#!/usr/bin/env python3
"""
ColPali Embedding API for generating query embeddings
"""

import os
import logging
import numpy as np
from pathlib import Path
from typing import List, Dict
from fastapi import FastAPI, Query, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import torch
from PIL import Image
import uvicorn

from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device

# Set HF token if available
hf_token = os.environ.get("HUGGING_FACE_TOKEN") or os.environ.get("HF_TOKEN")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI
app = FastAPI(title="ColPali Embedding API")

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://localhost:3025", "http://localhost:4000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model variables (populated lazily by load_model)
model = None
processor = None
device = None

# Upper bound on the number of per-token embedding rows returned per query
MAX_QUERY_TERMS = 64


def _load_pretrained(model_name, dtype, target_device):
    """Load one ColPali checkpoint and its processor; return ``(model, processor)``.

    Nothing is written to module globals here, so a failure part-way through
    (model loaded, processor not) leaves no half-initialised state behind.
    """
    loaded_model = ColPali.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=target_device,
        token=hf_token,
    ).eval()
    loaded_processor = ColPaliProcessor.from_pretrained(model_name, token=hf_token)
    return loaded_model, loaded_processor


def load_model():
    """Load the ColPali model and processor once and cache them in module globals.

    Tries ``vidore/colpali-v1.2`` first; if anything in that load raises, logs
    the error and falls back to ``vidore/colpaligemma-3b-pt-448-base``.

    Returns:
        The cached ``(model, processor)`` pair.

    Raises:
        Exception: whatever the fallback load raises if it fails as well. In
            that case the globals are left untouched so a later call retries.
    """
    global model, processor, device

    # Both halves must be present: checking only ``model`` would hand out a
    # ``None`` processor after a partially failed earlier attempt.
    if model is not None and processor is not None:
        return model, processor

    logger.info("Loading ColPali model...")
    device = get_torch_device("auto")
    logger.info(f"Using device: {device}")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    try:
        model_name = "vidore/colpali-v1.2"
        loaded_model, loaded_processor = _load_pretrained(model_name, dtype, device)
        logger.info("ColPali model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        # Try alternative model (same dtype/device/token as the primary load)
        model_name = "vidore/colpaligemma-3b-pt-448-base"
        loaded_model, loaded_processor = _load_pretrained(model_name, dtype, device)
        logger.info(f"Loaded alternative model: {model_name}")

    # Publish only once both objects exist.
    model, processor = loaded_model, loaded_processor
    return model, processor


def _to_vespa_embeddings(query_embeddings) -> Dict[str, Dict[str, List]]:
    """Convert a 2-D embedding tensor into the float/binary dicts returned by the API.

    Args:
        query_embeddings: tensor indexed as ``[row, dim]``; only the first
            ``MAX_QUERY_TERMS`` rows are used.

    Returns:
        ``{"float": {...}, "binary": {...}}`` where both inner dicts are keyed
        by the row index as a string. Float values are the raw vector; binary
        values are the sign bits (``> 0``) packed 8 per byte and reinterpreted
        as int8.
    """
    # Cast to float32 before leaving torch: the model runs in bfloat16 on
    # CUDA and Tensor.numpy() does not support bfloat16. One slice + one
    # device-to-host copy replaces a transfer per row.
    rows = query_embeddings[:MAX_QUERY_TERMS].to(torch.float32).cpu().numpy()

    # packbits along axis 1 packs each row independently, matching the
    # original per-row packing.
    packed = np.packbits(rows > 0, axis=1).astype(np.int8)

    return {
        "float": {str(idx): vector for idx, vector in enumerate(rows.tolist())},
        "binary": {str(idx): bits for idx, bits in enumerate(packed.tolist())},
    }


@app.get("/health")
async def health():
    """Health check endpoint; returns a static status payload."""
    return {"status": "healthy", "service": "colpali-embedding-api"}


@app.get("/embed_query")
async def embed_query(
    query: str = Query(..., description="Text query to embed")
):
    """Generate ColPali embeddings for a text query.

    Returns a JSON object with the original ``query``, an ``embeddings``
    object holding ``float`` and ``binary`` per-row vectors keyed by row
    index, and ``num_tokens`` (the number of rows returned, at most
    ``MAX_QUERY_TERMS``).

    Raises:
        HTTPException: status 500 with the error text if model loading,
            preprocessing or inference fails.
    """
    try:
        model, processor = load_model()

        # Create a dummy image for text-only queries
        # ColPali expects image inputs, so we use a white image
        dummy_image = Image.new('RGB', (448, 448), color='white')

        # Process query with dummy image
        inputs = processor(
            images=[dummy_image],
            text=[query],
            return_tensors="pt",
            padding=True
        ).to(device)

        # Generate embeddings
        with torch.no_grad():
            embeddings = model(**inputs)  # Direct output, not .last_hidden_state

        # First (only) item in the batch.
        # NOTE(review): the rows kept are the FIRST MAX_QUERY_TERMS positions
        # of the sequence. If the processor places image tokens ahead of the
        # text tokens, these are not the query-token embeddings - confirm
        # against the colpali_engine processor (e.g. its query-specific path).
        query_embeddings = embeddings[0]

        vespa_embeddings = _to_vespa_embeddings(query_embeddings)

        return {
            "query": query,
            "embeddings": vespa_embeddings,
            "num_tokens": len(vespa_embeddings["float"])
        }

    except Exception as e:
        # logger.exception keeps the traceback, which the bare message lost.
        logger.exception(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/embed_query_simple")
async def embed_query_simple(
    query: str = Query(..., description="Text query to embed")
):
    """Generate simplified embeddings for text query (for testing).

    Does not touch the model: always returns a constant 128-dimensional
    vector of ``0.1`` together with the echoed query and a model label.
    """
    try:
        # For testing, return mock embeddings
        # In production, this would use the actual ColPali model
        mock_embedding = [0.1] * 128  # 128-dimensional embedding

        return {
            "query": query,
            "embedding": mock_embedding,
            "model": "colpali-v1.2"
        }

    except Exception as e:
        logger.exception(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    port = int(os.getenv("EMBEDDING_PORT", "7861"))
    logger.info(f"Starting ColPali Embedding API on port {port}")

    # Pre-load model so the first request does not pay the load cost
    load_model()

    uvicorn.run(app, host="0.0.0.0", port=port)