import gradio as gr
import spaces
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
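# Assumed runtime dependencies (not pinned in the source): gradio, spaces,
# faiss-cpu, numpy, datasets, sentence-transformers, transformers, torch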
# ── Config ──
DATASET_NAME = "username/mealplan-chunks"
INDEX_PATH = "mealplan.index"
# Swapped in for "facebook/rag-sequence-nq": RagSequenceForGeneration cannot
# generate without a RagRetriever, and this app already retrieves manually via
# FAISS, so a plain seq2seq generator is assumed here.
MODEL_NAME = "google/flan-t5-base"
# ── Load chunks ──
ds = load_dataset(DATASET_NAME, split="train")
texts = ds["text"]
sources = ds["source"]
pages = ds["page"]
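# Each row is assumed to be one pre-chunked passage, with the originating
# document name ("source") and page number ("page") kept for citations.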
# ── Embedder & FAISS index ──
embedder = SentenceTransformer("all-MiniLM-L6-v2")
index = faiss.read_index(INDEX_PATH)
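# mealplan.index is assumed to be built offline. A minimal sketch (flat L2
# index over the same embedder's vectors; hypothetical, the original build
# script is not shown):
#   chunk_embeddings = embedder.encode(texts, convert_to_numpy=True)
#   idx = faiss.IndexFlatL2(chunk_embeddings.shape[1])
#   idx.add(chunk_embeddings)
#   faiss.write_index(idx, INDEX_PATH)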
# ── Generator ──
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
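# NOTE: on ZeroGPU Spaces, @spaces.GPU only allocates a GPU for the call; the
# model and inputs would still need an explicit .to("cuda") to actually use it.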
@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    goal: str,
    diet: list[str],
    meals: int,
    avoid: str,
    weeks: str,
):
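    """Retrieve top-matching chunks via FAISS and generate a context-grounded reply."""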
    # Parse preferences
    avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
    prefs = (
        f"Goal={goal}; Diet={','.join(diet)}; "
        f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
    )
    # 1) Query embedding & FAISS search
    q_emb = embedder.encode([message], convert_to_numpy=True)
    D, I = index.search(q_emb, 5)  # D: L2 distances, I: top-5 row indices
    ctx_chunks = [
        f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in I[0]
    ]
    context = "\n".join(ctx_chunks)
    # 2) Build prompt
    prompt = (
        "SYSTEM: Only answer using the provided CONTEXT. "
        "If it's not there, say \"I'm sorry, I don't know.\"\n"
        f"PREFS: {prefs}\n"
        f"CONTEXT:\n{context}\n"
        f"Q: {message}\n"
    )
    # 3) Generate
    inputs = tokenizer([prompt], return_tensors="pt")
    outputs = model.generate(**inputs, num_beams=2, max_new_tokens=200)
    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    # 4) Return only the new reply; gr.ChatInterface manages history itself
    return answer
# ── Build Gradio chat interface ──
goal = gr.Dropdown(["Lose weight", "Bulk", "Maintain"], value="Lose weight", label="Goal")
diet = gr.CheckboxGroup(["Omnivore", "Vegetarian", "Vegan", "Keto", "Paleo", "Low-Carb"], label="Diet Style")
meals = gr.Slider(1, 6, value=3, step=1, label="Meals per day")
avoid = gr.Textbox(placeholder="e.g. Gluten, Dairy, Nuts...", label="Avoidances (comma-separated)")
weeks = gr.Dropdown(["1 week", "2 weeks", "3 weeks", "4 weeks"], value="1 week", label="Plan Length")
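# gr.ChatInterface renders additional_inputs in an accordion below the chat box
# and passes their values to respond() after (message, history).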
demo = gr.ChatInterface(
fn=respond,
additional_inputs=[goal, diet, meals, avoid, weeks],
)
if __name__ == "__main__":
demo.launch()