# Hugging Face Space (status when this page was captured: Sleeping)
import gradio as gr
import spaces
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import RagTokenizer, RagSequenceForGeneration
# --- Config ---
DATASET_NAME = "username/mealplan-chunks"  # dataset holding the text chunks
INDEX_PATH = "mealplan.index"  # serialized FAISS index read at startup
MODEL_NAME = "facebook/rag-sequence-nq"  # checkpoint for tokenizer + generator

# --- Load chunks ---
# The three columns are used as parallel sequences: position i in each one
# describes the same chunk (text, originating document, page).
ds = load_dataset(DATASET_NAME, split="train")
texts = ds["text"]
sources = ds["source"]
pages = ds["page"]

# --- Sentence embedder & FAISS index ---
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# NOTE(review): chunk_embeddings is not referenced again in the visible code;
# the search index is read from INDEX_PATH instead. Encoding every chunk at
# import time may be pure startup cost -- confirm before removing.
chunk_embeddings = embedder.encode(texts, convert_to_numpy=True)
# Assumes the stored index was built from these same chunks, in the same
# order, with the same embedder: respond() uses the ids returned by FAISS as
# positions into texts/sources/pages -- TODO confirm.
index = faiss.read_index(INDEX_PATH)

# --- RAG generator ---
# NOTE(review): the model is loaded without a retriever argument. RAG
# generation normally needs either an attached retriever or precomputed
# context inputs -- confirm that model.generate() works as loaded here.
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
model = RagSequenceForGeneration.from_pretrained(MODEL_NAME)
def respond(
    message: str,
    history: list[tuple[str, str]],
    goal: str,
    diet: list[str],
    meals: int,
    avoid: str,
    weeks: str,
) -> str:
    """Answer one chat message using retrieved meal-plan chunks as context.

    Args:
        message: The user's latest chat message; also used as the search query.
        history: Previous (user, bot) turns supplied by the chat UI. It is not
            used to build the prompt and is left unmodified.
        goal: Selected goal, e.g. "Lose weight".
        diet: Selected diet styles (may be empty).
        meals: Number of meals per day.
        avoid: Comma-separated items to avoid (may be empty).
        weeks: Selected plan length, e.g. "1 week".

    Returns:
        The generated answer text. ``gr.ChatInterface`` expects ``fn`` to
        return the reply itself and maintains the chat history on its own.
    """
    # Parse preferences; blank entries such as "a, ,b" or a trailing comma
    # are dropped, and missing widget values are treated as empty.
    avoid_list = [item.strip() for item in (avoid or "").split(",") if item.strip()]
    prefs = (
        f"Goal={goal}; Diet={','.join(diet or [])}; "
        f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
    )

    # 1) Query embedding & FAISS search (top-5)
    top_k = 5
    q_emb = embedder.encode([message], convert_to_numpy=True)
    _distances, neighbor_ids = index.search(q_emb, top_k)
    # FAISS pads the result with -1 when fewer than top_k neighbours exist.
    # Indexing with -1 would silently pull in the *last* chunk, so keep only
    # ids that are valid positions in the chunk lists.
    n_chunks = len(texts)
    hits = [int(i) for i in neighbor_ids[0] if 0 <= i < n_chunks]
    ctx_chunks = [f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in hits]
    context = "\n".join(ctx_chunks)

    # 2) Build prompt
    prompt = (
        "SYSTEM: Only answer using the provided CONTEXT. "
        "If it's not there, say \"I'm sorry, I don't know.\"\n"
        f"PREFS: {prefs}\n"
        f"CONTEXT:\n{context}\n"
        f"Q: {message}\n"
    )

    # 3) Generate
    # NOTE(review): the prompt is tokenized without truncation; a long context
    # may exceed the encoder's maximum input length -- confirm.
    inputs = tokenizer([prompt], return_tensors="pt")
    outputs = model.generate(**inputs, num_beams=2, max_new_tokens=200)
    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    # 4) Hand the reply back; the chat UI appends it to the conversation.
    return answer
# --- Build Gradio chat interface ---
# Preference widgets shown alongside the chat box. Their values are passed to
# respond() after (message, history), in the order listed in
# additional_inputs below, so that order must match respond()'s parameters.
goal = gr.Dropdown(
    ["Lose weight", "Bulk", "Maintain"],
    value="Lose weight",
    label="Goal",
)
diet = gr.CheckboxGroup(
    ["Omnivore", "Vegetarian", "Vegan", "Keto", "Paleo", "Low-Carb"],
    label="Diet Style",
)
meals = gr.Slider(1, 6, value=3, step=1, label="Meals per day")
avoid = gr.Textbox(
    placeholder="e.g. Gluten, Dairy, Nuts...",
    label="Avoidances (comma-separated)",
)
weeks = gr.Dropdown(
    ["1 week", "2 weeks", "3 weeks", "4 weeks"],
    value="1 week",
    label="Plan Length",
)

demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[goal, diet, meals, avoid, weeks],
)

if __name__ == "__main__":
    demo.launch()