graph-rec / exp /prepare_recsys.py
erermeev-d
Updated experiments code
b8f4763
import argparse
import sqlite3
import io
import os
import faiss
import pandas as pd
import numpy as np
def convert_numpy_array_to_text(array):
stream = io.BytesIO()
np.save(stream, array)
stream.seek(0)
return sqlite3.Binary(stream.read())
def prepare_items_db(items_path, embeddings_path, db_path):
items = pd.read_csv(items_path)
embeddings = np.load(embeddings_path)
items["embedding"] = np.split(embeddings, embeddings.shape[0])
sqlite3.register_adapter(np.ndarray, convert_numpy_array_to_text)
with sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES) as conn:
items.to_sql("items", conn, if_exists="replace", index=False, dtype={"embedding": "embedding"})
def build_index(embeddings_path, save_path, n_neighbors):
embeddings = np.load(embeddings_path)
index = faiss.IndexHNSWFlat(embeddings.shape[-1], n_neighbors)
index.add(embeddings)
faiss.write_index(index, save_path)
def prepare_recsys(
items_path,
embeddings_path,
save_directory,
n_neighbors=32,
):
prepare_items_db(items_path, embeddings_path, os.path.join(save_directory, "items.db"))
build_index(embeddings_path, os.path.join(save_directory, "index.faiss"), n_neighbors)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Prepare items database and HNSW index from a CSV file and embeddings.")
parser.add_argument("--items_path", required=True, type=str, help="Path to the CSV file containing items.")
parser.add_argument("--embeddings_path", required=True, type=str, help="Path to the .npy file containing item embeddings.")
parser.add_argument("--save_directory", required=True, type=str, help="Path to the save directory.")
parser.add_argument("--n_neighbors", type=int, default=32, help="Number of neighbors for the index.")
args = parser.parse_args()
prepare_recsys(**vars(args))