# M2-Encoder-0.4B — handler.py
# Hugging Face Inference Endpoints handler (uploaded by malusama, commit 3b428b1, verified).
import base64
import io
import os
from typing import Any, Dict, List
from urllib.parse import urlparse
from urllib.request import urlopen
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
class EndpointHandler:
    """Inference Endpoints handler for a CLIP-style text/image embedding model.

    Accepts text and/or image inputs (in several flexible shapes), returns the
    corresponding embeddings, and — when both modalities are supplied —
    image-to-text similarity scores.
    """

    # Seconds to wait when fetching an image over HTTP(S); without a timeout a
    # hung remote server would stall the endpoint worker indefinitely.
    _URL_TIMEOUT = 10.0

    def __init__(self, path: str = ""):
        """Load model and processor from ``path``; use CUDA when available.

        Args:
            path: Model repository/checkpoint directory (supplied by the
                Inference Endpoints runtime).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True)
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        ``data`` may wrap the payload under ``"inputs"`` (the HF Endpoints
        convention) or be the payload itself. Optional ``"parameters"``:
        ``return_probs`` (default True) and ``return_logits`` (default False).

        Returns:
            Dict with any of ``"text_embedding"``, ``"image_embedding"``,
            ``"scores"``, ``"probs"``, ``"logits_per_image"``.

        Raises:
            ValueError: if neither texts nor images are provided.
        """
        payload = data.pop("inputs", data)
        # Popping "parameters" also removes it from `payload` when payload is
        # the top-level dict, so it is never misread as text/image input.
        parameters = data.pop("parameters", {}) or {}
        texts = self._coerce_texts(payload)
        images = self._coerce_images(payload)
        if not texts and not images:
            raise ValueError(
                "Expected `inputs` to include `text`/`texts` and/or `image`/`images`."
            )
        result: Dict[str, Any] = {}
        with torch.no_grad():
            text_embeds = None
            image_embeds = None
            if texts:
                text_inputs = self._move_to_device(
                    self.processor(text=texts, return_tensors="pt")
                )
                text_embeds = self.model(**text_inputs).text_embeds
                result["text_embedding"] = text_embeds.cpu().tolist()
            if images:
                image_inputs = self._move_to_device(
                    self.processor(images=images, return_tensors="pt")
                )
                image_embeds = self.model(**image_inputs).image_embeds
                result["image_embedding"] = image_embeds.cpu().tolist()
            if text_embeds is not None and image_embeds is not None:
                # Dot-product similarity; rows = images, columns = texts.
                # NOTE(review): assumes the model emits normalized embeddings if
                # cosine similarity is intended — confirm in the remote model code.
                scores = image_embeds @ text_embeds.t()
                result["scores"] = scores.cpu().tolist()
                if parameters.get("return_probs", True):
                    result["probs"] = scores.softmax(dim=-1).cpu().tolist()
                if parameters.get("return_logits", False):
                    # NOTE(review): `.model.model.logit_scale` relies on the remote
                    # code wrapping the CLIP module under `.model` — verify there.
                    logit_scale = self.model.model.logit_scale.exp()
                    # Reuse `scores` instead of recomputing the matmul.
                    result["logits_per_image"] = (logit_scale * scores).cpu().tolist()
        return result

    def _move_to_device(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Move every tensor-like value in ``batch`` to the handler's device."""
        moved = {}
        for key, value in batch.items():
            # Non-tensor entries (e.g. lists of ids) are passed through untouched.
            moved[key] = value.to(self.device) if hasattr(value, "to") else value
        return moved

    def _coerce_texts(self, payload: Any) -> List[str]:
        """Normalize the payload's text input into a list of strings.

        Accepts a bare string, or a dict with ``"text"``/``"texts"`` holding a
        string or an iterable; returns ``[]`` when no text is present.
        """
        if isinstance(payload, str):
            return [payload]
        if not isinstance(payload, dict):
            return []
        texts = payload.get("text", payload.get("texts"))
        if texts is None:
            return []
        if isinstance(texts, str):
            return [texts]
        return [str(item) for item in texts]

    def _coerce_images(self, payload: Any) -> List[Image.Image]:
        """Normalize the payload's image input into a list of PIL images.

        Accepts a dict with ``"image"``/``"images"`` holding a single item or a
        list/tuple of items understood by :meth:`_load_image`.
        """
        if not isinstance(payload, dict):
            return []
        images = payload.get("image", payload.get("images"))
        if images is None:
            return []
        if not isinstance(images, (list, tuple)):
            images = [images]
        return [self._load_image(item) for item in images]

    def _load_image(self, value: Any) -> Image.Image:
        """Resolve one image reference into an RGB PIL image.

        Supported forms: a PIL image, a dict wrapping one of the supported
        forms under ``data``/``image``/``url``/``path``, a local file path, an
        http(s) URL, a ``data:image/...`` URI, or a raw base64 string.

        Raises:
            TypeError: for unsupported input types.
            ValueError: for strings that are none of the supported forms.
        """
        if isinstance(value, Image.Image):
            return value.convert("RGB")
        if isinstance(value, dict):
            for key in ("data", "image", "url", "path"):
                if key in value:
                    # Recurse so nested dicts and pre-loaded PIL images work too;
                    # previously they fell through to the string check and failed.
                    return self._load_image(value[key])
        if not isinstance(value, str):
            raise TypeError(f"Unsupported image input type: {type(value)!r}")
        if os.path.exists(value):
            return Image.open(value).convert("RGB")
        parsed = urlparse(value)
        if parsed.scheme in ("http", "https"):
            # Bounded fetch: a dead remote host must not hang the worker.
            with urlopen(value, timeout=self._URL_TIMEOUT) as response:
                return Image.open(io.BytesIO(response.read())).convert("RGB")
        if value.startswith("data:image/"):
            _, encoded = value.split(",", 1)
            return Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")
        try:
            # Last resort: treat the string as raw base64 image bytes.
            return Image.open(io.BytesIO(base64.b64decode(value))).convert("RGB")
        except Exception as exc:
            raise ValueError("Unsupported image string. Use URL, local path, or base64.") from exc