"""Gradio chat app answering meal-planning questions with retrieval-augmented generation.

Retrieval: a prebuilt FAISS index over the chunks of ``DATASET_NAME``, queried with a
SentenceTransformer embedding of the user's message.
Generation: a Hugging Face RAG-Sequence model conditioned on the retrieved chunks,
which are passed to it directly (no ``RagRetriever`` is attached to the model).
"""

import gradio as gr
import spaces
import faiss
import numpy as np
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import RagTokenizer, RagSequenceForGeneration

# — Config —
DATASET_NAME = "username/mealplan-chunks"
INDEX_PATH = "mealplan.index"
MODEL_NAME = "facebook/rag-sequence-nq"
FALLBACK_ANSWER = "I'm sorry, I don't know."
# Doc score given to padding documents: RAG applies a softmax over doc scores,
# so a score this low gives the padding essentially zero weight.
PAD_DOC_SCORE = -1e4

# — Load chunks —
ds = load_dataset(DATASET_NAME, split="train")
texts = ds["text"]
sources = ds["source"]
pages = ds["page"]

# — Query embedder & prebuilt FAISS index —
# The corpus is not re-embedded here; the prebuilt index is used directly.
# NOTE(review): assumes INDEX_PATH was built from all-MiniLM-L6-v2 embeddings of
# `texts` in dataset row order (FAISS ids are used as row indices) — confirm.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
index = faiss.read_index(INDEX_PATH)

# — RAG generator —
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
model = RagSequenceForGeneration.from_pretrained(MODEL_NAME)
# Each "passage // question" string is truncated to max_combined_length tokens.
# Cut from the left so the question at the end survives long passages.
tokenizer.generator.truncation_side = "left"


def _retrieve(query: str, k: int) -> list[tuple[int, float]]:
    """Return up to ``k`` ``(row, score)`` hits for ``query``; higher score = more relevant."""
    # FAISS requires a float32 query matrix.
    query_emb = np.asarray(embedder.encode([query], convert_to_numpy=True), dtype="float32")
    distances, ids = index.search(query_emb, k)
    # L2-style indexes return distances (smaller = closer); RAG doc scores are
    # logits (larger = better), so flip the sign unless the index is inner-product.
    sign = 1.0 if index.metric_type == faiss.METRIC_INNER_PRODUCT else -1.0
    # FAISS fills missing neighbours with id -1; indexing the dataset with -1
    # would silently return the last chunk, so drop those slots.
    return [(int(i), sign * float(d)) for i, d in zip(ids[0], distances[0]) if i != -1]


def _format_prefs(goal: str, diet: list[str], meals: int, avoid: str, weeks: str) -> str:
    """Render the sidebar preferences as a single compact string."""
    avoid_list = [a.strip() for a in (avoid or "").split(",") if a.strip()]
    return (
        f"Goal={goal}; Diet={','.join(diet or [])}; "
        f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
    )


# NOTE(review): nothing in this file moves `model` to a CUDA device, so
# @spaces.GPU does not by itself make inference run on the GPU — confirm the
# intended device placement.
@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    goal: str,
    diet: list[str],
    meals: int,
    avoid: str,
    weeks: str,
) -> str:
    """Answer one chat turn using the top retrieved chunks as RAG context.

    Args:
        message: the user's question.
        history: prior turns, supplied by ``gr.ChatInterface``. It is unused here,
            because the interface records each returned reply itself.
        goal, diet, meals, avoid, weeks: values of the additional input widgets.
            ``avoid`` is a comma-separated string.

    Returns:
        The assistant's reply text, as ``gr.ChatInterface`` expects.
        Returns ``FALLBACK_ANSWER`` when nothing is retrieved or generated.
    """
    cfg = model.config
    # RagSequenceForGeneration re-scores its candidates through forward(), which
    # uses config.n_docs, so always supply exactly that many documents.
    n_docs = cfg.n_docs

    # 1) Retrieve
    hits = _retrieve(message, n_docs)
    if not hits:
        return FALLBACK_ANSWER
    while len(hits) < n_docs:
        hits.append((hits[0][0], PAD_DOC_SCORE))

    # 2) Build the generator inputs in RAG's "title / passage // question" layout
    question = f"PREFS: {_format_prefs(goal, diet, meals, avoid, weeks)} Q: {message}"
    passages = [
        f"{sources[i]} p{pages[i]}{cfg.title_sep}{texts[i]}{cfg.doc_sep}{question}"
        for i, _ in hits
    ]
    ctx = tokenizer.generator(
        passages,
        max_length=cfg.max_combined_length,
        padding="longest",
        truncation=True,
        return_tensors="pt",
    )
    doc_scores = torch.tensor([[score for _, score in hits]], dtype=model.dtype)

    # 3) Generate. The model has no retriever, so the context must be passed in
    #    explicitly rather than via input_ids.
    device = model.device
    outputs = model.generate(
        context_input_ids=ctx["input_ids"].to(device),
        context_attention_mask=ctx["attention_mask"].to(device),
        doc_scores=doc_scores.to(device),
        n_docs=n_docs,
        num_beams=2,
        max_new_tokens=200,
    )
    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
    return answer or FALLBACK_ANSWER


# — Build Gradio chat interface —
goal = gr.Dropdown(["Lose weight", "Bulk", "Maintain"], value="Lose weight", label="Goal")
diet = gr.CheckboxGroup(
    ["Omnivore", "Vegetarian", "Vegan", "Keto", "Paleo", "Low-Carb"],
    label="Diet Style",
)
meals = gr.Slider(1, 6, value=3, step=1, label="Meals per day")
avoid = gr.Textbox(placeholder="e.g. Gluten, Dairy, Nuts...", label="Avoidances (comma-separated)")
weeks = gr.Dropdown(
    ["1 week", "2 weeks", "3 weeks", "4 weeks"],
    value="1 week",
    label="Plan Length",
)

demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[goal, diet, meals, avoid, weeks],
)

if __name__ == "__main__":
    demo.launch()