# (Hugging Face Spaces page header — "Spaces: Sleeping" — scrape artifact, not code)
import gradio as gr | |
import requests | |
import os | |
from PyPDF2 import PdfReader | |
from sentence_transformers import SentenceTransformer | |
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
# Constants
CHUNK_SIZE = 300  # words per chunk when splitting extracted PDF text
MODEL_NAME = "all-MiniLM-L6-v2"  # sentence-transformers embedding model
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")  # None when unset; API calls will then fail with 401
SERPER_API_KEY = os.getenv("SERPER_API_KEY")  # None when unset

# Load sentence embedding model once at import time (heavy; shared by all handlers)
model = SentenceTransformer(MODEL_NAME)

# Global state: chunks of the currently uploaded document and their embeddings.
# Repopulated by handle_file_upload; read by get_top_chunks / respond_to_query.
doc_chunks, doc_embeddings = [], []
# --- Text Extraction from PDF ---
def extract_pdf_text(file_obj):
    """Extract and join the text of every page in a PDF.

    Args:
        file_obj: A path or binary file-like object accepted by ``PdfReader``.

    Returns:
        str: All non-empty page texts joined with newlines ("" if no page
        yields text, e.g. a scanned/image-only PDF).
    """
    reader = PdfReader(file_obj)
    # Call extract_text() exactly once per page — the original called it twice
    # (once for the filter, once for the value), doubling the extraction work.
    page_texts = (page.extract_text() for page in reader.pages)
    return "\n".join(text for text in page_texts if text)
# --- Chunk Text ---
def split_text(text, size=CHUNK_SIZE):
    """Break *text* into chunks of at most *size* whitespace-separated words."""
    tokens = text.split()
    chunks = []
    for start in range(0, len(tokens), size):
        chunks.append(" ".join(tokens[start:start + size]))
    return chunks
# --- File Upload Handling ---
def handle_file_upload(file):
    """Process an uploaded PDF: extract text, chunk it, and cache embeddings.

    Args:
        file: The uploaded file object from ``gr.File`` (falsy when cleared).

    Returns:
        tuple: (status message str, ``gr.update`` for the chunks-info textbox).
    """
    global doc_chunks, doc_embeddings
    if not file:
        return "β οΈ Please upload a file.", gr.update(visible=False)
    try:
        text = extract_pdf_text(file)
        doc_chunks = split_text(text)
        if not doc_chunks:
            # A scanned/image-only PDF yields no extractable text. Reset the
            # cached embeddings too so stale state from a previous upload
            # doesn't answer queries for this document.
            doc_embeddings = []
            return "β οΈ No extractable text found in the PDF.", gr.update(visible=False)
        doc_embeddings = model.encode(doc_chunks)
        return (
            f"β Processed {len(doc_chunks)} chunks.",
            gr.update(visible=True, value=f"{len(doc_chunks)} chunks ready."),
        )
    except Exception as e:
        # Surface the failure in the status box rather than crashing the UI.
        return f"β Failed to process file: {e}", gr.update(visible=False)
# --- Semantic Retrieval ---
def get_top_chunks(query, k=3):
    """Return the *k* cached chunks most relevant to *query* as one string.

    Relevance is cosine similarity between the query embedding and the cached
    document embeddings; selected chunks are joined by blank lines.
    """
    embedded = model.encode([query])
    scores = cosine_similarity(embedded, doc_embeddings)[0]
    # argsort ascending, reversed, then truncated — keeps the original tie order.
    best = np.argsort(scores)[::-1][:k]
    return "\n\n".join(doc_chunks[idx] for idx in best)
# --- Call LLM via Together API ---
def call_together_ai(context, question):
    """Ask the Mixtral model (Together chat-completions API) to answer
    *question* using *context*.

    Args:
        context: Retrieved document chunks or web snippets.
        question: The user's question.

    Returns:
        str: The assistant message text.

    Raises:
        requests.HTTPError: if the API returns an error status.
        requests.Timeout: if the request exceeds the timeout.
    """
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {TOGETHER_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant answering from the given context."},
            {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"},
        ],
        "temperature": 0.7,
        "max_tokens": 512,
    }
    # timeout keeps the UI from hanging forever on a stalled request;
    # raise_for_status surfaces API errors instead of a cryptic KeyError
    # on the missing "choices" key in an error body.
    res = requests.post(url, headers=headers, json=payload, timeout=60)
    res.raise_for_status()
    return res.json()["choices"][0]["message"]["content"]
# --- Serper Web Search ---
def fetch_web_snippets(query):
    """Fetch the top 3 organic Google results for *query* via the Serper API.

    Returns:
        str: Markdown-formatted "[title](link) + snippet" entries, one per
        result; empty string when there are no organic results.

    Raises:
        requests.HTTPError: if the API returns an error status.
        requests.Timeout: if the request exceeds the timeout.
    """
    url = "https://google.serper.dev/search"
    headers = {"X-API-KEY": SERPER_API_KEY}
    res = requests.post(url, json={"q": query}, headers=headers, timeout=30)
    res.raise_for_status()
    data = res.json()
    # Individual results can lack "snippet" (or other keys); use .get() so one
    # sparse result doesn't raise KeyError and break the whole answer.
    return "\n".join(
        f"πΉ [{r.get('title', '')}]({r.get('link', '')})\n{r.get('snippet', '')}"
        for r in data.get("organic", [])[:3]
    )
# --- Main Chat Logic ---
def respond_to_query(question, source, history):
    """Answer *question* from the selected knowledge source and append the
    exchange to the chat *history*.

    Args:
        question: Raw text from the query box.
        source: The selected Radio value (web search or uploaded file).
        history: Chatbot history as a list of [user, assistant] pairs
            (mutated in place).

    Returns:
        tuple: (updated history, "" to clear the input textbox).
    """
    if not question.strip():
        return history, ""
    history.append([question, None])
    try:
        if source == "π Web Search":
            context = fetch_web_snippets(question)
            source_note = "π Web Search"
        elif source == "π Uploaded File":
            if not doc_chunks:
                history[-1][1] = "β οΈ Please upload a PDF document first."
                return history, ""
            context = get_top_chunks(question)
            source_note = "π Uploaded Document"
        else:
            history[-1][1] = "β Invalid knowledge source selected."
            return history, ""
        reply = call_together_ai(context, question)
        history[-1][1] = f"**{source_note}**\n\n{reply}"
    except Exception as e:
        # Show the error inline in the chat rather than crashing the handler.
        history[-1][1] = f"β Error: {e}"
    return history, ""
# --- Clear Chat ---
def clear_chat():
    """Return an empty list to reset the chatbot history."""
    return []
# --- UI Design ---
# Custom CSS: cap the app width and center headings.
css = """
.gradio-container { max-width: 1100px !important; margin: auto; }
h1, h2, h3 { text-align: center; }
"""

# NOTE(review): indentation/nesting reconstructed from a whitespace-mangled
# source — Accordion placement and event-binding scope inferred; confirm
# against the original layout.
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="π AI RAG Assistant") as demo:
    gr.HTML("<h1>π€ AI Chat with RAG Capabilities</h1><h3>Ask questions from PDFs or real-time web search</h3>")
    with gr.Row():
        # Left column: knowledge-source picker, PDF upload, and status readouts.
        with gr.Column(scale=1):
            source = gr.Radio(["π Web Search", "π Uploaded File"], label="Knowledge Source", value="π Web Search")
            file = gr.File(label="Upload PDF", file_types=[".pdf"])
            status = gr.Textbox(label="Status", interactive=False)
            doc_info = gr.Textbox(label="Chunks Info", visible=False, interactive=False)
        # Right column: chat transcript, query box, and action buttons.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat", height=500)
            query = gr.Textbox(placeholder="Type your question here...", lines=2)
            with gr.Row():
                send = gr.Button("Send")
                clear = gr.Button("Clear")
    with gr.Accordion("βΉοΈ Info", open=False):
        gr.Markdown("- Web Search fetches latest online results\n- PDF mode retrieves answers from your document")
    gr.HTML("<div style='text-align:center; font-size:0.9em; color:gray;'>ipradeepsengarr</div>")

    # Bind events
    # Uploading (or clearing) a file re-processes the document cache.
    file.change(handle_file_upload, inputs=file, outputs=[status, doc_info])
    # Enter in the textbox and the Send button trigger the same handler.
    query.submit(respond_to_query, inputs=[query, source, chatbot], outputs=[chatbot, query])
    send.click(respond_to_query, inputs=[query, source, chatbot], outputs=[chatbot, query])
    clear.click(clear_chat, outputs=[chatbot])
if __name__ == "__main__":
    # 0.0.0.0 makes the server reachable from outside the container (required
    # on hosted platforms); share=True additionally requests a public tunnel URL.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)