"""DocQueryAI: a Gradio app that answers questions about an uploaded PDF.

Pipeline: parse the PDF (text blocks + extracted figures), caption each figure
with BLIP, embed text and captions through the Hugging Face Inference API,
index the vectors in FAISS, and answer questions with a remote chat model
using the retrieved excerpts as context.
"""

import logging
import os
import shutil
from typing import List, Tuple

import gradio as gr
import PyPDF2  # noqa: F401  (kept: part of the file's original import set)
from PIL import Image

# Unstructured for rich PDF parsing
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy

# Vision-language captioning (BLIP)
from transformers import BlipProcessor, BlipForConditionalGeneration

# Hugging Face Inference client
from huggingface_hub import InferenceClient

# LangChain vectorstore and embeddings
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings  # noqa: F401  (kept: original import)

logger = logging.getLogger(__name__)

# ── Globals ───────────────────────────────────────────────────────────────────
retriever = None                  # FAISS retriever for multimodal content
current_pdf_name = None           # Name of the currently loaded PDF
combined_texts: List[str] = []    # Combined text + image captions corpus

# ── Configuration ─────────────────────────────────────────────────────────────
FIGURES_DIR = "figures"
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
BLIP_MODEL = "Salesforce/blip-image-captioning-base"
# NOTE(review): confirm this model id exists on the Hub and is served for the
# feature-extraction task; it is kept exactly as the original author wrote it.
EMBEDDING_MODEL = "google/Gemma-Embeddings-v1.0"
CHAT_MODEL = "google/gemma-3-27b-it"


# ── Setup: directories ────────────────────────────────────────────────────────
def _reset_figures_dir() -> None:
    """Delete and recreate FIGURES_DIR so it holds no images from earlier runs."""
    shutil.rmtree(FIGURES_DIR, ignore_errors=True)
    os.makedirs(FIGURES_DIR, exist_ok=True)


_reset_figures_dir()

# ── Clients & Models ──────────────────────────────────────────────────────────
# Pass HUGGINGFACEHUB_API_TOKEN explicitly; when it is unset, token=None lets
# the client fall back to its own default token lookup.
hf = InferenceClient(token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"))

# BLIP captioner
blip_processor = BlipProcessor.from_pretrained(BLIP_MODEL)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL)


def generate_caption(image_path: str) -> str:
    """Generate a caption for the image at *image_path* via BLIP.

    Args:
        image_path: Path to an image file readable by Pillow.

    Returns:
        The decoded caption text, without special tokens.
    """
    # The context manager closes the underlying file handle; convert() returns
    # an independent in-memory copy, so it stays usable after the close.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = blip_processor(image, return_tensors="pt")
    out = blip_model.generate(**inputs)
    return blip_processor.decode(out[0], skip_special_tokens=True)


def _embed_one(text: str) -> List[float]:
    """Embed a single string remotely and return it as a flat list of floats."""
    vector = hf.feature_extraction(text, model=EMBEDDING_MODEL)
    # NOTE(review): some models return token-level output with extra leading
    # axes; mean-pooling down to one vector is assumed acceptable here —
    # confirm for the chosen embedding model.
    while getattr(vector, "ndim", 1) > 1:
        vector = vector.mean(axis=0)
    return [float(value) for value in vector]


def embed_texts(texts: List[str]) -> List[List[float]]:
    """Embed each string in *texts* through the HF Inference API.

    Args:
        texts: Strings to embed, in order.

    Returns:
        One embedding vector per input string, in the same order.
    """
    return [_embed_one(text) for text in texts]


def _extract_pdf_content(pdf_path: str) -> Tuple[List[str], List[str]]:
    """Return (text blocks, extracted image paths) for the PDF at *pdf_path*.

    Tries Unstructured's hi-res strategy first, which also writes Image/Table
    crops into FIGURES_DIR. If that fails for any reason, falls back to plain
    per-page text extraction with pypdf and returns no images.
    """
    try:
        elements = partition_pdf(
            filename=pdf_path,
            strategy=PartitionStrategy.HI_RES,
            extract_image_block_types=["Image", "Table"],
            extract_image_block_output_dir=FIGURES_DIR,
        )
    except Exception:
        # Deliberate best-effort: hi-res parsing needs optional system tools,
        # so degrade to text-only instead of failing — but leave a trace.
        logger.warning(
            "Hi-res PDF parsing failed; falling back to text-only extraction.",
            exc_info=True,
        )
        from pypdf import PdfReader

        reader = PdfReader(pdf_path)
        page_texts = [page.extract_text() or "" for page in reader.pages]
        return [text for text in page_texts if text.strip()], []

    text_elements = [
        el.text
        for el in elements
        if el.category not in ("Image", "Table") and el.text and el.text.strip()
    ]
    image_files = [
        os.path.join(FIGURES_DIR, name)
        for name in sorted(os.listdir(FIGURES_DIR))
        if name.lower().endswith(IMAGE_EXTENSIONS)
    ]
    return text_elements, image_files


def process_pdf(pdf_file) -> str:
    """Parse a PDF, caption its images, embed everything and build the index.

    Args:
        pdf_file: The uploaded file from Gradio. With ``type="filepath"`` this
            is a path string; objects exposing the path as ``.name`` are also
            accepted.

    Returns:
        A status message for the UI (prefixed with ✅ on success, ❌ on failure).
        On success the module-level ``retriever``, ``current_pdf_name`` and
        ``combined_texts`` are replaced; on failure they are left untouched.
    """
    global retriever, current_pdf_name, combined_texts

    if pdf_file is None:
        return "❌ Please upload a PDF file."

    if isinstance(pdf_file, (str, os.PathLike)):
        pdf_path = os.fspath(pdf_file)
    else:
        pdf_path = pdf_file.name
    pdf_name = os.path.basename(pdf_path)

    # Drop figures extracted from a previously processed PDF so their captions
    # do not leak into this document's index.
    _reset_figures_dir()

    try:
        text_elements, image_files = _extract_pdf_content(pdf_path)
        captions = [generate_caption(img) for img in image_files]
        texts = text_elements + captions
        if not texts:
            return f"❌ No extractable text or images found in '{pdf_name}'."

        vectors = embed_texts(texts)
        # from_embeddings takes (text, vector) pairs plus the embedder used
        # for queries, which must match the model that produced the vectors.
        # NOTE(review): FAISS logs a deprecation warning for a plain callable;
        # swap in an Embeddings subclass if that support is removed.
        index = FAISS.from_embeddings(
            text_embeddings=list(zip(texts, vectors)),
            embedding=_embed_one,
        )
    except Exception as exc:
        # UI boundary: report the failure in the status box instead of raising.
        logger.exception("Failed to index %s", pdf_name)
        return f"❌ Failed to index '{pdf_name}': {exc}"

    # Publish the new state only after every step has succeeded.
    retriever = index.as_retriever(search_kwargs={"k": 2})
    current_pdf_name = pdf_name
    combined_texts = texts
    return (
        f"✅ Indexed '{current_pdf_name}' — "
        f"{len(text_elements)} text blocks + {len(captions)} image captions"
    )


def ask_question(question: str) -> str:
    """Retrieve relevant excerpts from FAISS and answer via chat completion.

    Args:
        question: The user's question about the currently indexed PDF.

    Returns:
        The model's answer, or a ❌-prefixed message when no PDF is indexed,
        the question is blank, or a remote call fails.
    """
    if retriever is None:
        return "❌ Please process a PDF first."
    if not question or not question.strip():
        return "❌ Please enter a question."

    try:
        # invoke() is the supported retriever entry point;
        # get_relevant_documents() is deprecated.
        docs = retriever.invoke(question)
        context = "\n\n".join(doc.page_content for doc in docs)
        prompt = (
            "Use the following excerpts to answer the question:\n\n"
            f"{context}\n\nQuestion: {question}\nAnswer:"
        )
        response = hf.chat_completion(
            model=CHAT_MODEL,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=128,
            temperature=0.5,
        )
    except Exception as exc:
        # UI boundary: show the failure in the answer box instead of raising.
        logger.exception("Question answering failed")
        return f"❌ Failed to answer the question: {exc}"

    content = response["choices"][0]["message"]["content"]
    return (content or "").strip()


def clear_interface() -> Tuple[str, str]:
    """Reset all state and clear extracted images.

    Returns:
        Two empty strings, one for each bound output (status box, answer box).
    """
    global retriever, current_pdf_name, combined_texts
    retriever = None
    current_pdf_name = None
    combined_texts = []
    _reset_figures_dir()
    # The click handler is bound to two outputs, so return one value for each.
    return "", ""


# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")) as demo:
    gr.Markdown("# DocQueryAI (Remote‐RAG)")
    with gr.Row():
        with gr.Column():
            pdf_file = gr.File(file_types=[".pdf"], type="filepath")
            process_btn = gr.Button("Process PDF")
            status_box = gr.Textbox(interactive=False)
        with gr.Column():
            question_input = gr.Textbox(lines=3)
            ask_btn = gr.Button("Ask")
            answer_output = gr.Textbox(interactive=False)
    clear_btn = gr.Button("Clear All")

    process_btn.click(fn=process_pdf, inputs=[pdf_file], outputs=[status_box])
    ask_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output])
    clear_btn.click(fn=clear_interface, outputs=[status_box, answer_output])

if __name__ == "__main__":
    demo.launch()