#!/usr/bin/env python3
"""
ColPali Embedding API: a FastAPI service that generates multi-vector ColPali
embeddings for text queries, in the float and binary formats expected by Vespa.
"""
import logging
import os

import numpy as np
import torch
import uvicorn
from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware

# Mirror HUGGING_FACE_TOKEN into HF_TOKEN so the hub libraries pick it up
hf_token = os.environ.get("HUGGING_FACE_TOKEN") or os.environ.get("HF_TOKEN")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI
app = FastAPI(title="ColPali Embedding API")

# Configure CORS for the local frontends that call this API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://localhost:3025", "http://localhost:4000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model variables, populated lazily by load_model()
model = None
processor = None
device = None

MAX_QUERY_TERMS = 64
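# Queries longer than MAX_QUERY_TERMS tokens are truncated; 64 matches the
# cap used in common ColPali/Vespa examples, but the exact value is a
# deployment choice.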


def load_model():
    """Load the ColPali model and processor (idempotent)."""
    global model, processor, device
    if model is None:
        logger.info("Loading ColPali model...")
        device = get_torch_device("auto")
        logger.info(f"Using device: {device}")
        try:
            model_name = "vidore/colpali-v1.2"
            model = ColPali.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                device_map=device,
                token=hf_token,
            ).eval()
            processor = ColPaliProcessor.from_pretrained(model_name, token=hf_token)
            logger.info("ColPali model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            # Fall back to the base checkpoint
            model_name = "vidore/colpaligemma-3b-pt-448-base"
            model = ColPali.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                device_map=device,
                token=hf_token,
            ).eval()
            processor = ColPaliProcessor.from_pretrained(model_name, token=hf_token)
            logger.info(f"Loaded alternative model: {model_name}")
    return model, processor
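
# load_model() is idempotent: after the first call it returns the cached
# globals, so request handlers can call it to lazily ensure the model exists.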
@app.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "healthy", "service": "colpali-embedding-api"}
@app.get("/embed_query")
async def embed_query(
query: str = Query(..., description="Text query to embed")
):
"""Generate ColPali embeddings for a text query"""
try:
model, processor = load_model()
# Create a dummy image for text-only queries
# ColPali expects image inputs, so we use a white image
dummy_image = Image.new('RGB', (448, 448), color='white')
# Process query with dummy image
inputs = processor(
images=[dummy_image],
text=[query],
return_tensors="pt",
padding=True
).to(device)
# Generate embeddings
with torch.no_grad():
embeddings = model(**inputs) # Direct output, not .last_hidden_state

        # Convert to the per-token dict format expected by Vespa:
        # token index (as a string) -> embedding vector
        query_embeddings = embeddings[0]  # first (and only) item in the batch

        float_query_embedding = {}
        binary_query_embeddings = {}

        for idx in range(min(query_embeddings.shape[0], MAX_QUERY_TERMS)):
            # .float() first: bfloat16 tensors cannot be converted to numpy
            embedding_vector = query_embeddings[idx].cpu().float().numpy().tolist()
            float_query_embedding[str(idx)] = embedding_vector

            # Binary version: binarize on sign, then pack every 8 bits into
            # one byte (e.g. a 128-dim float vector becomes 16 int8 values)
            binary_vector = (
                np.packbits(np.where(np.array(embedding_vector) > 0, 1, 0))
                .astype(np.int8)
                .tolist()
            )
            binary_query_embeddings[str(idx)] = binary_vector

        return {
            "query": query,
            "embeddings": {
                "float": float_query_embedding,
                "binary": binary_query_embeddings,
            },
            "num_tokens": len(float_query_embedding),
        }
    except Exception as e:
        logger.error(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
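
# Response shape, sketched (the per-token vector length depends on the model;
# for colpali-v1.2 the 128-dim float vectors pack down to 16 int8 bytes):
#   {
#     "query": "...",
#     "embeddings": {
#       "float":  {"0": [0.01, ...], "1": [...], ...},
#       "binary": {"0": [-47, 12, ...], "1": [...], ...}
#     },
#     "num_tokens": <number of token vectors, capped at MAX_QUERY_TERMS>
#   }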
@app.get("/embed_query_simple")
async def embed_query_simple(
query: str = Query(..., description="Text query to embed")
):
"""Generate simplified embeddings for text query (for testing)"""
try:
# For testing, return mock embeddings
# In production, this would use the actual ColPali model
mock_embedding = [0.1] * 128 # 128-dimensional embedding
return {
"query": query,
"embedding": mock_embedding,
"model": "colpali-v1.2"
}
except Exception as e:
logger.error(f"Embedding error: {e}")
raise HTTPException(status_code=500, detail=str(e))
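
# Example (port as above):
#   curl "http://localhost:7861/embed_query_simple?query=hello"
#   -> {"query": "hello", "embedding": [0.1, 0.1, ...], "model": "colpali-v1.2"}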


if __name__ == "__main__":
    port = int(os.getenv("EMBEDDING_PORT", "7861"))
    logger.info(f"Starting ColPali Embedding API on port {port}")
    # Pre-load the model so the first request doesn't pay the startup cost
    load_model()
    uvicorn.run(app, host="0.0.0.0", port=port)