import os
import re
import faiss
import gradio as gr
import numpy as np
import pdfplumber
import docx
from typing import List

from sentence_transformers import SentenceTransformer
from transformers import pipeline

# Utility: Clean text helper
def clean_text(text: str) -> str:
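    """Collapse runs of whitespace and strip leading/trailing spaces."""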
    text = re.sub(r'\s+', ' ', text)  # collapse whitespace
    text = text.strip()
    return text

# Text chunking (smaller chunks for better semantic search)
def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
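    """Split text into overlapping word-based chunks.

    Consecutive chunks share `overlap` words so that content spanning a
    chunk boundary remains retrievable; assumes overlap < chunk_size,
    otherwise the window would not advance.
    """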
    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = min(start + chunk_size, len(words))
        chunk = ' '.join(words[start:end])
        chunks.append(clean_text(chunk))
        start += chunk_size - overlap
    return chunks

# Document loader for txt, pdf, docx
def load_document(file_path: str) -> str:
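    """Extract plain text from a .txt, .pdf, or .docx file.

    Raises ValueError for unsupported extensions.
    """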
    ext = os.path.splitext(file_path)[1].lower()
    text = ""
    if ext == ".txt":
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
    elif ext == ".pdf":
        with pdfplumber.open(file_path) as pdf:
            pages = [page.extract_text() for page in pdf.pages if page.extract_text()]
            text = "\n".join(pages)
    elif ext == ".docx":
        doc = docx.Document(file_path)
        paragraphs = [para.text for para in doc.paragraphs if para.text.strip()]
        text = "\n".join(paragraphs)
    else:
        raise ValueError(f"Unsupported file type: {ext}")
    return clean_text(text)

class SmartDocumentRAG:
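    """Retrieval-augmented Q&A over user-uploaded documents.

    Chunks are embedded with SentenceTransformer and indexed in FAISS;
    questions are answered with an extractive (DistilBERT) or generative
    (Flan-T5) pipeline over the retrieved context.
    """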
    def __init__(self):
        print("Loading embedder and models...")
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')  # small, fast
        self.documents = []
        self.embeddings = None
        self.index = None
        self.is_indexed = False

        # Load QA pipelines
        self.model_type = "distilbert-qa"  # change to "flan-t5" for generative
        if self.model_type == "distilbert-qa":
            self.qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
        elif self.model_type == "flan-t5":
            self.qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
        else:
            self.qa_pipeline = None

        self.document_summary = ""

    def process_documents(self, files: List[gr.File]) -> str:
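        """Load, chunk, embed, and index the uploaded files."""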
        if not files:
            return "⚠️ No files uploaded."
        print(f"Processing {len(files)} files...")

        all_text = ""
        for file in files:
            try:
                # Gradio may pass a tempfile-like object exposing .name, or a plain path string
                path = file.name if hasattr(file, 'name') else file
                text = load_document(path)
                all_text += text + "\n"
            except Exception as e:
                print(f"Error loading {file}: {e}")

        all_text = clean_text(all_text)
        chunks = chunk_text(all_text)

        if not chunks:
            return "⚠️ No text extracted from documents."

        self.documents = chunks
        print(f"Created {len(chunks)} text chunks.")

        # Embed and build FAISS index
        self.embeddings = self.embedder.encode(self.documents, convert_to_numpy=True)
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # Cosine similarity with normalized vectors
        faiss.normalize_L2(self.embeddings)
        self.index.add(self.embeddings)
        self.is_indexed = True

        # Generate summary (simple: first 3 chunks joined)
        summary_text = " ".join(self.documents[:3])
        self.document_summary = summary_text if summary_text else "Summary not available."

        return f"βœ… Processed {len(files)} files and created index with {len(chunks)} chunks."

    def find_relevant_content(self, query: str, k: int = 5) -> str:
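        """Retrieve up to k chunks whose cosine similarity to the query exceeds 0.1."""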
        if not self.is_indexed:
            return ""

        query_emb = self.embedder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_emb)
        k = min(k, len(self.documents))
        scores, indices = self.index.search(query_emb, k)

        relevant_chunks = []
        for score, idx in zip(scores[0], indices[0]):
            # IndexFlatIP returns inner-product scores; with L2-normalized
            # vectors this is cosine similarity. FAISS pads missing results
            # with -1, so guard the index as well.
            if score > 0.1 and 0 <= idx < len(self.documents):
                relevant_chunks.append(self.documents[idx])
        context = " ".join(relevant_chunks)
        print(f"Found {len(relevant_chunks)} relevant chunks with similarity > 0.1")
        return context

    def answer_question(self, query: str) -> str:
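        """Answer a question from the indexed documents; summary-style queries return the stored summary."""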
        if not query.strip():
            return "❓ Please ask a valid question."
        if not self.is_indexed:
            return "πŸ“ Please upload and process documents first."

        query_lower = query.lower()
        if any(word in query_lower for word in ['summary', 'summarize', 'overview', 'about']):
            return f"πŸ“„ Document Summary:\n\n{self.document_summary}"

        context = self.find_relevant_content(query, k=5)
        print(f"Context for query: {context[:500]}...")

        if not context:
            return "πŸ” Sorry, no relevant information found. Try rephrasing your question."

        try:
            if self.model_type == "distilbert-qa":
                result = self.qa_pipeline(question=query, context=context)
                print(f"QA pipeline result: {result}")
                answer = result.get('answer', '').strip()
                score = result.get('score', 0.0)

                if not answer or score < 0.05:
                    return "πŸ€” I couldn't find a confident answer based on the documents."

                snippet = context[:300].strip()
                if len(context) > 300:
                    snippet += "..."
                return f"**Answer:** {answer}\n\n*Context snippet:* {snippet}"

            elif self.model_type == "flan-t5":
                prompt = (
                    f"Answer the question based on the context below.\n\n"
                    f"Context:\n{context}\n\n"
                    f"Question: {query}\nAnswer:"
                )
                result = self.qa_pipeline(prompt, max_new_tokens=200, num_return_sequences=1)
                print(f"Generative pipeline result: {result}")
                # text2text-generation returns only the generated text; the
                # prompt is not echoed, so no stripping is needed.
                answer = result[0]['generated_text'].strip()
                if not answer:
                    return "πŸ€” I couldn't find a confident answer based on the documents."
                return f"**Answer:** {answer}"

            else:
                return "⚠️ Unsupported model type."

        except Exception as e:
            print(f"Exception in answer_question: {e}")
            return f"❌ Error: {str(e)}"

# Create Gradio UI
def create_interface():
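    """Build the Gradio Blocks UI wired to a single SmartDocumentRAG instance."""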
    rag = SmartDocumentRAG()

    with gr.Blocks(title="🧠 Enhanced Document Q&A") as demo:
        gr.Markdown(
            """
            # 🧠 Enhanced Document Q&A System
            **Features:**
            - Semantic search with FAISS + SentenceTransformer
            - Supports PDF, DOCX, TXT uploads
            - Uses DistilBERT or Flan-T5 for Q&A
            - Shows answer with context snippet
            """
        )

        with gr.Tab("Upload & Process"):
            file_upload = gr.File(file_types=['.pdf', '.docx', '.txt'], label="Upload Documents", file_count="multiple")
            process_btn = gr.Button("Process Documents")
            process_status = gr.Textbox(label="Processing Status", interactive=False, lines=4)

            process_btn.click(fn=rag.process_documents, inputs=[file_upload], outputs=[process_status])

        with gr.Tab("Q&A"):
            question_input = gr.Textbox(label="Ask your question", lines=2, placeholder="Type your question here...")
            ask_btn = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", lines=8, interactive=False)

            ask_btn.click(fn=rag.answer_question, inputs=[question_input], outputs=[answer_output])

        with gr.Tab("Summary"):
            summary_btn = gr.Button("Get Document Summary")
            summary_output = gr.Textbox(label="Summary", lines=6, interactive=False)

            summary_btn.click(fn=lambda: rag.answer_question("summary"), inputs=[], outputs=[summary_output])

    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)