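"""Smart Document Q&A (Gradio app).

Upload PDF, DOCX, or TXT files; the app splits them into overlapping chunks,
embeds the chunks with SentenceTransformer (all-MiniLM-L6-v2), indexes them in
FAISS, and answers questions with an extractive DistilBERT QA pipeline (or,
optionally, a generative Flan-T5 pipeline)."""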
import os
import re
import faiss
import gradio as gr
import numpy as np
import pdfplumber
import docx
from typing import List
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# Utility: clean text helper
def clean_text(text: str) -> str:
    text = re.sub(r'\s+', ' ', text)  # collapse runs of whitespace
    text = text.strip()
    return text
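# Example: clean_text("  Hello\n\n  world ") returns "Hello world".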
# Text chunking (smaller chunks give better semantic search granularity)
def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = min(start + chunk_size, len(words))
        chunk = ' '.join(words[start:end])
        chunks.append(clean_text(chunk))
        start += chunk_size - overlap  # slide the window, keeping `overlap` words
    return chunks
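# Worked example: a 700-word text with chunk_size=300 and overlap=50 yields
# word windows [0:300], [250:550], and [500:700]; each chunk repeats the last
# 50 words of the previous one, so ideas spanning a boundary stay searchable.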
# Document loader for txt, pdf, docx
def load_document(file_path: str) -> str:
    ext = os.path.splitext(file_path)[1].lower()
    text = ""
    if ext == ".txt":
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
    elif ext == ".pdf":
        with pdfplumber.open(file_path) as pdf:
            pages = [page.extract_text() for page in pdf.pages if page.extract_text()]
            text = "\n".join(pages)
    elif ext == ".docx":
        doc = docx.Document(file_path)
        paragraphs = [para.text for para in doc.paragraphs if para.text.strip()]
        text = "\n".join(paragraphs)
    else:
        raise ValueError(f"Unsupported file type: {ext}")
    return clean_text(text)
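# Usage sketch (the file name below is illustrative, not part of this repo):
#   raw = load_document("report.pdf")  # extracted, whitespace-normalized text
#   chunks = chunk_text(raw)           # overlapping chunks ready for embedding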
class SmartDocumentRAG:
    def __init__(self):
        print("Loading embedder and models...")
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')  # small, fast
        self.documents = []
        self.embeddings = None
        self.index = None
        self.is_indexed = False
        # Load the QA pipeline
        self.model_type = "distilbert-qa"  # change to "flan-t5" for generative
        if self.model_type == "distilbert-qa":
            self.qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
        elif self.model_type == "flan-t5":
            self.qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
        else:
            self.qa_pipeline = None
        self.document_summary = ""
    def process_documents(self, files: List[gr.File]) -> str:
        if not files:
            return "⚠️ No files uploaded."
        print(f"Processing {len(files)} files...")
        all_text = ""
        for file in files:
            try:
                # Gradio upload objects expose the temp-file path via .name
                path = file.name if hasattr(file, 'name') else file
                text = load_document(path)
                all_text += text + "\n"
            except Exception as e:
                print(f"Error loading {file}: {e}")
        all_text = clean_text(all_text)
        chunks = chunk_text(all_text)
        if not chunks:
            return "⚠️ No text extracted from documents."
        self.documents = chunks
        print(f"Created {len(chunks)} text chunks.")
        # Embed and build FAISS index
        self.embeddings = self.embedder.encode(self.documents, convert_to_numpy=True)
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # cosine similarity once vectors are L2-normalized
        faiss.normalize_L2(self.embeddings)
        self.index.add(self.embeddings)
        self.is_indexed = True
        # Generate summary (simple: first 3 chunks joined)
        summary_text = " ".join(self.documents[:3])
        self.document_summary = summary_text if summary_text else "Summary not available."
        return f"✅ Processed {len(files)} files and created an index with {len(chunks)} chunks."
    def find_relevant_content(self, query: str, k: int = 5) -> str:
        if not self.is_indexed:
            return ""
        query_emb = self.embedder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_emb)
        k = min(k, len(self.documents))
        scores, indices = self.index.search(query_emb, k)
        relevant_chunks = []
        for score, idx in zip(scores[0], indices[0]):
            if score > 0.1 and idx < len(self.documents):
                relevant_chunks.append(self.documents[idx])
        context = " ".join(relevant_chunks)
        print(f"Found {len(relevant_chunks)} relevant chunks with similarity > 0.1")
        return context
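    # Note: with L2-normalized vectors, IndexFlatIP scores are cosine
    # similarities in [-1, 1]; the 0.1 cutoff above drops chunks that are
    # nearly orthogonal (semantically unrelated) to the query.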
    def answer_question(self, query: str) -> str:
        if not query.strip():
            return "❌ Please ask a valid question."
        if not self.is_indexed:
            return "📁 Please upload and process documents first."
        query_lower = query.lower()
        if any(word in query_lower for word in ['summary', 'summarize', 'overview', 'about']):
            return f"📄 Document Summary:\n\n{self.document_summary}"
        context = self.find_relevant_content(query, k=5)
        print(f"Context for query: {context[:500]}...")
        if not context:
            return "🔍 Sorry, no relevant information found. Try rephrasing your question."
        try:
            if self.model_type == "distilbert-qa":
                result = self.qa_pipeline(question=query, context=context)
                print(f"QA pipeline result: {result}")
                answer = result.get('answer', '').strip()
                score = result.get('score', 0.0)
                if not answer or score < 0.05:
                    return "🤔 I couldn't find a confident answer based on the documents."
                snippet = context[:300].strip()
                if len(context) > 300:
                    snippet += "..."
                return f"**Answer:** {answer}\n\n*Context snippet:* {snippet}"
            elif self.model_type == "flan-t5":
                prompt = (
                    f"Answer the question based on the context below.\n\n"
                    f"Context:\n{context}\n\n"
                    f"Question: {query}\nAnswer:"
                )
                result = self.qa_pipeline(prompt, max_length=200, num_return_sequences=1)
                print(f"Generative pipeline result: {result}")
                answer = result[0]['generated_text'].replace(prompt, '').strip()
                if not answer:
                    return "🤔 I couldn't find a confident answer based on the documents."
                return f"**Answer:** {answer}"
            else:
                return "⚠️ Unsupported model type."
        except Exception as e:
            print(f"Exception in answer_question: {e}")
            return f"❌ Error: {str(e)}"
# Create Gradio UI
def create_interface():
    rag = SmartDocumentRAG()
    with gr.Blocks(title="🧠 Enhanced Document Q&A") as demo:
        gr.Markdown(
            """
            # 🧠 Enhanced Document Q&A System
            **Features:**
            - Semantic search with FAISS + SentenceTransformer
            - Supports PDF, DOCX, TXT uploads
            - Uses DistilBERT or Flan-T5 for Q&A
            - Shows answer with context snippet
            """
        )
        with gr.Tab("Upload & Process"):
            file_upload = gr.File(file_types=['.pdf', '.docx', '.txt'], label="Upload Documents", file_count="multiple")
            process_btn = gr.Button("Process Documents")
            process_status = gr.Textbox(label="Processing Status", interactive=False, lines=4)
            process_btn.click(fn=rag.process_documents, inputs=[file_upload], outputs=[process_status])
        with gr.Tab("Q&A"):
            question_input = gr.Textbox(label="Ask your question", lines=2, placeholder="Type your question here...")
            ask_btn = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", lines=8, interactive=False)
            ask_btn.click(fn=rag.answer_question, inputs=[question_input], outputs=[answer_output])
        with gr.Tab("Summary"):
            summary_btn = gr.Button("Get Document Summary")
            summary_output = gr.Textbox(label="Summary", lines=6, interactive=False)
            summary_btn.click(fn=lambda: rag.answer_question("summary"), inputs=[], outputs=[summary_output])
    return demo
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
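    # share=True also requests a temporary public Gradio link in addition to
    # the local server; set it to False to keep the app local-only.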