File size: 1,053 Bytes
b8f4763 d4852d9 b8f4763 d4852d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import functools
import io
import sqlite3
import numpy as np
class ItemDatabase:
    """Query helpers for the ``items`` table of a SQLite database file.

    Every query opens a fresh connection to ``db_path`` and closes it before
    returning. Columns whose declared SQLite type is ``embedding`` are decoded
    into NumPy arrays when read (see ``_text_to_numpy_array``).
    """

    def __init__(self, db_path):
        """Remember *db_path* and register the ``embedding`` column converter.

        Args:
            db_path: Database location passed to ``sqlite3.connect`` on every
                query.

        Note:
            ``sqlite3.register_converter`` is a module-wide registration, so
            it affects every ``sqlite3`` connection opened with
            ``PARSE_DECLTYPES`` in this process, not just this instance's.
        """
        sqlite3.register_converter("embedding", self._text_to_numpy_array)
        self._db_path = db_path

    @staticmethod
    def _text_to_numpy_array(text):
        """Decode a stored ``embedding`` value back into a NumPy array.

        Args:
            text: Raw column value handed over by sqlite3 (a ``bytes``
                object); assumed to be the output of ``np.save`` — TODO
                confirm against the writer.

        Returns:
            The array reconstructed by ``np.load``.
        """
        # A freshly constructed BytesIO is already positioned at offset 0.
        return np.load(io.BytesIO(text))

    def _connect(self):
        """Open a new connection that applies converters by declared type."""
        return sqlite3.connect(
            self._db_path, detect_types=sqlite3.PARSE_DECLTYPES)

    def search_items(self, query, n_items=10):
        """Return ids of items whose title contains *query*.

        Matching uses SQL ``LIKE``: it is case-insensitive for ASCII, and
        ``%`` / ``_`` inside *query* still act as wildcards.

        Args:
            query: Substring to look for in ``items.title`` (formatted with
                ``str`` semantics, as before).
            n_items: Maximum number of ids to return; applied as a Python
                slice, so ``None`` means "all".

        Returns:
            List of ``item_id`` values, in SQLite's result order.
        """
        conn = self._connect()
        try:
            # Bind the pattern instead of interpolating it, so quotes in
            # *query* can neither break nor rewrite the statement.
            rows = conn.execute(
                "select item_id from items where title like ?",
                (f"%{query}%",),
            ).fetchall()
        finally:
            # The sqlite3 connection context manager only ends the
            # transaction; it does not close the connection.
            conn.close()
        return [row[0] for row in rows[:n_items]]

    # NOTE(review): lru_cache on a method is shared by all instances and keeps
    # each instance alive while it has cached entries (ruff B019); cached rows
    # also won't reflect later writes to the database.
    @functools.lru_cache(maxsize=2**14)
    def get_item(self, item_id):
        """Return the first ``items`` row whose ``item_id`` matches, or None.

        Args:
            item_id: Identifier to look up. Compared with SQL ``LIKE``
                (kept for compatibility), so ASCII case is ignored.

        Returns:
            A ``sqlite3.Row`` (columns accessible by name), or ``None`` when
            nothing matches. Results are memoised per ``(self, item_id)``.
        """
        conn = self._connect()
        try:
            cur = conn.cursor()
            cur.row_factory = sqlite3.Row
            # Bound parameter: the id is never spliced into the SQL text.
            cur.execute("select * from items where item_id like ?", (item_id,))
            # fetchone() materializes the row, so closing afterwards is safe.
            return cur.fetchone()
        finally:
            conn.close()