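"""Slogan Finder Space: FAISS retrieval over a slogan dataset plus FLAN-T5 generation, served with Gradio."""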

import os
import json

import numpy as np
import pandas as pd
import gradio as gr
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from logic.cleaning import clean_dataframe
from logic.search import SloganSearcher
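
# ---- Configuration ----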
ASSETS_DIR = "assets"
DATA_PATH = "data/slogan.csv"
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
NORMALIZE = True  # normalized embeddings + inner product == cosine similarity

GEN_MODEL_NAME = "google/flan-t5-base"
NUM_GEN_CANDIDATES = 6
MAX_NEW_TOKENS = 24
TEMPERATURE = 0.9
TOP_P = 0.95
NOVELTY_SIM_THRESHOLD = 0.80  # candidates at/above this similarity to retrieved slogans count as derivative

META_PATH = os.path.join(ASSETS_DIR, "meta.json")
PARQUET_PATH = os.path.join(ASSETS_DIR, "slogans_clean.parquet")
INDEX_PATH = os.path.join(ASSETS_DIR, "faiss.index")
EMB_PATH = os.path.join(ASSETS_DIR, "embeddings.npy")

def _log(m):
    print(f"[SLOGAN-SPACE] {m}", flush=True)

def _build_assets():
    """Clean the raw CSV, embed it, and persist parquet/FAISS/embeddings/meta."""
    if not os.path.exists(DATA_PATH):
        raise FileNotFoundError(f"Dataset not found at {DATA_PATH} (CSV with columns: 'tagline', 'description').")
    os.makedirs(ASSETS_DIR, exist_ok=True)
    _log(f"Loading dataset: {DATA_PATH}")
    df = pd.read_csv(DATA_PATH)
    _log(f"Rows before cleaning: {len(df)}")
    df = clean_dataframe(df)
    _log(f"Rows after cleaning: {len(df)}")
    # Embed the richer 'description' column when present, falling back to
    # 'tagline' for rows with a missing description.
    if "description" in df.columns and df["description"].notna().any():
        texts = df["description"].fillna(df["tagline"]).astype(str).tolist()
        text_col, fallback_col = "description", "tagline"
    else:
        texts = df["tagline"].astype(str).tolist()
        text_col, fallback_col = "tagline", "tagline"
    _log(f"Encoding with {MODEL_NAME} (normalize={NORMALIZE}) …")
    encoder = SentenceTransformer(MODEL_NAME)
    emb = encoder.encode(texts, batch_size=64, convert_to_numpy=True, normalize_embeddings=NORMALIZE)
    dim = emb.shape[1]
    # Inner product over normalized vectors is cosine similarity; otherwise use L2.
    index = faiss.IndexFlatIP(dim) if NORMALIZE else faiss.IndexFlatL2(dim)
    index.add(emb)
    _log("Persisting assets …")
    df.to_parquet(PARQUET_PATH, index=False)
    faiss.write_index(index, INDEX_PATH)
    np.save(EMB_PATH, emb)
    meta = {
        "model_name": MODEL_NAME,
        "dim": int(dim),
        "normalized": NORMALIZE,
        "metric": "ip" if NORMALIZE else "l2",
        "row_count": int(len(df)),
        "text_col": text_col,
        "fallback_col": fallback_col,
    }
    with open(META_PATH, "w") as f:
        json.dump(meta, f, indent=2)
    _log("Assets built successfully.")

def _ensure_assets():
    """Build assets when any file is missing or the parquet is unreadable."""
    need = False
    for p in (META_PATH, PARQUET_PATH, INDEX_PATH):
        if not os.path.exists(p):
            _log(f"Missing asset: {p}")
            need = True
    if need:
        _log("Building assets from scratch …")
        _build_assets()
        return
    try:
        pd.read_parquet(PARQUET_PATH)
    except Exception as e:
        _log(f"Parquet read failed ({e}); rebuilding assets.")
        _build_assets()
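
# ---- Startup: build or load assets and models once per process ----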
_ensure_assets()

searcher = SloganSearcher(assets_dir=ASSETS_DIR, use_rerank=False)
with open(META_PATH) as f:
    meta = json.load(f)
_encoder = SentenceTransformer(meta["model_name"])
_gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
_gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)

# ---- Prompt (adjust if you want your exact wording) ----
def _prompt_for(description: str) -> str:
    return (
        "You are a professional slogan writer. "
        "Write ONE original, catchy startup slogan under 8 words, Title Case, no punctuation. "
        "Do not copy examples. Description:\n"
        f"{description}\nSlogan:"
    )
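
# Candidate generation: one prompt, n sampled continuations. Sampling with
# temperature/top_p yields varied slogans; greedy decoding would return the
# same string n times.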
def _generate_candidates(description: str, n: int = NUM_GEN_CANDIDATES):
    prompt = _prompt_for(description)
    # Encode the prompt once; num_return_sequences produces the n samples.
    # (Batching [prompt] * n AND num_return_sequences=n would yield n * n outputs.)
    inputs = _gen_tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = _gen_model.generate(
        **inputs,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        num_return_sequences=n,
        max_new_tokens=MAX_NEW_TOKENS,
        eos_token_id=_gen_tokenizer.eos_token_id,
    )
    texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return [t.replace("Slogan:", "").strip().strip('"') for t in texts if t.strip()]
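
# Selection strategy: embed candidates and retrieved slogans in the same space,
# score novelty as 1 - max cosine similarity, and prefer candidates that clear
# NOVELTY_SIM_THRESHOLD so the suggestion is not a near-copy of the dataset.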
def _pick_most_novel(candidates, retrieved_texts):
    """Return the generated candidate least similar to the retrieved slogans."""
    if not candidates:
        return None
    if not retrieved_texts:
        return candidates[0]
    R = _encoder.encode(retrieved_texts, convert_to_numpy=True, normalize_embeddings=True)
    C = _encoder.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
    # Normalized embeddings make the dot product a cosine similarity.
    max_sims = (C @ R.T).max(axis=1)
    novelties = 1.0 - max_sims
    # Prefer candidates below the similarity threshold; if none qualify,
    # fall back to the most novel candidate overall.
    below = np.where(max_sims < NOVELTY_SIM_THRESHOLD)[0]
    pool = below if len(below) else np.arange(len(candidates))
    return candidates[int(pool[np.argmax(novelties[pool])])]
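
# End-to-end pipeline: retrieve the top-3 similar slogans, sample candidates,
# pick the most novel, and format the result as Markdown for the UI.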
def run_pipeline(user_description: str):
    if not user_description or not user_description.strip():
        return "Please enter a description."
    retrieved_df = searcher.search(user_description, top_k=3, rerank_top_n=10)
    retrieved_texts = retrieved_df["display"].tolist() if not retrieved_df.empty else []
    gens = _generate_candidates(user_description, NUM_GEN_CANDIDATES)
    generated = _pick_most_novel(gens, retrieved_texts) or (gens[0] if gens else "—")
    lines = ["### 🔎 Top 3 similar slogans"]
    if retrieved_texts:
        for i, s in enumerate(retrieved_texts, 1):
            lines.append(f"{i}. {s}")
    else:
        lines.append("_No similar slogans found._")
    lines.append("\n### ✨ AI-generated suggestion")
    lines.append(generated)
    return "\n".join(lines)
with gr.Blocks(title="Slogan Finder") as demo:
    gr.Markdown("# 🔎 Slogan Finder\nDescribe your product/company; get 3 similar slogans + 1 AI-generated suggestion.")
    query = gr.Textbox(label="Describe your product/company", placeholder="AI-powered patient financial navigation platform...")
    btn = gr.Button("Get slogans", variant="primary")
    out = gr.Markdown()
    btn.click(run_pipeline, inputs=[query], outputs=out)

demo.queue(max_size=64).launch()
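
# Local run (a sketch; assumes this file is saved as app.py, with data/slogan.csv
# and the logic/ package alongside, and dependencies such as gradio, faiss-cpu,
# sentence-transformers, transformers, pandas, and pyarrow installed):
#   python app.py
# The first launch builds assets/ (clean parquet, FAISS index, embeddings,
# meta.json); later launches reuse them.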