pradeepsengarr committed on
Commit f1e12d6 · verified · 1 Parent(s): c80ff9d

Update app.py

Files changed (1):
  1. app.py (+31 -38)
app.py CHANGED
@@ -5,7 +5,7 @@ import faiss
 import numpy as np
 from io import BytesIO
 from sentence_transformers import SentenceTransformer
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from huggingface_hub import login
 
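The import swap tracks the model-family change below: FLAN-T5 is an encoder-decoder (seq2seq) model, so it loads through `AutoModelForSeq2SeqLM` and runs under the `text2text-generation` pipeline task, whereas the decoder-only Mistral model used `AutoModelForCausalLM` with `text-generation`. A minimal sketch of the new loading path, using the model ID this commit adds:

```python
from transformers import pipeline

# Seq2seq task for encoder-decoder models such as FLAN-T5.
# (A decoder-only model like mistralai/Mistral-7B-Instruct-v0.1 would instead
# use the "text-generation" task with AutoModelForCausalLM.)
llm = pipeline("text2text-generation", model="google/flan-t5-base")
print(llm("Translate English to German: How old are you?", max_new_tokens=20))
```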
@@ -18,42 +18,35 @@ login(token=hf_token)
 # Load embedding model
 embed_model = SentenceTransformer("BAAI/bge-base-en-v1.5")
 
-# Load Mistral LLM (CPU compatible)
-model_id = "mistralai/Mistral-7B-Instruct-v0.1"
-tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map={"": "cpu"},  # Force CPU
-    torch_dtype="auto",  # Safe for CPU
-    token=hf_token
-)
-llm = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
-# Global state
+# ✅ Load FLAN-T5 base (CPU-friendly)
+model_id = "google/flan-t5-base"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+llm = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
+
+# Globals
 index = None
 doc_texts = []
 
-# Extract text from uploaded file
-def extract_text(file_obj):
+# Extract text from PDF or TXT
+def extract_text(file):
     text = ""
-    file_path = file_obj.name
-    if file_path.endswith(".pdf"):
-        with open(file_path, "rb") as f:
-            pdf_stream = BytesIO(f.read())
+    file_bytes = file.read()
+    if file.name.endswith(".pdf"):
+        pdf_stream = BytesIO(file_bytes)
         doc = fitz.open(stream=pdf_stream, filetype="pdf")
         for page in doc:
             text += page.get_text()
-    elif file_path.endswith(".txt"):
-        with open(file_path, "r", encoding="utf-8") as f:
-            text = f.read()
+    elif file.name.endswith(".txt"):
+        text = file_bytes.decode("utf-8")
     else:
         return "❌ Unsupported file type."
     return text
 
-# Process file and build FAISS index
-def process_file(file_obj):
+# Process the file, build FAISS index
+def process_file(file):
     global index, doc_texts
-    text = extract_text(file_obj)
+    text = extract_text(file)
     if text.startswith("❌"):
         return text
 
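The rewritten `extract_text` pulls bytes straight off the upload with `file.read()` and branches on `file.name`. Whether `.read()` is available depends on the Gradio version: some versions pass `gr.File` values as temp-file wrappers, and newer ones can pass a plain path string. A version-tolerant variant (a hypothetical helper, not part of this commit) would re-open the path from disk, as the old code did:

```python
from io import BytesIO
import fitz  # PyMuPDF

def extract_text_tolerant(file):
    # Hypothetical: accept either a path string or an object exposing .name.
    path = file if isinstance(file, str) else file.name
    with open(path, "rb") as f:  # re-open from disk instead of relying on .read()
        file_bytes = f.read()
    if path.endswith(".pdf"):
        doc = fitz.open(stream=BytesIO(file_bytes), filetype="pdf")
        return "".join(page.get_text() for page in doc)
    if path.endswith(".txt"):
        return file_bytes.decode("utf-8")
    return "❌ Unsupported file type."
```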
@@ -65,9 +58,9 @@ def process_file(file_obj):
     index = faiss.IndexFlatL2(dim)
     index.add(embeddings)
 
-    return "✅ File processed successfully. You can now ask questions!"
+    return "✅ File processed! You can now ask questions."
 
-# Generate answer from FAISS context + LLM
+# Generate answer using context + LLM
 def generate_answer(question):
     global index, doc_texts
     if index is None or not doc_texts:
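This hunk shows only the tail of `process_file`; the chunking and embedding lines sit in the unchanged region between hunks and are not displayed. For orientation, a typical path from raw text to the `IndexFlatL2` built here looks roughly like the following (an illustrative sketch using the file's own imports; the chunk sizes are assumptions, not the committed values):

```python
import faiss
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer

embed_model = SentenceTransformer("BAAI/bge-base-en-v1.5")
text = "...extracted document text..."

# Assumed chunking parameters; the committed values are not shown in this diff.
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
doc_texts = splitter.split_text(text)

embeddings = np.asarray(embed_model.encode(doc_texts), dtype="float32")
dim = embeddings.shape[1]       # FAISS needs the embedding width
index = faiss.IndexFlatL2(dim)  # exact L2 nearest-neighbour search
index.add(embeddings)
```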
@@ -77,30 +70,30 @@ def generate_answer(question):
     _, I = index.search(question_emb, k=3)
     context = "\n".join([doc_texts[i] for i in I[0]])
 
-    prompt = f"""<s>[INST] You are a helpful assistant. Use the context below to answer the question.
+    prompt = f"""Use the following context to answer the question.
 
 Context:
 {context}
 
 Question: {question}
-Answer: [/INST]</s>"""
+"""
 
-    response = llm(prompt, max_new_tokens=300, do_sample=True, temperature=0.7)
-    return response[0]["generated_text"].split("Answer:")[-1].strip()
+    response = llm(prompt, max_new_tokens=300)
+    return response[0]["generated_text"].strip()
 
 # Gradio UI
-with gr.Blocks(title="RAG Chatbot with Mistral-7B (CPU-Friendly)") as demo:
-    gr.Markdown("## 🤖 Upload a PDF/TXT file and ask questions using Mistral-7B")
+with gr.Blocks(title="RAG Chatbot (Fast & CPU Compatible)") as demo:
+    gr.Markdown("## 📚 Upload PDF/TXT and Ask Questions using FLAN-T5")
 
     with gr.Row():
-        file_input = gr.File(label="📁 Upload PDF or TXT", file_types=[".pdf", ".txt"])
-        upload_status = gr.Textbox(label="📥 Upload Status", interactive=False)
+        file_input = gr.File(label="📁 Upload File (.pdf or .txt)", file_types=[".pdf", ".txt"])
+        upload_status = gr.Textbox(label="Upload Status", interactive=False)
 
     with gr.Row():
-        question_input = gr.Textbox(label="❓ Ask a Question")
-        answer_output = gr.Textbox(label="💬 Answer", interactive=False)
+        question_box = gr.Textbox(label="❓ Ask a Question", placeholder="What would you like to know?")
+        answer_box = gr.Textbox(label="💬 Answer", interactive=False)
 
     file_input.change(fn=process_file, inputs=file_input, outputs=upload_status)
-    question_input.submit(fn=generate_answer, inputs=question_input, outputs=answer_output)
+    question_box.submit(fn=generate_answer, inputs=question_box, outputs=answer_box)
 
 demo.launch()
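Two behavioural notes on the swap, both visible in the hunk above. First, `text2text-generation` returns only the newly generated text rather than echoing the prompt, which is why the `.split("Answer:")` post-processing could be dropped. Second, FLAN-T5 was trained with 512-token inputs, so a long retrieved context may be truncated or answered poorly; three retrieved chunks usually fit, but it is worth keeping in mind. A quick smoke test of the new pipeline (assuming the weights download):

```python
from transformers import pipeline

llm = pipeline("text2text-generation", model="google/flan-t5-base")
out = llm("Question: What is the capital of France?\nAnswer:", max_new_tokens=10)
# The result holds only the generated answer, with no prompt echo,
# hence no "Answer:"-splitting in generate_answer().
print(out[0]["generated_text"])
```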
 