File size: 1,909 Bytes
d4852d9
 
 
b8f4763
d4852d9
b8f4763
d4852d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8f4763
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4852d9
b8f4763
 
d4852d9
 
b8f4763
 
d4852d9
 
b8f4763
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
import sqlite3
import io
import os

import faiss
import pandas as pd
import numpy as np


def convert_numpy_array_to_text(array):
    """Serialize *array* in ``.npy`` format and wrap it as an SQLite BLOB.

    Despite the name, the result is binary (``sqlite3.Binary``), not text.
    This function is registered as the sqlite3 adapter for ``np.ndarray`` so
    arrays can be stored directly in a table column.
    """
    buffer = io.BytesIO()
    np.save(buffer, array)
    # getvalue() returns the whole buffer regardless of the current position.
    return sqlite3.Binary(buffer.getvalue())


def prepare_items_db(items_path, embeddings_path, db_path):
    """Store the items CSV, with one embedding per row, in an SQLite table.

    Row ``i`` of the ``.npy`` embeddings array is attached to row ``i`` of the
    CSV and written into an ``embedding`` column as an ``.npy``-serialized
    BLOB (via ``convert_numpy_array_to_text``). The column is declared with
    type ``embedding``. A reader that registers an ``"embedding"`` converter
    and connects with ``detect_types=sqlite3.PARSE_DECLTYPES`` can therefore
    decode it back to an array. Any existing ``items`` table is replaced.

    Args:
        items_path: Path to the CSV file of items.
        embeddings_path: Path to the ``.npy`` file of embeddings, aligned
            with the CSV rows along the first axis.
        db_path: Path of the SQLite database file to write.

    Raises:
        ValueError: If the number of embeddings differs from the number of items.
    """
    items = pd.read_csv(items_path)
    embeddings = np.load(embeddings_path)
    if len(embeddings) != len(items):
        raise ValueError(
            f"Number of embeddings ({len(embeddings)}) does not match "
            f"number of items ({len(items)})."
        )
    # Each stored embedding keeps a leading axis of length 1 (the same views
    # np.split would give). Slicing is used because np.split raises when
    # there are zero rows.
    items["embedding"] = [embeddings[i : i + 1] for i in range(len(embeddings))]

    sqlite3.register_adapter(np.ndarray, convert_numpy_array_to_text)
    conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES)
    try:
        # `with conn` only commits (or rolls back on error); it does not
        # close the connection, so close it explicitly in `finally`.
        with conn:
            items.to_sql("items", conn, if_exists="replace", index=False, dtype={"embedding": "embedding"})
    finally:
        conn.close()


def build_index(embeddings_path, save_path, n_neighbors):
    """Build a flat-storage HNSW index over the embeddings and write it to disk.

    Args:
        embeddings_path: Path to a ``.npy`` file of vectors. The last axis is
            used as the vector dimension; faiss presumably expects a 2-D
            (n_items, dim) array here.
        save_path: Destination file for ``faiss.write_index``.
        n_neighbors: Passed as the second ``IndexHNSWFlat`` constructor
            argument, i.e. faiss's HNSW connectivity parameter ``M``.
    """
    vectors = np.load(embeddings_path)
    dimension = vectors.shape[-1]
    hnsw_index = faiss.IndexHNSWFlat(dimension, n_neighbors)
    hnsw_index.add(vectors)
    faiss.write_index(hnsw_index, save_path)


def prepare_recsys(
    items_path,
    embeddings_path,
    save_directory,
    n_neighbors=32,
):
    """Write ``items.db`` and ``index.faiss`` into *save_directory*.

    Args:
        items_path: Path to the CSV file of items.
        embeddings_path: Path to the ``.npy`` file of item embeddings.
        save_directory: Output directory. It is created if it does not exist.
        n_neighbors: Forwarded to ``build_index`` as the HNSW graph parameter.
    """
    # Neither sqlite3.connect nor faiss.write_index creates missing parent
    # directories, so make sure the output directory exists first.
    os.makedirs(save_directory, exist_ok=True)
    prepare_items_db(items_path, embeddings_path, os.path.join(save_directory, "items.db"))
    build_index(embeddings_path, os.path.join(save_directory, "index.faiss"), n_neighbors)


if __name__ == "__main__":
    # Command-line entry point: each flag maps 1:1 onto a prepare_recsys argument.
    cli = argparse.ArgumentParser(
        description="Prepare items database and HNSW index from a CSV file and embeddings."
    )
    cli.add_argument(
        "--items_path",
        required=True,
        type=str,
        help="Path to the CSV file containing items.",
    )
    cli.add_argument(
        "--embeddings_path",
        required=True,
        type=str,
        help="Path to the .npy file containing item embeddings.",
    )
    cli.add_argument(
        "--save_directory",
        required=True,
        type=str,
        help="Path to the save directory.",
    )
    cli.add_argument(
        "--n_neighbors",
        type=int,
        default=32,
        help="Number of neighbors for the index.",
    )

    options = cli.parse_args()
    prepare_recsys(
        items_path=options.items_path,
        embeddings_path=options.embeddings_path,
        save_directory=options.save_directory,
        n_neighbors=options.n_neighbors,
    )