#!/usr/bin/env python3
"""
ColPali Embedding API for generating query embeddings
"""
import os
import logging

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device
# Set the HF token if available (either env var spelling works)
hf_token = os.environ.get("HUGGING_FACE_TOKEN") or os.environ.get("HF_TOKEN")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Initialize FastAPI
app = FastAPI(title="ColPali Embedding API")

# Configure CORS for the local frontends that call this service
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://localhost:3025", "http://localhost:4000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global model state, lazily initialized on first use
model = None
processor = None
device = None

MAX_QUERY_TERMS = 64
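# NOTE: ColPali is a late-interaction model, so a query is embedded as one
# vector per token rather than a single pooled vector. MAX_QUERY_TERMS caps
# how many token vectors are returned per query; 64 is this service's choice
# and must match the query-tensor sizing in the consuming Vespa schema.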
def load_model():
    """Load the ColPali model and processor (idempotent)."""
    global model, processor, device
    if model is None:
        logger.info("Loading ColPali model...")
        device = get_torch_device("auto")
        logger.info(f"Using device: {device}")
        try:
            model_name = "vidore/colpali-v1.2"
            model = ColPali.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                device_map=device,
                token=hf_token,
            ).eval()
            processor = ColPaliProcessor.from_pretrained(model_name, token=hf_token)
            logger.info("ColPali model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            # Fall back to the base checkpoint
            model_name = "vidore/colpaligemma-3b-pt-448-base"
            model = ColPali.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                device_map=device,
                token=hf_token,
            ).eval()
            processor = ColPaliProcessor.from_pretrained(model_name, token=hf_token)
            logger.info(f"Loaded alternative model: {model_name}")
    return model, processor
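# Shape sketch (assuming vidore/colpali-v1.2): model(**inputs) returns a
# tensor of shape [batch, num_tokens, 128], i.e. a normalized 128-dim vector
# per input token, which the endpoint below slices token by token.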
@app.get("/health")
async def health():
    """Health check endpoint"""
    return {"status": "healthy", "service": "colpali-embedding-api"}
@app.get("/embed")  # NOTE: route path is this service's choice; keep it in sync with the client
async def embed_query(
    query: str = Query(..., description="Text query to embed")
):
    """Generate ColPali embeddings for a text query"""
    try:
        model, processor = load_model()

        # ColPali's processor expects image inputs, so pair the text query
        # with a blank white image for text-only embedding
        dummy_image = Image.new("RGB", (448, 448), color="white")
        inputs = processor(
            images=[dummy_image],
            text=[query],
            return_tensors="pt",
            padding=True,
        ).to(device)

        # Generate embeddings; the model returns the projected multi-vector
        # output directly, not a .last_hidden_state
        with torch.no_grad():
            embeddings = model(**inputs)

        # First (and only) item in the batch: one vector per token
        query_embeddings = embeddings[0]

        # Convert to the {token_index: vector} format expected by Vespa;
        # cast to float32 first, since bfloat16 tensors cannot be converted
        # to NumPy directly
        float_query_embedding = {}
        binary_query_embeddings = {}
        for idx in range(min(query_embeddings.shape[0], MAX_QUERY_TERMS)):
            embedding_vector = query_embeddings[idx].cpu().float().numpy().tolist()
            float_query_embedding[str(idx)] = embedding_vector

            # Binary version: one sign bit per dimension, packed into int8 bytes
            binary_vector = (
                np.packbits(np.where(np.array(embedding_vector) > 0, 1, 0))
                .astype(np.int8)
                .tolist()
            )
            binary_query_embeddings[str(idx)] = binary_vector

        return {
            "query": query,
            "embeddings": {
                "float": float_query_embedding,
                "binary": binary_query_embeddings,
            },
            "num_tokens": len(float_query_embedding),
        }
    except Exception as e:
        logger.error(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
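# Hypothetical example response for GET /embed?query=what+is+colpali
# (values depend on the model weights; 128 floats per token, and the binary
# form packs one sign bit per dimension into 16 int8 bytes):
# {
#   "query": "what is colpali",
#   "embeddings": {
#     "float":  {"0": [0.021, -0.113, ...], "1": [...], ...},
#     "binary": {"0": [-74, 12, ...], "1": [...], ...}
#   },
#   "num_tokens": 9
# }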
@app.get("/embed_simple")  # NOTE: route path is this service's choice
async def embed_query_simple(
    query: str = Query(..., description="Text query to embed")
):
    """Generate simplified (mock) embeddings for a text query, for testing"""
    try:
        # For testing, return a constant mock vector; production traffic
        # should use the real ColPali endpoint above
        mock_embedding = [0.1] * 128  # ColPali's per-token embedding dimension
        return {
            "query": query,
            "embedding": mock_embedding,
            "model": "colpali-v1.2",
        }
    except Exception as e:
        logger.error(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    port = int(os.getenv("EMBEDDING_PORT", "7861"))
    logger.info(f"Starting ColPali Embedding API on port {port}")
    # Pre-load the model so the first request doesn't pay the load cost
    load_model()
    uvicorn.run(app, host="0.0.0.0", port=port)
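# Example client calls (assuming the route names chosen above and the
# default port):
#   curl "http://localhost:7861/embed?query=multimodal%20retrieval"
#
#   import requests
#   resp = requests.get(
#       "http://localhost:7861/embed",
#       params={"query": "multimodal retrieval"},
#   )
#   print(resp.json()["num_tokens"])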