# Custom_Rag_Bot / app.py
import os
import re
import faiss
import gradio as gr
import numpy as np
import pdfplumber
import docx
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from transformers import pipeline
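# Dependencies (pip package names): faiss-cpu, gradio, numpy, pdfplumber,
# python-docx, sentence-transformers (which requires torch), and transformers.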
# Utility: Clean text helper
def clean_text(text: str) -> str:
text = re.sub(r'\s+', ' ', text) # collapse whitespace
text = text.strip()
return text
# Text chunking (smaller chunks for better semantic search)
def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
words = text.split()
chunks = []
start = 0
while start < len(words):
end = min(start + chunk_size, len(words))
chunk = ' '.join(words[start:end])
chunks.append(clean_text(chunk))
start += chunk_size - overlap
return chunks
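# With the defaults, successive chunks start chunk_size - overlap = 250 words
# apart, so a 1000-word document produces chunks beginning at words 0, 250,
# 500, and 750, each sharing 50 words with its neighbor.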
# Document loader for txt, pdf, docx
def load_document(file_path: str) -> str:
ext = os.path.splitext(file_path)[1].lower()
text = ""
if ext == ".txt":
with open(file_path, 'r', encoding='utf-8') as f:
text = f.read()
elif ext == ".pdf":
with pdfplumber.open(file_path) as pdf:
            page_texts = (page.extract_text() for page in pdf.pages)  # extract each page once
            pages = [page_text for page_text in page_texts if page_text]
text = "\n".join(pages)
elif ext == ".docx":
doc = docx.Document(file_path)
paragraphs = [para.text for para in doc.paragraphs if para.text.strip()]
text = "\n".join(paragraphs)
else:
raise ValueError(f"Unsupported file type: {ext}")
return clean_text(text)
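# Note: python-docx only reads .docx (not legacy .doc), and pdfplumber only
# extracts text from text-based PDFs; scanned/image-only pages come back empty.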
class SmartDocumentRAG:
def __init__(self):
print("Loading embedder and models...")
self.embedder = SentenceTransformer('all-MiniLM-L6-v2') # small, fast
self.documents = []
self.embeddings = None
self.index = None
self.is_indexed = False
# Load QA pipelines
self.model_type = "distilbert-qa" # change to "flan-t5" for generative
if self.model_type == "distilbert-qa":
self.qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
elif self.model_type == "flan-t5":
self.qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
else:
self.qa_pipeline = None
self.document_summary = ""
    def process_documents(self, files: Optional[List]) -> str:
if not files:
return "⚠️ No files uploaded."
print(f"Processing {len(files)} files...")
all_text = ""
for file in files:
try:
                # Gradio may pass tempfile-like objects (with a .name path) or plain string paths
                path = file.name if hasattr(file, 'name') else file
text = load_document(path)
all_text += text + "\n"
except Exception as e:
print(f"Error loading {file}: {e}")
all_text = clean_text(all_text)
chunks = chunk_text(all_text)
if not chunks:
return "⚠️ No text extracted from documents."
self.documents = chunks
print(f"Created {len(chunks)} text chunks.")
# Embed and build FAISS index
self.embeddings = self.embedder.encode(self.documents, convert_to_numpy=True)
dimension = self.embeddings.shape[1]
self.index = faiss.IndexFlatIP(dimension) # Cosine similarity with normalized vectors
faiss.normalize_L2(self.embeddings)
self.index.add(self.embeddings)
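        # Inner product over L2-normalized vectors equals cosine similarity,
        # so search scores returned by this index are cosines in [-1, 1].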
self.is_indexed = True
# Generate summary (simple: first 3 chunks joined)
summary_text = " ".join(self.documents[:3])
self.document_summary = summary_text if summary_text else "Summary not available."
return f"βœ… Processed {len(files)} files and created index with {len(chunks)} chunks."
def find_relevant_content(self, query: str, k: int = 5) -> str:
if not self.is_indexed:
return ""
query_emb = self.embedder.encode([query], convert_to_numpy=True)
faiss.normalize_L2(query_emb)
k = min(k, len(self.documents))
        scores, indices = self.index.search(query_emb, k)
        relevant_chunks = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads results with idx == -1 when fewer than k hits exist
            if score > 0.1 and 0 <= idx < len(self.documents):
                relevant_chunks.append(self.documents[idx])
        context = " ".join(relevant_chunks)
        print(f"Found {len(relevant_chunks)} relevant chunks with similarity > 0.1")
return context
def answer_question(self, query: str) -> str:
if not query.strip():
return "❓ Please ask a valid question."
if not self.is_indexed:
return "πŸ“ Please upload and process documents first."
query_lower = query.lower()
if any(word in query_lower for word in ['summary', 'summarize', 'overview', 'about']):
return f"πŸ“„ Document Summary:\n\n{self.document_summary}"
context = self.find_relevant_content(query, k=5)
print(f"Context for query: {context[:500]}...")
if not context:
return "πŸ” Sorry, no relevant information found. Try rephrasing your question."
try:
if self.model_type == "distilbert-qa":
result = self.qa_pipeline(question=query, context=context)
print(f"QA pipeline result: {result}")
answer = result.get('answer', '').strip()
score = result.get('score', 0.0)
if not answer or score < 0.05:
return "πŸ€” I couldn't find a confident answer based on the documents."
snippet = context[:300].strip()
if len(context) > 300:
snippet += "..."
return f"**Answer:** {answer}\n\n*Context snippet:* {snippet}"
elif self.model_type == "flan-t5":
prompt = (
f"Answer the question based on the context below.\n\n"
f"Context:\n{context}\n\n"
f"Question: {query}\nAnswer:"
)
result = self.qa_pipeline(prompt, max_length=200, num_return_sequences=1)
print(f"Generative pipeline result: {result}")
                answer = result[0]['generated_text'].strip()  # text2text pipelines return only the generated answer, not the prompt
if not answer:
return "πŸ€” I couldn't find a confident answer based on the documents."
return f"**Answer:** {answer}"
else:
return "⚠️ Unsupported model type."
except Exception as e:
print(f"Exception in answer_question: {e}")
return f"❌ Error: {str(e)}"
# Create Gradio UI
def create_interface():
rag = SmartDocumentRAG()
    with gr.Blocks(title="🧠 Enhanced Document Q&A") as demo:
gr.Markdown(
"""
            # 🧠 Enhanced Document Q&A System
**Features:**
- Semantic search with FAISS + SentenceTransformer
- Supports PDF, DOCX, TXT uploads
- Uses DistilBERT or Flan-T5 for Q&A
- Shows answer with context snippet
"""
)
with gr.Tab("Upload & Process"):
file_upload = gr.File(file_types=['.pdf', '.docx', '.txt'], label="Upload Documents", file_count="multiple")
process_btn = gr.Button("Process Documents")
process_status = gr.Textbox(label="Processing Status", interactive=False, lines=4)
process_btn.click(fn=rag.process_documents, inputs=[file_upload], outputs=[process_status])
with gr.Tab("Q&A"):
question_input = gr.Textbox(label="Ask your question", lines=2, placeholder="Type your question here...")
ask_btn = gr.Button("Get Answer")
answer_output = gr.Textbox(label="Answer", lines=8, interactive=False)
ask_btn.click(fn=rag.answer_question, inputs=[question_input], outputs=[answer_output])
with gr.Tab("Summary"):
summary_btn = gr.Button("Get Document Summary")
summary_output = gr.Textbox(label="Summary", lines=6, interactive=False)
summary_btn.click(fn=lambda: rag.answer_question("summary"), inputs=[], outputs=[summary_output])
return demo
if __name__ == "__main__":
demo = create_interface()
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
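# Headless usage, a minimal untested sketch (assumes a local file "notes.txt"
# exists; the filename is only illustrative):
#
#   rag = SmartDocumentRAG()
#   print(rag.process_documents(["notes.txt"]))  # plain string paths also work
#   print(rag.answer_question("What are the main points?"))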