"""Prepare an items database and an HNSW index from a CSV file and embeddings."""

import argparse
import contextlib
import io
import os
import sqlite3

import faiss
import numpy as np
import pandas as pd


def convert_numpy_array_to_text(array):
    """Serialize a NumPy array into a SQLite-storable binary value.

    Despite the name, the result is not text: the array is written in the
    ``.npy`` format (via ``np.save``) into an in-memory buffer and the raw
    bytes are wrapped in ``sqlite3.Binary``.

    Args:
        array: The ``np.ndarray`` to serialize.

    Returns:
        A ``sqlite3.Binary`` object holding the ``.npy``-encoded bytes.
    """
    stream = io.BytesIO()
    np.save(stream, array)
    return sqlite3.Binary(stream.getvalue())


def prepare_items_db(items_path, embeddings_path, db_path):
    """Write the items CSV, plus one embedding per row, to a SQLite database.

    The CSV is loaded with pandas, an ``embedding`` column is added that
    holds the matching row of the embeddings array, and the frame is stored
    as the ``items`` table (replacing any existing table of that name).

    Args:
        items_path: Path to the CSV file containing items.
        embeddings_path: Path to the ``.npy`` file containing embeddings;
            its first axis must have one entry per CSV row.
        db_path: Path of the SQLite database file to write.

    Raises:
        ValueError: If the number of embeddings differs from the number of
            items.
    """
    items = pd.read_csv(items_path)
    embeddings = np.load(embeddings_path)

    if embeddings.shape[0] != len(items):
        raise ValueError(
            f"Number of embeddings ({embeddings.shape[0]}) does not match "
            f"number of items ({len(items)})."
        )

    # np.split along the first axis keeps that axis, so each stored
    # embedding retains a leading dimension of size 1.
    items["embedding"] = np.split(embeddings, embeddings.shape[0])

    # Registered process-wide: every ndarray bound as a SQL parameter is
    # stored as .npy-encoded bytes from now on.
    sqlite3.register_adapter(np.ndarray, convert_numpy_array_to_text)

    # A sqlite3 connection used as a context manager only commits or rolls
    # back the transaction; it does not close the connection. `closing`
    # guarantees the database handle is released.
    with contextlib.closing(
        sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES)
    ) as conn:
        with conn:
            items.to_sql(
                "items",
                conn,
                if_exists="replace",
                index=False,
                dtype={"embedding": "embedding"},
            )


def build_index(embeddings_path, save_path, n_neighbors):
    """Build a FAISS ``IndexHNSWFlat`` over the embeddings and save it.

    Args:
        embeddings_path: Path to the ``.npy`` file containing embeddings.
        save_path: Path the serialized index is written to.
        n_neighbors: Passed as the second argument of
            ``faiss.IndexHNSWFlat`` (the HNSW ``M`` parameter, i.e. the
            number of graph links per node).
    """
    # FAISS operates on C-contiguous float32 data; coerce up front so that
    # float64 or non-contiguous inputs are handled as well.
    embeddings = np.ascontiguousarray(np.load(embeddings_path), dtype=np.float32)

    index = faiss.IndexHNSWFlat(embeddings.shape[-1], n_neighbors)
    index.add(embeddings)
    faiss.write_index(index, save_path)


def prepare_recsys(
    items_path,
    embeddings_path,
    save_directory,
    n_neighbors=32,
):
    """Create ``items.db`` and ``index.faiss`` inside ``save_directory``.

    Args:
        items_path: Path to the CSV file containing items.
        embeddings_path: Path to the ``.npy`` file containing item
            embeddings.
        save_directory: Directory that receives both output files; it is
            created if it does not exist.
        n_neighbors: Number of neighbors for the index.
    """
    # An empty string means "current directory", which os.makedirs rejects.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)

    prepare_items_db(items_path, embeddings_path, os.path.join(save_directory, "items.db"))
    build_index(embeddings_path, os.path.join(save_directory, "index.faiss"), n_neighbors)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare items database and HNSW index from a CSV file and embeddings.")
    parser.add_argument("--items_path", required=True, type=str, help="Path to the CSV file containing items.")
    parser.add_argument("--embeddings_path", required=True, type=str, help="Path to the .npy file containing item embeddings.")
    parser.add_argument("--save_directory", required=True, type=str, help="Path to the save directory.")
    parser.add_argument("--n_neighbors", type=int, default=32, help="Number of neighbors for the index.")
    args = parser.parse_args()
    prepare_recsys(**vars(args))