File size: 2,758 Bytes
6edbce0
fed20c9
9b1fba6
 
 
 
 
6edbce0
9b1fba6
e05063c
9b1fba6
 
 
 
 
 
 
 
 
 
 
 
 
6edbce0
9b1fba6
e05063c
9b1fba6
e05063c
30f2776
6edbce0
e05063c
6edbce0
e05063c
 
 
 
 
6edbce0
9b1fba6
e05063c
 
9b1fba6
 
e05063c
9b1fba6
 
 
 
 
 
 
 
 
 
e05063c
 
9b1fba6
e05063c
9b1fba6
e05063c
 
9b1fba6
 
e05063c
 
 
6edbce0
9b1fba6
e05063c
 
 
6edbce0
9b1fba6
e05063c
 
9b1fba6
 
e05063c
30f2776
6edbce0
e05063c
 
6edbce0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import spaces
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import RagTokenizer, RagSequenceForGeneration

# β€” Config β€”
DATASET_NAME = "username/mealplan-chunks"
INDEX_PATH   = "mealplan.index"
MODEL_NAME   = "facebook/rag-sequence-nq"

# β€” Load chunks & FAISS index β€”
ds = load_dataset(DATASET_NAME, split="train")
texts   = ds["text"]
sources = ds["source"]
pages   = ds["page"]

# β€” Embeddings embedder & FAISS β€”
embedder         = SentenceTransformer("all-MiniLM-L6-v2")
chunk_embeddings = embedder.encode(texts, convert_to_numpy=True)
index            = faiss.read_index(INDEX_PATH)

# β€” RAG generator β€”
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
model     = RagSequenceForGeneration.from_pretrained(MODEL_NAME)

@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    goal: str,
    diet: list[str],
    meals: int,
    avoid: str,
    weeks: str,
) -> str:
    """Answer a meal-plan question grounded in retrieved document chunks.

    Parameters mirror the ChatInterface wiring: ``history`` is supplied by
    Gradio (unused here — ChatInterface maintains the transcript itself);
    the remaining arguments come from the additional-input widgets.

    Returns the assistant's reply as a plain string, which is what
    ``gr.ChatInterface`` expects its fn to return.

    Bug fix: the original appended to and returned the ``history`` list, so
    the chat window would have rendered the repr of the entire history
    instead of just the new answer; it also mutated Gradio's own state.
    """
    # Collapse the preference widgets into one compact string for the prompt.
    avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
    prefs = (
        f"Goal={goal}; Diet={','.join(diet)}; "
        f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
    )

    # 1) Embed the query and retrieve the top-5 nearest chunks.
    #    Distances are unused, so they are discarded.
    q_emb = embedder.encode([message], convert_to_numpy=True)
    _, I = index.search(q_emb, 5)  # top-5
    ctx_chunks = [
        f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in I[0]
    ]
    context = "\n".join(ctx_chunks)

    # 2) Build a grounded prompt with an instruction, preferences, context,
    #    and the user question.
    prompt = (
        "SYSTEM: Only answer using the provided CONTEXT. "
        "If it’s not there, say \"I'm sorry, I don't know.\"\n"
        f"PREFS: {prefs}\n"
        f"CONTEXT:\n{context}\n"
        f"Q: {message}\n"
    )

    # 3) Generate.
    # NOTE(review): RagSequenceForGeneration is normally driven by a
    # RagRetriever / context_input_ids rather than a prompt-stuffed input;
    # this prompt-stuffing approach may underperform — consider a plain
    # seq2seq model instead. Generation behavior kept as-is here.
    inputs  = tokenizer([prompt], return_tensors="pt")
    outputs = model.generate(**inputs, num_beams=2, max_new_tokens=200)
    answer  = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    # 4) Return only the reply; ChatInterface manages the transcript.
    return answer

# β€” Build Gradio chat interface β€”
goal  = gr.Dropdown(["Lose weight","Bulk","Maintain"], value="Lose weight", label="Goal")
diet  = gr.CheckboxGroup(["Omnivore","Vegetarian","Vegan","Keto","Paleo","Low-Carb"], label="Diet Style")
meals = gr.Slider(1,6,value=3,step=1,label="Meals per day")
avoid = gr.Textbox(placeholder="e.g. Gluten, Dairy, Nuts...", label="Avoidances (comma-separated)")
weeks = gr.Dropdown(["1 week","2 weeks","3 weeks","4 weeks"], value="1 week", label="Plan Length")

demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[goal, diet, meals, avoid, weeks],
)

if __name__ == "__main__":
    demo.launch()