import os
import gradio as gr
import base64
from PIL import Image
import io
import requests
# Import vectorstore and embeddings from the langchain community package
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
# Text splitter to break large documents into manageable chunks
from langchain.text_splitter import RecursiveCharacterTextSplitter
# HF Inference client for running chat completions
from huggingface_hub import InferenceClient
# Unstructured for advanced PDF processing with image/table extraction
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy
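
# Assumed runtime dependencies, inferred from the imports above (the source does
# not pin them): gradio, huggingface_hub, langchain, langchain-community,
# faiss-cpu, sentence-transformers, unstructured[pdf], pillow, requests.
# Exact package names and versions are an assumption; pin them in
# requirements.txt for your Space as needed.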

# ── Globals ────────────────────────────────────────────────────────────────────
index = None              # FAISS index storing document embeddings
retriever = None          # Retriever to fetch relevant chunks
current_pdf_name = None   # Name of the currently loaded PDF
pdf_text = None           # Full text of the uploaded PDF
extracted_images = []     # List of extracted images and their descriptions

# Create a directory for storing extracted figures
FIGURES_DIR = "extracted_figures/"
os.makedirs(FIGURES_DIR, exist_ok=True)

# ── HF Inference clients for different models ──────────────────────────────────
# Text generation model
text_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3")

# Vision-language model (choose one based on your needs and HF availability)
# Option 1: BLIP-2 for general image understanding
vision_client = InferenceClient(model="Salesforce/blip2-opt-2.7b")
# Option 2: Alternative image-captioning models you can use:
# vision_client = InferenceClient(model="microsoft/git-base-coco")
# vision_client = InferenceClient(model="nlpconnect/vit-gpt2-image-captioning")
# vision_client = InferenceClient(model="Salesforce/blip-image-captioning-large")
# For conversational or instruction-following text tasks, you could also use:
# multimodal_client = InferenceClient(model="microsoft/DialoGPT-medium")   # Conversational AI
# multimodal_client = InferenceClient(model="facebook/opt-iml-max-30b")    # Instruction following

# ── Multimodal embeddings ───────────────────────────────────────────────────────
# Primary: CLIP embeddings for text-image alignment
try:
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/clip-ViT-B-32",
        model_kwargs={'device': 'cpu'},  # Ensure CPU usage for HF Spaces
        encode_kwargs={'normalize_embeddings': True}
    )
    print("✅ Using CLIP embeddings for multimodal support")
except Exception as e:
    print(f"⚠️ CLIP failed, falling back to BGE: {e}")
    # Fallback to BGE embeddings
    embeddings = HuggingFaceEmbeddings(
        model_name="BAAI/bge-base-en-v1.5",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True}
    )
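
# The PDF-processing function below calls `text_splitter.split_text(...)`, but the
# splitter is never defined in the source. A minimal sketch is added here using the
# RecursiveCharacterTextSplitter imported above; the chunk sizes are assumptions
# (chosen to mirror the max_characters=1000 used in partition_pdf) and should be
# tuned for your documents.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,    # maximum characters per chunk (assumed value)
    chunk_overlap=100,  # overlap between adjacent chunks (assumed value)
)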

def create_multimodal_embeddings(text_chunks, image_descriptions):
    """
    Combine text chunks and image descriptions into a single list of strings,
    tagged by content type, ready to be embedded into the FAISS index.
    """
    try:
        all_chunks = []
        # Process text chunks
        for chunk in text_chunks:
            # Add a context marker for better embedding
            enhanced_chunk = f"Document text: {chunk}"
            all_chunks.append(enhanced_chunk)
        # Process image descriptions with special formatting
        for img_desc in image_descriptions:
            # Mark visual content for better embedding alignment
            enhanced_desc = f"Visual content: {img_desc}"
            all_chunks.append(enhanced_desc)
        return all_chunks
    except Exception as e:
        print(f"Error creating multimodal embeddings: {e}")
        return text_chunks + image_descriptions
""" | |
Enhanced image description using multiple vision models | |
""" | |
try: | |
# Load and process image | |
with open(image_path, "rb") as f: | |
image_bytes = f.read() | |
# Method 1: Use BLIP-2 for detailed image captioning | |
try: | |
description = vision_client.image_to_text(image_bytes) | |
base_description = description if isinstance(description, str) else description.get('generated_text', '') | |
except Exception as e: | |
print(f"BLIP-2 failed: {e}") | |
base_description = "Image could not be processed with vision model" | |
# Method 2: Enhance with text-based analysis using the text model | |
enhancement_prompt = f""" | |
Analyze this image description and provide a detailed analysis focusing on: | |
1. Any text, numbers, or data visible | |
2. Charts, graphs, or tables | |
3. Key visual elements and their significance | |
4. Context and meaning | |
Description: {base_description} | |
Provide a comprehensive analysis: | |
""" | |
try: | |
response = text_client.chat_completion( | |
messages=[{"role": "user", "content": enhancement_prompt}], | |
max_tokens=300, | |
temperature=0.3 | |
) | |
enhanced_description = response["choices"][0]["message"]["content"].strip() | |
except Exception as e: | |
print(f"Text enhancement failed: {e}") | |
enhanced_description = base_description | |
return f"Visual Element Analysis:\n{enhanced_description}" | |
except Exception as e: | |
print(f"Error processing image {image_path}: {str(e)}") | |
return f"Visual element detected: {os.path.basename(image_path)} (processing failed)" | |

def process_pdf_multimodal_advanced(pdf_file):
    """
    Advanced multimodal PDF processing with enhanced vision capabilities.
    """
    global current_pdf_name, index, retriever, pdf_text, extracted_images

    if pdf_file is None:
        return None, "❌ Please upload a PDF file.", gr.update(interactive=False)

    # gr.File(type="filepath") passes a plain path string; older Gradio versions
    # pass a tempfile-like object, so fall back to .name when it is present.
    pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
    current_pdf_name = os.path.basename(pdf_path)
    extracted_images = []

    # Clear the existing figures directory
    for file in os.listdir(FIGURES_DIR):
        try:
            os.remove(os.path.join(FIGURES_DIR, file))
        except OSError:
            pass

    try:
        # Process the PDF with unstructured
        elements = partition_pdf(
            pdf_path,
            strategy=PartitionStrategy.HI_RES,
            extract_image_block_types=["Image", "Table"],
            extract_image_block_output_dir=FIGURES_DIR,
            extract_image_block_to_payload=False,
            # Additional parameters for better extraction
            infer_table_structure=True,
            chunking_strategy="by_title",
            max_characters=1000,
            combine_text_under_n_chars=100
        )
        # Process elements
        text_elements = []
        visual_descriptions = []
        for element in elements:
            if element.category in ["Image", "Table"]:
                # Image/table elements are handled below from the extracted files
                continue
            elif element.category == "Title":
                text_elements.append(f"TITLE: {element.text}")
            elif element.category == "Header":
                text_elements.append(f"HEADER: {element.text}")
            else:
                if hasattr(element, 'text') and element.text.strip():
                    text_elements.append(element.text)
        pdf_text = "\n\n".join(text_elements)

        # Process extracted visual elements
        if os.path.exists(FIGURES_DIR):
            for filename in sorted(os.listdir(FIGURES_DIR)):
                if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
                    image_path = os.path.join(FIGURES_DIR, filename)
                    # Get an enhanced description
                    description = extract_image_description_advanced(image_path)
                    visual_descriptions.append(description)
                    extracted_images.append({
                        'path': image_path,
                        'description': description,
                        'filename': filename,
                        'type': 'table' if 'table' in filename.lower() else 'image'
                    })

        # Combine text and visual content with the enhanced embedding strategy
        text_chunks = text_splitter.split_text(pdf_text) if pdf_text else []
        # Create the multimodal chunk list
        all_chunks = create_multimodal_embeddings(text_chunks, visual_descriptions)

        # Create the FAISS index with settings tuned for multimodal content
        if all_chunks:
            index = FAISS.from_texts(all_chunks, embeddings)
            retriever = index.as_retriever(
                search_type="mmr",          # Maximal marginal relevance for diverse results
                search_kwargs={
                    "k": 5,                 # Return more results for multimodal content
                    "fetch_k": 10,          # Broader initial search
                    "lambda_mult": 0.6      # Balance between relevance and diversity
                }
            )
        else:
            raise Exception("No content extracted from PDF")

        status = (
            f"✅ Advanced processing complete for '{current_pdf_name}'\n"
            f"📄 {len(text_elements)} text sections\n"
            f"🖼️ {len(extracted_images)} visual elements\n"
            f"📦 {len(all_chunks)} total searchable chunks"
        )
        return current_pdf_name, status, gr.update(interactive=True)
    except Exception as e:
        error_msg = f"❌ Processing error: {str(e)}"
        return current_pdf_name, error_msg, gr.update(interactive=False)

def ask_question_multimodal_advanced(pdf_name, question):
    """
    Advanced multimodal question answering with smart routing.
    """
    global retriever, extracted_images

    if index is None or retriever is None:
        return "❌ Please upload and process a PDF first."
    if not question.strip():
        return "❌ Please enter a question."

    try:
        # Retrieve relevant chunks
        docs = retriever.get_relevant_documents(question)
        context = "\n\n".join([doc.page_content for doc in docs])

        # Enhanced visual query detection
        visual_keywords = [
            'image', 'figure', 'chart', 'graph', 'table', 'diagram', 'picture',
            'visual', 'show', 'display', 'plot', 'data', 'visualization',
            'illustration', 'screenshot', 'photo', 'drawing'
        ]
        is_visual_query = any(keyword in question.lower() for keyword in visual_keywords)

        # Smart context enhancement
        if is_visual_query and extracted_images:
            # Prioritize visual content for visual queries
            visual_context = "\n\n".join([img['description'] for img in extracted_images])
            enhanced_context = f"{visual_context}\n\nAdditional Context:\n{context}"
        else:
            enhanced_context = context

        # Advanced prompting based on query type
        if is_visual_query:
            system_prompt = """You are an expert document analyst specializing in multimodal content analysis.
You excel at interpreting charts, graphs, tables, images, and visual data alongside textual information.
When answering questions about visual elements, be specific about what you observe and provide detailed insights."""
        else:
            system_prompt = """You are an expert document analyst. Provide accurate, comprehensive answers based on the document content.
Use the context provided to give detailed and helpful responses."""

        prompt = f"""{system_prompt}
Context: {enhanced_context}
Question: {question}
Provide a detailed, accurate answer based on the context above. If the question relates to visual elements, describe what you can understand from the visual descriptions provided."""

        response = text_client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=400,
            temperature=0.4
        )
        answer = response["choices"][0]["message"]["content"].strip()
        return answer
    except Exception as e:
        return f"❌ Error generating answer: {str(e)}"

def analyze_document_structure():
    """
    New feature: analyze the overall structure of the document.
    """
    global pdf_text, extracted_images

    if not pdf_text and not extracted_images:
        return "❌ Please upload and process a PDF first."

    try:
        structure_prompt = f"""
Analyze the structure and organization of this document. Provide insights about:
1. Document type and purpose
2. Main sections and topics
3. Visual elements present ({len(extracted_images)} images/tables/charts)
4. Key information hierarchy
5. Overall document quality and completeness
Text content sample: {(pdf_text or '')[:1000]}
Visual elements: {len(extracted_images)} items detected
Provide a structural analysis:
"""
        response = text_client.chat_completion(
            messages=[{"role": "user", "content": structure_prompt}],
            max_tokens=300,
            temperature=0.3
        )
        return response["choices"][0]["message"]["content"].strip()
    except Exception as e:
        return f"❌ Error analyzing structure: {str(e)}"

def generate_summary_multimodal():
    """Enhanced summary generation considering both text and visual content."""
    global pdf_text, extracted_images

    if not pdf_text and not extracted_images:
        return "❌ Please upload and process a PDF first."

    try:
        content_parts = []
        if pdf_text:
            content_parts.append(f"Text Content:\n{pdf_text[:2000]}")
        if extracted_images:
            visual_summary = "\n".join([img['description'][:200] for img in extracted_images[:3]])
            content_parts.append(f"Visual Content:\n{visual_summary}")
        combined_content = "\n\n".join(content_parts)

        prompt = f"""Provide a comprehensive summary of this document that includes both textual and visual elements.
Focus on key findings, main topics, and insights from charts, tables, or images.
Content: {combined_content}
Summary:"""

        response = text_client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=250,
            temperature=0.5
        )
        return response["choices"][0]["message"]["content"].strip()
    except Exception as e:
        return f"❌ Error generating summary: {str(e)}"

def extract_keywords_multimodal():
    """Enhanced keyword extraction from both text and visual content."""
    global pdf_text, extracted_images

    if not pdf_text and not extracted_images:
        return "❌ Please upload and process a PDF first."

    try:
        content_parts = []
        if pdf_text:
            content_parts.append(f"Text: {pdf_text[:1500]}")
        if extracted_images:
            visual_content = "\n".join([img['description'][:150] for img in extracted_images])
            content_parts.append(f"Visual Content: {visual_content}")
        combined_content = "\n\n".join(content_parts)

        prompt = f"""Extract key terms, concepts, and topics from this document content.
Include technical terms, important concepts, and themes from both text and visual elements.
Content: {combined_content}
Key terms and concepts:"""

        response = text_client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=120,
            temperature=0.5
        )
        return response["choices"][0]["message"]["content"].strip()
    except Exception as e:
        return f"❌ Error extracting keywords: {str(e)}"

def show_extracted_images():
    """Display information about extracted images."""
    global extracted_images

    if not extracted_images:
        return "No visual elements extracted from the current document."

    info = f"🖼️ Extracted {len(extracted_images)} visual elements:\n\n"
    for i, img in enumerate(extracted_images, 1):
        element_type = "📊 Table" if img['type'] == 'table' else "🖼️ Image"
        info += f"{i}. {element_type}: {img['filename']}\n"
        info += f"   Description: {img['description'][:150]}...\n\n"
        if i >= 5:  # Limit display to the first 5 elements
            remaining = len(extracted_images) - 5
            if remaining > 0:
                info += f"... and {remaining} more visual elements."
            break
    return info

def clear_interface_multimodal():
    """Enhanced clear function for the multimodal system."""
    global index, retriever, current_pdf_name, pdf_text, extracted_images
    index = retriever = None
    current_pdf_name = pdf_text = None
    extracted_images = []

    if os.path.exists(FIGURES_DIR):
        for file in os.listdir(FIGURES_DIR):
            try:
                os.remove(os.path.join(FIGURES_DIR, file))
            except OSError:
                pass
    return None, "", gr.update(interactive=False), "", "", "", "", ""

# Enhanced Gradio UI
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")

with gr.Blocks(theme=theme, css="""
    .container { border-radius: 10px; padding: 15px; }
    .pdf-active { border-left: 3px solid #6366f1; padding-left: 10px; background-color: rgba(99,102,241,0.1); }
    .footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #666; }
    .main-title {
        text-align: center;
        font-size: 56px;
        font-weight: bold;
        margin-bottom: 20px;
        background: linear-gradient(45deg, #6366f1, #8b5cf6, #ec4899);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
    }
    .feature-badge {
        background: linear-gradient(45deg, #10b981, #3b82f6);
        color: white;
        padding: 4px 12px;
        border-radius: 15px;
        font-size: 11px;
        margin: 2px;
        display: inline-block;
    }
""") as demo:
gr.Markdown("<div class='main-title'>π€ DocQueryAI Pro</div>") | |
gr.Markdown(""" | |
<div style='text-align: center; margin-bottom: 25px;'> | |
<span class='feature-badge'>π Advanced RAG</span> | |
<span class='feature-badge'>πΌοΈ Vision AI</span> | |
<span class='feature-badge'>π Table Analysis</span> | |
<span class='feature-badge'>π Chart Understanding</span> | |
<span class='feature-badge'>π§ Smart Retrieval</span> | |
</div> | |
""") | |
with gr.Row(): | |
with gr.Column(): | |
gr.Markdown("## π Document Processing") | |
pdf_display = gr.Textbox(label="Active Document", interactive=False, elem_classes="pdf-active") | |
pdf_file = gr.File(file_types=[".pdf"], type="filepath", label="Upload PDF Document") | |
upload_button = gr.Button("π Process with AI Vision", variant="primary", size="lg") | |
status_box = gr.Textbox(label="Processing Status", interactive=False, lines=3) | |
with gr.Column(): | |
gr.Markdown("## π¬ Intelligent Q&A") | |
gr.Markdown("*Ask about any content: text, images, charts, tables, or data visualizations*") | |
question_input = gr.Textbox( | |
lines=3, | |
placeholder="Examples:\nβ’ What does the chart show?\nβ’ Summarize the table data\nβ’ Explain the main findings", | |
label="Your Question" | |
) | |
ask_button = gr.Button("π Get AI Answer", variant="primary", size="lg") | |
answer_output = gr.Textbox(label="AI Response", lines=8, interactive=False) | |
with gr.Row(): | |
with gr.Column(): | |
summary_button = gr.Button("π Generate Summary", variant="secondary") | |
summary_output = gr.Textbox(label="Document Summary", lines=5, interactive=False) | |
with gr.Column(): | |
keywords_button = gr.Button("π·οΈ Extract Keywords", variant="secondary") | |
keywords_output = gr.Textbox(label="Key Concepts", lines=5, interactive=False) | |
with gr.Row(): | |
with gr.Column(): | |
structure_button = gr.Button("ποΈ Analyze Structure", variant="secondary") | |
structure_output = gr.Textbox(label="Document Structure Analysis", lines=5, interactive=False) | |
with gr.Column(): | |
images_button = gr.Button("πΌοΈ Show Visual Elements", variant="secondary") | |
images_output = gr.Textbox(label="Extracted Visual Elements", lines=5, interactive=False) | |
with gr.Row(): | |
clear_button = gr.Button("ποΈ Clear All", variant="secondary", size="sm") | |
gr.Markdown(""" | |
<div class='footer'> | |
π <strong>Powered by Advanced AI</strong><br> | |
π§ HuggingFace Transformers β’ LangChain β’ FAISS β’ Unstructured<br> | |
π― Multimodal RAG: Text + Vision + Tables + Charts | |
</div> | |
""") | |

    # Event bindings
    upload_button.click(process_pdf_multimodal_advanced, [pdf_file], [pdf_display, status_box, question_input])
    ask_button.click(ask_question_multimodal_advanced, [pdf_display, question_input], answer_output)
    summary_button.click(generate_summary_multimodal, [], summary_output)
    keywords_button.click(extract_keywords_multimodal, [], keywords_output)
    structure_button.click(analyze_document_structure, [], structure_output)
    images_button.click(show_extracted_images, [], images_output)
    clear_button.click(clear_interface_multimodal, [], [pdf_file, pdf_display, question_input, answer_output, summary_output, keywords_output, structure_output, images_output])

if __name__ == "__main__":
    demo.launch(debug=True, share=True)