File size: 1,053 Bytes
b8f4763
d4852d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8f4763
d4852d9
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import contextlib
import functools
import io
import sqlite3

import numpy as np


class ItemDatabase:
    """Query helper for the ``items`` table of an SQLite database file.

    Columns whose declared type is ``embedding`` are decoded into NumPy
    arrays by :meth:`_text_to_numpy_array`; :meth:`_connect` enables
    ``PARSE_DECLTYPES`` so that converter is applied.
    """

    def __init__(self, db_path):
        """Register the ``embedding`` converter and remember *db_path*.

        Note: ``sqlite3.register_converter`` is process-global, so the
        converter applies to every sqlite3 connection opened with
        ``detect_types=PARSE_DECLTYPES``, not just this instance's.
        """
        sqlite3.register_converter("embedding", self._text_to_numpy_array)
        self._db_path = db_path

    @staticmethod
    def _text_to_numpy_array(text):
        """Decode a stored blob with ``np.load``.

        The converter receives raw bytes; presumably they were written with
        ``np.save`` (``.npy`` format) — confirm against the writer.
        """
        # np.load needs a file-like object; a fresh BytesIO is already at 0.
        return np.load(io.BytesIO(text))

    def _connect(self):
        """Open a new connection with declared-type conversion enabled."""
        return sqlite3.connect(
            self._db_path, detect_types=sqlite3.PARSE_DECLTYPES)

    def search_items(self, query, n_items=10):
        """Return up to *n_items* ``item_id`` values whose title contains *query*.

        *query* is passed as a bound parameter, so quotes in it are safe.
        ``%`` and ``_`` inside *query* still act as LIKE wildcards, as in the
        original string-built query. *n_items* should be a non-negative int;
        ``None`` means "no limit".
        """
        pattern = f"%{query}%"
        # SQLite treats a negative LIMIT as "no upper bound".
        limit = -1 if n_items is None else n_items
        # Connection's own context manager only commits/rolls back; closing()
        # actually releases the connection when the block exits.
        with contextlib.closing(self._connect()) as conn:
            c = conn.cursor()
            c.execute(
                "select item_id from items where title like ? limit ?",
                (pattern, limit))
            return [row[0] for row in c.fetchall()]

    # NOTE(review): lru_cache on a method keys on ``self`` and keeps each
    # instance alive while it has cached entries (ruff B019); the 2**14-entry
    # capacity is shared by all instances, and results are never invalidated
    # if the database changes. Kept for API compatibility
    # (``get_item.cache_clear()``).
    @functools.lru_cache(maxsize=2**14)
    def get_item(self, item_id):
        """Return the first ``items`` row matching *item_id*, or ``None``.

        The row is a ``sqlite3.Row`` (access columns by name or index).
        Matching keeps the original LIKE semantics on ``str(item_id)``, but
        the value is now a bound parameter instead of interpolated SQL.
        """
        with contextlib.closing(self._connect()) as conn:
            c = conn.cursor()
            c.row_factory = sqlite3.Row
            c.execute(
                "select * from items where item_id like ?", (str(item_id),))
            return c.fetchone()