Spaces:
Running
Running
import torch | |
from typing import List | |
from PIL import Image | |
from transformers import ViTImageProcessor, ViTModel | |
from utils.logger import logger | |
from config.model_configs import IMAGE_EMBEDDING_MODEL | |
class ImageEmbeddingModel:
    """Embed image files with a pretrained ViT model.

    The checkpoint named by ``IMAGE_EMBEDDING_MODEL`` is loaded once at
    construction, moved to CUDA when available (CPU otherwise) and switched
    to evaluation mode.
    """

    def __init__(self):
        """Load the ViT model and its image processor onto the chosen device."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Loading Image Embedding Model '{IMAGE_EMBEDDING_MODEL}' to device: {self.device}")
        self.model = ViTModel.from_pretrained(IMAGE_EMBEDDING_MODEL).to(self.device)
        self.processor = ViTImageProcessor.from_pretrained(IMAGE_EMBEDDING_MODEL)
        # Set model to evaluation mode
        self.model.eval()
        logger.info("Image Embedding Model loaded successfully.")

    def get_embeddings(self, image_paths: List[str]) -> List[List[float]]:
        """Return one L2-normalized embedding per input path, in input order.

        Args:
            image_paths: Paths of the image files to embed.

        Returns:
            ``[]`` when ``image_paths`` is empty or when none of the images
            could be loaded. Otherwise a list with exactly
            ``len(image_paths)`` entries where entry ``i`` belongs to
            ``image_paths[i]``: the embedding as a list of floats, or ``[]``
            if that image could not be loaded. If the model call itself
            fails, every entry is ``[]``.
        """
        if not image_paths:
            logger.warning("No image paths provided")
            return []

        images = []
        # Position in ``image_paths`` of every image that loaded, so each
        # embedding can be written back to the slot of the path it came from.
        loaded_indices = []
        for index, img_path in enumerate(image_paths):
            try:
                # The context manager releases the file handle; convert()
                # returns a separate in-memory image that stays usable.
                with Image.open(img_path) as source:
                    images.append(source.convert("RGB"))
            except Exception as e:
                logger.warning(f"Could not load image {img_path}: {e}. Skipping.")
                continue
            loaded_indices.append(index)

        if not images:
            logger.warning("No valid images to process")
            return []

        try:
            inputs = self.processor(images=images, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
                # Keep only position 0 along the sequence axis (the [CLS]
                # token): (batch, seq, hidden) -> (batch, hidden).
                cls_embeddings = outputs.last_hidden_state[:, 0, :]
                # normalize() clamps the denominator, so an all-zero vector
                # stays zero instead of becoming NaN.
                embeddings = torch.nn.functional.normalize(cls_embeddings, p=2, dim=-1)

            embeddings_list = embeddings.cpu().tolist()
            logger.debug(f"Generated {len(embeddings_list)} embeddings for {len(images)} images.")

            if len(embeddings_list) != len(image_paths):
                logger.warning(f"Mismatch: {len(embeddings_list)} embeddings for {len(image_paths)} input paths")

            # Start with an empty placeholder for every input and fill in the
            # slots whose image loaded; failed paths keep their [] in place
            # rather than shifting later embeddings onto the wrong path.
            aligned = [[] for _ in image_paths]
            for index, vector in zip(loaded_indices, embeddings_list):
                aligned[index] = vector
            return aligned
        except Exception as e:
            logger.error(f"Error generating embeddings: {e}")
            # Return empty embeddings for all input paths
            return [[] for _ in image_paths]