pradeepsengarr committed on
Commit 1dedfac · verified
1 Parent(s): 87bdf56

Update app.py

Files changed (1)
  1. app.py +359 -71
app.py CHANGED
@@ -1,83 +1,371 @@
- import os
- import time
- import torch
  import gradio as gr
- from huggingface_hub import login
- from transformers import AutoTokenizer
- from auto_gptq import AutoGPTQForCausalLM
  from sentence_transformers import SentenceTransformer
- from langchain_community.vectorstores import FAISS
-
- # Load HF token and login
- hf_token = os.environ.get("HUGGINGFACE_TOKEN")
- if not hf_token:
-     raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable")
-
- login(token=hf_token)
-
- # Load tokenizer and quantized model
- model_id = "TheBloke/mistral-7B-GPTQ"
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- print("Loading tokenizer...")
- tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
-
- print("Loading quantized model...")
- start = time.time()
- model = AutoGPTQForCausalLM.from_quantized(
-     model_id,
-     use_safetensors=True,
-     device=device,
-     use_triton=True,
-     quantize_config=None,
- )
- print(f"Model loaded in {time.time() - start:.2f} seconds on {device}")
-
- # Load embedding model for FAISS vector store
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
-
- # Sample documents to build vector index (can replace with your own)
- texts = [
-     "Hello world",
-     "Mistral 7B is a powerful language model",
-     "Langchain and FAISS make vector search easy",
-     "This is a test document for vector search",
- ]
- embeddings = embedder.encode(texts)

- faiss_index = FAISS.from_embeddings(embeddings, texts)
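A side note on the deleted line above: langchain_community's FAISS.from_embeddings takes an iterable of (text, vector) pairs plus an Embeddings object, not a raw embedding matrix and a list of texts, so this call would fail at runtime. A minimal sketch of the usage that API expects, assuming HuggingFaceEmbeddings as the embedding wrapper:

    # Sketch of the LangChain API the old code misused (not part of this commit).
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS

    emb = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    texts = ["Hello world", "Mistral 7B is a powerful language model"]
    pairs = zip(texts, emb.embed_documents(texts))  # (text, vector) pairs
    faiss_index = FAISS.from_embeddings(pairs, emb)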
-
- # Generate text from prompt
- def generate_text(prompt, max_length=128):
-     inputs = tokenizer(prompt, return_tensors="pt").to(device)
-     with torch.no_grad():
-         outputs = model.generate(**inputs, max_length=max_length)
-     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return decoded
-
- # Search docs with vector similarity
- def search_docs(query):
-     query_emb = embedder.encode([query])
-     results = faiss_index.similarity_search_by_vector(query_emb[0], k=3)
-     return "\n\n".join(results)
-
- # Gradio UI
- with gr.Blocks() as demo:
-     gr.Markdown("# Mistral GPTQ + FAISS Vector Search Demo")
-
-     with gr.Tab("Text Generation"):
-         prompt_input = gr.Textbox(label="Enter prompt", lines=3)
-         generate_btn = gr.Button("Generate")
-         output_text = gr.Textbox(label="Output", lines=6)
-
-         generate_btn.click(fn=generate_text, inputs=prompt_input, outputs=output_text)
-
-     with gr.Tab("Vector Search"):
-         query_input = gr.Textbox(label="Enter search query", lines=2)
-         search_btn = gr.Button("Search")
-         search_output = gr.Textbox(label="Search Results", lines=6)
-
-         search_btn.click(fn=search_docs, inputs=query_input, outputs=search_output)
-
  if __name__ == "__main__":
-     demo.launch()

  import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
  from sentence_transformers import SentenceTransformer
+ import faiss
+ import numpy as np
+ import PyPDF2
+ import docx
+ import io
+ import os
+ from typing import List, Optional
+
+ class DocumentRAG:
+     def __init__(self):
+         print("🚀 Initializing RAG System...")
+
+         # Initialize embedding model (lightweight)
+         self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
+         print("✅ Embedding model loaded")
+
+         # Initialize quantized LLM
+         self.setup_llm()
+
+         # Document storage
+         self.documents = []
+         self.index = None
+         self.is_indexed = False
+
+     def setup_llm(self):
+         """Set up the quantized Mistral model"""
+         try:
+             quantization_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_compute_dtype=torch.float16,
+                 bnb_4bit_use_double_quant=True,
+                 bnb_4bit_quant_type="nf4"
+             )
+
+             model_name = "mistralai/Mistral-7B-Instruct-v0.1"
+
+             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+             if self.tokenizer.pad_token is None:
+                 self.tokenizer.pad_token = self.tokenizer.eos_token
+
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 quantization_config=quantization_config,
+                 device_map="auto",
+                 torch_dtype=torch.float16,
+                 trust_remote_code=True
+             )
+             print("✅ Quantized Mistral model loaded")
+
+         except Exception as e:
+             print(f"❌ Error loading model: {e}")
+             # Fall back to a smaller model if Mistral fails
+             self.setup_fallback_model()
+
+     def setup_fallback_model(self):
+         """Fall back to a smaller model if Mistral fails"""
+         try:
+             model_name = "microsoft/DialoGPT-small"
+             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+             self.model = AutoModelForCausalLM.from_pretrained(model_name)
+             print("✅ Fallback model loaded")
+         except Exception as e:
+             print(f"❌ Fallback model failed: {e}")
+             self.model = None
+             self.tokenizer = None

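For context on the config above: nf4 is a 4-bit "normal float" data type fitted to normally distributed weights, and double quantization also compresses the quantization constants, so the 7B model drops from roughly 14 GB in float16 to around 4 GB. A standalone sketch for checking the footprint, assuming the same transformers/bitsandbytes stack and a CUDA GPU:

    # Standalone sketch (not part of the commit): check the 4-bit memory footprint.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,   # compute in fp16, store in 4-bit
        bnb_4bit_use_double_quant=True,         # quantize the quantization constants too
        bnb_4bit_quant_type="nf4",
    )
    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.1",
        quantization_config=cfg,
        device_map="auto",
    )
    print(f"{model.get_memory_footprint() / 1e9:.1f} GB")  # ~4 GB vs ~14 GB in fp16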
+     def extract_text_from_file(self, file_path: str) -> str:
+         """Extract text from various file formats"""
+         try:
+             file_extension = os.path.splitext(file_path)[1].lower()
+
+             if file_extension == '.pdf':
+                 return self.extract_from_pdf(file_path)
+             elif file_extension == '.docx':
+                 return self.extract_from_docx(file_path)
+             elif file_extension == '.txt':
+                 return self.extract_from_txt(file_path)
+             else:
+                 return f"Unsupported file format: {file_extension}"
+
+         except Exception as e:
+             return f"Error reading file: {str(e)}"
+
+     def extract_from_pdf(self, file_path: str) -> str:
+         """Extract text from PDF"""
+         text = ""
+         try:
+             with open(file_path, 'rb') as file:
+                 pdf_reader = PyPDF2.PdfReader(file)
+                 for page in pdf_reader.pages:
+                     text += page.extract_text() + "\n"
+         except Exception as e:
+             text = f"Error reading PDF: {str(e)}"
+         return text
+
+     def extract_from_docx(self, file_path: str) -> str:
+         """Extract text from DOCX"""
+         try:
+             doc = docx.Document(file_path)
+             text = ""
+             for paragraph in doc.paragraphs:
+                 text += paragraph.text + "\n"
+             return text
+         except Exception as e:
+             return f"Error reading DOCX: {str(e)}"
+
+     def extract_from_txt(self, file_path: str) -> str:
+         """Extract text from TXT"""
+         try:
+             with open(file_path, 'r', encoding='utf-8') as file:
+                 return file.read()
+         except Exception as e:
+             try:
+                 with open(file_path, 'r', encoding='latin-1') as file:
+                     return file.read()
+             except Exception as e2:
+                 return f"Error reading TXT: {str(e2)}"
+
+     def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
+         """Split text into overlapping chunks"""
+         if not text.strip():
+             return []
+
+         words = text.split()
+         chunks = []
+
+         for i in range(0, len(words), chunk_size - overlap):
+             chunk = ' '.join(words[i:i + chunk_size])
+             if chunk.strip():
+                 chunks.append(chunk.strip())
+
+             if i + chunk_size >= len(words):
+                 break
+
+         return chunks

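The chunker advances by chunk_size - overlap = 450 words per step, so consecutive chunks share a 50-word seam: chunk 0 covers words 0-499, chunk 1 covers 450-949, and so on. A quick illustrative check, not part of the commit:

    # Illustrative: 1000 numbered words -> 3 chunks starting at words 0, 450, 900.
    words = [str(i) for i in range(1000)]
    chunks = rag_system.chunk_text(" ".join(words), chunk_size=500, overlap=50)
    print(len(chunks))            # 3
    print(chunks[0].split()[-1])  # '499' -- end of chunk 0
    print(chunks[1].split()[0])   # '450' -- chunk 1 re-covers the last 50 words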
+     def process_documents(self, files) -> str:
+         """Process uploaded files and create embeddings"""
+         if not files:
+             return "❌ No files uploaded!"
+
+         try:
+             all_text = ""
+             processed_files = []
+
+             # Extract text from all files
+             for file in files:
+                 if file is None:
+                     continue
+
+                 file_text = self.extract_text_from_file(file.name)
+                 if not file_text.startswith("Error") and not file_text.startswith("Unsupported"):
+                     all_text += f"\n\n--- {os.path.basename(file.name)} ---\n\n{file_text}"
+                     processed_files.append(os.path.basename(file.name))
+                 else:
+                     return f"❌ {file_text}"
+
+             if not all_text.strip():
+                 return "❌ No text extracted from files!"
+
+             # Chunk the text
+             self.documents = self.chunk_text(all_text)
+
+             if not self.documents:
+                 return "❌ No valid text chunks created!"
+
+             # Create embeddings
+             print(f"📄 Creating embeddings for {len(self.documents)} chunks...")
+             embeddings = self.embedder.encode(self.documents, show_progress_bar=True)
+
+             # Build FAISS index
+             dimension = embeddings.shape[1]
+             self.index = faiss.IndexFlatIP(dimension)
+
+             # Normalize embeddings for cosine similarity
+             faiss.normalize_L2(embeddings)
+             self.index.add(embeddings.astype('float32'))
+
+             self.is_indexed = True
+
+             return f"✅ Successfully processed {len(processed_files)} files:\n" + \
+                    f"📄 Files: {', '.join(processed_files)}\n" + \
+                    f"📊 Created {len(self.documents)} text chunks\n" + \
+                    f"🔍 Ready for Q&A!"
+
+         except Exception as e:
+             return f"❌ Error processing documents: {str(e)}"

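A note on the index choice: IndexFlatIP ranks by inner product, and on L2-normalized vectors the inner product equals cosine similarity, which is why both the chunk embeddings here and the query embedding later go through faiss.normalize_L2. A self-contained sketch of the equivalence:

    # Sketch: inner product over L2-normalized vectors == cosine similarity.
    import faiss
    import numpy as np

    docs = np.random.rand(4, 8).astype("float32")
    query = np.random.rand(1, 8).astype("float32")
    cosine = (query @ docs.T) / (np.linalg.norm(query) * np.linalg.norm(docs, axis=1))

    faiss.normalize_L2(docs)   # in place; compute the reference above first
    faiss.normalize_L2(query)
    index = faiss.IndexFlatIP(8)
    index.add(docs)
    scores, ids = index.search(query, 4)
    print(np.allclose(np.sort(cosine.ravel())[::-1], scores[0], atol=1e-5))  # True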
+     def retrieve_context(self, query: str, k: int = 3) -> str:
+         """Retrieve relevant context for the query"""
+         if not self.is_indexed:
+             return ""
+
+         try:
+             # Get query embedding
+             query_embedding = self.embedder.encode([query])
+             faiss.normalize_L2(query_embedding)
+
+             # Search for similar chunks
+             scores, indices = self.index.search(query_embedding.astype('float32'), k)
+
+             # Get relevant documents (FAISS pads with -1 when k exceeds the index size)
+             relevant_docs = []
+             for i, idx in enumerate(indices[0]):
+                 if 0 <= idx < len(self.documents) and scores[0][i] > 0.1:  # similarity threshold
+                     relevant_docs.append(self.documents[idx])
+
+             return "\n\n".join(relevant_docs)
+
+         except Exception as e:
+             print(f"Error in retrieval: {e}")
+             return ""

+     def generate_answer(self, query: str, context: str) -> str:
+         """Generate answer using the LLM"""
+         if self.model is None or self.tokenizer is None:
+             return "❌ Model not available. Please try again."
+
+         try:
+             # Create prompt (context is truncated to keep the input length bounded)
+             prompt = f"""<s>[INST] Based on the following context, answer the question. If the answer is not in the context, say "I don't have enough information to answer this question."
+
+ Context:
+ {context[:2000]}
+
+ Question: {query}
+
+ Answer: [/INST]"""
+
+             # Tokenize and move inputs to the model's device
+             inputs = self.tokenizer(
+                 prompt,
+                 return_tensors="pt",
+                 max_length=1024,
+                 truncation=True,
+                 padding=True
+             )
+             inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+
+             # Generate
+             with torch.no_grad():
+                 outputs = self.model.generate(
+                     **inputs,
+                     max_new_tokens=256,
+                     temperature=0.7,
+                     do_sample=True,
+                     top_p=0.9,
+                     pad_token_id=self.tokenizer.eos_token_id,
+                     eos_token_id=self.tokenizer.eos_token_id
+                 )
+
+             # Decode response
+             full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+             # Extract answer (remove the prompt part)
+             if "[/INST]" in full_response:
+                 answer = full_response.split("[/INST]")[-1].strip()
+             else:
+                 answer = full_response[len(prompt):].strip()
+
+             return answer if answer else "I couldn't generate a proper response."
+
+         except Exception as e:
+             return f"❌ Error generating answer: {str(e)}"

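Splitting on "[/INST]" works, but it is brittle if the model ever emits that tag itself. An alternative sketch (not what the commit does) is to decode only the tokens generated after the prompt:

    # Sketch: skip the prompt tokens instead of string-splitting the response.
    prompt_len = inputs["input_ids"].shape[-1]
    answer = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()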
+     def answer_question(self, query: str) -> str:
+         """Main function to answer questions"""
+         if not query.strip():
+             return "❓ Please ask a question!"
+
+         if not self.is_indexed:
+             return "📝 Please upload and process documents first!"
+
+         try:
+             # Retrieve relevant context
+             context = self.retrieve_context(query)
+
+             if not context:
+                 return "🔍 No relevant information found in the uploaded documents."
+
+             # Generate answer
+             answer = self.generate_answer(query, context)
+
+             return f"💡 **Answer:** {answer}\n\n📄 **Source Context:** {context[:500]}..."
+
+         except Exception as e:
+             return f"❌ Error answering question: {str(e)}"
+
+ # Initialize the RAG system
+ print("Initializing Document RAG System...")
+ rag_system = DocumentRAG()

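Since the whole pipeline hangs off this one instance, it can be smoke-tested without the UI. A hypothetical snippet ("notes.txt" is a placeholder path; SimpleNamespace stands in for Gradio's upload wrapper, of which only the .name attribute is used here):

    # Hypothetical smoke test, bypassing Gradio.
    from types import SimpleNamespace

    fake_upload = SimpleNamespace(name="notes.txt")  # mimics Gradio's file object
    print(rag_system.process_documents([fake_upload]))
    print(rag_system.answer_question("What is the main topic?"))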
+ # Gradio interface
+ def create_interface():
+     with gr.Blocks(title="📚 Document Q&A with RAG", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("""
+         # 📚 Document Q&A System
+
+         Upload your documents and ask questions about them!
+
+         **Supported formats:** PDF, DOCX, TXT
+         """)
+
+         with gr.Tab("📤 Upload Documents"):
+             with gr.Row():
+                 with gr.Column():
+                     file_upload = gr.File(
+                         label="Upload Documents",
+                         file_count="multiple",
+                         file_types=[".pdf", ".docx", ".txt"]
+                     )
+                     process_btn = gr.Button("🔄 Process Documents", variant="primary")
+
+                 with gr.Column():
+                     process_status = gr.Textbox(
+                         label="Processing Status",
+                         lines=8,
+                         interactive=False
+                     )
+
+             process_btn.click(
+                 fn=rag_system.process_documents,
+                 inputs=[file_upload],
+                 outputs=[process_status]
+             )
+
+         with gr.Tab("❓ Ask Questions"):
+             with gr.Row():
+                 with gr.Column():
+                     question_input = gr.Textbox(
+                         label="Your Question",
+                         placeholder="What would you like to know about your documents?",
+                         lines=3
+                     )
+                     ask_btn = gr.Button("🔍 Get Answer", variant="primary")
+
+                 with gr.Column():
+                     answer_output = gr.Textbox(
+                         label="Answer",
+                         lines=10,
+                         interactive=False
+                     )
+
+             ask_btn.click(
+                 fn=rag_system.answer_question,
+                 inputs=[question_input],
+                 outputs=[answer_output]
+             )
+
+             # Example questions
+             gr.Markdown("""
+             ### 💡 Example Questions:
+             - What is the main topic of the document?
+             - Can you summarize the key points?
+             - What are the conclusions mentioned?
+             - Are there any specific numbers or statistics?
+             """)
+
+     return demo
+
+ # Launch the app
  if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=True
+     )
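The new import list implies roughly this dependency set; a plausible requirements.txt sketch (standard PyPI names for these imports; version pins omitted since the commit does not specify any):

    gradio
    torch
    transformers
    accelerate            # required for device_map="auto"
    bitsandbytes          # 4-bit quantization backend
    sentence-transformers
    faiss-cpu             # or faiss-gpu on CUDA machines
    numpy
    PyPDF2
    python-docx           # provides the `docx` module

One launch-time detail: on a Hugging Face Space the app is already served publicly, so share=True is redundant there (Gradio generally warns and ignores it in that environment); it matters only when running the script locally.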