import os
import re

import faiss
import gradio as gr
import numpy as np
import pdfplumber
import docx  # python-docx
from typing import List

from sentence_transformers import SentenceTransformer
from transformers import pipeline

# Utility: collapse whitespace and trim.
def clean_text(text: str) -> str:
    text = re.sub(r'\s+', ' ', text)  # collapse runs of whitespace into single spaces
    return text.strip()

# Text chunking: sliding word window (smaller chunks give sharper semantic search).
def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = min(start + chunk_size, len(words))
        chunks.append(' '.join(words[start:end]))
        start += chunk_size - overlap  # window stride; overlap must stay below chunk_size
    return chunks
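
# Illustrative note (not in the original code): with the defaults chunk_size=300
# and overlap=50, the window stride is 250 words, so chunks start at word
# offsets 0, 250, 500, ... and each consecutive pair shares 50 words of context.
# For example, a 600-word document yields chunks covering words
# [0:300], [250:550], and [500:600].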

# Document loader for .txt, .pdf, and .docx files.
def load_document(file_path: str) -> str:
    ext = os.path.splitext(file_path)[1].lower()
    text = ""
    if ext == ".txt":
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
    elif ext == ".pdf":
        with pdfplumber.open(file_path) as pdf:
            # Extract each page once, skipping pages with no extractable text.
            pages = [page_text for page in pdf.pages if (page_text := page.extract_text())]
            text = "\n".join(pages)
    elif ext == ".docx":
        doc = docx.Document(file_path)
        paragraphs = [para.text for para in doc.paragraphs if para.text.strip()]
        text = "\n".join(paragraphs)
    else:
        raise ValueError(f"Unsupported file type: {ext}")
    return clean_text(text)
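
# Quick sanity check (illustrative only; "sample.pdf" is a hypothetical path):
#   print(load_document("sample.pdf")[:200])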

class SmartDocumentRAG:
    def __init__(self):
        print("Loading embedder and models...")
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')  # small, fast embedding model
        self.documents = []
        self.embeddings = None
        self.index = None
        self.is_indexed = False
        # Load the QA pipeline: extractive by default; change to "flan-t5" for generative answers.
        self.model_type = "distilbert-qa"
        if self.model_type == "distilbert-qa":
            self.qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
        elif self.model_type == "flan-t5":
            self.qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
        else:
            self.qa_pipeline = None
        self.document_summary = ""

    def process_documents(self, files: List[gr.File]) -> str:
        if not files:
            return "⚠️ No files uploaded."
        print(f"Processing {len(files)} files...")
        all_text = ""
        for file in files:
            try:
                # Gradio may pass a tempfile wrapper exposing a .name path, or a plain path string.
                path = file.name if hasattr(file, 'name') else file
                text = load_document(path)
                all_text += text + "\n"
            except Exception as e:
                print(f"Error loading {file}: {e}")
        all_text = clean_text(all_text)
        chunks = chunk_text(all_text)
        if not chunks:
            return "⚠️ No text extracted from documents."
        self.documents = chunks
        print(f"Created {len(chunks)} text chunks.")
        # Embed the chunks, then build a FAISS index.
        self.embeddings = self.embedder.encode(self.documents, convert_to_numpy=True)
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # inner product == cosine similarity on normalized vectors
        faiss.normalize_L2(self.embeddings)
        self.index.add(self.embeddings)
        self.is_indexed = True
        # Simple summary: the first three chunks joined.
        summary_text = " ".join(self.documents[:3])
        self.document_summary = summary_text if summary_text else "Summary not available."
        return f"✅ Processed {len(files)} files and created an index with {len(chunks)} chunks."

    def find_relevant_content(self, query: str, k: int = 5) -> str:
        if not self.is_indexed:
            return ""
        query_emb = self.embedder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_emb)
        k = min(k, len(self.documents))
        scores, indices = self.index.search(query_emb, k)  # IndexFlatIP returns similarities, not distances
        relevant_chunks = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing results with idx == -1; also apply a minimum-similarity cutoff.
            if score > 0.1 and 0 <= idx < len(self.documents):
                relevant_chunks.append(self.documents[idx])
        context = " ".join(relevant_chunks)
        print(f"Found {len(relevant_chunks)} relevant chunks with similarity > 0.1")
        return context
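
    # Note on the 0.1 cutoff (explanatory, not in the original code): cosine
    # scores fall in [-1, 1], and unrelated text typically scores near 0 with
    # this embedder, so 0.1 is a permissive floor that mostly filters noise.
    # Raising it (e.g. to 0.3) would trade recall for precision.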

    def answer_question(self, query: str) -> str:
        if not query.strip():
            return "❓ Please ask a valid question."
        if not self.is_indexed:
            return "📄 Please upload and process documents first."
        query_lower = query.lower()
        if any(word in query_lower for word in ['summary', 'summarize', 'overview', 'about']):
            return f"📋 Document Summary:\n\n{self.document_summary}"
        context = self.find_relevant_content(query, k=5)
        print(f"Context for query: {context[:500]}...")
        if not context:
            return "🔍 Sorry, no relevant information found. Try rephrasing your question."
        try:
            if self.model_type == "distilbert-qa":
                # Extractive QA: pull an answer span directly out of the retrieved context.
                result = self.qa_pipeline(question=query, context=context)
                print(f"QA pipeline result: {result}")
                answer = result.get('answer', '').strip()
                score = result.get('score', 0.0)
                if not answer or score < 0.05:
                    return "🤔 I couldn't find a confident answer based on the documents."
                snippet = context[:300].strip()
                if len(context) > 300:
                    snippet += "..."
                return f"**Answer:** {answer}\n\n*Context snippet:* {snippet}"
            elif self.model_type == "flan-t5":
                prompt = (
                    f"Answer the question based on the context below.\n\n"
                    f"Context:\n{context}\n\n"
                    f"Question: {query}\nAnswer:"
                )
                result = self.qa_pipeline(prompt, max_length=200, num_return_sequences=1)
                print(f"Generative pipeline result: {result}")
                # text2text-generation returns only the generated text, not the prompt.
                answer = result[0]['generated_text'].strip()
                if not answer:
                    return "🤔 I couldn't find a confident answer based on the documents."
                return f"**Answer:** {answer}"
            else:
                return "⚠️ Unsupported model type."
        except Exception as e:
            print(f"Exception in answer_question: {e}")
            return f"❌ Error: {str(e)}"

# Build the Gradio UI.
def create_interface():
    rag = SmartDocumentRAG()
    with gr.Blocks(title="🧠 Enhanced Document Q&A") as demo:
        gr.Markdown(
            """
            # 🧠 Enhanced Document Q&A System
            **Features:**
            - Semantic search with FAISS + SentenceTransformer
            - Supports PDF, DOCX, and TXT uploads
            - Uses DistilBERT or Flan-T5 for Q&A
            - Shows the answer with a context snippet
            """
        )
        with gr.Tab("Upload & Process"):
            file_upload = gr.File(file_types=['.pdf', '.docx', '.txt'], label="Upload Documents", file_count="multiple")
            process_btn = gr.Button("Process Documents")
            process_status = gr.Textbox(label="Processing Status", interactive=False, lines=4)
            process_btn.click(fn=rag.process_documents, inputs=[file_upload], outputs=[process_status])
        with gr.Tab("Q&A"):
            question_input = gr.Textbox(label="Ask your question", lines=2, placeholder="Type your question here...")
            ask_btn = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", lines=8, interactive=False)
            ask_btn.click(fn=rag.answer_question, inputs=[question_input], outputs=[answer_output])
        with gr.Tab("Summary"):
            summary_btn = gr.Button("Get Document Summary")
            summary_output = gr.Textbox(label="Summary", lines=6, interactive=False)
            summary_btn.click(fn=lambda: rag.answer_question("summary"), inputs=[], outputs=[summary_output])
    return demo

if __name__ == "__main__":
    demo = create_interface()
    # Binding 0.0.0.0:7860 matches the Hugging Face Spaces convention;
    # share=True additionally creates a public link when run locally.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)