# eat2fit / app.py
# Copied from the Hugging Face Space file page: last commit 9b1fba6
# ("Update app.py", verified) by DurgaDeepak; file size 2.76 kB.
import gradio as gr
import spaces
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import RagTokenizer, RagSequenceForGeneration
# — Config —
DATASET_NAME = "username/mealplan-chunks"  # NOTE(review): looks like a placeholder repo id — confirm
INDEX_PATH = "mealplan.index"
MODEL_NAME = "facebook/rag-sequence-nq"

# — Load chunks & FAISS index —
# Row i of the dataset must correspond to vector id i in the FAISS index:
# respond() maps search ids straight back into these columns.
ds = load_dataset(DATASET_NAME, split="train")
texts = ds["text"]
sources = ds["source"]
pages = ds["page"]

# — Query embedder & FAISS —
# The chunk vectors already live in the on-disk index, so the corpus is NOT
# re-embedded here (that was an expensive, unused startup step); the embedder
# is only used to encode incoming queries.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
index = faiss.read_index(INDEX_PATH)

# Fail fast on an index that was not built from this dataset/embedder,
# instead of returning wrong chunks or crashing on the first query.
_query_dim = embedder.get_sentence_embedding_dimension()
if _query_dim is not None and index.d != _query_dim:
    raise ValueError(
        f"FAISS index dimension {index.d} does not match the embedder "
        f"dimension {_query_dim}; rebuild {INDEX_PATH!r} with the same model."
    )
if index.ntotal != len(ds):
    raise ValueError(
        f"FAISS index holds {index.ntotal} vectors but {DATASET_NAME!r} has "
        f"{len(ds)} rows; the index and dataset are out of sync."
    )

# — RAG generator —
# Loaded without a RagRetriever: retrieval is done with the FAISS index above
# and the retrieved chunks are passed to generate() explicitly.
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
model = RagSequenceForGeneration.from_pretrained(MODEL_NAME)
@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    goal: str,
    diet: list[str],
    meals: int,
    avoid: str,
    weeks: str,
) -> str:
    """Answer one chat turn with retrieval-augmented generation.

    Called by ``gr.ChatInterface`` as ``fn(message, history, *additional_inputs)``.
    Retrieves the nearest chunks from the FAISS index and conditions the RAG
    generator on them directly (the model has no built-in retriever).

    Args:
        message: The user's latest message; also used as the retrieval query.
        history: Prior (user, bot) turns. Gradio maintains this itself; it is
            neither read nor modified here.
        goal: Selected fitness goal.
        diet: Selected diet styles (an empty selection may arrive as None).
        meals: Meals per day from the slider.
        avoid: Comma-separated foods to avoid (may be empty or None).
        weeks: Selected plan length.

    Returns:
        The reply text. ``gr.ChatInterface`` expects the bot message string,
        not an updated history list. Returns "I'm sorry, I don't know." when
        nothing is retrieved or the generator produces no text.
    """
    import torch  # only needed here, for the generator's doc_scores tensor

    fallback = "I'm sorry, I don't know."

    # Parse preferences; tolerate None from empty Gradio components.
    avoid_list = [a.strip() for a in (avoid or "").split(",") if a.strip()]
    prefs = (
        f"Goal={goal}; Diet={','.join(diet or [])}; "
        f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
    )

    # 1) Query embedding & FAISS search (FAISS expects float32 rows).
    # RagSequenceForGeneration.generate() re-scores its candidates with the
    # model's configured n_docs, so exactly that many documents are retrieved.
    cfg = model.config
    n_docs = cfg.n_docs
    q_emb = np.asarray(
        embedder.encode([message], convert_to_numpy=True), dtype=np.float32
    )
    k = min(n_docs, index.ntotal)
    hits: list[int] = []
    if k > 0:
        _, ids = index.search(q_emb, k)
        # FAISS pads missing results with -1; texts[-1] would silently cite
        # the last chunk instead of meaning "no result".
        hits = [int(i) for i in ids[0] if i >= 0]
    if not hits:
        return fallback
    # Cycle the hits when the index holds fewer than n_docs vectors, so the
    # generator receives the document count it expects.
    hits = [hits[j % len(hits)] for j in range(n_docs)]

    # 2) One generator input per chunk, in RAG's retriever layout:
    #    "<prefix><title><title_sep><text><doc_sep><question>".
    # The citation tag serves as the title; preferences ride on the question.
    prefix = getattr(model.generator.config, "prefix", None) or ""
    question = f"{message} ({prefs})"
    doc_strings = [
        (
            f"{prefix}[{sources[i]} p{pages[i]}]{cfg.title_sep}"
            f"{texts[i]}{cfg.doc_sep}{question}"
        ).replace("  ", " ")
        for i in hits
    ]
    # NOTE: truncation keeps the start of each string, so a very long chunk
    # can push the question past max_combined_length.
    enc = tokenizer.generator(
        doc_strings,
        max_length=cfg.max_combined_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(model.device)
    # Uniform prior over the retrieved chunks: the scale/sign of FAISS
    # distances depends on the index type, which is not known here.
    doc_scores = torch.zeros(1, n_docs, device=model.device)

    # 3) Generate, marginalising over the retrieved chunks.
    outputs = model.generate(
        context_input_ids=enc["input_ids"],
        context_attention_mask=enc["attention_mask"],
        doc_scores=doc_scores,
        n_docs=n_docs,
        num_beams=2,
        max_new_tokens=200,
    )
    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
    return answer or fallback
# — Build Gradio chat interface —
# Extra controls rendered below the chat box. ChatInterface passes their
# current values to respond() positionally, after (message, history), in the
# order listed in additional_inputs.
goal = gr.Dropdown(
    choices=["Lose weight", "Bulk", "Maintain"],
    value="Lose weight",
    label="Goal",
)
diet = gr.CheckboxGroup(
    choices=["Omnivore", "Vegetarian", "Vegan", "Keto", "Paleo", "Low-Carb"],
    label="Diet Style",
)
meals = gr.Slider(
    minimum=1,
    maximum=6,
    value=3,
    step=1,
    label="Meals per day",
)
avoid = gr.Textbox(
    label="Avoidances (comma-separated)",
    placeholder="e.g. Gluten, Dairy, Nuts...",
)
weeks = gr.Dropdown(
    choices=["1 week", "2 weeks", "3 weeks", "4 weeks"],
    value="1 week",
    label="Plan Length",
)

demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[goal, diet, meals, avoid, weeks],
)

if __name__ == "__main__":
    demo.launch()