"""
SAM 3 Custom Inference Handler for Hugging Face Inference Endpoints.

Model: facebook/sam3, loaded through Meta's official ``sam3`` package
(``pip install sam3``), NOT the transformers integration.

Built for ProofPath video assessment: text-prompted segmentation to find UI
elements, using text prompts such as "Save button", "dropdown menu" or
"text input field".

KEY CAPABILITIES:
- Text-to-segment: find ALL instances of a concept (e.g. "button" -> all buttons)
- Promptable Concept Segmentation (PCS): 270K unique concepts
- Video tracking: consistent object IDs across frames
- Presence token: discriminates similar elements ("player in white" vs "player in red")

REQUIREMENTS:
1. Set the HF_TOKEN environment variable (the model is gated).
2. Accept the license at https://huggingface.co/facebook/sam3
"""

from typing import Dict, List, Any, Optional, Union
import torch
import numpy as np
import base64
import io
import logging
import os

logger = logging.getLogger(__name__)

# (connect, read) timeout in seconds for fetching caller-supplied image/video
# URLs, so a slow or unresponsive host cannot hang an inference worker forever.
_HTTP_TIMEOUT_S = (10, 120)

# Element types searched for when "ui_elements" mode is called without "elements".
_DEFAULT_UI_ELEMENTS = ["button", "text input", "dropdown", "checkbox", "link"]


class EndpointHandler:
    """Inference Endpoints handler wrapping the sam3 image processor and video predictor."""

    def __init__(self, path: str = ""):
        """
        Initialize SAM 3 model for text-prompted segmentation.

        Uses the official sam3 package from Meta.

        Args:
            path: Path to the model directory (ignored - the model is built by
                ``build_sam3_image_model`` and loaded from the HF hub).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Import from the official sam3 package. The import is deferred to here,
        # so importing this module does not by itself require sam3.
        from sam3.model_builder import build_sam3_image_model
        from sam3.model.sam3_image_processor import Sam3Processor

        # Build model - this downloads from HuggingFace automatically.
        # Requires HF_TOKEN for gated model access.
        self.model = build_sam3_image_model()
        self.processor = Sam3Processor(self.model)

        # The video predictor is only needed for mode="video"; it is built
        # lazily by _get_video_predictor().
        self._video_predictor = None

    def _get_video_predictor(self):
        """Build the sam3 video predictor on first use and cache it on the instance."""
        if self._video_predictor is None:
            from sam3.model_builder import build_sam3_video_predictor

            self._video_predictor = build_sam3_video_predictor()
        return self._video_predictor

    # ------------------------------------------------------------------
    # Small pure helpers (no model access)
    # ------------------------------------------------------------------

    @staticmethod
    def _as_text_list(value: Any) -> List[str]:
        """
        Normalize a prompt parameter to a list of prompts.

        ``None`` / ``""`` -> ``[]``; a single string -> ``[string]`` (instead of
        iterating it character by character); any other iterable -> ``list(value)``.
        """
        if value is None or value == "":
            return []
        if isinstance(value, str):
            return [value]
        return list(value)

    @staticmethod
    def _to_list(value: Any) -> list:
        """Convert a tensor / ndarray / sequence of model outputs to a plain list ([] for None)."""
        if value is None:
            return []
        if hasattr(value, "tolist"):
            result = value.tolist()
            # A 0-d tensor/array yields a bare scalar from tolist().
            return result if isinstance(result, list) else [result]
        return list(value)

    @staticmethod
    def _box_center(box: Any) -> Optional[List[float]]:
        """Return the ``[cx, cy]`` center of an ``[x1, y1, x2, y2]`` box, or None if it is too short."""
        if box is None or len(box) < 4:
            return None
        return [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]

    @staticmethod
    def _encode_mask(mask: Any) -> str:
        """
        Encode one mask (bool or [0, 1] float, tensor or array) as a base64 PNG string.

        Raises:
            ValueError: if the mask cannot be reduced to a single 2-D plane.
        """
        from PIL import Image as PILImage

        if hasattr(mask, "detach"):
            # float() turns bool/bf16 tensors into something .numpy() accepts;
            # bool -> 0.0/1.0, so the 0/255 output below is unchanged.
            mask = mask.detach().float().cpu().numpy()
        mask = np.asarray(mask)
        # Masks may carry leading singleton dims (e.g. (1, H, W)), but PIL
        # needs a 2-D array for a grayscale image.
        if mask.ndim > 2:
            mask = mask.reshape(mask.shape[-2:])
        mask_uint8 = (mask * 255).astype(np.uint8)
        buffer = io.BytesIO()
        PILImage.fromarray(mask_uint8).save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    @staticmethod
    def _sample_frame_indices(total_frames: int, video_fps: float, max_frames: int, fps: float) -> tuple:
        """
        Choose evenly spaced frame indices to probe in a video.

        Args:
            total_frames: Frame count reported by the container.
            video_fps: Native frame rate reported by the container (<= 0 if unknown).
            max_frames: Upper bound on the number of sampled frames.
            fps: Desired sampling rate in frames per second of video.

        Returns:
            ``(indices, duration)``: an int ndarray of frame indices and the
            duration in seconds (0.0 when the frame count or fps is unknown).
        """
        if total_frames <= 0:
            # OpenCV may report 0 or a negative count for some streams; a
            # negative sample count would make np.linspace raise.
            return np.array([], dtype=int), 0.0
        duration = total_frames / video_fps if video_fps > 0 else 0.0
        target = min(max_frames, int(duration * fps), total_frames)
        if target <= 0:
            # Unknown fps or a very short clip: fall back to the frame budget alone.
            target = min(max_frames, total_frames)
        target = max(int(target), 0)
        return np.linspace(0, total_frames - 1, target, dtype=int), duration

    @classmethod
    def _collect_frame_outputs(cls, tracks: Dict[int, Dict[int, Dict[str, Any]]],
                               frame_index: Any, outputs: Dict[str, Any]) -> None:
        """
        Merge one frame's video-predictor outputs into ``tracks``.

        ``tracks`` maps object_id -> {frame_index: entry}. A later call for the
        same (object, frame) overwrites the earlier entry. Objects are inserted
        in order of first appearance.
        """
        if not outputs or frame_index is None:
            return
        obj_ids = cls._to_list(outputs.get("out_obj_ids"))
        probs = cls._to_list(outputs.get("out_probs"))
        boxes = cls._to_list(outputs.get("out_boxes_xywh"))
        frame_index = int(frame_index)
        for i, obj_id in enumerate(obj_ids):
            tracks.setdefault(int(obj_id), {})[frame_index] = {
                "frame_index": frame_index,
                # Box exactly as emitted by the predictor (xywh; normalization
                # per the sam3 version in use - TODO confirm).
                "box_xywh": boxes[i] if i < len(boxes) else None,
                "score": float(probs[i]) if i < len(probs) else 0.0,
            }

    # ------------------------------------------------------------------
    # Input loading
    # ------------------------------------------------------------------

    def _load_image(self, image_data: Any):
        """
        Load an RGB PIL image from a PIL image, raw bytes, an http(s) URL, a
        ``data:`` URI, or a bare base64 string.

        Raises:
            ValueError: for unsupported input types.
            requests.HTTPError: if a URL fetch returns an error status.
        """
        from PIL import Image

        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")

        if isinstance(image_data, bytes):
            raw = image_data
        elif isinstance(image_data, str):
            if image_data.startswith(("http://", "https://")):
                import requests

                # NOTE(security): fetches a caller-supplied URL from the
                # endpoint's network; restrict hosts if clients are untrusted.
                response = requests.get(image_data, timeout=_HTTP_TIMEOUT_S)
                response.raise_for_status()
                # .content (unlike .raw) honours Content-Encoding such as gzip.
                raw = response.content
            elif image_data.startswith("data:"):
                _header, encoded = image_data.split(",", 1)
                raw = base64.b64decode(encoded)
            else:
                # Assume base64 encoded
                raw = base64.b64decode(image_data)
        else:
            raise ValueError(f"Unsupported image input type: {type(image_data)}")

        return Image.open(io.BytesIO(raw)).convert("RGB")

    def _load_video_frames(self, video_data: Any, max_frames: int = 100, fps: float = 2.0) -> tuple:
        """
        Write the video to a temporary .mp4 file and probe it with OpenCV.

        Args:
            video_data: http(s) URL, ``data:`` URI, bare base64 string, or raw bytes.
            max_frames: Upper bound on the number of frames sampled for probing.
            fps: Target sampling rate (per second of video) for probing.

        Returns:
            ``(video_path, metadata)``. The caller owns ``video_path`` and must
            delete it. ``metadata`` holds ``duration``, ``total_frames``,
            ``sampled_frames`` (sampled frames that decoded successfully) and
            ``video_fps``.

        NOTE: sampled frames are decoded only to fill ``sampled_frames``; the
        video predictor is later given the whole file at ``video_path``, so
        ``max_frames`` / ``fps`` do not limit what gets tracked.

        Raises:
            ValueError: for unsupported input types.
            requests.HTTPError: if a URL download returns an error status.
        """
        import cv2
        import tempfile

        url = None
        video_bytes = None
        if isinstance(video_data, str):
            if video_data.startswith(("http://", "https://")):
                url = video_data
            elif video_data.startswith("data:"):
                _header, encoded = video_data.split(",", 1)
                video_bytes = base64.b64decode(encoded)
            else:
                video_bytes = base64.b64decode(video_data)
        elif isinstance(video_data, bytes):
            video_bytes = video_data
        else:
            raise ValueError(f"Unsupported video input type: {type(video_data)}")

        fd, video_path = tempfile.mkstemp(suffix=".mp4")
        try:
            with os.fdopen(fd, "wb") as f:
                if url is not None:
                    import requests

                    # NOTE(security): fetches a caller-supplied URL from the
                    # endpoint's network; restrict hosts if clients are untrusted.
                    with requests.get(url, stream=True, timeout=_HTTP_TIMEOUT_S) as response:
                        response.raise_for_status()
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
                else:
                    f.write(video_bytes)

            cap = cv2.VideoCapture(video_path)
            try:
                video_fps = cap.get(cv2.CAP_PROP_FPS)
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                frame_indices, duration = self._sample_frame_indices(
                    total_frames, video_fps, max_frames, fps
                )
                sampled_frames = 0
                for idx in frame_indices:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                    ok, _frame = cap.read()
                    if ok:
                        sampled_frames += 1
            finally:
                cap.release()
        except Exception:
            # The temp file is not auto-deleted; remove it on any failure.
            if os.path.exists(video_path):
                os.unlink(video_path)
            raise

        metadata = {
            "duration": duration,
            "total_frames": total_frames,
            "sampled_frames": sampled_frames,
            "video_fps": video_fps,
        }
        return video_path, metadata

    # ------------------------------------------------------------------
    # Request handling
    # ------------------------------------------------------------------

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an image or video with SAM 3 for text-prompted segmentation.

        INPUT FORMATS (``inputs`` is an http(s) URL, ``data:`` URI, bare base64
        string, or raw bytes; images may also be a PIL image):

        1. Single image with text prompt (find all instances):
           {"inputs": <image>,
            "parameters": {"prompt": "Save button", "return_masks": true}}

        2. Single image with multiple text prompts:
           {"inputs": <image>,
            "parameters": {"prompts": ["button", "text field", "dropdown"]}}

        3. Video with text prompt (track all instances):
           {"inputs": <video>,
            "parameters": {"mode": "video", "prompt": "Submit button",
                           "max_frames": 100}}

        4. ProofPath UI element detection:
           {"inputs": <image>,
            "parameters": {"mode": "ui_elements",
                           "elements": ["Save button", "Cancel button", "text input"]}}

        Any ``mode`` other than "video" / "ui_elements" is treated as image mode.

        OUTPUT FORMAT (image mode):
           {"results": [{"prompt": "Save button",
                         "instances": [{"box": [x1, y1, x2, y2],
                                        "score": 0.95,
                                        "mask": "<base64 PNG>"}],  // only if return_masks
                         "count": 1}],
            "image_size": {"width": 1920, "height": 1080}}

        "ui_elements" mode returns {"ui_elements", "image_size", "total_elements"};
        "video" mode returns {"prompt", "video_metadata", "objects_tracked",
        "tracks", "session_id"}.

        Raises:
            ValueError: if ``inputs`` is missing or no prompt is given.
        """
        inputs = data.get("inputs")
        # "parameters": null in the request JSON is treated like an omitted key.
        params = data.get("parameters") or {}

        if inputs is None:
            raise ValueError("No inputs provided")

        mode = params.get("mode", "image")
        if mode == "video":
            return self._process_video(inputs, params)
        if mode == "ui_elements":
            return self._process_ui_elements(inputs, params)
        return self._process_single_image(inputs, params)

    def _process_single_image(self, image_data: Any, params: Dict) -> Dict[str, Any]:
        """Segment one image for each text prompt using the official sam3 image API."""
        image = self._load_image(image_data)
        return_masks = params.get("return_masks", True)

        # "prompts" wins when non-empty; otherwise fall back to the single "prompt".
        prompts = self._as_text_list(params.get("prompts")) or self._as_text_list(params.get("prompt"))
        if not prompts:
            raise ValueError("No text prompt(s) provided")

        # Encode the image once; every prompt reuses this state.
        inference_state = self.processor.set_image(image)

        results = []
        for text_prompt in prompts:
            output = self.processor.set_text_prompt(state=inference_state, prompt=text_prompt)

            boxes = self._to_list(output.get("boxes"))
            scores = self._to_list(output.get("scores"))
            masks = output.get("masks") if return_masks else None

            instances = []
            for i, box in enumerate(boxes):
                instance = {
                    "box": box,
                    "score": float(scores[i]) if i < len(scores) else 0.0,
                }
                if masks is not None and i < len(masks):
                    instance["mask"] = self._encode_mask(masks[i])
                instances.append(instance)

            results.append({
                "prompt": text_prompt,
                "instances": instances,
                "count": len(instances),
            })

        return {
            "results": results,
            "image_size": {"width": image.width, "height": image.height},
        }

    def _process_ui_elements(self, image_data: Any, params: Dict) -> Dict[str, Any]:
        """
        ProofPath-specific mode: detect several UI element types in a screenshot.

        Returns, per element type, its instance count plus each instance's box,
        score and box center (no masks).
        """
        image = self._load_image(image_data)

        elements = self._as_text_list(params.get("elements"))
        if not elements:
            elements = list(_DEFAULT_UI_ELEMENTS)

        # Encode the image once; every element prompt reuses this state.
        inference_state = self.processor.set_image(image)

        all_detections = {}
        for element_type in elements:
            output = self.processor.set_text_prompt(state=inference_state, prompt=element_type)
            boxes = self._to_list(output.get("boxes"))
            scores = self._to_list(output.get("scores"))

            detections = [
                {
                    "box": box,
                    "score": float(scores[i]) if i < len(scores) else 0.0,
                    "center": self._box_center(box),
                }
                for i, box in enumerate(boxes)
            ]
            all_detections[element_type] = {
                "count": len(detections),
                "instances": detections,
            }

        return {
            "ui_elements": all_detections,
            "image_size": {"width": image.width, "height": image.height},
            "total_elements": sum(d["count"] for d in all_detections.values()),
        }

    def _process_video(self, video_data: Any, params: Dict) -> Dict[str, Any]:
        """
        Track every instance matching a text prompt through a video.

        NOTE(review): request types ("start_session", "add_prompt",
        "propagate_in_video" via handle_stream_request, "close_session") and the
        ``out_*`` output keys follow the sam3 README / video-predictor example;
        confirm against the installed sam3 version.

        Raises:
            ValueError: if no text prompt is given.
        """
        # Validate cheaply before loading the (large) video predictor.
        prompt = params.get("prompt")
        if not prompt:
            raise ValueError("Text prompt required for video mode")
        max_frames = int(params.get("max_frames", 100))

        video_predictor = self._get_video_predictor()
        video_path, video_metadata = self._load_video_frames(video_data, max_frames)

        session_id = None
        try:
            response = video_predictor.handle_request(
                request=dict(type="start_session", resource_path=video_path)
            )
            session_id = response.get("session_id")

            # Add the text prompt at frame 0.
            response = video_predictor.handle_request(
                request=dict(
                    type="add_prompt",
                    session_id=session_id,
                    frame_index=0,
                    text=prompt,
                )
            )
            tracks: Dict[int, Dict[int, Dict[str, Any]]] = {}
            self._collect_frame_outputs(
                tracks, response.get("frame_index", 0), response.get("outputs") or {}
            )

            # Propagate through the video; per-frame results overwrite the
            # frame-0 entry recorded above.
            for frame_response in video_predictor.handle_stream_request(
                request=dict(type="propagate_in_video", session_id=session_id)
            ):
                self._collect_frame_outputs(
                    tracks,
                    frame_response.get("frame_index"),
                    frame_response.get("outputs") or {},
                )

            track_list = [
                {"object_id": obj_id, "frames": [frames[k] for k in sorted(frames)]}
                for obj_id, frames in tracks.items()
            ]

            return {
                "prompt": prompt,
                "video_metadata": video_metadata,
                "objects_tracked": len(track_list),
                "tracks": track_list,
                "session_id": session_id,
            }
        finally:
            # Release per-session predictor state before deleting its video file.
            if session_id is not None:
                try:
                    video_predictor.handle_request(
                        request=dict(type="close_session", session_id=session_id)
                    )
                except Exception:
                    # Best-effort: never mask the primary result or error.
                    logger.warning("Failed to close sam3 video session %s", session_id, exc_info=True)
            if os.path.exists(video_path):
                os.unlink(video_path)


# For testing locally
if __name__ == "__main__":
    handler = EndpointHandler()

    # Test with a sample image URL
    test_data = {
        "inputs": "http://images.cocodataset.org/val2017/000000077595.jpg",
        "parameters": {
            "prompt": "ear",
            "return_masks": False,
        },
    }
    result = handler(test_data)
    first = result["results"][0]
    print(f"Found {first['count']} instances of '{first['prompt']}'")
    for inst in first["instances"]:
        print(f" Box: {inst['box']}, Score: {inst['score']:.3f}")