# RAG / app.py — Gradio RAG demo by pradeepsengarr (commit b28dcf5)
import gradio as gr
import requests
import os
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
# Constants
CHUNK_SIZE = 300  # words per document chunk fed to the embedder
MODEL_NAME = "all-MiniLM-L6-v2"  # SentenceTransformer embedding model id
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")  # LLM provider credential, read from env
SERPER_API_KEY = os.getenv("SERPER_API_KEY")  # web-search API credential, read from env
# Load sentence embedding model
model = SentenceTransformer(MODEL_NAME)
# Global state
# Cache for the most recently uploaded PDF: parallel lists of text chunks and
# their embeddings, populated by handle_file_upload and read by get_top_chunks.
doc_chunks, doc_embeddings = [], []
# --- Text Extraction from PDF ---
def extract_pdf_text(file_obj):
    """Extract and join the text of every page of a PDF.

    Args:
        file_obj: Path or binary file-like object accepted by PyPDF2.PdfReader.

    Returns:
        Newline-joined page texts; pages with no extractable text are skipped.
    """
    reader = PdfReader(file_obj)
    # extract_text() re-parses the page content stream each call, so call it
    # once per page (the original called it twice: filter and value).
    page_texts = (page.extract_text() for page in reader.pages)
    return "\n".join(text for text in page_texts if text)
# --- Chunk Text ---
def split_text(text, size=CHUNK_SIZE):
    """Break *text* into chunks of at most *size* whitespace-separated words."""
    tokens = text.split()
    chunks = []
    for start in range(0, len(tokens), size):
        chunks.append(" ".join(tokens[start:start + size]))
    return chunks
# --- File Upload Handling ---
def handle_file_upload(file):
    """Read the uploaded PDF, chunk its text, and cache chunk embeddings globally."""
    global doc_chunks, doc_embeddings
    # Guard: nothing to do until the user actually selects a file.
    if not file:
        return "⚠️ Please upload a file.", gr.update(visible=False)
    try:
        raw_text = extract_pdf_text(file)
        doc_chunks = split_text(raw_text)
        doc_embeddings = model.encode(doc_chunks)
    except Exception as e:
        return f"❌ Failed to process file: {e}", gr.update(visible=False)
    count = len(doc_chunks)
    return f"βœ… Processed {count} chunks.", gr.update(visible=True, value=f"{count} chunks ready.")
# --- Semantic Retrieval ---
def get_top_chunks(query, k=3):
    """Return the k cached chunks most similar to *query*, joined by blank lines."""
    query_vec = model.encode([query])
    scores = cosine_similarity(query_vec, doc_embeddings)[0]
    # Rank chunk indices from most to least similar, then keep the best k.
    ranked = np.argsort(scores)[::-1]
    selected = [doc_chunks[i] for i in ranked[:k]]
    return "\n\n".join(selected)
# --- Call LLM via Together API ---
def call_together_ai(context, question):
    """Query the Mixtral-8x7B-Instruct model via the Together chat-completions API.

    Args:
        context: Retrieved grounding text (document chunks or web snippets).
        question: The user's question.

    Returns:
        The assistant message content of the first completion choice.

    Raises:
        requests.HTTPError: On a non-2xx API response (bad key, rate limit, ...).
        requests.Timeout: If the request exceeds the timeout.
    """
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {TOGETHER_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant answering from the given context."},
            {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
        ],
        "temperature": 0.7,
        "max_tokens": 512
    }
    # timeout keeps the chat handler from hanging indefinitely; raise_for_status
    # surfaces API errors instead of a cryptic KeyError on ["choices"] below.
    # Both exceptions are caught and reported by the caller (respond_to_query).
    res = requests.post(url, headers=headers, json=payload, timeout=60)
    res.raise_for_status()
    return res.json()["choices"][0]["message"]["content"]
# --- Serper Web Search ---
def fetch_web_snippets(query):
    """Fetch the top-3 organic results for *query* via the Serper search API.

    Returns:
        A markdown string with one "title (link) + snippet" entry per result;
        empty string when no organic results are returned.

    Raises:
        requests.HTTPError: On a non-2xx API response.
        requests.Timeout: If the request exceeds the timeout.
    """
    url = "https://google.serper.dev/search"
    headers = {"X-API-KEY": SERPER_API_KEY}
    # timeout prevents the chat handler hanging; raise_for_status turns
    # auth/quota failures into exceptions the caller can report.
    res = requests.post(url, json={"q": query}, headers=headers, timeout=20)
    res.raise_for_status()
    data = res.json()
    return "\n".join(
        # .get("snippet") — some results omit the snippet field; fall back to "".
        f"πŸ”Ή [{r['title']}]({r['link']})\n{r.get('snippet', '')}"
        for r in data.get("organic", [])[:3]
    )
# --- Main Chat Logic ---
def respond_to_query(question, source, history):
    """Answer *question* from the selected knowledge source and append the turn to *history*.

    Returns the updated history plus "" to clear the input textbox.
    """
    # Ignore blank/whitespace-only submissions.
    if not question.strip():
        return history, ""
    history.append([question, None])
    try:
        if source == "🌐 Web Search":
            context = fetch_web_snippets(question)
            note = "🌐 Web Search"
        elif source == "πŸ“„ Uploaded File":
            # A PDF must be ingested before document retrieval can run.
            if not doc_chunks:
                history[-1][1] = "⚠️ Please upload a PDF document first."
                return history, ""
            context = get_top_chunks(question)
            note = "πŸ“„ Uploaded Document"
        else:
            history[-1][1] = "❌ Invalid knowledge source selected."
            return history, ""
        reply = call_together_ai(context, question)
        history[-1][1] = f"**{note}**\n\n{reply}"
    except Exception as e:
        # Top-level UI boundary: report the failure in the chat instead of crashing.
        history[-1][1] = f"❌ Error: {e}"
    return history, ""
# --- Clear Chat ---
def clear_chat():
    """Reset the chatbot history to an empty conversation."""
    return []
# --- UI Design ---
# Inline CSS: constrain the app width and center all headings.
css = """
.gradio-container { max-width: 1100px !important; margin: auto; }
h1, h2, h3 { text-align: center; }
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="πŸ” AI RAG Assistant") as demo:
    gr.HTML("<h1>πŸ€– AI Chat with RAG Capabilities</h1><h3>Ask questions from PDFs or real-time web search</h3>")
    with gr.Row():
        # Left column: knowledge-source selector plus PDF upload and status readouts.
        with gr.Column(scale=1):
            source = gr.Radio(["🌐 Web Search", "πŸ“„ Uploaded File"], label="Knowledge Source", value="🌐 Web Search")
            file = gr.File(label="Upload PDF", file_types=[".pdf"])
            status = gr.Textbox(label="Status", interactive=False)
            doc_info = gr.Textbox(label="Chunks Info", visible=False, interactive=False)
        # Right column: chat window, question input, and action buttons.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat", height=500)
            query = gr.Textbox(placeholder="Type your question here...", lines=2)
            with gr.Row():
                send = gr.Button("Send")
                clear = gr.Button("Clear")
    with gr.Accordion("ℹ️ Info", open=False):
        gr.Markdown("- Web Search fetches latest online results\n- PDF mode retrieves answers from your document")
    gr.HTML("<div style='text-align:center; font-size:0.9em; color:gray;'>ipradeepsengarr</div>")
    # Bind events
    # Re-chunk and re-embed the document whenever a new file is selected.
    file.change(handle_file_upload, inputs=file, outputs=[status, doc_info])
    # Enter key and Send button share one handler; both clear the input box on return.
    query.submit(respond_to_query, inputs=[query, source, chatbot], outputs=[chatbot, query])
    send.click(respond_to_query, inputs=[query, source, chatbot], outputs=[chatbot, query])
    clear.click(clear_chat, outputs=[chatbot])
if __name__ == "__main__":
    # 0.0.0.0 binds all interfaces (container-friendly); share=True publishes a
    # temporary public Gradio link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)