|
import asyncio |
|
import io |
|
import logging |
|
import os |
|
import threading |
|
import uuid |
|
|
|
import streamlit as st |
|
import torch |
|
from byaldi import RAGMultiModalModel |
|
from pdf2image import convert_from_bytes |
|
from PIL import Image |
|
from transformers import (AutoModelForVision2Seq, AutoProcessor, |
|
BitsAndBytesConfig) |
|
from transformers.image_utils import load_image |
|
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
|
|
|
|
# Configure root logging once at import time.  Streamlit re-executes this
# module on every rerun, but basicConfig is a no-op after the first call.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)


# Give each browser session a stable UUID so per-session artefacts
# (image folder, index name) do not collide between concurrent users.
if "session_id" not in st.session_state:

    st.session_state["session_id"] = str(uuid.uuid4())
|
|
|
|
|
|
|
async def load_model_embedding_async():
    """Fetch the ColPali embedding model and publish it in session state.

    Toggles the ``loading_model_embedding`` flag around the download so the
    UI can show a progress banner while the model initialises.
    """
    st.session_state["loading_model_embedding"] = True
    # Brief await so the event loop gets a chance to schedule other work.
    await asyncio.sleep(0.1)
    st.session_state["model_embedding"] = RAGMultiModalModel.from_pretrained(
        "vidore/colpali-v1.2"
    )
    st.session_state["loading_model_embedding"] = False
|
|
|
|
|
|
|
def load_model_embedding():
    """Thread target: run the async embedding-model loader to completion.

    Uses ``asyncio.run``, which creates a fresh event loop for this worker
    thread, runs the coroutine, and — unlike the previous
    ``new_event_loop``/``run_until_complete`` pair — also closes the loop,
    so no event loop is leaked per invocation.
    """
    asyncio.run(load_model_embedding_async())
|
|
|
|
|
# Start the embedding-model download in a background daemon thread so the
# first page render is not blocked by the download.
# NOTE(review): the thread mutates st.session_state without Streamlit's
# add_script_run_ctx, and the st.status context exits immediately rather
# than tracking the thread — confirm this behaves as intended on the
# deployed Streamlit version.
if "model_embedding" not in st.session_state:

    with st.status("Loading embedding model... β³"):

        threading.Thread(target=load_model_embedding, daemon=True).start()
|
|
|
|
|
|
|
|
|
async def load_model_vlm_async():
    """Load the SmolVLM processor and 8-bit quantised model into session state.

    Sets the ``loading_model_vlm`` flag while the checkpoint downloads so
    the UI can show progress, then stores both the model and its processor.
    """
    st.session_state["loading_model_vlm"] = True
    # Brief await so the event loop gets a chance to schedule other work.
    await asyncio.sleep(0.1)

    checkpoint = "HuggingFaceTB/SmolVLM-Instruct"
    eight_bit_config = BitsAndBytesConfig(load_in_8bit=True)
    processor = AutoProcessor.from_pretrained(checkpoint)
    vlm = AutoModelForVision2Seq.from_pretrained(
        checkpoint,
        quantization_config=eight_bit_config,
    )

    st.session_state["model_vlm"] = vlm
    st.session_state["processor_vlm"] = processor
    st.session_state["loading_model_vlm"] = False
|
|
|
|
|
|
|
def load_model_vlm():
    """Thread target: run the async VLM loader to completion.

    Uses ``asyncio.run``, which creates a fresh event loop for this worker
    thread, runs the coroutine, and — unlike the previous
    ``new_event_loop``/``run_until_complete`` pair — also closes the loop,
    so no event loop is leaked per invocation.
    """
    asyncio.run(load_model_vlm_async())
|
|
|
|
|
|
|
# Start the VLM download in a background daemon thread so the first page
# render is not blocked by the download.
# NOTE(review): same caveat as the embedding loader — session_state is
# mutated from a worker thread without add_script_run_ctx, and st.status
# exits immediately; confirm on the deployed Streamlit version.
if "model_vlm" not in st.session_state:

    with st.status("Loading VLM model... β³"):

        threading.Thread(target=load_model_vlm, daemon=True).start()
|
|
|
|
|
|
|
|
|
def save_images_to_local(dataset, output_folder="data/"):
    """Persist every image in *dataset* as a numbered PNG under *output_folder*.

    Files are named ``image_<n>.png`` where ``<n>`` is the image's position
    in *dataset*; the folder is created if it does not already exist.
    """
    os.makedirs(output_folder, exist_ok=True)

    for idx, page in enumerate(dataset):
        destination = os.path.join(output_folder, f"image_{idx}.png")
        page.save(destination, format="PNG")
|
|
|
|
|
|
|
|
|
# Sidebar: link back to the hosted source of this Space.
with st.sidebar:

    # Streamlit "magic": a bare string literal is rendered as markdown.
    "[Source Code](https://huggingface.co/spaces/deepakkarkala/multimodal-rag/tree/main)"



st.title("π Image Q&A with VLM")



# Inputs: a PDF to index, and a free-text question about its pages.
# The question box stays disabled until a PDF has been uploaded.
# NOTE(review): type=("pdf") is just the string "pdf" (parens, not a tuple).
uploaded_pdf = st.file_uploader("Upload PDF file", type=("pdf"))

query = st.text_input(

    "Ask something about the image",

    placeholder="Can you describe me the image ?",

    disabled=not uploaded_pdf,

)


# Model-loading status banners.  The default of True means "still loading"
# is shown until the background loader threads flip the flags to False.
if st.session_state.get("loading_model_embedding", True):

    st.warning("Loading Embedding model....")

else:

    st.success("Embedding Model loaded successfully! π")



if st.session_state.get("loading_model_vlm", True):

    st.warning("Loading VLM model....")

else:

    st.success("VLM Model loaded successfully! π")



# Per-session artefact locations, keyed by the session UUID so concurrent
# users do not overwrite each other's image folders or indexes.
# `images` is rebuilt only on the run that performs indexing; on later
# reruns it stays empty.
images = []

images_folder = "data/" + st.session_state["session_id"] + "/"

index_name = "index_" + st.session_state["session_id"]
|
|
|
|
|
|
|
# Convert the uploaded PDF to one PNG per page and build the ColPali index.
# Guarded by "is_index_complete" so the expensive conversion + indexing runs
# only once per session rather than on every Streamlit rerun.
if uploaded_pdf and "model_embedding" in st.session_state and "is_index_complete" not in st.session_state:
    images = convert_from_bytes(uploaded_pdf.getvalue())
    save_images_to_local(images, output_folder=images_folder)

    st.session_state["model_embedding"].index(
        input_path=images_folder, index_name=index_name, store_collection_with_index=False, overwrite=True
    )
    # Use the module-level `logger` (defined above) rather than the root
    # logger, and lazy %-style args so formatting is skipped when disabled.
    logger.info("%d number of images extracted from PDF and indexed", len(images))
    st.session_state["is_index_complete"] = True
|
|
|
|
|
|
|
# Answer the query: retrieve the most relevant page image, then ask the VLM.
if uploaded_pdf and query and "model_embedding" in st.session_state and "model_vlm" in st.session_state:
    # Retrieve the single best-matching page for the query (k=1).
    docs_retrieved = st.session_state["model_embedding"].search(query, k=1)
    logger.info("%d number of images retrieved as relevant to query", len(docs_retrieved))
    image_id = docs_retrieved[0]["doc_id"]
    logger.info("Image id:%s retrieved", image_id)

    # Reload the page image from disk instead of indexing into the in-memory
    # `images` list: after indexing completes, the conversion branch above is
    # skipped on every Streamlit rerun, leaving `images` empty, so
    # `images[image_id]` would raise IndexError.  The PNGs written by
    # save_images_to_local are the durable per-session copy.
    # NOTE(review): assumes byaldi's doc_id matches the image_<n>.png
    # numbering assigned at indexing time — confirm against byaldi docs.
    image_similar_to_query = Image.open(os.path.join(images_folder, f"image_{image_id}.png"))

    model_vlm, processor_vlm = st.session_state["model_vlm"], st.session_state["processor_vlm"]

    # Build the chat prompt: a system instruction plus a user turn holding
    # the image placeholder and the question text.
    system_prompt = "You are an AI assistant. Your task is reply to user questions based on the provided image context."
    chat_template = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": query}
            ]
        },
    ]

    prompt = processor_vlm.apply_chat_template(chat_template, add_generation_prompt=True)
    inputs = processor_vlm(text=prompt, images=[image_similar_to_query], return_tensors="pt")
    inputs = inputs.to(DEVICE)

    generated_ids = model_vlm.generate(**inputs, max_new_tokens=500)

    # NOTE(review): batch_decode on the raw generate() output returns
    # prompt + completion, so the UI echoes the prompt text — consider
    # slicing off the input tokens if only the answer should be shown.
    generated_texts = processor_vlm.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    response = generated_texts[0]

    st.write(response)
|
|