colpali-backend-api / embedding_api.py
vk98's picture
Add debug Dockerfile with better error handling and logging
9fc7504
#!/usr/bin/env python3
"""
ColPali Embedding API for generating query embeddings
"""
import os
import logging
import numpy as np
from pathlib import Path
from typing import List, Dict
from fastapi import FastAPI, Query, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import torch
from PIL import Image
import uvicorn
from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device
# Take the Hugging Face token from HUGGING_FACE_TOKEN, falling back to HF_TOKEN;
# when a non-empty token was found, mirror it into HF_TOKEN.
hf_token = os.environ.get("HUGGING_FACE_TOKEN")
if not hf_token:
    hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token
# Root logging at INFO with "time - logger - level - message" lines, plus this
# module's own logger.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# The FastAPI application that the routes below attach to.
app = FastAPI(title="ColPali Embedding API")

# Allow browser clients on these local origins to call the API, with
# credentials and any method/header.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=[
        "http://localhost:3000",
        "http://localhost:3025",
        "http://localhost:4000",
    ],
)
# Lazily populated by load_model(): the ColPali model, its processor and the
# torch device they were placed on. All three start out unset.
model = processor = device = None

# Upper bound on how many per-token query vectors embed_query returns.
MAX_QUERY_TERMS = 64
def _load_colpali(model_name, target_device, token=None):
    """Load a ColPali model plus its processor and return ``(model, processor)``.

    The model is loaded in bfloat16 when CUDA is available (float32 otherwise),
    placed on ``target_device`` and switched to eval mode.
    """
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    loaded_model = ColPali.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=target_device,
        token=token,
    ).eval()
    loaded_processor = ColPaliProcessor.from_pretrained(model_name, token=token)
    return loaded_model, loaded_processor


def load_model():
    """Load (once) and return the module-wide ColPali model and processor.

    On the first call this picks a torch device and tries ``vidore/colpali-v1.2``.
    If that fails, it falls back to ``vidore/colpaligemma-3b-pt-448-base``.
    The globals ``model``, ``processor`` and ``device`` are filled in. Later
    calls reuse the cached pair.

    Returns:
        tuple: ``(model, processor)``.

    Raises:
        Exception: whatever the fallback load raises if both attempts fail.
            The cached globals are left unset so the next call retries.
    """
    global model, processor, device
    # Both halves must be present. Checking only `model` could hand back
    # (model, None) after a half-finished load.
    if model is not None and processor is not None:
        return model, processor

    logger.info("Loading ColPali model...")
    device = get_torch_device("auto")
    logger.info(f"Using device: {device}")
    try:
        new_model, new_processor = _load_colpali(
            "vidore/colpali-v1.2", device, token=hf_token
        )
        logger.info("ColPali model loaded successfully")
    except Exception as e:
        logger.exception(f"Error loading model: {e}")
        # Try alternative model.
        # NOTE(review): this looks like the base checkpoint rather than a
        # trained ColPali adapter; confirm its embeddings are actually usable.
        model_name = "vidore/colpaligemma-3b-pt-448-base"
        new_model, new_processor = _load_colpali(model_name, device, token=hf_token)
        logger.info(f"Loaded alternative model: {model_name}")

    # Publish to the globals only once a complete pair exists.
    model, processor = new_model, new_processor
    return model, processor
@app.get("/health")
async def health():
    """Liveness probe: return a fixed status payload without touching the model."""
    payload = {"status": "healthy", "service": "colpali-embedding-api"}
    return payload
def _to_vespa_embeddings(token_embeddings):
    """Convert per-token query embeddings into the two dicts returned to callers.

    Args:
        token_embeddings: 2-D tensor with one row per query token (presumably
            ``(num_tokens, dim)``; confirm against the model output). Only the
            first ``MAX_QUERY_TERMS`` rows are used.

    Returns:
        tuple: ``(float_embeddings, binary_embeddings)``. Both are dicts keyed by
        the token index as a string. Float values are lists of Python floats.
        Binary values hold the sign bits (``> 0``) of each vector, packed with
        ``np.packbits`` and reinterpreted as int8.
    """
    # numpy has no bfloat16 dtype (the model runs in bf16 on CUDA), so upcast
    # to float32 before leaving torch.
    vectors = token_embeddings[:MAX_QUERY_TERMS].detach().float().cpu().numpy()
    bits = np.packbits(vectors > 0, axis=-1).astype(np.int8)
    float_embeddings = {str(i): row.tolist() for i, row in enumerate(vectors)}
    binary_embeddings = {str(i): row.tolist() for i, row in enumerate(bits)}
    return float_embeddings, binary_embeddings


@app.get("/embed_query")
async def embed_query(
    query: str = Query(..., description="Text query to embed")
):
    """Generate ColPali multi-vector embeddings for a text query.

    Returns:
        dict: the original ``query``; ``embeddings`` with a ``float`` dict and a
        ``binary`` dict, each keyed by token position as a string; and
        ``num_tokens``, the number of vectors returned (at most
        ``MAX_QUERY_TERMS``).

    Raises:
        HTTPException: status 500 carrying the underlying error message on any
            failure (model loading, preprocessing or inference).
    """
    try:
        model, processor = load_model()

        # Queries are text-only. process_queries tokenizes them the way ColPali
        # expects for queries, without image tokens. Pairing the text with a
        # dummy image would prepend image tokens, and the first output vectors
        # would then be image-patch embeddings instead of query tokens.
        inputs = processor.process_queries([query]).to(device)

        with torch.no_grad():
            embeddings = model(**inputs)  # Direct output, not .last_hidden_state

        # First (and only) item in the batch.
        float_query_embedding, binary_query_embeddings = _to_vespa_embeddings(
            embeddings[0]
        )

        return {
            "query": query,
            "embeddings": {
                "float": float_query_embedding,
                "binary": binary_query_embeddings
            },
            "num_tokens": len(float_query_embedding)
        }
    except Exception as e:
        logger.exception(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/embed_query_simple")
async def embed_query_simple(
    query: str = Query(..., description="Text query to embed")
):
    """Return a placeholder embedding for *query* (testing aid; no model is run).

    The response always carries a 128-element vector of 0.1 labelled as
    "colpali-v1.2", alongside the echoed query.
    """
    try:
        placeholder = [0.1 for _ in range(128)]  # fixed 128-dim mock vector
        response = {
            "query": query,
            "embedding": placeholder,
            "model": "colpali-v1.2",
        }
    except Exception as e:
        logger.error(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return response
if __name__ == "__main__":
    # Listening port comes from EMBEDDING_PORT, defaulting to 7861.
    port = int(os.environ.get("EMBEDDING_PORT", "7861"))
    logger.info(f"Starting ColPali Embedding API on port {port}")

    # Load the model eagerly so the first request does not pay the load cost.
    load_model()

    uvicorn.run(app, port=port, host="0.0.0.0")