import os
import re
from typing import List

import docx
import faiss
import gradio as gr
import numpy as np
import pdfplumber
from sentence_transformers import SentenceTransformer
from transformers import pipeline


# Utility: collapse whitespace and trim
def clean_text(text: str) -> str:
    text = re.sub(r'\s+', ' ', text)  # collapse whitespace
    return text.strip()


# Text chunking (smaller chunks for better semantic search)
def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = min(start + chunk_size, len(words))
        chunk = ' '.join(words[start:end])
        chunks.append(clean_text(chunk))
        start += chunk_size - overlap  # step forward, keeping `overlap` words of context
    return chunks


# Document loader for .txt, .pdf, and .docx files
def load_document(file_path: str) -> str:
    ext = os.path.splitext(file_path)[1].lower()
    text = ""
    if ext == ".txt":
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
    elif ext == ".pdf":
        with pdfplumber.open(file_path) as pdf:
            pages = [page.extract_text() for page in pdf.pages if page.extract_text()]
            text = "\n".join(pages)
    elif ext == ".docx":
        doc = docx.Document(file_path)
        paragraphs = [para.text for para in doc.paragraphs if para.text.strip()]
        text = "\n".join(paragraphs)
    else:
        raise ValueError(f"Unsupported file type: {ext}")
    return clean_text(text)


class SmartDocumentRAG:
    def __init__(self):
        print("Loading embedder and models...")
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')  # small, fast
        self.documents = []
        self.embeddings = None
        self.index = None
        self.is_indexed = False

        # Load QA pipeline
        self.model_type = "distilbert-qa"  # change to "flan-t5" for generative answers
        if self.model_type == "distilbert-qa":
            self.qa_pipeline = pipeline(
                "question-answering",
                model="distilbert-base-cased-distilled-squad",
            )
        elif self.model_type == "flan-t5":
            self.qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
        else:
            self.qa_pipeline = None

        self.document_summary = ""

    def process_documents(self, files: List[gr.File]) -> str:
        if not files:
            return "⚠️ No files uploaded."
        print(f"Processing {len(files)} files...")

        all_text = ""
        for file in files:
            try:
                # Gradio file objects expose the temp path via .name;
                # plain path strings are passed through unchanged
                path = file.name if hasattr(file, 'name') else file
                text = load_document(path)
                all_text += text + "\n"
            except Exception as e:
                print(f"Error loading {file}: {e}")

        all_text = clean_text(all_text)
        chunks = chunk_text(all_text)
        if not chunks:
            return "⚠️ No text extracted from documents."

        self.documents = chunks
        print(f"Created {len(chunks)} text chunks.")

        # Embed and build FAISS index: inner product over L2-normalized
        # vectors is equivalent to cosine similarity
        self.embeddings = self.embedder.encode(self.documents, convert_to_numpy=True)
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)
        faiss.normalize_L2(self.embeddings)
        self.index.add(self.embeddings)
        self.is_indexed = True

        # Generate summary (simple: first 3 chunks joined)
        summary_text = " ".join(self.documents[:3])
        self.document_summary = summary_text if summary_text else "Summary not available."

        return f"✅ Processed {len(files)} files and created index with {len(chunks)} chunks."
    def find_relevant_content(self, query: str, k: int = 5) -> str:
        if not self.is_indexed:
            return ""

        query_emb = self.embedder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_emb)

        k = min(k, len(self.documents))
        # IndexFlatIP returns similarity scores (higher is better), not distances
        scores, indices = self.index.search(query_emb, k)

        relevant_chunks = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing results with idx == -1, so guard the range
            if score > 0.1 and 0 <= idx < len(self.documents):
                relevant_chunks.append(self.documents[idx])

        context = " ".join(relevant_chunks)
        print(f"Found {len(relevant_chunks)} relevant chunks with similarity > 0.1")
        return context

    def answer_question(self, query: str) -> str:
        if not query.strip():
            return "❓ Please ask a valid question."
        if not self.is_indexed:
            return "📁 Please upload and process documents first."

        query_lower = query.lower()
        if any(word in query_lower for word in ['summary', 'summarize', 'overview', 'about']):
            return f"📄 Document Summary:\n\n{self.document_summary}"

        context = self.find_relevant_content(query, k=5)
        print(f"Context for query: {context[:500]}...")
        if not context:
            return "🔍 Sorry, no relevant information found. Try rephrasing your question."

        try:
            if self.model_type == "distilbert-qa":
                result = self.qa_pipeline(question=query, context=context)
                print(f"QA pipeline result: {result}")
                answer = result.get('answer', '').strip()
                score = result.get('score', 0.0)
                if not answer or score < 0.05:
                    return "🤔 I couldn't find a confident answer based on the documents."
                snippet = context[:300].strip()
                if len(context) > 300:
                    snippet += "..."
                return f"**Answer:** {answer}\n\n*Context snippet:* {snippet}"
            elif self.model_type == "flan-t5":
                prompt = (
                    f"Answer the question based on the context below.\n\n"
                    f"Context:\n{context}\n\n"
                    f"Question: {query}\nAnswer:"
                )
                result = self.qa_pipeline(prompt, max_length=200, num_return_sequences=1)
                print(f"Generative pipeline result: {result}")
                # text2text-generation returns only the generated answer,
                # so no prompt stripping is needed
                answer = result[0]['generated_text'].strip()
                if not answer:
                    return "🤔 I couldn't find a confident answer based on the documents."
                return f"**Answer:** {answer}"
            else:
                return "⚠️ Unsupported model type."
        except Exception as e:
            print(f"Exception in answer_question: {e}")
            return f"❌ Error: {str(e)}"


# Create the Gradio UI
def create_interface():
    rag = SmartDocumentRAG()

    with gr.Blocks(title="🧠 Enhanced Document Q&A") as demo:
        gr.Markdown(
            """
            # 🧠 Enhanced Document Q&A System

            **Features:**
            - Semantic search with FAISS + SentenceTransformer
            - Supports PDF, DOCX, TXT uploads
            - Uses DistilBERT or Flan-T5 for Q&A
            - Shows answer with context snippet
            """
        )

        with gr.Tab("Upload & Process"):
            file_upload = gr.File(
                file_types=['.pdf', '.docx', '.txt'],
                label="Upload Documents",
                file_count="multiple",
            )
            process_btn = gr.Button("Process Documents")
            process_status = gr.Textbox(label="Processing Status", interactive=False, lines=4)
            process_btn.click(fn=rag.process_documents, inputs=[file_upload], outputs=[process_status])

        with gr.Tab("Q&A"):
            question_input = gr.Textbox(
                label="Ask your question",
                lines=2,
                placeholder="Type your question here...",
            )
            ask_btn = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", lines=8, interactive=False)
            ask_btn.click(fn=rag.answer_question, inputs=[question_input], outputs=[answer_output])

        with gr.Tab("Summary"):
            summary_btn = gr.Button("Get Document Summary")
            summary_output = gr.Textbox(label="Summary", lines=6, interactive=False)
            summary_btn.click(fn=lambda: rag.answer_question("summary"), inputs=[], outputs=[summary_output])

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
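
# ---------------------------------------------------------------------------
# Usage note (a minimal sketch, not part of the app): the imports above
# correspond to these PyPI distributions; swap faiss-cpu for faiss-gpu on a
# CUDA machine.
#
#   pip install gradio faiss-cpu sentence-transformers transformers \
#       pdfplumber python-docx numpy
#
# The class can also be exercised without the UI. Plain path strings work
# because process_documents falls back to the value itself when there is no
# .name attribute ("notes.txt" below is a hypothetical file):
#
#   rag = SmartDocumentRAG()
#   print(rag.process_documents(["notes.txt"]))
#   print(rag.answer_question("What is this document about?"))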