"""
SigLIP 2 Custom Inference Handler for Hugging Face Inference Endpoints

Model: google/siglip2-so400m-patch14-384 (best balance of performance and quality)

For ProofPath video assessment - identifies objects, tools, and actions in video frames.
"""

import base64
import io
from typing import Any, Dict, List, Optional

import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize SigLIP 2 model for image/frame classification and embedding.

        Args:
            path: Path to the model directory (provided by HF Inference Endpoints)
        """
        # Prefer the local model directory provided by the endpoint; fall back
        # to the Hub ID when running outside Inference Endpoints.
        model_id = path or "google/siglip2-so400m-patch14-384"

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            attn_implementation="sdpa",
        )

        # device_map="auto" already placed the model on GPU; only move it on CPU.
        if not torch.cuda.is_available():
            self.model = self.model.to(self.device)

        self.model.eval()

    def _decode_image(self, image_data: Any) -> Image.Image:
        """
        Decode image from various input formats.

        Supports:
        - PIL Image
        - URL to an image (http/https)
        - Data URI ("data:image/...;base64,...")
        - Bare base64-encoded string
        - Raw bytes
        """
        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")
        elif isinstance(image_data, str):
            if image_data.startswith(("http://", "https://")):
                # Fetch over HTTP and fail fast on bad status codes.
                response = requests.get(image_data, timeout=30)
                response.raise_for_status()
                return Image.open(io.BytesIO(response.content)).convert("RGB")
            elif image_data.startswith("data:"):
                # Data URI: strip the "data:image/...;base64," header.
                header, encoded = image_data.split(",", 1)
                image_bytes = base64.b64decode(encoded)
                return Image.open(io.BytesIO(image_bytes)).convert("RGB")
            else:
                # Assume a bare base64 string.
                image_bytes = base64.b64decode(image_data)
                return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        elif isinstance(image_data, bytes):
            return Image.open(io.BytesIO(image_data)).convert("RGB")
        else:
            raise ValueError(f"Unsupported image input type: {type(image_data)}")

    def _process_batch(
        self,
        images: List[Image.Image],
        texts: Optional[List[str]] = None
    ) -> Dict[str, torch.Tensor]:
        """Process a batch of images and optional texts."""
        if texts:
            # SigLIP checkpoints are trained with fixed-length padding, so pad
            # every text to max_length=64 rather than to the longest in batch.
            inputs = self.processor(
                images=images,
                text=texts,
                padding="max_length",
                max_length=64,
                return_tensors="pt"
            )
        else:
            inputs = self.processor(
                images=images,
                return_tensors="pt"
            )

        # Move tensors to the model's device; floating tensors (pixel_values)
        # must also match the model dtype, which is float16 on GPU.
        return {
            k: v.to(self.model.device, dtype=self.model.dtype)
            if v.is_floating_point() else v.to(self.model.device)
            for k, v in inputs.items()
        }
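
    # For two images and three texts against this 384px checkpoint, the batch
    # typically contains pixel_values of shape (2, 3, 384, 384) and input_ids
    # of shape (3, 64); exact keys and shapes follow the processor defaults.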

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process image(s) for classification or embedding extraction.

        Expected input formats:

        1. Zero-shot classification:
            {
                "inputs": <image_data>,  # single image or list of images
                "parameters": {
                    "candidate_labels": ["label1", "label2", ...],
                    "hypothesis_template": "This is a photo of {}."  # Optional
                }
            }

        2. Image embedding only:
            {
                "inputs": <image_data>,
                "parameters": {
                    "mode": "embedding"
                }
            }

        3. Image-text similarity:
            {
                "inputs": {
                    "images": [<image1>, <image2>, ...],
                    "texts": ["text1", "text2", ...]
                },
                "parameters": {
                    "mode": "similarity"
                }
            }

        Returns for classification (a batch of images is wrapped as
        {"results": [...]}):
            {
                "labels": ["label1", "label2"],
                "scores": [0.85, 0.12],
                "predictions": [{"label": "label1", "score": 0.85}, ...]
            }

        Returns for embedding:
            {
                "image_embeddings": [[...], ...],
                "embedding_shape": [batch, hidden_dim]
            }

        Returns for similarity:
            {
                "similarity_matrix": [[...], ...],
                "shape": [num_images, num_texts]
            }
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = data.get("image") or data.get("images")
        if inputs is None:
            raise ValueError("No input provided. Use 'inputs', 'image', or 'images' key.")

        params = data.get("parameters", {})
        mode = params.get("mode", "classification")

        try:
            if mode == "embedding":
                return self._extract_embeddings(inputs)
            elif mode == "similarity":
                return self._compute_similarity(inputs, params)
            else:
                # Default mode is zero-shot classification.
                return self._classify(inputs, params)
        except Exception as e:
            # Surface errors as a JSON-serializable payload instead of a 500.
            return {"error": str(e), "error_type": type(e).__name__}

    def _classify(self, inputs: Any, params: Dict) -> Dict[str, Any]:
        """Zero-shot image classification."""
        candidate_labels = params.get("candidate_labels", [])
        if not candidate_labels:
            raise ValueError("candidate_labels required for classification mode")

        hypothesis_template = params.get("hypothesis_template", "This is a photo of {}.")

        # Normalize to a list of PIL images.
        if isinstance(inputs, list):
            images = [self._decode_image(img) for img in inputs]
        else:
            images = [self._decode_image(inputs)]

        texts = [hypothesis_template.format(label) for label in candidate_labels]

        results = []
        for image in images:
            # One image against all candidate texts; logits_per_image has
            # shape (1, num_texts), so the image runs through the vision
            # tower once rather than once per label.
            processed = self._process_batch([image], texts)

            with torch.no_grad():
                outputs = self.model(**processed)

            # SigLIP is trained with a sigmoid (not softmax) loss, so each
            # label gets an independent probability; scores need not sum to 1.
            logits_per_image = outputs.logits_per_image
            probs = torch.sigmoid(logits_per_image[0])

            # Rank labels by descending probability.
            sorted_indices = probs.argsort(descending=True)

            predictions = []
            for idx in sorted_indices:
                predictions.append({
                    "label": candidate_labels[idx.item()],
                    "score": float(probs[idx].item())
                })

            results.append({
                "labels": [p["label"] for p in predictions],
                "scores": [p["score"] for p in predictions],
                "predictions": predictions
            })

        # Single image: return its result directly; batch: wrap in "results".
        if len(results) == 1:
            return results[0]
        return {"results": results}

    def _extract_embeddings(self, inputs: Any) -> Dict[str, Any]:
        """Extract image embeddings only."""
        if isinstance(inputs, list):
            images = [self._decode_image(img) for img in inputs]
        else:
            images = [self._decode_image(inputs)]

        # Reuse _process_batch so device and dtype handling stay in one place.
        processed = self._process_batch(images)

        with torch.no_grad():
            # Pooled image features, shape (batch, hidden_dim).
            vision_outputs = self.model.get_image_features(**processed)

        embeddings = vision_outputs.cpu().numpy().tolist()

        return {
            "image_embeddings": embeddings,
            "embedding_shape": list(vision_outputs.shape)
        }

    def _compute_similarity(self, inputs: Dict, params: Dict) -> Dict[str, Any]:
        """Compute image-text similarity matrix."""
        images_data = inputs.get("images", [])
        texts = inputs.get("texts", [])

        if not images_data or not texts:
            raise ValueError("Both 'images' and 'texts' required for similarity mode")

        images = [self._decode_image(img) for img in images_data]

        # Reuse _process_batch for consistent padding, device, and dtype handling.
        processed = self._process_batch(images, texts)

        with torch.no_grad():
            outputs = self.model(**processed)

        # logits_per_image has shape (num_images, num_texts); sigmoid maps each
        # entry to an independent image-text match probability.
        similarity = outputs.logits_per_image
        probs = torch.sigmoid(similarity)

        return {
            "similarity_matrix": probs.cpu().numpy().tolist(),
            "shape": list(probs.shape),
            "logits": similarity.cpu().numpy().tolist()
        }
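

# A minimal local smoke test (a sketch, not part of the endpoint contract).
# It assumes network access to download the checkpoint and enough memory to
# load it; the synthetic frame and labels below are placeholders.
if __name__ == "__main__":
    handler = EndpointHandler()

    # A synthetic 384x384 red frame stands in for a real video frame.
    frame = Image.new("RGB", (384, 384), color=(255, 0, 0))

    classification = handler({
        "inputs": frame,
        "parameters": {"candidate_labels": ["a red screen", "a hammer", "a person"]}
    })
    print(classification)

    embedding = handler({"inputs": frame, "parameters": {"mode": "embedding"}})
    print(embedding.get("embedding_shape", embedding))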