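"""Slogan Finder Space: retrieves the 3 most similar slogans from a FAISS
index and generates one novel suggestion with FLAN-T5."""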
import json
import os

import faiss
import gradio as gr
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from logic.cleaning import clean_dataframe
from logic.search import SloganSearcher
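# ---- Configuration ----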
ASSETS_DIR = "assets"
DATA_PATH = "data/slogan.csv"
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
NORMALIZE = True  # normalized embeddings -> inner product equals cosine similarity
GEN_MODEL_NAME = "google/flan-t5-base"
NUM_GEN_CANDIDATES = 6  # slogans sampled per request
MAX_NEW_TOKENS = 24
TEMPERATURE = 0.9
TOP_P = 0.95
NOVELTY_SIM_THRESHOLD = 0.80  # candidates at least this similar to a retrieved slogan are disfavored
META_PATH = os.path.join(ASSETS_DIR, "meta.json")
PARQUET_PATH = os.path.join(ASSETS_DIR, "slogans_clean.parquet")
INDEX_PATH = os.path.join(ASSETS_DIR, "faiss.index")
EMB_PATH = os.path.join(ASSETS_DIR, "embeddings.npy")
def _log(m): print(f"[SLOGAN-SPACE] {m}", flush=True)
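# One-time asset build: clean the CSV, embed the texts, and persist the cleaned
# frame, FAISS index, raw embeddings, and metadata under assets/.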
def _build_assets():
    if not os.path.exists(DATA_PATH):
        raise FileNotFoundError(f"Dataset not found at {DATA_PATH} (CSV with columns: 'tagline', 'description').")
    os.makedirs(ASSETS_DIR, exist_ok=True)
    _log(f"Loading dataset: {DATA_PATH}")
    df = pd.read_csv(DATA_PATH)
    _log(f"Rows before cleaning: {len(df)}")
    df = clean_dataframe(df)
    _log(f"Rows after cleaning: {len(df)}")
    if "description" in df.columns and df["description"].notna().any():
        texts = df["description"].fillna(df["tagline"]).astype(str).tolist()
        text_col, fallback_col = "description", "tagline"
    else:
        texts = df["tagline"].astype(str).tolist()
        text_col, fallback_col = "tagline", "tagline"
    _log(f"Encoding with {MODEL_NAME} (normalize={NORMALIZE}) …")
    encoder = SentenceTransformer(MODEL_NAME)
    emb = encoder.encode(texts, batch_size=64, convert_to_numpy=True, normalize_embeddings=NORMALIZE)
    dim = emb.shape[1]
    index = faiss.IndexFlatIP(dim) if NORMALIZE else faiss.IndexFlatL2(dim)
    index.add(emb)
    _log("Persisting assets …")
    df.to_parquet(PARQUET_PATH, index=False)
    faiss.write_index(index, INDEX_PATH)
    np.save(EMB_PATH, emb)
    meta = {
        "model_name": MODEL_NAME,
        "dim": int(dim),
        "normalized": NORMALIZE,
        "metric": "ip" if NORMALIZE else "l2",
        "row_count": int(len(df)),
        "text_col": text_col,
        "fallback_col": fallback_col,
    }
    with open(META_PATH, "w") as f:
        json.dump(meta, f, indent=2)
    _log("Assets built successfully.")
def _ensure_assets():
    need = False
    for p in (META_PATH, PARQUET_PATH, INDEX_PATH):
        if not os.path.exists(p):
            _log(f"Missing asset: {p}")
            need = True
    if need:
        _log("Building assets from scratch …")
        _build_assets()
        return
    try:
        pd.read_parquet(PARQUET_PATH)
    except Exception as e:
        _log(f"Parquet read failed ({e}); rebuilding assets.")
        _build_assets()
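# ---- Startup: load retrieval assets and models once at import time ----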
_ensure_assets()
searcher = SloganSearcher(assets_dir=ASSETS_DIR, use_rerank=False)
with open(META_PATH) as f:
    meta = json.load(f)
_encoder = SentenceTransformer(meta["model_name"])
_gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
_gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
# ---- Prompt (adjust if you want your exact wording) ----
def _prompt_for(description: str) -> str:
    return (
        "You are a professional slogan writer. "
        "Write ONE original, catchy startup slogan under 8 words, Title Case, no punctuation. "
        "Do not copy examples. Description:\n"
        f"{description}\nSlogan:"
    )
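# Sample several candidates in one batch; the novelty filter below picks one.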
def _generate_candidates(description: str, n: int = NUM_GEN_CANDIDATES):
    prompt = _prompt_for(description)
    # Encode the prompt once; num_return_sequences already yields n samples.
    # Tiling the prompt n times as well would produce n*n outputs.
    inputs = _gen_tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = _gen_model.generate(
        **inputs,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        num_return_sequences=n,
        max_new_tokens=MAX_NEW_TOKENS,
        eos_token_id=_gen_tokenizer.eos_token_id,
    )
    texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return [t.replace("Slogan:", "").strip().strip('"') for t in texts if t.strip()]
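# Pick the candidate least similar (cosine) to any retrieved slogan, preferring
# candidates whose maximum similarity stays under NOVELTY_SIM_THRESHOLD.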
def _pick_most_novel(candidates, retrieved_texts):
    if not candidates:
        return None
    R = _encoder.encode(retrieved_texts, convert_to_numpy=True, normalize_embeddings=True) if retrieved_texts else None
    best, best_novelty = None, -1.0
    fallback, fallback_novelty = None, -1.0
    for c in candidates:
        c_emb = _encoder.encode([c], convert_to_numpy=True, normalize_embeddings=True)
        # With normalized embeddings, the dot product is cosine similarity.
        max_sim = float(np.max(np.dot(R, c_emb[0]))) if R is not None else 0.0
        novelty = 1.0 - max_sim
        if max_sim < NOVELTY_SIM_THRESHOLD and novelty > best_novelty:
            best, best_novelty = c, novelty
        if novelty > fallback_novelty:
            fallback, fallback_novelty = c, novelty
    # If every candidate is too close to an existing slogan, return the most novel one anyway.
    return best if best is not None else fallback
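# End-to-end request: retrieve the top-3 similar slogans, then generate one suggestion.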
def run_pipeline(user_description: str):
    if not user_description or not user_description.strip():
        return "Please enter a description."
    retrieved_df = searcher.search(user_description, top_k=3, rerank_top_n=10)
    retrieved_texts = retrieved_df["display"].tolist() if not retrieved_df.empty else []
    gens = _generate_candidates(user_description, NUM_GEN_CANDIDATES)
    generated = _pick_most_novel(gens, retrieved_texts) or (gens[0] if gens else "—")
    lines = ["### 🔎 Top 3 similar slogans"]
    if retrieved_texts:
        for i, s in enumerate(retrieved_texts, 1):
            lines.append(f"{i}. {s}")
    else:
        lines.append("_No similar slogans found._")
    lines.append("\n### ✨ AI-generated suggestion")
    lines.append(generated)
    return "\n".join(lines)
with gr.Blocks(title="Slogan Finder") as demo:
    gr.Markdown("# 🔎 Slogan Finder\nDescribe your product/company; get 3 similar slogans + 1 AI-generated suggestion.")
    query = gr.Textbox(label="Describe your product/company", placeholder="AI-powered patient financial navigation platform...")
    btn = gr.Button("Get slogans", variant="primary")
    out = gr.Markdown()
    btn.click(run_pipeline, inputs=[query], outputs=out)

demo.queue(max_size=64).launch()