"""
Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints

Model: allenai/Molmo2-8B
For ProofPath video assessment - video pointing, tracking, and grounded analysis.
Unique capability: Returns pixel-level coordinates for objects in videos.
"""
from typing import Dict, List, Any, Optional, Tuple, Union, Callable
import torch
import numpy as np
import base64
import io
import tempfile
import os
import re


class EndpointHandler:
    """Inference Endpoints handler wrapping allenai/Molmo2-8B.

    Routes each request to single-image, multi-image, or video processing.
    When the generated text contains Molmo ``<points>``/``<tracks>`` markup,
    the decoded pixel coordinates are returned alongside the text.
    """

    # Hub id the model is always loaded from (the ``path`` given to __init__ is ignored).
    MODEL_ID = "allenai/Molmo2-8B"

    # Seconds to wait for an image URL download before giving up.
    HTTP_TIMEOUT = 30

    # Substrings that mark a (non data-URI) string input as a video.
    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv', '.webm', '.m4v')

    # Regex patterns for parsing Molmo pointing output.
    # <points|tracks ... coords="..."> element; group 1 is the raw coords payload.
    COORD_REGEX = re.compile(r"<(?:points|tracks).*? coords=\"([0-9\t:;, .]+)\"/?>")
    # One frame inside coords: group 1 = timestamp / frame label, group 2 = point triples.
    FRAME_REGEX = re.compile(r"(?:^|\t|:|,|;)([0-9\.]+) ([0-9\. ]+)")
    # One point: instance id, x and y (x/y scaled by 1000).
    POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})")

    def __init__(self, path: str = ""):
        """Load the Molmo 2 processor and model for video pointing and tracking.

        Args:
            path: Path to the model directory. Ignored - the model is always
                loaded from the Hugging Face hub by ``MODEL_ID``.
        """
        # Molmo2 uses AutoModelForImageTextToText; imported lazily so the module
        # can be imported without loading transformers.
        from transformers import AutoProcessor, AutoModelForImageTextToText

        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        self.processor = AutoProcessor.from_pretrained(
            self.MODEL_ID,
            trust_remote_code=True,
        )
        self.model = AutoModelForImageTextToText.from_pretrained(
            self.MODEL_ID,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
            # On GPU the weights are placed via device_map; on CPU they are moved below.
            device_map="auto" if use_cuda else None,
        )
        if not use_cuda:
            self.model = self.model.to(self.device)
        self.model.eval()

        # Molmo 2 limits (stored for reference; not read elsewhere in this handler).
        self.max_frames = 128
        self.default_fps = 2.0

    # ------------------------------------------------------------------ parsing

    def _extract_points(
        self, text: str, size_for_frame: Callable[[str], Tuple[float, float]]
    ) -> List[Dict]:
        """Decode every point in *text*, scaling each frame by ``size_for_frame(label)``.

        Args:
            text: Generated model output.
            size_for_frame: Maps a frame's raw label (timestamp or image index
                string) to the ``(width, height)`` used to convert its points.

        Returns:
            ``{"timestamp", "instance_id", "x", "y"}`` dicts in order of
            appearance. Points outside the frame are dropped.
        """
        points = []
        for coord_match in self.COORD_REGEX.finditer(text):
            for frame_match in self.FRAME_REGEX.finditer(coord_match.group(1)):
                label = frame_match.group(1)
                timestamp = float(label)
                frame_w, frame_h = size_for_frame(label)
                for point_match in self.POINTS_REGEX.finditer(frame_match.group(2)):
                    raw_id, raw_x, raw_y = point_match.groups()
                    # Coordinates are scaled by 1000.
                    x = float(raw_x) / 1000 * frame_w
                    y = float(raw_y) / 1000 * frame_h
                    if 0 <= x <= frame_w and 0 <= y <= frame_h:
                        points.append({
                            "timestamp": timestamp,
                            "instance_id": int(raw_id),
                            "x": x,
                            "y": y,
                        })
        return points

    def _parse_video_points(self, text: str, image_w: int, image_h: int) -> List[Dict]:
        """Extract pointing/tracking coordinates from Molmo output.

        The text is scanned for ``<points ... coords="...">`` or
        ``<tracks ... coords="...">`` elements. Inside ``coords``, frames are
        separated by tab, ``:``, ``,`` or ``;``. Each frame is a timestamp
        followed by one or more ``instance_id x y`` triples, for example
        ``coords="0.5 1 250 400;1.0 1 260 410"``. ``x``/``y`` are 3-4 digit
        integers scaled by 1000. They are converted to pixels with
        *image_w*/*image_h*, and out-of-bounds points are dropped.

        Returns:
            A list of ``{"timestamp", "instance_id", "x", "y"}`` dicts.
        """
        return self._extract_points(text, lambda _label: (image_w, image_h))

    def _parse_multi_image_points(
        self, text: str, sizes: List[Tuple[int, int]]
    ) -> List[Dict]:
        """Extract points for a multi-image request, scaling each by its own image.

        Each frame label is treated as a 1-based index into *sizes*. This is
        assumed from the Molmo 2 multi-image output format; TODO confirm
        against the model card. A label that is not a valid index falls back
        to the first image's size, matching single-size behaviour.

        Args:
            text: Generated model output.
            sizes: ``(width, height)`` of each input image, in request order.

        Returns:
            Point dicts as produced by ``_parse_video_points``, or ``[]`` when
            *sizes* is empty.
        """
        if not sizes:
            return []

        def size_for_label(label: str) -> Tuple[int, int]:
            try:
                index = float(label)
            except ValueError:
                return sizes[0]
            if index.is_integer() and 1 <= index <= len(sizes):
                return sizes[int(index) - 1]
            return sizes[0]

        return self._extract_points(text, size_for_label)

    # ----------------------------------------------------------------- loading

    @staticmethod
    def _decode_base64(payload: str) -> bytes:
        """Decode a base64 string, stripping a leading ``data:...,`` URI header if present.

        Raises:
            ValueError: If a ``data:`` URI has no ``,`` separating header and payload.
        """
        if payload.startswith("data:"):
            _header, sep, payload = payload.partition(",")
            if not sep:
                raise ValueError("Malformed data URI: missing ',' separator.")
        return base64.b64decode(payload)

    def _load_image(self, image_data: Any):
        """Load a single image as RGB from several input formats.

        Accepts a PIL image, an http(s) URL, a data URI, a raw base64 string,
        or raw bytes.

        Raises:
            ValueError: For unsupported input types or a malformed data URI.
            requests.HTTPError: If an image URL responds with an error status.
        """
        from PIL import Image
        import requests

        if isinstance(image_data, Image.Image):
            # Normalise the mode (RGBA, P, L, ...) so every path yields RGB.
            return image_data.convert('RGB')
        if isinstance(image_data, str):
            if image_data.startswith(('http://', 'https://')):
                response = requests.get(image_data, timeout=self.HTTP_TIMEOUT)
                response.raise_for_status()
                # .content honours Content-Encoding (gzip etc.), unlike response.raw.
                image_bytes = response.content
            else:
                image_bytes = self._decode_base64(image_data)
        elif isinstance(image_data, (bytes, bytearray)):
            image_bytes = bytes(image_data)
        else:
            raise ValueError(f"Unsupported image input type: {type(image_data)}")
        return Image.open(io.BytesIO(image_bytes)).convert('RGB')

    # -------------------------------------------------------------- generation

    def _generate(self, model_inputs: Any, max_new_tokens: int) -> str:
        """Run greedy generation and decode only the newly generated tokens.

        Args:
            model_inputs: Mapping of processor outputs that includes
                ``input_ids``. Values that support ``.to`` are moved to the
                model's device.
            max_new_tokens: Upper bound on generated tokens.
        """
        model_inputs = {
            k: v.to(self.model.device) if hasattr(v, "to") else v
            for k, v in model_inputs.items()
        }
        with torch.inference_mode():
            output = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        # The output row starts with the prompt tokens; keep only the continuation.
        prompt_len = model_inputs["input_ids"].size(1)
        return self.processor.tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True)

    @staticmethod
    def _attach_points(result: Dict[str, Any], points: List[Dict]) -> Dict[str, Any]:
        """Add ``points``/``num_points`` to *result* when any were parsed; return *result*."""
        if points:
            result["points"] = points
            result["num_points"] = len(points)
        return result

    # ---------------------------------------------------------------- routing

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process video or images with Molmo 2.

        Expected input formats:

        1. Image analysis with pointing::

            {
                "inputs": "<image URL, data URI, or base64 string>",
                "parameters": {
                    "prompt": "Point to the Excel cell B2.",
                    "max_new_tokens": 1024
                }
            }

        2. Video analysis. The input is detected by a video file extension or
           a ``data:video/...`` URI, or forced with ``"input_type": "video"``::

            {
                "inputs": "<video URL, data URI, or base64 string>",
                "parameters": {
                    "prompt": "What happens in this video?",
                    "max_new_tokens": 2048
                }
            }

        3. Multi-image comparison::

            {
                "inputs": ["<image 1>", "<image 2>"],
                "parameters": {"prompt": "Compare these screenshots."}
            }

        ``"video"``, ``"image"`` or ``"images"`` may be used instead of ``"inputs"``.

        Returns:
            On success, a dict with ``generated_text`` and size information
            (``image_size``, ``image_sizes``/``num_images``, or
            ``video_metadata``). When pointing markup is detected it also
            contains ``points`` and ``num_points``. On a processing failure,
            ``{"error", "error_type", "traceback"}``.

        Raises:
            ValueError: If no input is provided under any recognised key.
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = data.get("video") or data.get("image") or data.get("images")
        if inputs is None:
            raise ValueError("No input provided. Use 'inputs', 'video', 'image', or 'images' key.")

        # Treat "parameters": null like an absent key.
        params = data.get("parameters") or {}
        prompt = params.get("prompt")

        try:
            raw_max = params.get("max_new_tokens")
            max_new_tokens = 1024 if raw_max is None else int(raw_max)

            if isinstance(inputs, list):
                text = prompt if prompt is not None else "Describe these images."
                return self._process_multi_image(inputs, text, max_new_tokens)
            if self._is_video(inputs, params):
                text = prompt if prompt is not None else "Describe this video."
                return self._process_video(inputs, text, params, max_new_tokens)
            text = prompt if prompt is not None else "Describe this image."
            return self._process_image(inputs, text, max_new_tokens)
        except Exception as e:
            # Endpoint boundary: report the failure to the caller instead of raising.
            import traceback
            return {"error": str(e), "error_type": type(e).__name__, "traceback": traceback.format_exc()}

    def _is_video(self, inputs: Any, params: Dict) -> bool:
        """Determine whether *inputs* should be processed as a video.

        An explicit ``params["input_type"]`` of ``"video"`` or ``"image"``
        wins. Otherwise a ``data:`` URI is classified by its media type, and
        any other string is a video if it contains a known video extension.
        Non-string inputs are treated as images.
        """
        input_type = params.get("input_type")
        if input_type == "video":
            return True
        if input_type == "image":
            return False
        if not isinstance(inputs, str):
            return False
        lower = inputs.lower()
        if lower.startswith("data:"):
            return lower.startswith("data:video/")
        return any(ext in lower for ext in self.VIDEO_EXTENSIONS)

    # ------------------------------------------------------------- processing

    def _process_image(self, image_data: Any, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
        """Answer *prompt* about a single image and decode any pointing output."""
        image = self._load_image(image_data)

        # Build message in Molmo format.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        model_inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )
        generated_text = self._generate(model_inputs, max_new_tokens)

        result = {
            "generated_text": generated_text,
            "image_size": {"width": image.width, "height": image.height},
        }
        points = self._parse_video_points(generated_text, image.width, image.height)
        return self._attach_points(result, points)

    def _process_video(
        self,
        video_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int,
    ) -> Dict[str, Any]:
        """Answer *prompt* about a video using molmo_utils frame extraction.

        Args:
            video_data: An http(s) URL (passed to molmo_utils as-is), a
                base64 string or ``data:`` URI, or raw bytes. Non-URL input is
                written to a temporary ``.mp4`` file that is always removed.
            prompt: Text prompt for the model.
            params: Request parameters (currently unused here).
            max_new_tokens: Upper bound on generated tokens.

        Raises:
            ValueError: If no video could be decoded from the input.
        """
        from molmo_utils import process_vision_info

        temp_path = None
        try:
            if isinstance(video_data, str) and video_data.startswith(('http://', 'https://')):
                video_source = video_data
            else:
                if isinstance(video_data, str):
                    video_bytes = self._decode_base64(video_data)
                else:
                    video_bytes = video_data
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                    # Record the path before writing so cleanup runs even if write() fails.
                    temp_path = f.name
                    f.write(video_bytes)
                video_source = temp_path

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "video", "video": video_source},
                    ],
                }
            ]

            # Process video with molmo_utils.
            _, videos, video_kwargs = process_vision_info(messages)
            if not videos:
                raise ValueError("Could not decode any video from the input.")
            videos, video_metadatas = zip(*videos)
            videos, video_metadatas = list(videos), list(video_metadatas)

            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            model_inputs = self.processor(
                videos=videos,
                video_metadata=video_metadatas,
                text=text,
                padding=True,
                return_tensors="pt",
                **(video_kwargs or {}),
            )
            generated_text = self._generate(model_inputs, max_new_tokens)

            # NOTE(review): metadata assumed to be a dict; 1920x1080 is only a fallback guess.
            video_w = video_metadatas[0].get("width", 1920)
            video_h = video_metadatas[0].get("height", 1080)

            result = {
                "generated_text": generated_text,
                "video_metadata": {
                    "width": video_w,
                    "height": video_h,
                },
            }
            points = self._parse_video_points(generated_text, video_w, video_h)
            return self._attach_points(result, points)
        finally:
            if temp_path and os.path.exists(temp_path):
                os.unlink(temp_path)

    def _process_multi_image(
        self,
        images_data: List,
        prompt: str,
        max_new_tokens: int,
    ) -> Dict[str, Any]:
        """Answer *prompt* about several images.

        Each parsed point is scaled by the image it refers to.
        """
        images = [self._load_image(img) for img in images_data]

        # All images first, then the prompt.
        content = [{"type": "image", "image": image} for image in images]
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]

        model_inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )
        generated_text = self._generate(model_inputs, max_new_tokens)

        sizes = [(img.width, img.height) for img in images]
        result = {
            "generated_text": generated_text,
            "num_images": len(images),
            "image_sizes": [{"width": w, "height": h} for w, h in sizes],
        }
        points = self._parse_multi_image_points(generated_text, sizes)
        return self._attach_points(result, points)