import contextlib
import functools
import io
import sqlite3

import numpy as np


class ItemDatabase:
    """Query helper for an SQLite database containing an ``items`` table.

    Columns whose declared type is ``embedding`` are decoded into NumPy
    arrays when read (see ``_text_to_numpy_array``).
    """

    def __init__(self, db_path):
        """Remember *db_path*; each query opens and closes its own connection."""
        # register_converter is process-global: it applies to every sqlite3
        # connection opened with PARSE_DECLTYPES, not only this instance's.
        sqlite3.register_converter("embedding", self._text_to_numpy_array)
        self._db_path = db_path
        # Shadow the class-level get_item with a per-instance LRU-cached
        # wrapper. A method-level @lru_cache would share one cache (and its
        # size budget) across all instances and keep every instance alive
        # for the lifetime of the class (flake8-bugbear B019).
        # `self.get_item.cache_clear()` keeps working as before.
        self.get_item = functools.lru_cache(maxsize=2**14)(self.get_item)

    @staticmethod
    def _text_to_numpy_array(text):
        """Decode a stored ``embedding`` value into a NumPy array.

        *text* is the raw ``bytes`` that sqlite3 hands to converters; it is
        expected (not verified here) to be data written with ``np.save``.
        ``np.load`` is called with its default ``allow_pickle=False``.
        """
        out = io.BytesIO(text)
        return np.load(out)

    def _connect(self):
        """Open a new connection that applies declared-type converters."""
        return sqlite3.connect(
            self._db_path, detect_types=sqlite3.PARSE_DECLTYPES)

    def search_items(self, query, n_items=10):
        """Return up to *n_items* ``item_id`` values whose title contains *query*.

        *query* is bound as a parameter (never spliced into the SQL text),
        but ``%`` and ``_`` inside it still act as LIKE wildcards, and
        SQLite's LIKE is case-insensitive for ASCII letters. ``n_items=None``
        (or a negative value) means no limit. No ORDER BY is applied, so the
        order of results is whatever SQLite yields.
        """
        limit = -1 if n_items is None else n_items
        # closing() is required: a sqlite3 connection used as a context
        # manager only commits/rolls back, it does not close itself.
        with contextlib.closing(self._connect()) as conn:
            rows = conn.execute(
                "select item_id from items where title like ? limit ?",
                (f"%{query}%", limit),
            ).fetchall()
        return [row[0] for row in rows]

    def get_item(self, item_id):
        """Return the ``items`` row whose ``item_id`` equals *item_id*, or None.

        The row is a ``sqlite3.Row`` (indexable by column name); columns
        declared as ``embedding`` are returned as NumPy arrays. Results,
        including ``None`` for missing ids, are memoized per instance by the
        wrapper installed in ``__init__``, so later database changes are not
        seen for ids that were already looked up.
        """
        with contextlib.closing(self._connect()) as conn:
            conn.row_factory = sqlite3.Row
            # Exact match; LIKE would treat % and _ in the id as wildcards.
            return conn.execute(
                "select * from items where item_id = ?", (item_id,)
            ).fetchone()