import asyncio
import logging
import os
import threading
import uuid
import streamlit as st
import torch
from byaldi import RAGMultiModalModel
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import (AutoModelForVision2Seq, AutoProcessor,
                          BitsAndBytesConfig)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
if "session_id" not in st.session_state:
st.session_state["session_id"] = str(uuid.uuid4()) # Generate unique session ID
# Load the ColPali embedding model without blocking the Streamlit script
async def load_model_embedding_async():
    st.session_state["loading_model_embedding"] = True  # Show loading status
    await asyncio.sleep(0.1)  # Yield briefly so the UI can update
    model_embedding = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
    st.session_state["model_embedding"] = model_embedding
    st.session_state["loading_model_embedding"] = False  # Model is ready
# Run the async loader to completion on a fresh event loop (called from a worker thread)
def load_model_embedding():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(load_model_embedding_async())
# Start model loading in a background thread
if "model_embedding" not in st.session_state:
with st.status("Loading embedding model... ⏳"):
threading.Thread(target=load_model_embedding, daemon=True).start()
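# NOTE: Streamlit can warn about a missing ScriptRunContext when a plain thread
# touches st.session_state. A minimal sketch of one workaround, assuming the
# add_script_run_ctx helper in the installed Streamlit version:
#
#   from streamlit.runtime.scriptrunner import add_script_run_ctx
#   thread = threading.Thread(target=load_model_embedding, daemon=True)
#   add_script_run_ctx(thread)  # attach the current script-run context to the thread
#   thread.start()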
# Load the SmolVLM vision-language model without blocking the Streamlit script
async def load_model_vlm_async():
    st.session_state["loading_model_vlm"] = True  # Show loading status
    await asyncio.sleep(0.1)  # Yield briefly so the UI can update
    checkpoint = "HuggingFaceTB/SmolVLM-Instruct"
    processor_vlm = AutoProcessor.from_pretrained(checkpoint)
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    model_vlm = AutoModelForVision2Seq.from_pretrained(
        checkpoint,
        quantization_config=quantization_config,  # 8-bit weights via bitsandbytes
    )
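    # NOTE: 8-bit loading relies on bitsandbytes and a CUDA device; on a CPU-only
    # host the from_pretrained call above may fail. A guarded fallback sketch
    # (assumption: plain full-precision load on CPU):
    #
    #   if DEVICE == "cuda":
    #       model_vlm = AutoModelForVision2Seq.from_pretrained(
    #           checkpoint, quantization_config=quantization_config)
    #   else:
    #       model_vlm = AutoModelForVision2Seq.from_pretrained(checkpoint)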
st.session_state["model_vlm"] = model_vlm
st.session_state["processor_vlm"] = processor_vlm
st.session_state["loading_model_vlm"] = False # Model is ready
# Run the async loader to completion on a fresh event loop (called from a worker thread)
def load_model_vlm():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(load_model_vlm_async())
# Start model loading in a background thread
if "model_vlm" not in st.session_state:
with st.status("Loading VLM model... ⏳"):
threading.Thread(target=load_model_vlm, daemon=True).start()
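# An alternative to hand-rolled threads: Streamlit's st.cache_resource decorator
# loads a model once per server process and shares it across reruns and sessions.
# A minimal sketch (same checkpoint as above; not wired into this app):
#
#   @st.cache_resource
#   def get_embedding_model():
#       return RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")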
def save_images_to_local(dataset, output_folder="data/"):
    # Save each PDF page as a PNG so it can be indexed from (and reloaded off) disk
    os.makedirs(output_folder, exist_ok=True)
    for image_id, image in enumerate(dataset):
        output_path = os.path.join(output_folder, f"image_{image_id}.png")
        image.save(output_path, format="PNG")
# Home page UI
with st.sidebar:
"[Source Code](https://huggingface.co/spaces/deepakkarkala/multimodal-rag/tree/main)"
st.title("📝 Image Q&A with VLM")
uploaded_pdf = st.file_uploader("Upload PDF file", type="pdf")
query = st.text_input(
    "Ask something about the image",
    placeholder="Can you describe the image?",
    disabled=not uploaded_pdf,
)
if st.session_state.get("loading_model_embedding", True):
st.warning("Loading Embedding model....")
else:
st.success("Embedding Model loaded successfully! πŸŽ‰")
if st.session_state.get("loading_model_vlm", True):
st.warning("Loading VLM model....")
else:
st.success("VLM Model loaded successfully! πŸŽ‰")
images = []
images_folder = "data/" + st.session_state["session_id"] + "/"
index_name = "index_" + st.session_state["session_id"]
if uploaded_pdf and "model_embedding" in st.session_state and "is_index_complete" not in st.session_state:
images = convert_from_bytes(uploaded_pdf.getvalue())
save_images_to_local(images, output_folder=images_folder)
# index documents using the document retrieval model
st.session_state["model_embedding"].index(
input_path=images_folder, index_name=index_name, store_collection_with_index=False, overwrite=True
)
logging.info(f"{len(images)} number of images extracted from PDF and indexed")
st.session_state["is_index_complete"] = True
if uploaded_pdf and query and "model_embedding" in st.session_state and "model_vlm" in st.session_state:
docs_retrieved = st.session_state["model_embedding"].search(query, k=1)
logging.info(f"{len(docs_retrieved)} number of images retrieved as relevant to query")
image_id = docs_retrieved[0]["doc_id"]
logging.info(f"Image id:{image_id} retrieved" )
image_similar_to_query = images[image_id]
model_vlm, processor_vlm = st.session_state["model_vlm"], st.session_state["processor_vlm"]
# Create input messages
    system_prompt = "You are an AI assistant. Your task is to reply to user questions based on the provided image context."
chat_template = [
{"role": "system", "content": system_prompt},
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": query}
]
},
]
# Prepare inputs
prompt = processor_vlm.apply_chat_template(chat_template, add_generation_prompt=True)
inputs = processor_vlm(text=prompt, images=[image_similar_to_query], return_tensors="pt")
inputs = inputs.to(DEVICE)
# Generate outputs
generated_ids = model_vlm.generate(**inputs, max_new_tokens=500)
generated_texts = processor_vlm.batch_decode(
generated_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
response = generated_texts[0]
st.write(response)
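    # The decoded text above still contains the chat prompt, because generate()
    # returns the prompt tokens followed by the completion. A minimal sketch of
    # trimming the output to only the newly generated tokens:
    #
    #   prompt_len = inputs.input_ids.shape[1]
    #   answer_ids = generated_ids[:, prompt_len:]
    #   response = processor_vlm.batch_decode(answer_ids, skip_special_tokens=True)[0]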