# sam3/handler.py
"""
SAM 3 Custom Inference Handler for Hugging Face Inference Endpoints
Model: facebook/sam3
Uses the official sam3 package from Meta (pip install sam3), not the Transformers integration.
For ProofPath video assessment - text-prompted segmentation to find UI elements.
Supports text prompts like "Save button", "dropdown menu", "text input field".
KEY CAPABILITIES:
- Text-to-segment: Find ALL instances of a concept (e.g., "button" → all buttons)
- Promptable Concept Segmentation (PCS): 270K unique concepts
- Video tracking: Consistent object IDs across frames
- Presence token: Discriminates similar elements ("player in white" vs "player in red")
REQUIREMENTS:
1. Set HF_TOKEN environment variable (model is gated)
2. Accept license at https://huggingface.co/facebook/sam3
"""
from typing import Dict, List, Any, Optional, Union
import torch
import numpy as np
import base64
import io
import os
class EndpointHandler:
def __init__(self, path: str = ""):
"""
Initialize SAM 3 model for text-prompted segmentation.
Uses the official sam3 package from Meta.
Args:
path: Path to the model directory (ignored - we load from HF hub)
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Import from official sam3 package
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# Build model - this downloads from HuggingFace automatically
# Requires HF_TOKEN for gated model access
self.model = build_sam3_image_model()
self.processor = Sam3Processor(self.model)
# Video model will be loaded lazily
self._video_predictor = None
def _get_video_predictor(self):
"""Lazy load video predictor only when needed."""
if self._video_predictor is None:
from sam3.model_builder import build_sam3_video_predictor
self._video_predictor = build_sam3_video_predictor()
return self._video_predictor
def _load_image(self, image_data: Any):
"""Load image from various formats."""
from PIL import Image
import requests
if isinstance(image_data, Image.Image):
return image_data.convert('RGB')
elif isinstance(image_data, str):
            if image_data.startswith(('http://', 'https://')):
                # Stream the download and fail fast on HTTP errors
                response = requests.get(image_data, stream=True)
                response.raise_for_status()
                return Image.open(response.raw).convert('RGB')
elif image_data.startswith('data:'):
header, encoded = image_data.split(',', 1)
image_bytes = base64.b64decode(encoded)
return Image.open(io.BytesIO(image_bytes)).convert('RGB')
else:
# Assume base64 encoded
image_bytes = base64.b64decode(image_data)
return Image.open(io.BytesIO(image_bytes)).convert('RGB')
elif isinstance(image_data, bytes):
return Image.open(io.BytesIO(image_data)).convert('RGB')
else:
raise ValueError(f"Unsupported image input type: {type(image_data)}")
def _load_video_frames(self, video_data: Any, max_frames: int = 100, fps: float = 2.0) -> tuple:
"""Load video frames from various formats."""
import cv2
from PIL import Image
import tempfile
# Decode to temp file if needed
if isinstance(video_data, str):
if video_data.startswith(('http://', 'https://')):
                import requests
                response = requests.get(video_data, stream=True)
                response.raise_for_status()
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
video_path = f.name
elif video_data.startswith('data:'):
header, encoded = video_data.split(',', 1)
video_bytes = base64.b64decode(encoded)
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
f.write(video_bytes)
video_path = f.name
else:
video_bytes = base64.b64decode(video_data)
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
f.write(video_bytes)
video_path = f.name
elif isinstance(video_data, bytes):
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
f.write(video_data)
video_path = f.name
else:
raise ValueError(f"Unsupported video input type: {type(video_data)}")
try:
cap = cv2.VideoCapture(video_path)
video_fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
duration = total_frames / video_fps if video_fps > 0 else 0
# Calculate frames to sample
target_frames = min(max_frames, int(duration * fps), total_frames)
if target_frames <= 0:
target_frames = min(max_frames, total_frames)
frame_indices = np.linspace(0, total_frames - 1, target_frames, dtype=int)
frames = []
for idx in frame_indices:
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
ret, frame = cap.read()
if ret:
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(frame_rgb)
frames.append(pil_image)
cap.release()
metadata = {
"duration": duration,
"total_frames": total_frames,
"sampled_frames": len(frames),
"video_fps": video_fps
}
return video_path, metadata
        except Exception:
            # Remove the temp file before re-raising with the original traceback
            if os.path.exists(video_path):
                os.unlink(video_path)
            raise
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process image or video with SAM 3 for text-prompted segmentation.
INPUT FORMATS:
1. Single image with text prompt (find all instances):
{
"inputs": <image_url_or_base64>,
"parameters": {
"prompt": "Save button",
"return_masks": true
}
}
2. Single image with multiple text prompts:
{
"inputs": <image_url_or_base64>,
"parameters": {
"prompts": ["button", "text field", "dropdown"]
}
}
3. Video with text prompt (track all instances):
{
"inputs": <video_url_or_base64>,
"parameters": {
"mode": "video",
"prompt": "Submit button",
"max_frames": 100
}
}
4. ProofPath UI element detection:
{
"inputs": <screenshot_base64>,
"parameters": {
"mode": "ui_elements",
"elements": ["Save button", "Cancel button", "text input"]
}
}
        OUTPUT FORMAT:
        {
            "results": [
                {
                    "prompt": "Save button",
                    "instances": [
                        {
                            "box": [x1, y1, x2, y2],
                            "score": 0.95,
                            "mask": "<base64_png>"   // only if return_masks=true
                        }
                    ],
                    "count": 1
                }
            ],
            "image_size": {"width": 1920, "height": 1080}
        }
"""
inputs = data.get("inputs")
params = data.get("parameters", {})
if inputs is None:
raise ValueError("No inputs provided")
mode = params.get("mode", "image")
if mode == "video":
return self._process_video(inputs, params)
elif mode == "ui_elements":
return self._process_ui_elements(inputs, params)
else:
return self._process_single_image(inputs, params)
def _process_single_image(self, image_data: Any, params: Dict) -> Dict[str, Any]:
"""Process a single image with text prompts using official sam3 API."""
image = self._load_image(image_data)
return_masks = params.get("return_masks", True)
# Get prompts
prompt = params.get("prompt")
prompts = params.get("prompts", [prompt] if prompt else [])
if not prompts:
raise ValueError("No text prompt(s) provided")
# Set the image in processor
inference_state = self.processor.set_image(image)
results = []
for text_prompt in prompts:
# Use official sam3 API
output = self.processor.set_text_prompt(
state=inference_state,
prompt=text_prompt
)
masks = output.get("masks", [])
boxes = output.get("boxes", [])
scores = output.get("scores", [])
instances = []
# Convert tensors to lists
if hasattr(boxes, 'tolist'):
boxes = boxes.tolist()
if hasattr(scores, 'tolist'):
scores = scores.tolist()
for i in range(len(boxes)):
instance = {
"box": boxes[i] if i < len(boxes) else None,
"score": float(scores[i]) if i < len(scores) else 0.0
}
                if return_masks and masks is not None and i < len(masks):
                    # Encode the mask as a single-channel base64 PNG (0/255)
                    mask = masks[i]
                    if hasattr(mask, 'cpu'):
                        mask = mask.cpu().numpy()
                    # Drop any leading singleton dimension (e.g. (1, H, W) -> (H, W))
                    mask = np.squeeze(np.asarray(mask))
                    mask_uint8 = (mask * 255).astype(np.uint8)
                    from PIL import Image as PILImage
                    mask_img = PILImage.fromarray(mask_uint8)
                    buffer = io.BytesIO()
                    mask_img.save(buffer, format='PNG')
                    instance["mask"] = base64.b64encode(buffer.getvalue()).decode('utf-8')
instances.append(instance)
results.append({
"prompt": text_prompt,
"instances": instances,
"count": len(instances)
})
return {
"results": results,
"image_size": {"width": image.width, "height": image.height}
}
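    # Client-side sketch (not executed here): decoding a base64-encoded mask from the
    # response above back into a boolean numpy array. PIL and numpy are assumed to be
    # available on the client; `instance` is one entry from results[i]["instances"].
    #
    #     import base64, io
    #     import numpy as np
    #     from PIL import Image
    #     mask_bytes = base64.b64decode(instance["mask"])
    #     mask = np.array(Image.open(io.BytesIO(mask_bytes))) > 127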
def _process_ui_elements(self, image_data: Any, params: Dict) -> Dict[str, Any]:
"""
ProofPath-specific mode: Detect multiple UI element types in a screenshot.
Returns structured data for each element type with bounding boxes.
"""
image = self._load_image(image_data)
elements = params.get("elements", [])
if not elements:
# Default UI elements to look for
elements = ["button", "text input", "dropdown", "checkbox", "link"]
# Set the image once
inference_state = self.processor.set_image(image)
all_detections = {}
for element_type in elements:
output = self.processor.set_text_prompt(
state=inference_state,
prompt=element_type
)
boxes = output.get("boxes", [])
scores = output.get("scores", [])
if hasattr(boxes, 'tolist'):
boxes = boxes.tolist()
if hasattr(scores, 'tolist'):
scores = scores.tolist()
detections = []
for i in range(len(boxes)):
box = boxes[i]
detections.append({
"box": box,
"score": float(scores[i]) if i < len(scores) else 0.0,
"center": [
(box[0] + box[2]) / 2,
(box[1] + box[3]) / 2
] if len(box) >= 4 else None
})
all_detections[element_type] = {
"count": len(detections),
"instances": detections
}
return {
"ui_elements": all_detections,
"image_size": {"width": image.width, "height": image.height},
"total_elements": sum(d["count"] for d in all_detections.values())
}
def _process_video(self, video_data: Any, params: Dict) -> Dict[str, Any]:
"""
Process video with SAM3 Video for text-prompted tracking.
Uses the official sam3 video predictor API.
"""
video_predictor = self._get_video_predictor()
prompt = params.get("prompt")
if not prompt:
raise ValueError("Text prompt required for video mode")
max_frames = params.get("max_frames", 100)
# Load video to temp path
video_path, video_metadata = self._load_video_frames(video_data, max_frames)
try:
# Start video session
response = video_predictor.handle_request(
request=dict(
type="start_session",
resource_path=video_path,
)
)
session_id = response.get("session_id")
# Add text prompt at frame 0
response = video_predictor.handle_request(
request=dict(
type="add_prompt",
session_id=session_id,
frame_index=0,
text=prompt,
)
)
output = response.get("outputs", {})
# Get tracked objects
object_ids = output.get("object_ids", [])
if hasattr(object_ids, 'tolist'):
object_ids = object_ids.tolist()
# Propagate through video
propagate_response = video_predictor.handle_request(
request=dict(
type="propagate",
session_id=session_id,
)
)
            # Collect per-frame outputs from propagation
            per_frame_results = propagate_response.get("per_frame_outputs", {})
            # Build one track entry per object. Per-frame masks are not serialized
            # into the response (they can be very large); callers get object IDs
            # plus summary counts.
            tracks = []
            for obj_id in object_ids:
                tracks.append({
                    "object_id": int(obj_id) if hasattr(obj_id, 'item') else obj_id,
                    "frames": []
                })
            return {
                "prompt": prompt,
                "video_metadata": video_metadata,
                "objects_tracked": len(object_ids),
                "frames_processed": len(per_frame_results),
                "tracks": tracks,
                "session_id": session_id
            }
finally:
# Clean up temp file
if os.path.exists(video_path):
os.unlink(video_path)
# For testing locally
if __name__ == "__main__":
handler = EndpointHandler()
# Test with a sample image URL
test_data = {
"inputs": "http://images.cocodataset.org/val2017/000000077595.jpg",
"parameters": {
"prompt": "ear",
"return_masks": False
}
}
result = handler(test_data)
print(f"Found {result['results'][0]['count']} instances of '{result['results'][0]['prompt']}'")
for inst in result['results'][0]['instances']:
print(f" Box: {inst['box']}, Score: {inst['score']:.3f}")