import os
import io
import base64
import gc
import shutil

import torch
import gradio as gr
import spaces
import numpy as np
import pandas as pd
import pymupdf
from PIL import Image
from pypdf import PdfReader
from dotenv import load_dotenv

import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions
from chromadb.utils.data_loaders import ImageLoader

from huggingface_hub.utils import HfHubHTTPError
from langchain_core.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEndpoint

from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

from welcome_text import WELCOME_INTRO
from utils import extract_pdfs, extract_images, clean_text, image_to_bytes
# ─────────────────────────────────────────────────────────────────────────────
# Load .env
load_dotenv()
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

# OCR + multimodal image description setup.
# Nothing is moved to the GPU at import time: the LLaVA processor/model are
# lazy-loaded on first use inside get_image_description, and the OCR pipeline
# is built per request in extract_data_from_pdfs.
processor = None
vision_model = None

# Holds the in-memory Chroma client built by extract_data_from_pdfs
CURRENT_VDB = None

def get_image_description(image: Image.Image) -> str:
    """
    Lazy-loads the LLaVA processor + model inside the GPU worker,
    runs captioning, and returns a one-sentence description.
    """
    global processor, vision_model
    # On first call, instantiate and move to CUDA
    if processor is None or vision_model is None:
        processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
        vision_model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
        ).to("cuda")
    torch.cuda.empty_cache()
    gc.collect()
    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    inputs = processor(prompt, image, return_tensors="pt").to("cuda")
    output = vision_model.generate(**inputs, max_new_tokens=100)
    return processor.decode(output[0], skip_special_tokens=True)
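
# Example usage (commented out: it needs a CUDA device and downloads the LLaVA
# weights on first call; the file name below is purely illustrative):
# caption = get_image_description(Image.open("figure_1.png").convert("RGB"))
# print(caption)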

# ─────────────────────────────────────────────────────────────────────────────
# Vector DB setup

# 1) One shared embedding function (all-MiniLM-L6-v2, 384-dim) used by both collections
SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)
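
# Both collections share this embedding function, so text chunks and image
# captions land in the same 384-dimensional space and can be queried with the
# same question string. A minimal sketch of that assumption (an illustrative
# helper, not called anywhere in the app):
def _embedding_dim(emb_fn=SHARED_EMB_FN) -> int:
    """Embed one probe string and return the vector length (expected: 384)."""
    return len(emb_fn(["dimension probe"])[0])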

def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):
    """
    Build an in-memory ChromaDB instance with two collections:
      • text_db  (chunks of the PDF text)
      • image_db (image descriptions + raw image bytes)
    Returns the Chroma client for later querying.
    """
    # ——— 1) Init & wipe old collections ————————————
    client = chromadb.EphemeralClient()
    for col in ("text_db", "image_db"):
        if col in [c.name for c in client.list_collections()]:
            client.delete_collection(col)

    # ——— 2) Create fresh collections —————————
    text_col = client.get_or_create_collection(
        name="text_db",
        embedding_function=SHARED_EMB_FN,
        data_loader=ImageLoader(),  # loader only matters for images, benign here
    )
    img_col = client.get_or_create_collection(
        name="image_db",
        embedding_function=SHARED_EMB_FN,
        metadata={"hnsw:space": "cosine"},
        data_loader=ImageLoader(),
    )

    # ——— 3) Add images, if any ———————————————
    if images:
        descs = []
        metas = []
        for idx, img in enumerate(images):
            # Build a one-line caption (or a fallback if captioning fails)
            try:
                caption = get_image_description(img)
            except Exception:
                caption = "⚠️ could not describe image"
            descs.append(f"{img_names[idx]}: {caption}")
            metas.append({"image": image_to_bytes(img)})
        img_col.add(
            ids=[str(i) for i in range(len(images))],
            documents=descs,
            metadatas=metas,
        )

    # ——— 4) Chunk & add text ———————————————
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = splitter.create_documents([text])
    text_col.add(
        ids=[str(i) for i in range(len(docs))],
        documents=[d.page_content for d in docs],
    )
    return client
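
# A minimal usage sketch for the client returned above (illustrative only; the
# Gradio callbacks below query CURRENT_VDB directly, and this helper is never
# called by the app):
def preview_text_matches(client, question: str, k: int = 3) -> list[str]:
    """Return the k text chunks most similar to `question` from text_db."""
    res = client.get_collection("text_db", embedding_function=SHARED_EMB_FN).query(
        query_texts=[question],
        n_results=k,
        include=["documents"],
    )
    return res["documents"][0]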

# Text extraction
def result_to_text(result, as_text=False):
    """Flatten a docTR result into cleaned text: one string per page,
    or a single string when as_text=True."""
    pages = []
    for pg in result.pages:
        txt = " ".join(w.value for block in pg.blocks for line in block.lines for w in line.words)
        pages.append(clean_text(txt))
    return "\n\n".join(pages) if as_text else pages


OCR_CHOICES = {
    "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
    "db_resnet50 + crnn_vgg16_bn": ("db_resnet50", "crnn_vgg16_bn"),
}
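
# A minimal sketch of running one of the OCR pairs above on a single PDF
# (illustrative only; extract_data_from_pdfs builds its own predictor, and this
# hypothetical helper is not wired into the app):
def ocr_pdf_to_text(path: str, choice: str = "db_resnet50 + crnn_mobilenet_v3_large") -> str:
    """Run docTR with the chosen detection/recognition pair and return the cleaned text."""
    det, rec = OCR_CHOICES[choice]
    model = ocr_predictor(det, rec, pretrained=True, assume_straight_pages=True)
    pages = DocumentFile.from_pdf(path)
    return result_to_text(model(pages), as_text=True)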

def extract_data_from_pdfs(
    docs: list[str],
    session: dict,
    include_images: str,   # "Include Images" or "Exclude Images"
    do_ocr: str,           # "Get Text With OCR" or "Get Available Text Only"
    ocr_choice: str,       # key into OCR_CHOICES
    vlm_choice: str,       # HF repo ID for LlavaNext
    progress=gr.Progress()
):
    """
    1) (Optional) OCR setup
    2) Vision-language model setup & monkey-patch of get_image_description
    3) Extract text & images
    4) Build and stash the vector DB in CURRENT_VDB
    """
    if not docs:
        raise gr.Error("No documents to process")

    # 1) OCR pipeline, if requested
    if do_ocr == "Get Text With OCR":
        db_m, crnn_m = OCR_CHOICES[ocr_choice]
        local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
    else:
        local_ocr = None

    # 2) Vision-language model
    proc = LlavaNextProcessor.from_pretrained(vlm_choice)
    vis = (
        LlavaNextForConditionalGeneration
        .from_pretrained(vlm_choice, torch_dtype=torch.float16, low_cpu_mem_usage=True)
        .to("cuda")
    )

    # Monkey-patch the captioning helper so get_vectordb uses the chosen VLM
    def describe(img: Image.Image) -> str:
        torch.cuda.empty_cache()
        gc.collect()
        prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
        inputs = proc(prompt, img, return_tensors="pt").to("cuda")
        output = vis.generate(**inputs, max_new_tokens=100)
        return proc.decode(output[0], skip_special_tokens=True)

    global get_image_description, CURRENT_VDB
    get_image_description = describe

    # 3) Extract text + images
    progress(0.2, "Extracting text and images…")
    all_text = ""
    images, names = [], []
    for path in docs:
        if local_ocr:
            pdf = DocumentFile.from_pdf(path)
            res = local_ocr(pdf)
            all_text += result_to_text(res, as_text=True) + "\n\n"
        else:
            # Pull the available text layer from every page, not just the first
            reader = PdfReader(path)
            txt = "\n".join((page.extract_text() or "") for page in reader.pages)
            all_text += txt + "\n\n"

        if include_images == "Include Images":
            imgs = extract_images([path])
            images.extend(imgs)
            names.extend([os.path.basename(path)] * len(imgs))

    # 4) Build + store the vector DB
    progress(0.6, "Indexing in vector DB…")
    CURRENT_VDB = get_vectordb(all_text, images, names)
    session["processed"] = True
    sample_imgs = images[:4] if include_images == "Include Images" else []

    # ─── Return exactly four picklable outputs ───
    return (
        session,                   # gr.State: so the UI knows we're ready
        all_text[:2000] + "...",   # preview text
        sample_imgs,               # preview images
        "<h3>Done!</h3>"           # done message
    )
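
# The index above lives only in memory (chromadb.EphemeralClient), so it is lost
# on every restart. A minimal sketch of a persistent alternative, assuming a
# local ./chroma_store directory is acceptable (illustrative helper, not wired
# into the app):
def make_persistent_client(path: str = "chroma_store"):
    """Return a ChromaDB client that stores its collections on disk at `path`."""
    return chromadb.PersistentClient(path=path)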

# Chat function
def conversation(
    session: dict,
    question: str,
    num_ctx: int,
    img_ctx: int,
    history: list,
    temp: float,
    max_tok: int,
    model_id: str
):
    """
    Uses the global CURRENT_VDB (set by extract_data_from_pdfs) to answer.
    """
    global CURRENT_VDB
    if not session.get("processed") or CURRENT_VDB is None:
        raise gr.Error("Please extract data first")

    llm = HuggingFaceEndpoint(
        repo_id=model_id,
        temperature=temp,
        max_new_tokens=max_tok,
        huggingfacehub_api_token=HF_TOKEN
    )

    # 1) Text retrieval
    text_col = CURRENT_VDB.get_collection("text_db")
    docs = text_col.query(
        query_texts=[question],
        n_results=int(num_ctx),
        include=["documents"]
    )["documents"][0]

    # 2) Image retrieval
    img_col = CURRENT_VDB.get_collection("image_db")
    img_q = img_col.query(
        query_texts=[question],
        n_results=int(img_ctx),
        include=["metadatas", "documents"]
    )
    img_descs = img_q["documents"][0] or ["No images found"]
    images = []
    for meta in img_q["metadatas"][0]:
        b64 = meta.get("image", "")
        try:
            images.append(Image.open(io.BytesIO(base64.b64decode(b64))))
        except Exception:
            pass
    img_desc = "\n".join(img_descs)

    # 3) Build prompt & call LLM
    prompt = PromptTemplate(
        template="""
Context:
{text}

Included Images:
{img_desc}

Question:
{q}

Answer:
""",
        input_variables=["text", "img_desc", "q"],
    )
    user_input = prompt.format(
        text="\n\n".join(docs),
        img_desc=img_desc,
        q=question
    )

    try:
        answer = llm.invoke(user_input)
    except HfHubHTTPError as e:
        if e.response.status_code == 404:
            answer = f"❌ Model `{model_id}` not hosted on the HF Inference API."
        else:
            answer = f"⚠️ HF API error: {e}"
    except Exception as e:
        answer = f"⚠️ Unexpected error: {e}"

    new_history = history + [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer}
    ]
    # Wrap the chunks as single-column rows for the Dataframe output
    return new_history, [[d] for d in docs], images
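
# Example call shape (commented out: it needs a populated CURRENT_VDB and a
# Hugging Face API token; the values below are illustrative):
# history, chunks, imgs = conversation(
#     session={"processed": True},
#     question="What is the paper about?",
#     num_ctx=3, img_ctx=2, history=[],
#     temp=0.4, max_tok=200,
#     model_id="HuggingFaceH4/zephyr-7b-beta",
# )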

# ─────────────────────────────────────────────────────────────────────────────
# Gradio UI
CSS = """
footer {visibility:hidden;}
"""

MODEL_OPTIONS = [
    "HuggingFaceH4/zephyr-7b-beta",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "openchat/openchat-3.5-0106",
    "google/gemma-7b-it",
    "deepseek-ai/deepseek-llm-7b-chat",
    "microsoft/Phi-3-mini-4k-instruct",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen/Qwen1.5-7B-Chat",
    "tiiuae/falcon-7b-instruct",   # Falcon 7B Instruct
    "bigscience/bloomz-7b1",       # BLOOMZ 7B
    "facebook/opt-2.7b",
]

with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    # State to track that extraction completed (and carry any metadata)
    session_state = gr.State({})

    # ─── Welcome Screen ─────────────────────────────────────────────
    with gr.Column(visible=True) as welcome_col:
        gr.Markdown(
            f"<div style='text-align: center'>\n{WELCOME_INTRO}\n</div>",
            elem_id="welcome_md"
        )
        start_btn = gr.Button("🚀 Start")

    # ─── Main App (hidden until Start is clicked) ───────────────────
    with gr.Column(visible=False) as app_col:
        gr.Markdown("## 📚 Multimodal Chat-PDF Playground")

        # Capture the extract event so the "show chat tab" step can be chained later
        extract_event = None

        with gr.Tabs() as tabs:
            # ── Tab 1: Upload & Extract ───────────────────────────────
            with gr.TabItem("1. Upload & Extract"):
                docs = gr.File(
                    file_count="multiple",
                    file_types=[".pdf"],
                    label="Upload PDFs"
                )
                include_dd = gr.Radio(
                    ["Include Images", "Exclude Images"],
                    value="Exclude Images",
                    label="Images"
                )
                ocr_radio = gr.Radio(
                    ["Get Text With OCR", "Get Available Text Only"],
                    value="Get Available Text Only",
                    label="OCR"
                )
                ocr_dd = gr.Dropdown(
                    choices=list(OCR_CHOICES.keys()),
                    value=list(OCR_CHOICES.keys())[0],
                    label="OCR Model"
                )
                vlm_dd = gr.Dropdown(
                    choices=[
                        "llava-hf/llava-v1.6-mistral-7b-hf",
                        "llava-hf/llava-v1.6-vicuna-7b-hf"
                    ],
                    value="llava-hf/llava-v1.6-mistral-7b-hf",
                    label="Vision-Language Model"
                )
                extract_btn = gr.Button("Extract")
                preview_text = gr.Textbox(
                    lines=10,
                    label="Sample Text",
                    interactive=False
                )
                preview_img = gr.Gallery(
                    label="Sample Images",
                    rows=2,
                    value=[]
                )
                preview_html = gr.HTML()

                # Kick off extraction and capture the event
                extract_event = extract_btn.click(
                    fn=extract_data_from_pdfs,
                    inputs=[
                        docs,
                        session_state,
                        include_dd,
                        ocr_radio,
                        ocr_dd,
                        vlm_dd
                    ],
                    outputs=[
                        session_state,   # sets session["processed"] = True
                        preview_text,    # shows the first bit of text
                        preview_img,     # shows the first images
                        preview_html     # shows "<h3>Done!</h3>"
                    ]
                )
            # ── Tab 2: Chat (initially hidden) ────────────────────────
            with gr.TabItem("2. Chat", visible=False) as chat_tab:
                with gr.Row():
                    with gr.Column(scale=3):
                        chat = gr.Chatbot(type="messages", label="Chat")
                        msg = gr.Textbox(
                            placeholder="Ask about your PDF...",
                            label="Your question"
                        )
                        send = gr.Button("Send")
                    with gr.Column(scale=1):
                        model_dd = gr.Dropdown(
                            MODEL_OPTIONS,
                            value=MODEL_OPTIONS[0],
                            label="Choose Chat Model"
                        )
                        num_ctx = gr.Slider(1, 20, value=3, label="Text Contexts")
                        img_ctx = gr.Slider(1, 10, value=2, label="Image Contexts")
                        temp = gr.Slider(0.1, 1.0, step=0.1, value=0.4, label="Temperature")
                        max_tok = gr.Slider(10, 1000, step=10, value=200, label="Max Tokens")

                # Output components for the retrieved context
                retrieved_df = gr.Dataframe(label="Retrieved Text Chunks")
                relevant_imgs = gr.Gallery(label="Relevant Images", rows=2, value=[])

                send.click(
                    fn=conversation,
                    inputs=[
                        session_state,
                        msg,
                        num_ctx,
                        img_ctx,
                        chat,
                        temp,
                        max_tok,
                        model_dd
                    ],
                    outputs=[
                        chat,
                        retrieved_df,    # retrieved text chunks
                        relevant_imgs    # relevant images
                    ]
                )

        # After both tabs are defined, chain the "unhide chat tab" event
        extract_event.then(
            fn=lambda: gr.update(visible=True),
            inputs=[],
            outputs=[chat_tab]
        )
gr.HTML("<center>Made with ❤️ by Zamal</center>") | |
# ─── Wire the Start button ─────────────────────────────────────── | |
start_btn.click( | |
fn=lambda: (gr.update(visible=False), gr.update(visible=True)), | |
inputs=[], | |
outputs=[welcome_col, app_col] | |
) | |
if __name__ == "__main__": | |
demo.launch() | |