# Custom_Rag_Bot / app.py
# pradeepsengarr's picture
# Update app.py
# 3d46a83 verified
import os
import re
import faiss
import gradio as gr
import numpy as np
import pdfplumber
import docx
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# Utility: Clean text helper
def clean_text(text: str) -> str:
    """Return *text* with whitespace runs collapsed and the ends trimmed.

    Every run of whitespace (spaces, tabs, newlines, ...) becomes a single
    space, then leading and trailing whitespace is removed.
    """
    return re.sub(r'\s+', ' ', text).strip()
# Text chunking (smaller chunks for better semantic search)
def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
    """Split *text* into overlapping chunks of whitespace-separated words.

    Args:
        text: Text to split; words are delimited by any whitespace.
        chunk_size: Maximum number of words per chunk (must be positive).
        overlap: Number of words shared between consecutive chunks
            (must satisfy ``0 <= overlap < chunk_size``).

    Returns:
        The chunks in document order, each a single-space-joined string.
        Empty or whitespace-only input yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range. A step
            of ``chunk_size - overlap <= 0`` would never advance through the
            text.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        # Words come from split() and are re-joined with single spaces, so
        # each chunk is already whitespace-normalized.
        chunks.append(' '.join(words[start:end]))
        if end >= len(words):
            # This chunk already reaches the last word; another iteration
            # would only emit a tail fully contained in this chunk.
            break
    return chunks
# Document loader for txt, pdf, docx
def load_document(file_path: str) -> str:
    """Read a .txt, .pdf or .docx file and return its cleaned text.

    Args:
        file_path: Path to the document; the type is chosen from the
            (case-insensitive) file extension.

    Returns:
        The document text after ``clean_text`` normalization.

    Raises:
        ValueError: If the extension is not .txt, .pdf or .docx.
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext == ".txt":
        # utf-8-sig decodes plain UTF-8 and also drops a leading BOM, which
        # would otherwise end up inside the first chunk.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            text = f.read()
    elif ext == ".pdf":
        with pdfplumber.open(file_path) as pdf:
            # Extract each page once; pages without extractable text
            # (None or empty string) are skipped.
            extracted = (page.extract_text() for page in pdf.pages)
            text = "\n".join(page_text for page_text in extracted if page_text)
    elif ext == ".docx":
        doc = docx.Document(file_path)
        text = "\n".join(
            para.text for para in doc.paragraphs if para.text.strip()
        )
    else:
        raise ValueError(f"Unsupported file type: {ext}")
    return clean_text(text)
class SmartDocumentRAG:
    """Question answering over uploaded documents using retrieval + a QA model.

    Documents are split into word chunks, embedded with a SentenceTransformer
    model and stored in a FAISS inner-product index over L2-normalized
    vectors. Questions are answered by retrieving the best-scoring chunks and
    passing them to a Hugging Face pipeline.
    """

    def __init__(self, model_type: str = "distilbert-qa"):
        """Load the embedding model and the answering pipeline.

        Args:
            model_type: ``"distilbert-qa"`` for extractive answers or
                ``"flan-t5"`` for generative answers. Any other value leaves
                the pipeline unset and ``answer_question`` reports it as
                unsupported.
        """
        print("Loading embedder and models...")
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')  # small, fast
        self.documents = []
        self.embeddings = None
        self.index = None
        self.is_indexed = False
        # Load QA pipelines
        self.model_type = model_type
        if self.model_type == "distilbert-qa":
            self.qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
        elif self.model_type == "flan-t5":
            self.qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
        else:
            self.qa_pipeline = None
        self.document_summary = ""

    @staticmethod
    def _is_summary_request(query: str) -> bool:
        """Return True when *query* asks for a summary of the documents."""
        lowered = query.lower()
        if any(word in lowered for word in ('summary', 'summarize', 'overview')):
            return True
        # "about" only counts when it closes the question ("What is this
        # document about?"); "tell me about X" must be answered from the index.
        return re.search(r"\babout\W*$", lowered) is not None

    def process_documents(self, files: Optional[List[object]]) -> str:
        """Load, chunk, embed and index the uploaded files.

        Args:
            files: Uploaded items; each is either an object exposing the file
                path as ``.name`` or the path itself.

        Returns:
            A status message for the UI. Files that could not be loaded are
            listed in the message instead of being counted as processed.
        """
        if not files:
            return "⚠️ No files uploaded."
        print(f"Processing {len(files)} files...")

        texts = []
        failures = []
        for file in files:
            # Use the .name attribute when present, otherwise treat the item
            # itself as the path.
            path = getattr(file, 'name', file)
            try:
                texts.append(load_document(path))
            except Exception as e:
                # Best effort: one unreadable file must not abort the batch.
                print(f"Error loading {file}: {e}")
                failures.append(f"{os.path.basename(str(path))}: {e}")

        skipped_note = ""
        if failures:
            skipped_note = f"\n⚠️ Skipped {len(failures)} file(s): " + "; ".join(failures)

        chunks = chunk_text(clean_text("\n".join(texts)))
        if not chunks:
            return "⚠️ No text extracted from documents." + skipped_note
        print(f"Created {len(chunks)} text chunks.")

        # Embed and build the FAISS index in locals first, so a failure here
        # cannot leave self.documents out of sync with self.index.
        embeddings = self.embedder.encode(chunks, convert_to_numpy=True)
        faiss.normalize_L2(embeddings)  # inner product on unit vectors == cosine
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)

        self.documents = chunks
        self.embeddings = embeddings
        self.index = index
        self.is_indexed = True
        # Generate summary (simple: first 3 chunks joined)
        self.document_summary = " ".join(chunks[:3]) or "Summary not available."
        return (
            f"✅ Processed {len(texts)} files and created index with {len(chunks)} chunks."
            + skipped_note
        )

    def find_relevant_content(self, query: str, k: int = 5, min_score: float = 0.1) -> str:
        """Return the chunks most similar to *query*, joined by spaces.

        Args:
            query: The user's question.
            k: Maximum number of chunks to retrieve.
            min_score: Chunks must score strictly above this similarity.

        Returns:
            The concatenated matching chunks, or ``""`` when nothing is
            indexed or no chunk clears ``min_score``.
        """
        if not self.is_indexed or not self.documents:
            return ""
        k = min(k, len(self.documents))
        if k <= 0:
            return ""
        query_emb = self.embedder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_emb)
        scores, indices = self.index.search(query_emb, k)
        # Reject negative indices as well: a -1 would otherwise silently
        # select the last chunk through Python's negative indexing.
        relevant_chunks = [
            self.documents[idx]
            for score, idx in zip(scores[0], indices[0])
            if score > min_score and 0 <= idx < len(self.documents)
        ]
        print(f"Found {len(relevant_chunks)} relevant chunks with similarity > {min_score}")
        return " ".join(relevant_chunks)

    def answer_question(self, query: str) -> str:
        """Answer *query* from the indexed documents.

        Args:
            query: The user's question.

        Returns:
            A user-facing message: the answer, the document summary (for
            summary requests), or an explanation of why no answer is given.
            Exceptions raised by the model pipeline are reported in the
            returned text rather than propagated.
        """
        if not query or not query.strip():
            return "❓ Please ask a valid question."
        if not self.is_indexed:
            return "📁 Please upload and process documents first."
        if self._is_summary_request(query):
            return f"📄 Document Summary:\n\n{self.document_summary}"

        context = self.find_relevant_content(query, k=5)
        print(f"Context for query: {context[:500]}...")
        if not context:
            return "🔍 Sorry, no relevant information found. Try rephrasing your question."

        try:
            if self.model_type == "distilbert-qa":
                result = self.qa_pipeline(question=query, context=context)
                print(f"QA pipeline result: {result}")
                answer = result.get('answer', '').strip()
                score = result.get('score', 0.0)
                if not answer or score < 0.05:
                    return "🤔 I couldn't find a confident answer based on the documents."
                snippet = context[:300].strip()
                if len(context) > 300:
                    snippet += "..."
                return f"**Answer:** {answer}\n\n*Context snippet:* {snippet}"
            elif self.model_type == "flan-t5":
                prompt = (
                    f"Answer the question based on the context below.\n\n"
                    f"Context:\n{context}\n\n"
                    f"Question: {query}\nAnswer:"
                )
                result = self.qa_pipeline(prompt, max_length=200, num_return_sequences=1)
                print(f"Generative pipeline result: {result}")
                # Drop the prompt in case the model output echoes it.
                answer = result[0]['generated_text'].replace(prompt, '').strip()
                if not answer:
                    return "🤔 I couldn't find a confident answer based on the documents."
                return f"**Answer:** {answer}"
            else:
                return "⚠️ Unsupported model type."
        except Exception as e:
            # UI boundary: surface the failure as text instead of crashing
            # the Gradio callback.
            print(f"Exception in answer_question: {e}")
            return f"❌ Error: {str(e)}"
# Create Gradio UI
def create_interface():
    """Assemble the Gradio Blocks app and return it (not yet launched).

    One ``SmartDocumentRAG`` instance backs all three tabs: uploading and
    indexing documents, asking questions, and showing the document summary.
    """
    engine = SmartDocumentRAG()

    # Intro text shown above the tabs.
    intro_markdown = (
        "\n"
        "# 🧠 Enhanced Document Q&A System\n"
        "**Features:**\n"
        "- Semantic search with FAISS + SentenceTransformer\n"
        "- Supports PDF, DOCX, TXT uploads\n"
        "- Uses DistilBERT or Flan-T5 for Q&A\n"
        "- Shows answer with context snippet\n"
    )

    def fetch_summary():
        # The summary is served through the regular question path.
        return engine.answer_question("summary")

    with gr.Blocks(title="🧠 Enhanced Document Q&A") as app:
        gr.Markdown(intro_markdown)

        with gr.Tab("Upload & Process"):
            uploader = gr.File(
                file_types=['.pdf', '.docx', '.txt'],
                label="Upload Documents",
                file_count="multiple",
            )
            run_processing = gr.Button("Process Documents")
            status_box = gr.Textbox(label="Processing Status", interactive=False, lines=4)
            run_processing.click(
                fn=engine.process_documents,
                inputs=[uploader],
                outputs=[status_box],
            )

        with gr.Tab("Q&A"):
            question_box = gr.Textbox(
                label="Ask your question",
                lines=2,
                placeholder="Type your question here...",
            )
            run_question = gr.Button("Get Answer")
            answer_box = gr.Textbox(label="Answer", lines=8, interactive=False)
            run_question.click(
                fn=engine.answer_question,
                inputs=[question_box],
                outputs=[answer_box],
            )

        with gr.Tab("Summary"):
            run_summary = gr.Button("Get Document Summary")
            summary_box = gr.Textbox(label="Summary", lines=6, interactive=False)
            run_summary.click(fn=fetch_summary, inputs=[], outputs=[summary_box])

    return app
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on all interfaces at
    # port 7860 with a Gradio share link requested.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo = create_interface()
    demo.launch(**launch_options)