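# Streamlit app for multimodal RAG over a PDF: each page is rendered to an
# image, indexed with a ColPali retriever (byaldi), and the best-matching
# page image is passed to SmolVLM to answer the user's question.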
import asyncio
import io
import logging
import os
import threading
import uuid
import streamlit as st
import torch
from byaldi import RAGMultiModalModel
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import (AutoModelForVision2Seq, AutoProcessor,
                          BitsAndBytesConfig)
from transformers.image_utils import load_image
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Capture logs
#log_stream = io.StringIO()
#logging.basicConfig(stream=log_stream, level=logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
if "session_id" not in st.session_state:
    st.session_state["session_id"] = str(uuid.uuid4())  # Generate unique session ID

# Async function to load the embedding (retrieval) model
async def load_model_embedding_async():
    st.session_state["loading_model_embedding"] = True  # Show loading status
    await asyncio.sleep(0.1)  # Allow UI updates
    model_embedding = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
    st.session_state["model_embedding"] = model_embedding
    st.session_state["loading_model_embedding"] = False  # Model is ready

# Function to run async function in a separate thread
def load_model_embedding():
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(load_model_embedding_async())

# Start model loading in a background thread
if "model_embedding" not in st.session_state:
    with st.status("Loading embedding model... ⏳"):
        threading.Thread(target=load_model_embedding, daemon=True).start()

# Async function to load the vision-language model (VLM)
async def load_model_vlm_async():
    st.session_state["loading_model_vlm"] = True  # Show loading status
    await asyncio.sleep(0.1)  # Allow UI updates
    checkpoint = "HuggingFaceTB/SmolVLM-Instruct"
    processor_vlm = AutoProcessor.from_pretrained(checkpoint)
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    model_vlm = AutoModelForVision2Seq.from_pretrained(
        checkpoint,
        #torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
    )
    st.session_state["model_vlm"] = model_vlm
    st.session_state["processor_vlm"] = processor_vlm
    st.session_state["loading_model_vlm"] = False  # Model is ready

# Function to run async function in a separate thread
def load_model_vlm():
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(load_model_vlm_async())

# Start model loading in a background thread
if "model_vlm" not in st.session_state:
    with st.status("Loading VLM model... ⏳"):
        threading.Thread(target=load_model_vlm, daemon=True).start()

def save_images_to_local(dataset, output_folder="data/"):
    os.makedirs(output_folder, exist_ok=True)
    for image_id, image in enumerate(dataset):
        #if isinstance(image, str):
        #    image = Image.open(image)
        output_path = os.path.join(output_folder, f"image_{image_id}.png")
        #image = Image.open(io.BytesIO(image_data))
        image.save(output_path, format="PNG")

# Home page UI
with st.sidebar:
    "[Source Code](https://huggingface.co/spaces/deepakkarkala/multimodal-rag/tree/main)"

st.title("Image Q&A with VLM")
#st.text_area("Logs:", log_stream.getvalue(), height=200)

uploaded_pdf = st.file_uploader("Upload PDF file", type=("pdf"))
query = st.text_input(
    "Ask something about the image",
    placeholder="Can you describe the image?",
    disabled=not uploaded_pdf,
)
if st.session_state.get("loading_model_embedding", True):
    st.warning("Loading Embedding model...")
else:
    st.success("Embedding Model loaded successfully!")

if st.session_state.get("loading_model_vlm", True):
    st.warning("Loading VLM model...")
else:
    st.success("VLM Model loaded successfully!")
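
# Per-session image folder and retrieval index so concurrent sessions don't collide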
images = []
images_folder = "data/" + st.session_state["session_id"] + "/"
index_name = "index_" + st.session_state["session_id"]
if uploaded_pdf and "model_embedding" in st.session_state and "is_index_complete" not in st.session_state:
    images = convert_from_bytes(uploaded_pdf.getvalue())
    save_images_to_local(images, output_folder=images_folder)
    # index documents using the document retrieval model
    st.session_state["model_embedding"].index(
        input_path=images_folder, index_name=index_name, store_collection_with_index=False, overwrite=True
    )
    logging.info(f"{len(images)} images extracted from the PDF and indexed")
    st.session_state["is_index_complete"] = True
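
# Answer the query: retrieve the most relevant page image, then ask the VLM about it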
if uploaded_pdf and query and st.session_state.get("is_index_complete") and "model_vlm" in st.session_state:
    docs_retrieved = st.session_state["model_embedding"].search(query, k=1)
    logging.info(f"{len(docs_retrieved)} images retrieved as relevant to the query")
    image_id = docs_retrieved[0]["doc_id"]
    logging.info(f"Image id: {image_id} retrieved")
    if images:
        image_similar_to_query = images[image_id]
    else:
        # On reruns after indexing, "images" is empty: reload the page saved by
        # save_images_to_local (doc_id follows the image_{id}.png naming used above)
        image_similar_to_query = Image.open(os.path.join(images_folder, f"image_{image_id}.png"))
    model_vlm, processor_vlm = st.session_state["model_vlm"], st.session_state["processor_vlm"]
    # Create input messages
    system_prompt = "You are an AI assistant. Your task is to reply to user questions based on the provided image context."
    chat_template = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": query},
            ],
        },
    ]

    # Prepare inputs
    prompt = processor_vlm.apply_chat_template(chat_template, add_generation_prompt=True)
    inputs = processor_vlm(text=prompt, images=[image_similar_to_query], return_tensors="pt")
    inputs = inputs.to(DEVICE)
    # Generate outputs
    generated_ids = model_vlm.generate(**inputs, max_new_tokens=500)
    #generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    generated_texts = processor_vlm.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    response = generated_texts[0]
    st.write(response)