# %%
import base64
import concurrent.futures
import logging
import os
import shutil
from io import BytesIO
from pathlib import Path
import numpy as np
import torch
from colpali_engine.models import (
# ColPali,
# ColPaliProcessor,
ColQwen2_5,
ColQwen2_5_Processor,
)
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from dotenv import find_dotenv, load_dotenv
from openai import OpenAI
from pdf2image import convert_from_path
from PIL import Image
from pymilvus import DataType, MilvusClient
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.insurance_assistants.consts import PROJECT_ROOT_DIR, PROMPT
# Setup logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# %%
model_name = "vidore/colqwen2.5-v0.2"
# model_name = "vidore/colpali-v1.2"
device = get_torch_device()
model = ColQwen2_5.from_pretrained(
# model = ColPali.from_pretrained(
pretrained_model_name_or_path=model_name,
# torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = ColQwen2_5_Processor.from_pretrained(
# processor = ColPaliProcessor.from_pretrained(
pretrained_model_name_or_path=model_name,
use_fast=True,
)
_ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=False))
# Pass the API key from the environment explicitly instead of relying on the
# client's implicit default; `openai_client` is used by RAG.query_gpt4o_mini below.
openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
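# %%
# Optional sanity check (a sketch; uncomment to run). The processor turns a
# query string into model inputs, and the model maps it to a multi-vector of
# shape (num_query_tokens, 128) — the 128 matches MilvusManager's default dim.
# batch = processor.process_queries(["test query"]).to(model.device)
# with torch.no_grad():
#     print(model(**batch).shape)  # e.g. torch.Size([1, N, 128])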
# %%
class MilvusManager:
def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
"""
Initializes the MilvusManager.
Args:
            milvus_uri (str): URI of a Milvus server, or path to a local Milvus Lite .db file.
collection_name (str): Name of the collection.
create_collection (bool): Whether to create a new collection.
dim (int, optional): Dimension of the vector. Defaults to 128.
"""
self.client = MilvusClient(uri=milvus_uri)
self.collection_name = collection_name
if self.client.has_collection(collection_name=self.collection_name):
self.client.load_collection(collection_name)
self.dim = dim
self.max_doc_id = 0
if create_collection:
self.create_collection()
self.create_index()
def create_collection(self):
"""
Creates a new collection in Milvus. Drops existing collection if present.
"""
if self.client.has_collection(collection_name=self.collection_name):
self.client.drop_collection(collection_name=self.collection_name)
schema = self.client.create_schema(
auto_id=True,
enable_dynamic_fields=True,
)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
schema.add_field(
field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
)
schema.add_field(field_name="seq_id", datatype=DataType.INT16)
schema.add_field(field_name="doc_id", datatype=DataType.INT64)
schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)
self.client.create_collection(
collection_name=self.collection_name, schema=schema
)
def create_index(self):
"""
Creates a vector index for the collection in Milvus.
"""
self.client.release_collection(collection_name=self.collection_name)
        self.client.drop_index(
            collection_name=self.collection_name, index_name="vector_index"
        )
index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            # FLAT performs exact search and takes no build parameters;
            # M/efConstruction apply only to HNSW, so they are omitted here.
            index_type="FLAT",
            metric_type="IP",
        )
self.client.create_index(
collection_name=self.collection_name, index_params=index_params, sync=True
)
def create_scalar_index(self):
"""
Creates a scalar index for the doc_id field in Milvus.
"""
self.client.release_collection(collection_name=self.collection_name)
index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="doc_id",  # INT64 field, indexed with an INVERTED scalar index
            index_name="doc_id_index",
            index_type="INVERTED",
        )
self.client.create_index(
collection_name=self.collection_name, index_params=index_params, sync=True
)
def search(self, data, topk):
"""
Searches for the top-k most similar documents in Milvus.
Args:
data (np.ndarray): Query vector.
topk (int): Number of top results to return.
Returns:
list: List of (score, doc_id) tuples.
"""
        search_params = {"metric_type": "IP", "params": {}}
        # Retrieve up to 50 candidate hits per query token vector; the distinct
        # documents they belong to are then reranked with full MaxSim below.
        results = self.client.search(
            self.collection_name,
            data,
            limit=50,
            output_fields=["vector", "seq_id", "doc_id"],
            search_params=search_params,
        )
        doc_ids = set()
        # Collect the distinct documents that appear anywhere in the hits.
        for result in results:
            for hit in result:
                doc_ids.add(hit["entity"]["doc_id"])
scores = []
        def rerank_single_doc(doc_id, data, client, collection_name):
            # Fetch the token-level vectors belonging to this document only.
            # Note: limit=1000 assumes a page never yields more patch vectors
            # than that.
            doc_colqwen_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id in [{doc_id}]",
                output_fields=["seq_id", "vector", "doc"],
                limit=1000,
            )
            doc_vecs = np.vstack(
                [row["vector"] for row in doc_colqwen_vecs]
            )
            # MaxSim late interaction: each query token vector takes its best
            # match among the document's vectors; the maxima are summed into
            # one document score.
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id)
with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
futures = {
executor.submit(
rerank_single_doc, doc_id, data, self.client, self.collection_name
): doc_id
for doc_id in doc_ids
}
for future in concurrent.futures.as_completed(futures):
score, doc_id = future.result()
scores.append((score, doc_id))
scores.sort(key=lambda x: x[0], reverse=True)
        # Slicing handles the case where fewer than topk documents matched.
        return scores[:topk]
def insert(self, data):
"""
Inserts a document's vectors and metadata into Milvus.
Args:
data (dict): Dictionary containing 'colqwen_vecs', 'doc_id', and 'filepath'.
"""
        colqwen_vecs = list(data["colqwen_vecs"])
        seq_length = len(colqwen_vecs)
        doc_ids = [data["doc_id"]] * seq_length
        seq_ids = list(range(seq_length))
        # Store the source filepath with every vector so any hit maps back to its page image.
        docs = [data["filepath"]] * seq_length
        # docs = [""] * seq_length
        # docs[0] = data["filepath"]
self.client.insert(
self.collection_name,
[
{
"vector": colqwen_vecs[i],
"seq_id": seq_ids[i],
"doc_id": doc_ids[i],
"doc": docs[i],
}
for i in range(seq_length)
],
)
def get_images_as_doc(self, images_with_vectors: list):
"""
Converts a list of image vectors and filepaths into Milvus insertable format.
Args:
images_with_vectors (list): List of dicts with 'colqwen_vecs' and 'filepath'.
Returns:
list: List of dicts ready for Milvus insertion.
"""
        images_data = []
        for image in images_with_vectors:
            self.max_doc_id += 1
            images_data.append(
                {
                    "colqwen_vecs": image["colqwen_vecs"],
                    "doc_id": self.max_doc_id,
                    "filepath": image["filepath"],
                }
            )
        return images_data
def insert_images_data(self, image_data):
"""
Inserts multiple images' data into Milvus.
Args:
image_data (list): List of image data dicts.
"""
        data = self.get_images_as_doc(image_data)
        for item in data:
            self.insert(item)
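# %%
# A minimal sketch of the MaxSim late-interaction scoring used in
# `rerank_single_doc` above, with random stand-in vectors (uncomment to run):
# each query token vector picks its best-matching document vector (max over
# the inner products), and the per-token maxima are summed into one score.
# rng = np.random.default_rng(0)
# query_vecs = rng.standard_normal((20, 128))   # 20 query tokens, dim 128
# doc_vecs = rng.standard_normal((800, 128))    # patch vectors of one page
# maxsim_score = np.dot(query_vecs, doc_vecs.T).max(1).sum()
# print(maxsim_score)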
# %%
class VectorProcessor:
def __init__(
self,
id: str,
create_collection=True,
):
"""
Initializes the VectorProcessor with Milvus, Colqwen, and PDF managers.
Args:
id (str): Unique identifier for the session/user.
create_collection (bool, optional): Whether to create a new collection. Defaults to True.
"""
# hashed_id = hashlib.md5(id.encode()).hexdigest()[:8]
# milvus_db_name = f"milvus_{hashed_id}.db"
milvus_db_name = (
PROJECT_ROOT_DIR / f"src/insurance_assistants/milvus_{id}.db"
).as_posix()
self.milvus_manager = MilvusManager(milvus_db_name, f"{id}", create_collection)
self.colqwen_manager = ColqwenManager()
self.pdf_manager = PdfManager()
def index(
self,
pdf_path: str,
id: str,
max_pages: int,
):
"""
Indexes a PDF file by converting pages to images, embedding them, and storing in Milvus.
Args:
pdf_path (str): Path to the PDF file.
id (str): Unique identifier.
max_pages (int): Maximum number of pages to process.
Returns:
list: List of saved image paths.
"""
logger.info(f"Indexing {pdf_path}, id: {id}, max_pages: {max_pages}")
image_paths = self.pdf_manager.save_images(id, pdf_path, max_pages)
logger.info(f"Saved {len(image_paths)} images")
colqwen_vecs = self.colqwen_manager.process_images(image_paths)
images_data = [
{"colqwen_vecs": colqwen_vecs[i], "filepath": image_paths[i]}
for i in range(len(image_paths))
]
logger.info(f"Inserting {len(images_data)} images data to Milvus")
self.milvus_manager.insert_images_data(images_data)
logger.info("Indexing completed")
return image_paths
    def search(self, search_queries: list[str]):
        """
        Embeds each query with ColQwen and retrieves the top matching pages.
        Args:
            search_queries (list[str]): Queries to search for.
        Returns:
            list: One list of (score, doc_id) tuples per query.
        """
        logger.info(f"Searching for {len(search_queries)} queries")
final_res = []
for query in search_queries:
logger.info(f"Searching for query: {query}")
query_vec = self.colqwen_manager.process_text([query])[0]
search_res = self.milvus_manager.search(query_vec, topk=4)
logger.info(f"Search result: {search_res} for query: {query}")
final_res.append(search_res)
return final_res
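# %%
# Usage sketch (hypothetical paths). ColqwenManager and PdfManager are defined
# further down in this module, so run this only after the whole file has executed:
# vp = VectorProcessor(id="demo", create_collection=True)
# vp.index(pdf_path="sample_policy.pdf", id="sample_policy", max_pages=10)
# results = vp.search(["What does the policy cover?"])  # list of (score, doc_id) lists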
# %%
class PdfManager:
def __init__(self):
"""
Initializes the PdfManager.
"""
pass
def clear_and_recreate_dir(self, output_folder):
logger.info(f"Clearing output folder {output_folder}")
if os.path.exists(output_folder):
shutil.rmtree(output_folder)
os.makedirs(output_folder)
    def save_images(
        self, id, pdf_path, max_pages, pages: list[int] | None = None, output_folder=None
    ) -> list[str]:
"""
Saves images of PDF pages to disk.
Args:
id (str): Unique identifier.
pdf_path (str): Path to the PDF file.
max_pages (int): Maximum number of pages to save.
            pages (list[int], optional): Zero-based page indices to save. Defaults to None (all pages).
            output_folder (str | Path, optional): Destination directory. Defaults to
                the per-id pages directory under src/insurance_assistants.
Returns:
list[str]: List of saved image file paths.
"""
        if output_folder is None:
            output_folder = PROJECT_ROOT_DIR / f"src/insurance_assistants/pages/{id}/"
        else:
            output_folder = Path(output_folder)
        output_folder.mkdir(parents=True, exist_ok=True)
        images = convert_from_path(pdf_path=pdf_path)
        logger.info(
            f"Saving images from {pdf_path} to {output_folder}. Max pages: {max_pages}"
        )
        # self.clear_and_recreate_dir(output_folder)
        saved_paths: list[str] = []
        for i, image in enumerate(images):
            if max_pages and len(saved_paths) >= max_pages:
                break
            if pages and i not in pages:
                continue
            full_save_path = output_folder / f"{id}_page_{i + 1}.png"
            # logger.debug(f"Saving image to {full_save_path}")
            image.save(fp=full_save_path, format="PNG")
            # Track the paths actually written so pages skipped via `pages`
            # cannot desynchronise the returned list from the files on disk.
            saved_paths.append(full_save_path.as_posix())
        return saved_paths
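# %%
# Usage sketch (hypothetical path; uncomment to run): render the first 5 pages
# of a PDF to PNGs under src/insurance_assistants/pages/<id>/ and get their paths.
# pdf_manager = PdfManager()
# paths = pdf_manager.save_images(id="demo", pdf_path="sample_policy.pdf", max_pages=5)
# print(paths)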
# %%
class ColqwenManager:
def get_images(self, paths: list[str]) -> list[Image.Image]:
"""
Loads images from file paths.
Args:
paths (list[str]): List of image file paths.
Returns:
list[Image.Image]: List of PIL Image objects.
"""
return [Image.open(path) for path in paths]
    def process_images(self, image_paths: list[str], batch_size=5):
        """
        Embeds page images into ColQwen multi-vectors.
        Args:
            image_paths (list[str]): Paths of the page images.
            batch_size (int, optional): Batch size for inference. Defaults to 5.
        Returns:
            list[np.ndarray]: One (num_patches, dim) array per image.
        """
        logger.info(f"Processing {len(image_paths)} image_paths")
images = self.get_images(image_paths)
dataloader = DataLoader(
            dataset=ListDataset[Image.Image](images),
batch_size=batch_size,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
ds: list[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
ds.extend(list(torch.unbind(embeddings_doc.to(device))))
ds_np = [d.float().cpu().numpy() for d in ds]
return ds_np
    def process_text(self, texts: list[str]):
        """
        Embeds text queries into ColQwen multi-vectors.
        Args:
            texts (list[str]): Queries to embed.
        Returns:
            list[np.ndarray]: One (num_tokens, dim) array per query.
        """
        logger.info(f"Processing {len(texts)} texts")
dataloader = DataLoader(
dataset=ListDataset[str](texts),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
qs: list[torch.Tensor] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
embeddings_query = model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to(device))))
qs_np = [q.float().cpu().numpy() for q in qs]
return qs_np
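# %%
# Quick embedding sanity check (a sketch; uncomment to run — it uses the
# module-level model and processor). Every text or image becomes a
# variable-length stack of 128-dim vectors, one per token/patch, which is
# exactly what MilvusManager stores row by row.
# cm = ColqwenManager()
# q_vec = cm.process_text(["test query"])[0]
# print(q_vec.shape)  # e.g. (N, 128)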
# %%
# def generate_uuid(state):
# """
# Generates or retrieves a UUID for the user session.
# Args:
# state (dict): State dictionary containing 'user_uuid'.
# Returns:
# str: UUID string.
# """
# # Check if UUID already exists in session state
# if state["user_uuid"] is None:
# # Generate a new UUID if not already set
# state["user_uuid"] = str(uuid.uuid4())
# return state["user_uuid"]
class RAG:
def __init__(self):
"""
Initializes the RAG.
"""
self.vectordb_id = None
self.img_path_dir = PROJECT_ROOT_DIR / "src/insurance_assistants/pages/"
def create_vector_db(
self,
vectordb_id="policy_wordings",
dir=PROJECT_ROOT_DIR / "data",
max_pages=200,
):
"""
Uploads a PDF file, converts it to images, and indexes it.
Args:
state (dict): State dictionary for user session.
file: Uploaded file object.
max_pages (int, optional): Maximum number of pages to process. Defaults to 100.
Returns:
str: Status message.
"""
logger.info(f"Converting files in: {dir}.")
        try:
            vectorprocessor = None
            for idx, f in enumerate((dir / "policy_wordings").iterdir()):
                if idx == 0:
                    # Create the collection once; subsequent PDFs are indexed
                    # into the same collection.
                    vectorprocessor = VectorProcessor(
                        id=vectordb_id, create_collection=True
                    )
                    self.vectordb_id = vectordb_id
                _ = vectorprocessor.index(pdf_path=f, id=f.stem, max_pages=max_pages)
            if vectorprocessor is None:
                return f"❌ No PDF files found in {dir / 'policy_wordings'}."
            vectorprocessor.milvus_manager.client.close()
            return f"βœ… Created the vector_db: milvus_{vectordb_id} under `src/insurance_assistants`."
except Exception as err:
return f"❌ Error creating vector_db: {err}"
    def search_documents(self, query):
        """
        Retrieves the page images most relevant to the query from the vector db.
        Args:
            query (str): The user query.
        Returns:
            set: Set of (image_path, doc_id) tuples.
        """
        if self.vectordb_id is None:
            raise Exception(
                "Create the vector db first by invoking `create_vector_db`."
            )
try:
vectorprocessor = VectorProcessor(
id=self.vectordb_id, create_collection=False
)
search_results = vectorprocessor.search(search_queries=[query])[0]
check_res = vectorprocessor.milvus_manager.client.query(
collection_name=self.vectordb_id,
filter=f"doc_id in {[d[1] for d in search_results]}",
output_fields=["doc_id", "doc"],
)
vectorprocessor.milvus_manager.client.close()
img_path_doc_id = set((i["doc"], i["doc_id"]) for i in check_res)
            logger.info("βœ… Retrieved the images for answering the query.")
return img_path_doc_id
except Exception as err:
return f"❌ Error during search: {err}"
def encode_image_to_base64(self, image_path):
"""
Encodes an image file to a base64 string.
Args:
image_path (str): Path to the image file.
Returns:
str: Base64 encoded image string.
"""
        # Convert to RGB first: page images are saved as PNG and may carry an
        # alpha channel, which JPEG cannot encode.
        image = Image.open(image_path).convert("RGB")
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
def query_gpt4o_mini(self, query, image_path):
"""
Queries the OpenAI GPT-4o-mini model with a query and images.
Args:
query (str): The user query.
image_path (list): List of image file paths.
Returns:
str: The AI response.
"""
try:
base64_images = [self.encode_image_to_base64(pth) for pth in image_path]
response = openai_client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": PROMPT.format(query=query)}
]
+ [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{im}"},
}
for im in base64_images
],
}
],
max_tokens=500,
)
return response.choices[0].message.content
except Exception as err:
return f"Unable to generate the final output due to: {err}."