import os

import gradio as gr
import numpy as np
import requests
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Constants
CHUNK_SIZE = 300
MODEL_NAME = "all-MiniLM-L6-v2"
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
SERPER_API_KEY = os.getenv("SERPER_API_KEY")

# Load sentence embedding model
model = SentenceTransformer(MODEL_NAME)

# Global state
doc_chunks, doc_embeddings = [], []

# --- Text Extraction from PDF ---
def extract_pdf_text(file_obj):
    """Extracts and joins text from all pages of a PDF."""
    reader = PdfReader(file_obj)
    return "\n".join(text for page in reader.pages if (text := page.extract_text()))

# --- Chunk Text ---
def split_text(text, size=CHUNK_SIZE):
    """Splits text into fixed-size word chunks."""
    words = text.split()
    return [" ".join(words[i:i + size]) for i in range(0, len(words), size)]
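# --- Optional: Overlapping Chunks (illustrative sketch) ---
# Fixed-size chunks can cut a sentence in half at a boundary, which hurts
# retrieval. A common remedy is to let consecutive chunks share some words.
# This variant is NOT wired into the app; it is shown only as a sketch, and
# the `overlap` default of 50 words is a hypothetical choice.
def split_text_overlap(text, size=CHUNK_SIZE, overlap=50):
    """Splits text into word chunks where consecutive chunks share `overlap` words."""
    words = text.split()
    step = max(size - overlap, 1)
    return [" ".join(words[i:i + size]) for i in range(0, len(words), step)]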
# --- File Upload Handling ---
def handle_file_upload(file):
    """Processes the uploaded PDF and caches its embeddings."""
    global doc_chunks, doc_embeddings
    if not file:
        return "⚠️ Please upload a file.", gr.update(visible=False)
    try:
        text = extract_pdf_text(file)
        doc_chunks = split_text(text)
        doc_embeddings = model.encode(doc_chunks)
        return (
            f"✅ Processed {len(doc_chunks)} chunks.",
            gr.update(visible=True, value=f"{len(doc_chunks)} chunks ready."),
        )
    except Exception as e:
        return f"❌ Failed to process file: {e}", gr.update(visible=False)

# --- Semantic Retrieval ---
def get_top_chunks(query, k=3):
    """Finds top-k relevant chunks using cosine similarity."""
    query_emb = model.encode([query])
    sims = cosine_similarity(query_emb, doc_embeddings)[0]
    indices = np.argsort(sims)[::-1][:k]
    return "\n\n".join(doc_chunks[i] for i in indices)

# --- Call LLM via Together API ---
def call_together_ai(context, question):
    """Calls the Mixtral LLM via the Together API."""
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {TOGETHER_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant answering from the given context."},
            {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"},
        ],
        "temperature": 0.7,
        "max_tokens": 512,
    }
    res = requests.post(url, headers=headers, json=payload, timeout=60)
    res.raise_for_status()
    return res.json()["choices"][0]["message"]["content"]

# --- Serper Web Search ---
def fetch_web_snippets(query):
    """Performs a web search via the Serper API."""
    url = "https://google.serper.dev/search"
    headers = {"X-API-KEY": SERPER_API_KEY}
    res = requests.post(url, json={"q": query}, headers=headers, timeout=30)
    res.raise_for_status()
    return "\n".join(
        f"🔹 [{r['title']}]({r['link']})\n{r['snippet']}"
        for r in res.json().get("organic", [])[:3]
    )

# --- Main Chat Logic ---
def respond_to_query(question, source, history):
    """Handles query processing and LLM interaction."""
    if not question.strip():
        return history, ""
    history.append([question, None])
    try:
        if source == "🌐 Web Search":
            context = fetch_web_snippets(question)
            source_note = "🌐 Web Search"
        elif source == "📄 Uploaded File":
            if not doc_chunks:
                history[-1][1] = "⚠️ Please upload a PDF document first."
                return history, ""
            context = get_top_chunks(question)
            source_note = "📄 Uploaded Document"
        else:
            history[-1][1] = "❌ Invalid knowledge source selected."
            return history, ""
        answer = call_together_ai(context, question)
        history[-1][1] = f"**{source_note}**\n\n{answer}"
        return history, ""
    except Exception as e:
        history[-1][1] = f"❌ Error: {e}"
        return history, ""

# --- Clear Chat ---
def clear_chat():
    return []

# --- UI Design ---
css = """
.gradio-container { max-width: 1100px !important; margin: auto; }
h1, h2, h3 { text-align: center; }
"""

with gr.Blocks(css=css, theme=gr.themes.Soft(), title="📚 AI RAG Assistant") as demo:
    gr.HTML("<h1>📚 AI RAG Assistant</h1>")
    # The layout below is a minimal sketch: component names and arrangement
    # are assumptions, wired to the handlers defined above.
    with gr.Row():
        file_input = gr.File(label="Upload a PDF", file_types=[".pdf"])
        upload_status = gr.Textbox(label="Status", interactive=False)
    chunk_info = gr.Textbox(label="Chunks", visible=False, interactive=False)
    source_choice = gr.Radio(
        ["🌐 Web Search", "📄 Uploaded File"],
        value="🌐 Web Search",
        label="Knowledge Source",
    )
    chatbot = gr.Chatbot(label="Conversation")
    question_box = gr.Textbox(label="Ask a question", placeholder="Type your question and press Enter")
    clear_btn = gr.Button("Clear Chat")

    file_input.change(handle_file_upload, inputs=file_input, outputs=[upload_status, chunk_info])
    question_box.submit(respond_to_query, inputs=[question_box, source_choice, chatbot], outputs=[chatbot, question_box])
    clear_btn.click(clear_chat, outputs=chatbot)
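# A minimal launch sketch (assumption: the script is run directly).
# queue() is optional but useful when LLM calls take several seconds.
if __name__ == "__main__":
    demo.queue()
    demo.launch()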