DurgaDeepak committed on
Commit 901777e · verified · 1 Parent(s): b936c3d

Update app.py

Files changed (1)
  1. app.py +64 -94
app.py CHANGED
@@ -1,106 +1,76 @@
-# app.py
 import os
-import glob
-import faiss
+import fitz  # PyMuPDF
 import numpy as np
-
-import gradio as gr
-import spaces
-
-from unstructured.partition.pdf import partition_pdf
+import faiss
 from sentence_transformers import SentenceTransformer
-from transformers import RagTokenizer, RagSequenceForGeneration
-
-# ─── Configuration ─────────────────────────────────────────────
-PDF_FOLDER = "meal_plans"
-MODEL_NAME = "facebook/rag-sequence-nq"
-EMBED_MODEL = "all-MiniLM-L6-v2"
-TOP_K = 5
-
-# ─── 1) LOAD + CHUNK ALL PDFs ──────────────────────────────────
-rag_tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
-texts, sources, pages = [], [], []
-
-for pdf_path in glob.glob(f"{PDF_FOLDER}/*.pdf"):
-    book = os.path.basename(pdf_path)
-    pages_data = partition_pdf(filename=pdf_path)
-    for pg_num, page in enumerate(pages_data, start=1):
-        enc = rag_tokenizer(
-            page.text,
-            max_length=800,
-            truncation=True,
-            return_overflowing_tokens=True,
-            stride=50,
-            return_tensors="pt"
-        )
-        for token_ids in enc["input_ids"]:
-            chunk = rag_tokenizer.decode(token_ids, skip_special_tokens=True)
-            texts.append(chunk)
-            sources.append(book)
-            pages.append(pg_num)
-
-# ─── 2) EMBED + BUILD FAISS INDEX ─────────────────────────────
-embedder = SentenceTransformer(EMBED_MODEL)
-embeddings = embedder.encode(texts, convert_to_numpy=True)
-dim = embeddings.shape[1]
-index = faiss.IndexFlatL2(dim)
-index.add(embeddings)
-
-# ─── 3) LOAD RAG GENERATOR ────────────────────────────────────
-tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
-generator = RagSequenceForGeneration.from_pretrained(MODEL_NAME)
-
-@spaces.GPU
-def respond(
-    message: str,
-    history: list[tuple[str,str]],
-    goal: str,
-    diet: list[str],
-    meals: int,
-    avoid: str,
-    weeks: str
-):
-    # build prefs string
-    avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
-    prefs = (
-        f"Goal={goal}; Diet={','.join(diet)}; "
-        f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
-    )
-    # 1) RETRIEVE top-k chunks
-    q_emb = embedder.encode([message], convert_to_numpy=True)
-    D, I = index.search(q_emb, TOP_K)
-    context = "\n".join(f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in I[0])
-
-    # 2) BUILD PROMPT with guardrail
-    prompt = (
-        "SYSTEM: Only answer using the provided CONTEXT. "
-        "If it’s not there, say \"I'm sorry, I don't know.\" \n"
-        f"PREFS: {prefs}\n"
-        f"CONTEXT:\n{context}\n"
-        f"Q: {message}\n"
-    )
-
-    # 3) GENERATE
-    inputs = tokenizer([prompt], return_tensors="pt")
-    outputs = generator.generate(**inputs, num_beams=2, max_new_tokens=200)
-    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-
-    # update chat history
-    history = history or []
-    history.append((message, answer))
-    return history
-
-# ─── 4) BUILD UI ────────────────────────────────────────────────
-goal = gr.Dropdown(["Lose weight","Bulk","Maintain"], label="Goal", value="Lose weight")
-diet = gr.CheckboxGroup(["Omnivore","Vegetarian","Vegan","Keto","Paleo","Low-Carb"], label="Diet Style")
-meals = gr.Slider(1, 6, step=1, value=3, label="Meals per day")
-avoid = gr.Textbox(placeholder="e.g. Gluten, Dairy, Nuts…", label="Avoidances (comma-separated)")
-weeks = gr.Dropdown(["1 week","2 weeks","3 weeks","4 weeks"], label="Plan Length", value="1 week")
-
-demo = gr.ChatInterface(
-    fn=respond,
-    additional_inputs=[goal, diet, meals, avoid, weeks]
-)
-
-if __name__ == "__main__":
-    demo.launch()
+import gradio as gr
+import spaces  # for ZeroGPU
+
+@spaces.GPU
+def query_app(user_input, include_source, verbose):
+    return search_index(user_input, index, documents, include_source, verbose)
+
+# PDF reader
+def extract_text_from_pdf(folder_path="meal_plans"):
+    documents = []
+    for filename in os.listdir(folder_path):
+        if filename.lower().endswith(".pdf"):
+            path = os.path.join(folder_path, filename)
+            try:
+                doc = fitz.open(path)
+                text = ""
+                for page in doc:
+                    text += page.get_text()
+                documents.append({"text": text, "source": filename})
+            except Exception as e:
+                print(f"Error reading {filename}: {e}")
+    return documents
+
+# Index builder
+def create_index(docs):
+    texts = [doc["text"] for doc in docs]
+    embeddings = model.encode(texts)
+    dim = embeddings[0].shape[0]
+    index = faiss.IndexFlatL2(dim)
+    index.add(np.array(embeddings).astype("float32"))
+    return index
+
+# Search logic
+def search_index(query, index, docs, include_source=True, verbose=False, top_k=3):
+    query_vec = model.encode([query])
+    D, I = index.search(np.array(query_vec).astype("float32"), top_k)
+    responses = []
+    for i in I[0]:
+        doc = docs[i]
+        snippet = doc["text"][:750 if verbose else 300].replace("\n", " ").strip()
+        label = f"**📄 {doc['source']}**\n" if include_source else ""
+        responses.append(f"{label}{snippet}...")
+    return "\n\n---\n\n".join(responses)
+
+# Setup
+model = SentenceTransformer("all-MiniLM-L6-v2")
+documents = extract_text_from_pdf("meal_plans")
+index = create_index(documents)
+
+# Gradio UI
+with gr.Blocks(title="Meal Plan Chat Assistant") as demo:
+    gr.Markdown("## 🍽️ Meal Plan Assistant\nChat with your PDF documents in `meal_plans/` folder.")
+    with gr.Row():
+        with gr.Column(scale=4):
+            chatbot = gr.Chatbot()
+            user_input = gr.Textbox(placeholder="Ask something...", show_label=False)
+            send_btn = gr.Button("Ask")
+        with gr.Column(scale=1):
+            include_source = gr.Checkbox(label="Include Source", value=True)
+            verbose = gr.Checkbox(label="Verbose Mode", value=False)
+
+    def user_query(msg, history, source, verbose_mode):
+        answer = query_app(msg, source, verbose_mode)
+        history = history + [(msg, answer)]
+        return history, history
+
+    send_btn.click(user_query,
+                   inputs=[user_input, chatbot, include_source, verbose],
+                   outputs=[chatbot, chatbot])
+
+demo.launch()
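
Review note, not part of the commit: the rewrite drops the facebook/rag-sequence-nq generator and page-level chunking in favour of one embedding per whole PDF served from a flat FAISS index. For reviewers unfamiliar with the FAISS calls that create_index() and search_index() now lean on, a self-contained toy check of the IndexFlatL2 contract follows; the 384 dimension matches all-MiniLM-L6-v2 output, and the random vectors are purely illustrative stand-ins for model.encode(texts):

    import numpy as np
    import faiss

    # IndexFlatL2.search returns (D, I): squared-L2 distances and row
    # indices into whatever was add()-ed, each shaped (n_queries, top_k).
    vecs = np.random.rand(10, 384).astype("float32")
    index = faiss.IndexFlatL2(vecs.shape[1])
    index.add(vecs)

    D, I = index.search(vecs[:1], 3)  # query with the first stored vector
    print(I[0])  # nearest rows; I[0][0] == 0 (the query matches itself)
    print(D[0])  # D[0][0] == 0.0; smaller distance means closer

One consequence worth flagging: because each PDF is embedded as a single vector, search_index() now ranks whole documents rather than passages, so top_k=3 returns snippets from at most three source files.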