pradeepsengarr committed on
Commit fd77b07 · verified · 1 Parent(s): dcc21e5

Update app.py

Files changed (1):
  1. app.py +198 -638
app.py CHANGED
@@ -1,704 +1,264 @@
- import gradio as gr
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
- from sentence_transformers import SentenceTransformer
  import faiss
  import numpy as np
- import PyPDF2
- import docx
- import io
- import os
- import re
- from typing import List, Optional, Dict, Tuple
- import json
- from collections import Counter
- import warnings
- warnings.filterwarnings("ignore")
-
- class SmartDocumentRAG:
-     def __init__(self):
-         print("🚀 Initializing Enhanced Smart RAG System...")
-
-         # Initialize better embedding model
-         self.embedder = SentenceTransformer('all-MiniLM-L6-v2')  # Faster and good quality
-         print("✅ Embedding model loaded")
-
-         # Initialize optimized LLM with better quantization
-         self.setup_llm()
-
-         # Document storage
-         self.documents = []
-         self.document_metadata = []
-         self.index = None
-         self.is_indexed = False
-         self.raw_text = ""
-         self.document_type = "general"
-         self.document_summary = ""
-         self.sentence_embeddings = []
-         self.sentences = []
-
-     def setup_llm(self):
-         """Setup optimized model with better quantization"""
-         try:
-             # Check CUDA availability
-             device = "cuda" if torch.cuda.is_available() else "cpu"
-             print(f"🔧 Using device: {device}")
-
-             if device == "cuda":
-                 self.setup_gpu_model()
-             else:
-                 self.setup_cpu_model()
-
-         except Exception as e:
-             print(f"❌ Error loading models: {e}")
-             self.setup_fallback_model()
-
-     def setup_gpu_model(self):
-         """Setup GPU model with proper quantization"""
-         try:
-             # Use Phi-2 - excellent for Q&A and reasoning
-             model_name = "microsoft/DialoGPT-medium"
-
-             # Better quantization config
-             quantization_config = BitsAndBytesConfig(
-                 load_in_4bit=True,
-                 bnb_4bit_compute_dtype=torch.float16,
-                 bnb_4bit_use_double_quant=True,
-                 bnb_4bit_quant_type="nf4",
-                 bnb_4bit_quant_storage=torch.uint8
-             )
-
-             try:
-                 # Try Flan-T5 first - excellent for Q&A
-                 model_name = "google/flan-t5-base"
-                 print(f"🤖 Loading {model_name}...")
-
-                 self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-                 self.model = AutoModelForCausalLM.from_pretrained(
-                     model_name,
-                     quantization_config=quantization_config,
-                     device_map="auto",
-                     torch_dtype=torch.float16,
-                     trust_remote_code=True
-                 )
-
-                 # Create pipeline for easier use
-                 self.qa_pipeline = pipeline(
-                     "text2text-generation",
-                     model=self.model,
-                     tokenizer=self.tokenizer,
-                     max_length=512,
-                     do_sample=True,
-                     temperature=0.3,
-                     top_p=0.9
-                 )
-
-                 print("✅ Flan-T5 model loaded successfully")
-                 self.model_type = "flan-t5"
-
-             except Exception as e:
-                 print(f"Flan-T5 failed, trying Phi-2: {e}")
-                 # Try Phi-2 as backup
-                 model_name = "microsoft/phi-2"
-                 print(f"🤖 Loading {model_name}...")
-
-                 self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-                 self.model = AutoModelForCausalLM.from_pretrained(
-                     model_name,
-                     quantization_config=quantization_config,
-                     device_map="auto",
-                     torch_dtype=torch.float16,
-                     trust_remote_code=True
-                 )
-
-                 if self.tokenizer.pad_token is None:
-                     self.tokenizer.pad_token = self.tokenizer.eos_token
-
-                 print("✅ Phi-2 model loaded successfully")
-                 self.model_type = "phi-2"
-
-         except Exception as e:
-             print(f"❌ GPU models failed: {e}")
-             self.setup_cpu_model()
-
-     def setup_cpu_model(self):
-         """Setup CPU-optimized model"""
-         try:
-             # Use DistilBERT for Q&A - much better than DialoGPT for this task
-             model_name = "distilbert-base-cased-distilled-squad"
-             print(f"🤖 Loading CPU model: {model_name}")
-
-             self.qa_pipeline = pipeline(
-                 "question-answering",
-                 model=model_name,
-                 tokenizer=model_name
-             )
-             self.model_type = "distilbert-qa"
-             print("✅ DistilBERT Q&A model loaded successfully")
-
-         except Exception as e:
-             print(f"❌ CPU model failed: {e}")
-             self.setup_fallback_model()

-     def setup_fallback_model(self):
-         """Fallback to basic model"""
-         try:
-             print("🤖 Loading fallback model...")
-             self.qa_pipeline = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
-             self.model_type = "fallback"
-             print("✅ Fallback model loaded")
-         except Exception as e:
-             print(f"❌ All models failed: {e}")
-             self.qa_pipeline = None
-             self.model_type = "none"

-     def detect_document_type(self, text: str) -> str:
-         """Enhanced document type detection"""
-         text_lower = text.lower()
-
-         resume_patterns = [
-             'experience', 'skills', 'education', 'linkedin', 'email', 'phone',
-             'work experience', 'employment', 'resume', 'cv', 'curriculum vitae',
-             'internship', 'projects', 'achievements', 'career', 'profile', 'objective'
-         ]
-
-         research_patterns = [
-             'abstract', 'introduction', 'methodology', 'conclusion', 'references',
-             'literature review', 'hypothesis', 'study', 'research', 'findings',
-             'data analysis', 'results', 'discussion', 'bibliography', 'journal'
-         ]
-
-         business_patterns = [
-             'company', 'revenue', 'market', 'strategy', 'business', 'financial',
-             'quarter', 'profit', 'sales', 'growth', 'investment', 'stakeholder',
-             'operations', 'management', 'corporate', 'enterprise', 'budget'
-         ]
-
-         technical_patterns = [
-             'implementation', 'algorithm', 'system', 'technical', 'specification',
-             'architecture', 'development', 'software', 'programming', 'api',
-             'database', 'framework', 'deployment', 'infrastructure', 'code'
-         ]
-
-         def count_matches(patterns, text):
-             score = 0
-             for pattern in patterns:
-                 count = text.count(pattern)
-                 score += count * (2 if len(pattern.split()) > 1 else 1)  # Weight phrases higher
-             return score
-
-         scores = {
-             'resume': count_matches(resume_patterns, text_lower),
-             'research': count_matches(research_patterns, text_lower),
-             'business': count_matches(business_patterns, text_lower),
-             'technical': count_matches(technical_patterns, text_lower)
-         }
-
-         max_score = max(scores.values())
-         if max_score > 5:  # Higher threshold
-             return max(scores, key=scores.get)
-         return 'general'

-     def create_document_summary(self, text: str) -> str:
-         """Enhanced document summary creation"""
-         try:
-             clean_text = re.sub(r'\s+', ' ', text).strip()
-             sentences = re.split(r'[.!?]+', clean_text)
-             sentences = [s.strip() for s in sentences if len(s.strip()) > 30]
-
-             if not sentences:
-                 return "Document contains basic information."
-
-             # Use first few sentences and key information
-             if self.document_type == 'resume':
-                 return self.extract_resume_summary(sentences, clean_text)
-             elif self.document_type == 'research':
-                 return self.extract_research_summary(sentences)
-             elif self.document_type == 'business':
-                 return self.extract_business_summary(sentences)
-             else:
-                 return self.extract_general_summary(sentences)
-
-         except Exception as e:
-             print(f"Summary creation error: {e}")
-             return "Document summary not available."

-     def extract_resume_summary(self, sentences: List[str], full_text: str) -> str:
-         """Extract resume-specific summary with better name detection"""
-         summary_parts = []
-
-         # Extract name using multiple patterns
-         name = self.extract_name(full_text)
-         if name:
-             summary_parts.append(f"Resume of {name}")
-
-         # Extract role/title
-         role_patterns = [
-             r'(?:software|senior|junior|lead|principal)?\s*(?:engineer|developer|analyst|manager|designer|architect|consultant)',
-             r'(?:full stack|frontend|backend|data|ml|ai)\s*(?:engineer|developer)',
-             r'(?:product|project|technical)\s*manager'
-         ]
-
-         for sentence in sentences[:5]:
-             for pattern in role_patterns:
-                 matches = re.findall(pattern, sentence.lower())
-                 if matches:
-                     summary_parts.append(f"working as {matches[0].title()}")
-                     break
-
-         # Extract experience
-         exp_match = re.search(r'(\d+)[\+\-\s]*(?:years?|yrs?)\s*(?:of\s*)?(?:experience|exp)', full_text.lower())
-         if exp_match:
-             summary_parts.append(f"with {exp_match.group(1)}+ years of experience")
-
-         return '. '.join(summary_parts) + '.' if summary_parts else "Professional resume with career details."

-     def extract_name(self, text: str) -> str:
-         """Extract name from document using multiple strategies"""
-         # Strategy 1: Look for name patterns at the beginning
-         lines = text.split('\n')[:10]  # First 10 lines
-
-         for line in lines:
-             line = line.strip()
-             if len(line) < 50 and len(line) > 3:  # Likely a header line
-                 # Check if it looks like a name
-                 name_match = re.match(r'^([A-Z][a-z]+\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)?)(?:\s|$)', line)
-                 if name_match:
-                     return name_match.group(1)
-
-         # Strategy 2: Look for "Name:" pattern
-         name_patterns = [
-             r'(?:name|full name):\s*([A-Z][a-z]+\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)?)',
-             r'^([A-Z][a-z]+\s+[A-Z][a-z]+)(?:\s*\n|\s*email|\s*phone|\s*linkedin)',
-         ]
-
-         for pattern in name_patterns:
-             match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE)
-             if match:
-                 return match.group(1)
-
          return ""

-     def extract_research_summary(self, sentences: List[str]) -> str:
-         """Extract research paper summary"""
-         # Look for abstract or introduction
-         for sentence in sentences[:5]:
-             if any(word in sentence.lower() for word in ['abstract', 'study', 'research', 'paper']):
-                 return sentence[:200] + ('...' if len(sentence) > 200 else '')
-
-         return "Research document with academic content."
-
-     def extract_business_summary(self, sentences: List[str]) -> str:
-         """Extract business document summary"""
-         for sentence in sentences[:3]:
-             if any(word in sentence.lower() for word in ['company', 'business', 'organization']):
-                 return sentence[:200] + ('...' if len(sentence) > 200 else '')
-
-         return "Business document with organizational information."

-     def extract_general_summary(self, sentences: List[str]) -> str:
-         """Extract general document summary"""
-         return sentences[0][:200] + ('...' if len(sentences[0]) > 200 else '') if sentences else "General document."

-     def extract_text_from_file(self, file_path: str) -> str:
-         """Enhanced text extraction"""
-         try:
-             file_extension = os.path.splitext(file_path)[1].lower()
-
-             if file_extension == '.pdf':
-                 return self.extract_from_pdf(file_path)
-             elif file_extension == '.docx':
-                 return self.extract_from_docx(file_path)
-             elif file_extension == '.txt':
-                 return self.extract_from_txt(file_path)
-             else:
-                 return f"Unsupported file format: {file_extension}"
-
-         except Exception as e:
-             return f"Error reading file: {str(e)}"
-
-     def extract_from_pdf(self, file_path: str) -> str:
-         """Enhanced PDF extraction"""
-         text = ""
-         try:
-             with open(file_path, 'rb') as file:
-                 pdf_reader = PyPDF2.PdfReader(file)
-                 for page in pdf_reader.pages:
-                     page_text = page.extract_text()
-                     if page_text.strip():
-                         # Better text cleaning
-                         page_text = re.sub(r'\s+', ' ', page_text)
-                         page_text = re.sub(r'([a-z])([A-Z])', r'\1 \2', page_text)  # Fix merged words
-                         text += f"{page_text}\n"
-         except Exception as e:
-             text = f"Error reading PDF: {str(e)}"
-         return text.strip()

-     def extract_from_docx(self, file_path: str) -> str:
-         """Enhanced DOCX extraction"""
-         try:
-             doc = docx.Document(file_path)
              text = ""
-             for paragraph in doc.paragraphs:
-                 if paragraph.text.strip():
-                     text += paragraph.text.strip() + "\n"
-             return text.strip()
-         except Exception as e:
-             return f"Error reading DOCX: {str(e)}"
-
-     def extract_from_txt(self, file_path: str) -> str:
-         """Enhanced TXT extraction"""
-         encodings = ['utf-8', 'latin-1', 'cp1252', 'iso-8859-1']

-         for encoding in encodings:
-             try:
-                 with open(file_path, 'r', encoding=encoding) as file:
-                     return file.read().strip()
-             except UnicodeDecodeError:
-                 continue
-             except Exception as e:
-                 return f"Error reading TXT: {str(e)}"

-         return "Error: Could not decode file"
-
-     def enhanced_chunk_text(self, text: str, max_chunk_size: int = 300, overlap: int = 50) -> list[str]:
-         """
-         Splits text into smaller overlapping chunks for better semantic search.

-         Args:
-             text (str): The full text to chunk.
-             max_chunk_size (int): Maximum tokens/words per chunk.
-             overlap (int): Number of words overlapping between consecutive chunks.
-
-         Returns:
-             list[str]: List of text chunks.
-         """
-         import re
-
-         # Clean and normalize whitespace
-         text = re.sub(r'\s+', ' ', text).strip()

-         words = text.split()
-         chunks = []
-         start = 0
-         text_len = len(words)
-
-         while start < text_len:
-             end = min(start + max_chunk_size, text_len)
-             chunk_words = words[start:end]
-             chunk = ' '.join(chunk_words)
-             chunks.append(chunk)
-             # Move start forward by chunk size minus overlap to create overlap
-             start += max_chunk_size - overlap
-
-         return chunks
-
-     def process_documents(self, files) -> str:
-         """Enhanced document processing"""
-         if not files:
-             return "❌ No files uploaded!"

-         try:
-             all_text = ""
-             processed_files = []
-
-             for file in files:
-                 if file is None:
-                     continue
-
-                 file_text = self.extract_text_from_file(file.name)
-                 if not file_text.startswith("Error") and not file_text.startswith("Unsupported"):
-                     all_text += f"\n{file_text}"
-                     processed_files.append(os.path.basename(file.name))
-                 else:
-                     return f"❌ {file_text}"
-
-             if not all_text.strip():
-                 return "❌ No text extracted from files!"
-
-             # Store and analyze
-             self.raw_text = all_text
-             self.document_type = self.detect_document_type(all_text)
-             self.document_summary = self.create_document_summary(all_text)
-
-             # Enhanced chunking
-             chunk_data = self.enhanced_chunk_text(all_text)
-
-             if not chunk_data:
-                 return "❌ No valid text chunks created!"
-
-             self.documents = [chunk['text'] for chunk in chunk_data]
-             self.document_metadata = chunk_data
-
-             # Create embeddings
-             print(f"📄 Creating embeddings for {len(self.documents)} chunks...")
-             embeddings = self.embedder.encode(self.documents, show_progress_bar=False)
-
-             # Build FAISS index
-             dimension = embeddings.shape[1]
-             self.index = faiss.IndexFlatIP(dimension)
-
-             # Normalize for cosine similarity
-             faiss.normalize_L2(embeddings)
-             self.index.add(embeddings.astype('float32'))
-
-             self.is_indexed = True
-
-             return f"✅ Successfully processed {len(processed_files)} files:\n" + \
-                    f"📄 Files: {', '.join(processed_files)}\n" + \
-                    f"📊 Document Type: {self.document_type.title()}\n" + \
-                    f"🔍 Created {len(self.documents)} chunks\n" + \
-                    f"📝 Summary: {self.document_summary}\n" + \
-                    f"🚀 Ready for Q&A!"
-
-         except Exception as e:
-             return f"❌ Error processing documents: {str(e)}"
-
-     def find_relevant_content(self, query: str, k: int = 3) -> str:
-         """Improved content retrieval with stricter relevance filter"""
-         if not self.is_indexed:
-             return ""

-         try:
-             # Semantic search
-             query_embedding = self.embedder.encode([query])
-             faiss.normalize_L2(query_embedding)
-
-             scores, indices = self.index.search(query_embedding.astype('float32'), min(k, len(self.documents)))
-
-             relevant_chunks = []
-             for i, idx in enumerate(indices[0]):
-                 score = scores[0][i]
-                 if idx < len(self.documents) and score > 0.4:  # ✅ stricter similarity filter
-                     relevant_chunks.append(self.documents[idx])
-
-             return ' '.join(relevant_chunks)
-
-         except Exception as e:
-             print(f"Error in content retrieval: {e}")
-             return ""

-     def answer_question(self, query: str) -> str:
-         """Enhanced question answering with better model usage and hallucination reduction."""
-         if not query.strip():
-             return "❓ Please ask a question!"

-         if not self.is_indexed:
-             return "📝 Please upload and process documents first!"

-         try:
-             query_lower = query.lower()
-
-             # Handle summary requests explicitly
-             if any(word in query_lower for word in ['summary', 'summarize', 'about', 'overview']):
-                 return f"📄 **Document Summary:**\n\n{self.document_summary}"
-
-             # Retrieve relevant content chunks via semantic search
-             context = self.find_relevant_content(query, k=3)
-
-             if not context:
-                 return "🔍 No relevant information found. Try rephrasing your question."
-
-             # If no QA pipeline, fall back to direct extraction
-             if self.qa_pipeline is None:
-                 return self.extract_direct_answer(query, context)
-
-             try:
-                 if self.model_type in ["distilbert-qa", "fallback"]:
-                     # Use extractive Q&A pipeline
-                     result = self.qa_pipeline(question=query, context=context)
-                     answer = result.get('answer', '').strip()
-                     confidence = result.get('score', 0)
-
-                     if confidence > 0.1 and answer:
-                         return f"**Answer:** {answer}\n\n**Context:** {context[:200]}..."
-                     else:
-                         return self.extract_direct_answer(query, context)
-
-                 elif self.model_type == "flan-t5":
-                     # Use generative model with improved prompt to reduce hallucination
-                     prompt = (
-                         f"Answer concisely and strictly based on the following context.\n\n"
-                         f"Context:\n{context}\n\n"
-                         f"Question:\n{query}\n\n"
-                         f"If the answer is not contained in the context, reply with 'Not found in document.'\n"
-                         f"Answer:"
-                     )
-                     result = self.qa_pipeline(prompt, max_length=256, num_return_sequences=1)
-                     generated_text = result[0].get('generated_text', '')
-                     answer = generated_text.replace(prompt, '').strip()
-
-                     if answer.lower() in ["not found in document.", "no answer", "unknown", ""]:
-                         return "🔍 Sorry, the answer was not found in the documents."
-                     else:
-                         return f"**Answer:** {answer}"
-
-                 else:
-                     # Default fallback extraction
-                     return self.extract_direct_answer(query, context)
-
-             except Exception as e:
-                 print(f"Model inference error: {e}")
-                 return self.extract_direct_answer(query, context)
-
-         except Exception as e:
-             return f"❌ Error processing question: {str(e)}"

      def extract_direct_answer(self, query: str, context: str) -> str:
-         """Direct answer extraction as fallback"""
-         query_lower = query.lower()
-
-         # Name extraction
-         if any(word in query_lower for word in ['name', 'who is', 'who']):
              names = re.findall(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', context)
              if names:
                  return f"**Name:** {names[0]}"

-         # Experience extraction
-         if any(word in query_lower for word in ['experience', 'years']):
-             exp_matches = re.findall(r'(\d+)[\+\-\s]*(?:years?|yrs?)', context.lower())
-             if exp_matches:
-                 return f"**Experience:** {exp_matches[0]} years"

-         # Skills extraction
-         if any(word in query_lower for word in ['skill', 'technology', 'tech']):
-             # Common tech skills
-             tech_patterns = [
-                 r'\b(?:Python|Java|JavaScript|React|Node|SQL|AWS|Docker|Kubernetes|Git)\b',
-                 r'\b(?:HTML|CSS|Angular|Vue|Spring|Django|Flask|MongoDB|PostgreSQL)\b'
-             ]
-             skills = []
-             for pattern in tech_patterns:
-                 skills.extend(re.findall(pattern, context, re.IGNORECASE))
-
              if skills:
-                 return f"**Skills mentioned:** {', '.join(set(skills))}"

-         # Education extraction
-         if any(word in query_lower for word in ['education', 'degree', 'university']):
-             edu_matches = re.findall(r'(?:Bachelor|Master|PhD|B\.?S\.?|M\.?S\.?|B\.?A\.?|M\.?A\.?).*?(?:in|of)\s+([^.]+)', context)
-             if edu_matches:
-                 return f"**Education:** {edu_matches[0]}"

-         # Return first relevant sentence
          sentences = [s.strip() for s in context.split('.') if s.strip()]
          if sentences:
              return f"**Answer:** {sentences[0]}"
-
-         return "I found relevant content but couldn't extract a specific answer."
-
-     def clean_text(self, text: str) -> str:
-         """
-         Clean and normalize raw text by:
-         - Removing excessive whitespace
-         - Fixing merged words (camel case separation)
-         - Removing unwanted characters (optional)
-         - Lowercasing or preserving case (optional)
-         """
-         import re
-
-         # Replace multiple whitespace/newlines/tabs with single space
-         text = re.sub(r'\s+', ' ', text).strip()

-         # Fix merged words like 'wordAnotherWord' -> 'word Another Word'
-         text = re.sub(r'([a-z])([A-Z])', r'\1 \2', text)

-         # Optional: remove special characters except basic punctuation
-         # text = re.sub(r'[^a-zA-Z0-9,.!?;:\'\"()\-\s]', '', text)

-         return text
-
-
- # Initialize the system
- print("Initializing Enhanced Smart RAG System...")
- rag_system = SmartDocumentRAG()
-
- # Create the interface
- def create_interface():
      with gr.Blocks(title="🧠 Enhanced Document Q&A", theme=gr.themes.Soft()) as demo:
          gr.Markdown("""
          # 🧠 Enhanced Document Q&A System

-         **Optimized with Better Models & Quantization!**
-
-         **Features:**
-         - 🎯 Flan-T5 or DistilBERT for accurate Q&A
-         - ⚡ 4-bit quantization for GPU efficiency
-         - 📊 Direct answer extraction
-         - 🔍 Enhanced semantic search
          """)

          with gr.Tab("📤 Upload & Process"):
              with gr.Row():
                  with gr.Column():
-                     file_upload = gr.File(
-                         label="📁 Upload Documents",
-                         file_count="multiple",
-                         file_types=[".pdf", ".docx", ".txt"],
-                         height=150
-                     )
                      process_btn = gr.Button("🔄 Process Documents", variant="primary", size="lg")
-
                  with gr.Column():
-                     process_status = gr.Textbox(
-                         label="📋 Processing Status",
-                         lines=10,
-                         interactive=False
-                     )
-
-             process_btn.click(
-                 fn=rag_system.process_documents,
-                 inputs=[file_upload],
-                 outputs=[process_status]
-             )

          with gr.Tab("❓ Q&A"):
              with gr.Row():
                  with gr.Column():
-                     question_input = gr.Textbox(
-                         label="🤔 Ask Your Question",
-                         placeholder="What is the person's name? / How many years of experience? / What skills do they have?",
-                         lines=3
-                     )
-
                      with gr.Row():
                          ask_btn = gr.Button("🧠 Get Answer", variant="primary")
                          summary_btn = gr.Button("📊 Get Summary", variant="secondary")
-
                  with gr.Column():
-                     answer_output = gr.Textbox(
-                         label="💡 Answer",
-                         lines=8,
-                         interactive=False
-                     )
-
-             ask_btn.click(
-                 fn=rag_system.answer_question,
-                 inputs=[question_input],
-                 outputs=[answer_output]
-             )
-
-             summary_btn.click(
-                 fn=lambda: rag_system.answer_question("summary"),
-                 inputs=[],
-                 outputs=[answer_output]
-             )

-         return demo

- # Launch the app
  if __name__ == "__main__":
-     demo = create_interface()
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=True
-     )

+ import re
+ import os
  import faiss
  import numpy as np
+ import gradio as gr
+ from typing import List
+ from sentence_transformers import SentenceTransformer
+ from transformers import pipeline
+ from PyPDF2 import PdfReader
+ import docx2txt
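+
+ # Runtime dependencies assumed by these imports (the usual PyPI package names):
+ # gradio, faiss-cpu, numpy, sentence-transformers, transformers, PyPDF2, docx2txt.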

+ # === Helper functions ===

+ def clean_text(text: str) -> str:
+     """Clean and normalize text."""
+     text = re.sub(r'\s+', ' ', text)  # normalize whitespace
+     text = text.strip()
+     return text

+ def chunk_text(text: str, max_chunk_size: int = 300, overlap: int = 50) -> List[str]:
+     """Split text into smaller overlapping chunks for better semantic search."""
+     sentences = re.split(r'(?<=[.?!])\s+', text)
+     chunks = []
+     chunk = ""
+     for sentence in sentences:
+         if len(chunk) + len(sentence) <= max_chunk_size:
+             chunk += sentence + " "
+         else:
+             chunks.append(chunk.strip())
+             chunk = sentence + " "
+     if chunk:
+         chunks.append(chunk.strip())
+     # Add overlap between consecutive chunks to retain context
+     overlapped_chunks = []
+     for i in range(len(chunks)):
+         combined = chunks[i]
+         if i > 0:
+             combined = chunks[i-1][-overlap:] + " " + combined
+         overlapped_chunks.append(clean_text(combined))
+     return overlapped_chunks
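+
+ # A worked example (illustrative values only, not used by the app): with
+ # max_chunk_size=15 and overlap=6,
+ # chunk_text("First point. Second point. Third point.") returns
+ # ["First point.", "point. Second point.", "point. Third point."];
+ # each chunk carries the tail of its predecessor, so context at chunk
+ # boundaries is never lost.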

+ def extract_text_from_pdf(file_path: str) -> str:
+     """Extract text from PDF file."""
+     text = ""
+     try:
+         reader = PdfReader(file_path)
+         for page in reader.pages:
+             text += (page.extract_text() or "") + " "  # extract_text() can return None for image-only pages
+     except Exception as e:
+         print(f"Error reading PDF {file_path}: {e}")
+     return clean_text(text)

+ def extract_text_from_docx(file_path: str) -> str:
+     """Extract text from DOCX file."""
+     try:
+         text = docx2txt.process(file_path)
+         return clean_text(text)
+     except Exception as e:
+         print(f"Error reading DOCX {file_path}: {e}")
          return ""

+ def extract_text_from_txt(file_path: str) -> str:
+     """Extract text from TXT file."""
+     try:
+         with open(file_path, 'r', encoding='utf-8') as f:
+             text = f.read()
+         return clean_text(text)
+     except Exception as e:
+         print(f"Error reading TXT {file_path}: {e}")
+         return ""

+ # === Main RAG System ===

+ class SmartDocumentRAG:
+     def __init__(self):
+         # Model & embedding initialization
+         self.embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+         self.qa_pipeline = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
+         self.documents = []
+         self.chunks = []
+         self.index = None
+         self.is_indexed = False
+         self.document_summary = ""

+     def process_documents(self, uploaded_files) -> str:
+         """Load, extract, chunk, embed, and index documents."""
+         if not uploaded_files:
+             return "⚠️ No files uploaded."
+
+         self.documents.clear()
+         self.chunks.clear()
+         all_text = ""
+
+         # Extract text from each uploaded file
+         for file_obj in uploaded_files:
+             # Gradio exposes each upload as a temporary file on disk
+             file_path = file_obj.name
+             ext = os.path.splitext(file_path)[1].lower()
              text = ""
+             if ext == ".pdf":
+                 text = extract_text_from_pdf(file_path)
+             elif ext == ".docx":
+                 text = extract_text_from_docx(file_path)
+             elif ext == ".txt":
+                 text = extract_text_from_txt(file_path)
+             else:
+                 continue  # skip unsupported formats
+
+             if text:
+                 self.documents.append(text)
+                 all_text += text + " "

+         if not all_text.strip():
+             return "⚠️ No extractable text found in uploaded files."

+         # Create chunks for semantic search
+         self.chunks = chunk_text(all_text)

+         # Create embeddings for chunks
+         embeddings = self.embedder.encode(self.chunks, convert_to_numpy=True)
+         embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)  # unit-normalize rows

+         # Create FAISS index
+         dim = embeddings.shape[1]
+         self.index = faiss.IndexFlatIP(dim)
+         self.index.add(embeddings.astype('float32'))
+         self.is_indexed = True
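+
+         # With unit-norm rows, the inner product computed by IndexFlatIP is
+         # exactly cosine similarity, so search scores lie in [-1, 1]
+         # (equivalent to the faiss.normalize_L2 call in the previous version).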

+         # Create simple summary
+         self.document_summary = self.generate_summary(all_text)

+         return f"✅ Processed {len(self.documents)} document(s), {len(self.chunks)} chunks indexed."

+     def generate_summary(self, text: str) -> str:
+         """Generate a simple summary using top sentences."""
+         sentences = re.split(r'(?<=[.?!])\s+', text)
+         summary = ' '.join(sentences[:5])  # first 5 sentences as naive summary
+         return summary
+
+     def find_relevant_content(self, query: str, top_k: int = 3) -> str:
+         """Perform semantic search to find relevant content chunks."""
+         if not self.is_indexed or not self.chunks:
+             return ""
+         query_emb = self.embedder.encode([query], convert_to_numpy=True)
+         query_emb = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)

+         scores, indices = self.index.search(query_emb.astype('float32'), min(top_k, len(self.chunks)))

+         relevant_chunks = []
+         for i, idx in enumerate(indices[0]):
+             if scores[0][i] > 0.1:
+                 relevant_chunks.append(self.chunks[idx])
+         return " ".join(relevant_chunks)

      def extract_direct_answer(self, query: str, context: str) -> str:
+         """Simple regex-based fallback extraction."""
+         q = query.lower()
+         if any(word in q for word in ['name', 'who is', 'who']):
              names = re.findall(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', context)
              if names:
                  return f"**Name:** {names[0]}"

+         if any(word in q for word in ['experience', 'years']):
+             years = re.findall(r'(\d+)[\+\-\s]*(?:years?|yrs?)', context.lower())
+             if years:
+                 return f"**Experience:** {years[0]} years"

+         if any(word in q for word in ['skill', 'technology', 'tech']):
+             skills = re.findall(r'\b(?:Python|Java|JavaScript|React|Node|SQL|AWS|Docker|Kubernetes|Git|HTML|CSS|Angular|Vue|Spring|Django|Flask|MongoDB|PostgreSQL)\b', context, re.I)
              if skills:
+                 unique_skills = sorted(set(skills), key=skills.index)
+                 return f"**Skills:** {', '.join(unique_skills)}"

+         if any(word in q for word in ['education', 'degree', 'university']):
+             edu = re.findall(r'(?:Bachelor|Master|PhD|B\.?S\.?|M\.?S\.?|B\.?A\.?|M\.?A\.?).*?(?:in|of)\s+([^.]+)', context, re.I)
+             if edu:
+                 return f"**Education:** {edu[0]}"

+         # Fallback: first sentence from context
          sentences = [s.strip() for s in context.split('.') if s.strip()]
          if sentences:
              return f"**Answer:** {sentences[0]}"
+         return "I found relevant content but could not extract a specific answer."

+     def answer_question(self, query: str) -> str:
+         if not query.strip():
+             return "❓ Please ask a question."
+         if not self.is_indexed:
+             return "📝 Please upload and process documents first."

+         q_lower = query.lower()
+         if any(word in q_lower for word in ['summary', 'summarize', 'overview', 'about']):
+             return f"📄 **Document Summary:**\n\n{self.document_summary}"

+         context = self.find_relevant_content(query, top_k=3)
+         if not context:
+             return "🔍 No relevant information found. Try rephrasing your question."

+         try:
+             # Use model for QA
+             result = self.qa_pipeline(question=query, context=context)
+             answer = result.get('answer', '').strip()
+             score = result.get('score', 0)
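+             # The pipeline returns a dict shaped like
+             # {'answer': str, 'score': float, 'start': int, 'end': int};
+             # 'score' is the model's confidence in the extracted span.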
+
+             # Below the confidence threshold, fall back to regex extraction
+             if score < 0.1 or not answer:
+                 return self.extract_direct_answer(query, context)
+             return f"**Answer:** {answer}\n\n**Context:** {context[:200]}..."
+
+         except Exception as e:
+             print(f"QA model error: {e}")
+             return self.extract_direct_answer(query, context)

+ # === Gradio UI ===

+ def main():
+     rag = SmartDocumentRAG()

+     def process_files(files):
+         return rag.process_documents(files)
+
+     def ask_question(question):
+         return rag.answer_question(question)
+
+     def get_summary():
+         return rag.answer_question("summary")

      with gr.Blocks(title="🧠 Enhanced Document Q&A", theme=gr.themes.Soft()) as demo:
          gr.Markdown("""
          # 🧠 Enhanced Document Q&A System

+         **Optimized with Better Models & Semantic Search**

+         - Upload PDF, DOCX, TXT files
+         - Semantic search + QA pipeline
+         - Direct answer extraction fallback
          """)

          with gr.Tab("📤 Upload & Process"):
              with gr.Row():
                  with gr.Column():
+                     file_upload = gr.File(label="📁 Upload Documents", file_types=['.pdf', '.docx', '.txt'], file_count="multiple", height=150)
                      process_btn = gr.Button("🔄 Process Documents", variant="primary", size="lg")
                  with gr.Column():
+                     process_status = gr.Textbox(label="📋 Processing Status", lines=10, interactive=False)
+             process_btn.click(fn=process_files, inputs=file_upload, outputs=process_status)

          with gr.Tab("❓ Q&A"):
              with gr.Row():
                  with gr.Column():
+                     question_input = gr.Textbox(label="🤔 Ask Your Question", lines=3,
+                                                 placeholder="Name? Experience? Skills? Education?")
                      with gr.Row():
                          ask_btn = gr.Button("🧠 Get Answer", variant="primary")
                          summary_btn = gr.Button("📊 Get Summary", variant="secondary")
                  with gr.Column():
+                     answer_output = gr.Textbox(label="💡 Answer", lines=8, interactive=False)
+             ask_btn.click(fn=ask_question, inputs=question_input, outputs=answer_output)
+             summary_btn.click(fn=get_summary, inputs=None, outputs=answer_output)

+     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)

  if __name__ == "__main__":
+     main()