import argparse

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer

from utils import normalize_embeddings


def prepare_sbert_embeddings(
    items_path,
    embeddings_savepath,
    model_name,
    batch_size,
    device,
):
    """Encode item descriptions with an SBERT model and save the normalized embeddings."""
    # Sort by item_id so the rows of the embedding matrix follow the item_id order.
    items = pd.read_csv(items_path).sort_values("item_id")
    sentences = items["description"].tolist()

    # Load the SBERT model on the requested device and switch it to inference mode.
    model = SentenceTransformer(model_name, device=device)
    model.eval()

    # Encode the descriptions batch by batch to keep memory usage bounded.
    embeddings = []
    for start_index in tqdm(range(0, len(sentences), batch_size)):
        batch = sentences[start_index:start_index + batch_size]
        embeddings.extend(model.encode(batch))

    # Normalize the embeddings (utils.normalize_embeddings) and save them as a .npy file.
    embeddings = normalize_embeddings(np.array(embeddings))
    np.save(embeddings_savepath, embeddings)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare SBERT embeddings.")
    parser.add_argument("--items_path", type=str, required=True, help="Path to the items CSV file.")
    parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to save the embeddings (.npy).")
    parser.add_argument("--model_name", type=str, default="sentence-transformers/all-MiniLM-L6-v2", help="Name of the SBERT model to use.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for encoding.")
    parser.add_argument("--device", type=str, default="cpu", help="Device to use for encoding (cpu or cuda).")
    args = parser.parse_args()

    prepare_sbert_embeddings(**vars(args))
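
# Example invocation; the script name and data paths below are illustrative, not part of the repo:
#   python prepare_sbert_embeddings.py \
#       --items_path data/items.csv \
#       --embeddings_savepath data/item_sbert_embeddings.npy \
#       --batch_size 64 \
#       --device cuda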