# slogan / prepare_assets.py
# author: yaire · commit: initial (c3c36d1)
"""
Run this LOCALLY to build assets/ from your real dataset.
1) Put your CSV/Parquet with at least 'tagline' (and optional 'description') columns.
2) Adjust INPUT_PATH below.
3) python prepare_assets.py
Then commit assets/ into your Space repo (or upload to a Dataset repo).
"""
import os, json, numpy as np, pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
from logic.cleaning import clean_dataframe
# ---- CHANGE THIS (or override via environment variables) ----
# Raw CSV/Parquet export to build from, e.g. exported from your notebook.
INPUT_PATH = os.environ.get(
    "SLOGAN_INPUT_PATH", "/mnt/data/hf-slogan-space/data/raw_slogans.csv"
)
# Directory that receives the built assets.
ASSETS_DIR = os.environ.get("SLOGAN_ASSETS_DIR", "assets")
# sentence-transformers model used to embed the texts.
MODEL_NAME = os.environ.get(
    "SLOGAN_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2"
)
# True: unit-normalize embeddings and use an inner-product index (= cosine).
# False: keep raw embeddings and use an L2-distance index.
NORMALIZE = True
def _load_table(path):
    """Read *path* into a DataFrame, choosing the reader from its extension.

    Raises:
        ValueError: if the extension is neither ``.csv`` nor ``.parquet``.
    """
    # Compare case-insensitively so e.g. "EXPORT.CSV" is accepted too.
    suffix = os.path.splitext(path)[1].lower()
    if suffix == ".csv":
        return pd.read_csv(path)
    if suffix == ".parquet":
        return pd.read_parquet(path)
    raise ValueError("Use CSV or Parquet for INPUT_PATH")


def _select_texts(df):
    """Pick one text per row of *df* to embed.

    Uses 'description' when that column exists, falling back to 'tagline' for
    rows whose description is missing or blank; otherwise uses 'tagline'.

    Returns:
        ``(texts, text_col, fallback_col)`` where ``texts`` is a list of str in
        row order (missing values become "" rather than the string "nan").

    Raises:
        ValueError: if *df* has no 'tagline' column.
    """
    if "tagline" not in df.columns:
        raise ValueError("Cleaned data must contain a 'tagline' column")
    taglines = df["tagline"]
    if "description" in df.columns:
        desc = df["description"]
        # Empty / whitespace-only descriptions are as useless as missing ones,
        # so they get the tagline fallback too (fillna alone only catches NaN).
        blank = desc.isna() | desc.astype(str).str.strip().eq("")
        series = desc.mask(blank, taglines)
        text_col, fallback_col = "description", "tagline"
    else:
        series = taglines
        text_col, fallback_col = "tagline", "tagline"
    # astype(str) would turn NaN into the literal text "nan"; embed "" instead.
    texts = series.fillna("").astype(str).tolist()
    return texts, text_col, fallback_col


def _build_index(emb, normalize):
    """Build an exact (flat) FAISS index over the rows of *emb*.

    Inner product when *normalize* is true (cosine similarity for unit-length
    vectors), L2 distance otherwise. Vector i gets id i.
    """
    dim = emb.shape[1]
    index = faiss.IndexFlatIP(dim) if normalize else faiss.IndexFlatL2(dim)
    index.add(emb)
    return index


def main():
    """Build ASSETS_DIR from INPUT_PATH.

    Loads and cleans the raw table, embeds one text per row with MODEL_NAME,
    and writes into ASSETS_DIR:
      - slogans_clean.parquet : cleaned rows (row i <-> index vector i)
      - embeddings.npy        : embedding matrix (optional at runtime)
      - faiss.index           : FAISS index over the embeddings
      - meta.json             : settings needed to query the index consistently

    Raises:
        ValueError: unsupported input format, missing 'tagline' column, or no
            rows left after cleaning.
    """
    os.makedirs(ASSETS_DIR, exist_ok=True)

    df = _load_table(INPUT_PATH)

    # Clean using your real rules
    df_clean = clean_dataframe(df)
    if df_clean.empty:
        # Fail loudly here rather than with an opaque error in encode/FAISS.
        raise ValueError("No rows left after cleaning; nothing to index")

    # Validate/choose texts before writing anything derived from df_clean.
    texts, text_col, fallback_col = _select_texts(df_clean)
    df_clean.to_parquet(
        os.path.join(ASSETS_DIR, "slogans_clean.parquet"), index=False
    )

    encoder = SentenceTransformer(MODEL_NAME)
    emb = encoder.encode(
        texts,
        batch_size=64,
        convert_to_numpy=True,
        normalize_embeddings=NORMALIZE,
    )
    # FAISS only accepts C-contiguous float32 matrices.
    emb = np.ascontiguousarray(emb, dtype=np.float32)

    # Save embeddings numpy (optional; not required at runtime)
    np.save(os.path.join(ASSETS_DIR, "embeddings.npy"), emb)

    index = _build_index(emb, NORMALIZE)
    faiss.write_index(index, os.path.join(ASSETS_DIR, "faiss.index"))

    meta = {
        "model_name": MODEL_NAME,
        "dim": int(emb.shape[1]),
        "normalized": NORMALIZE,
        "metric": "ip" if NORMALIZE else "l2",
        "row_count": int(len(df_clean)),
        "text_col": text_col,
        "fallback_col": fallback_col,
    }
    with open(os.path.join(ASSETS_DIR, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    print("✅ Assets built in", ASSETS_DIR)
    print(meta)
if __name__ == "__main__":
main()