""" |
|
|
SAM 3 Custom Inference Handler for Hugging Face Inference Endpoints |
|
|
Model: facebook/sam3 |
|
|
|
|
|
Using the official sam3 package from Meta (pip install sam3) |
|
|
NOT the transformers integration. |
|
|
|
|
|
For ProofPath video assessment - text-prompted segmentation to find UI elements. |
|
|
Supports text prompts like "Save button", "dropdown menu", "text input field". |
|
|
|
|
|
KEY CAPABILITIES: |
|
|
- Text-to-segment: Find ALL instances of a concept (e.g., "button" → all buttons) |
|
|
- Promptable Concept Segmentation (PCS): 270K unique concepts |
|
|
- Video tracking: Consistent object IDs across frames |
|
|
- Presence token: Discriminates similar elements ("player in white" vs "player in red") |
|
|
|
|
|
REQUIREMENTS: |
|
|
1. Set HF_TOKEN environment variable (model is gated) |
|
|
2. Accept license at https://huggingface.co/facebook/sam3 |
|
|
""" |
from typing import Dict, List, Any, Optional, Union

import torch
import numpy as np
import base64
import io
import os


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize SAM 3 model for text-prompted segmentation.
        Uses the official sam3 package from Meta.

        Args:
            path: Path to the model directory (ignored - we load from HF hub)
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Official sam3 package (not the transformers integration)
        from sam3.model_builder import build_sam3_image_model
        from sam3.model.sam3_image_processor import Sam3Processor

        # Build the image model and processor (weights come from the HF hub;
        # the model is gated, so HF_TOKEN must be set)
        self.model = build_sam3_image_model()
        self.processor = Sam3Processor(self.model)

        # Video predictor is built lazily via _get_video_predictor()
        self._video_predictor = None

    def _get_video_predictor(self):
        """Lazy load video predictor only when needed."""
        if self._video_predictor is None:
            from sam3.model_builder import build_sam3_video_predictor
            self._video_predictor = build_sam3_video_predictor()
        return self._video_predictor

    def _load_image(self, image_data: Any):
        """Load an image from a PIL Image, URL, data URI, base64 string, or raw bytes."""
        from PIL import Image
        import requests

        if isinstance(image_data, Image.Image):
            return image_data.convert('RGB')
        elif isinstance(image_data, str):
            if image_data.startswith(('http://', 'https://')):
                response = requests.get(image_data, stream=True)
                return Image.open(response.raw).convert('RGB')
            elif image_data.startswith('data:'):
                header, encoded = image_data.split(',', 1)
                image_bytes = base64.b64decode(encoded)
                return Image.open(io.BytesIO(image_bytes)).convert('RGB')
            else:
                # Treat any other string as a bare base64-encoded image
                image_bytes = base64.b64decode(image_data)
                return Image.open(io.BytesIO(image_bytes)).convert('RGB')
        elif isinstance(image_data, bytes):
            return Image.open(io.BytesIO(image_data)).convert('RGB')
        else:
            raise ValueError(f"Unsupported image input type: {type(image_data)}")

    def _load_video_frames(self, video_data: Any, max_frames: int = 100, fps: float = 2.0) -> tuple:
        """
        Decode video input (URL, data URI, base64 string, or raw bytes) into a
        temporary file and sample frames to compute metadata.

        Returns:
            (video_path, metadata) - path to the temporary .mp4 file plus duration,
            frame counts, and source fps. The caller is responsible for deleting
            the temporary file.
        """
        import cv2
        from PIL import Image
        import tempfile

        # Write the video to a temporary file so OpenCV can read it
        if isinstance(video_data, str):
            if video_data.startswith(('http://', 'https://')):
                import requests
                response = requests.get(video_data, stream=True)
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
                    video_path = f.name
            elif video_data.startswith('data:'):
                header, encoded = video_data.split(',', 1)
                video_bytes = base64.b64decode(encoded)
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                    f.write(video_bytes)
                    video_path = f.name
            else:
                # Treat any other string as bare base64-encoded video bytes
                video_bytes = base64.b64decode(video_data)
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                    f.write(video_bytes)
                    video_path = f.name
        elif isinstance(video_data, bytes):
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                f.write(video_data)
                video_path = f.name
        else:
            raise ValueError(f"Unsupported video input type: {type(video_data)}")

        try:
            cap = cv2.VideoCapture(video_path)
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            duration = total_frames / video_fps if video_fps > 0 else 0

            # Sample at most max_frames frames, spread evenly across the video
            target_frames = min(max_frames, int(duration * fps), total_frames)
            if target_frames <= 0:
                target_frames = min(max_frames, total_frames)

            frame_indices = np.linspace(0, total_frames - 1, target_frames, dtype=int)

            frames = []
            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, frame = cap.read()
                if ret:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    pil_image = Image.fromarray(frame_rgb)
                    frames.append(pil_image)

            cap.release()

            metadata = {
                "duration": duration,
                "total_frames": total_frames,
                "sampled_frames": len(frames),
                "video_fps": video_fps
            }

            return video_path, metadata

        except Exception:
            # Clean up the temporary file if anything goes wrong before returning
            if os.path.exists(video_path):
                os.unlink(video_path)
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process image or video with SAM 3 for text-prompted segmentation.

        INPUT FORMATS:

        1. Single image with text prompt (find all instances):
            {
                "inputs": <image_url_or_base64>,
                "parameters": {
                    "prompt": "Save button",
                    "return_masks": true
                }
            }

        2. Single image with multiple text prompts:
            {
                "inputs": <image_url_or_base64>,
                "parameters": {
                    "prompts": ["button", "text field", "dropdown"]
                }
            }

        3. Video with text prompt (track all instances):
            {
                "inputs": <video_url_or_base64>,
                "parameters": {
                    "mode": "video",
                    "prompt": "Submit button",
                    "max_frames": 100
                }
            }

        4. ProofPath UI element detection:
            {
                "inputs": <screenshot_base64>,
                "parameters": {
                    "mode": "ui_elements",
                    "elements": ["Save button", "Cancel button", "text input"]
                }
            }

        OUTPUT FORMAT:
            {
                "results": [
                    {
                        "prompt": "Save button",
                        "instances": [
                            {
                                "box": [x1, y1, x2, y2],
                                "score": 0.95,
                                "mask": "<base64_png>"  // if return_masks=true
                            }
                        ]
                    }
                ],
                "image_size": {"width": 1920, "height": 1080}
            }
        """
        inputs = data.get("inputs")
        params = data.get("parameters", {})

        if inputs is None:
            raise ValueError("No inputs provided")

        mode = params.get("mode", "image")

        if mode == "video":
            return self._process_video(inputs, params)
        elif mode == "ui_elements":
            return self._process_ui_elements(inputs, params)
        else:
            return self._process_single_image(inputs, params)

    def _process_single_image(self, image_data: Any, params: Dict) -> Dict[str, Any]:
        """Process a single image with text prompts using the official sam3 API."""
        image = self._load_image(image_data)

        return_masks = params.get("return_masks", True)

        # Accept either a single "prompt" or a list of "prompts"
        prompt = params.get("prompt")
        prompts = params.get("prompts", [prompt] if prompt else [])

        if not prompts:
            raise ValueError("No text prompt(s) provided")

        # Encode the image once and reuse the state for every prompt
        inference_state = self.processor.set_image(image)

        results = []

        for text_prompt in prompts:
            output = self.processor.set_text_prompt(
                state=inference_state,
                prompt=text_prompt
            )

            masks = output.get("masks", [])
            boxes = output.get("boxes", [])
            scores = output.get("scores", [])

            instances = []

            # Convert tensors to plain lists for JSON serialization
            if hasattr(boxes, 'tolist'):
                boxes = boxes.tolist()
            if hasattr(scores, 'tolist'):
                scores = scores.tolist()

            for i in range(len(boxes)):
                instance = {
                    "box": boxes[i] if i < len(boxes) else None,
                    "score": float(scores[i]) if i < len(scores) else 0.0
                }

                if return_masks and masks is not None and i < len(masks):
                    # Encode the mask as a base64 PNG
                    mask = masks[i]
                    if hasattr(mask, 'cpu'):
                        mask = mask.cpu().numpy()
                    mask = np.squeeze(mask)  # defensive: drop any singleton channel dim
                    mask_uint8 = (mask * 255).astype(np.uint8)
                    from PIL import Image as PILImage
                    mask_img = PILImage.fromarray(mask_uint8)
                    buffer = io.BytesIO()
                    mask_img.save(buffer, format='PNG')
                    instance["mask"] = base64.b64encode(buffer.getvalue()).decode('utf-8')

                instances.append(instance)

            results.append({
                "prompt": text_prompt,
                "instances": instances,
                "count": len(instances)
            })

        return {
            "results": results,
            "image_size": {"width": image.width, "height": image.height}
        }

    def _process_ui_elements(self, image_data: Any, params: Dict) -> Dict[str, Any]:
        """
        ProofPath-specific mode: Detect multiple UI element types in a screenshot.
        Returns structured data for each element type with bounding boxes.
        """
        image = self._load_image(image_data)

        elements = params.get("elements", [])
        if not elements:
            # Fall back to a default set of common UI element types
            elements = ["button", "text input", "dropdown", "checkbox", "link"]

        # Encode the screenshot once, then query it with each element prompt
        inference_state = self.processor.set_image(image)

        all_detections = {}

        for element_type in elements:
            output = self.processor.set_text_prompt(
                state=inference_state,
                prompt=element_type
            )

            boxes = output.get("boxes", [])
            scores = output.get("scores", [])

            if hasattr(boxes, 'tolist'):
                boxes = boxes.tolist()
            if hasattr(scores, 'tolist'):
                scores = scores.tolist()

            detections = []
            for i in range(len(boxes)):
                box = boxes[i]
                detections.append({
                    "box": box,
                    "score": float(scores[i]) if i < len(scores) else 0.0,
                    # Midpoint of the bounding box
                    "center": [
                        (box[0] + box[2]) / 2,
                        (box[1] + box[3]) / 2
                    ] if len(box) >= 4 else None
                })

            all_detections[element_type] = {
                "count": len(detections),
                "instances": detections
            }

        return {
            "ui_elements": all_detections,
            "image_size": {"width": image.width, "height": image.height},
            "total_elements": sum(d["count"] for d in all_detections.values())
        }

    def _process_video(self, video_data: Any, params: Dict) -> Dict[str, Any]:
        """
        Process a video with SAM 3 for text-prompted tracking.
        Uses the official sam3 video predictor API.
        """
        video_predictor = self._get_video_predictor()

        prompt = params.get("prompt")
        if not prompt:
            raise ValueError("Text prompt required for video mode")

        max_frames = params.get("max_frames", 100)

        # Materialize the video to a temp file and gather basic metadata
        video_path, video_metadata = self._load_video_frames(video_data, max_frames)

        try:
            # Start a tracking session on the video
            response = video_predictor.handle_request(
                request=dict(
                    type="start_session",
                    resource_path=video_path,
                )
            )
            session_id = response.get("session_id")

            # Add the text prompt on the first frame
            response = video_predictor.handle_request(
                request=dict(
                    type="add_prompt",
                    session_id=session_id,
                    frame_index=0,
                    text=prompt,
                )
            )

            output = response.get("outputs", {})

            object_ids = output.get("object_ids", [])
            if hasattr(object_ids, 'tolist'):
                object_ids = object_ids.tolist()

            # Propagate the prompt across the remaining frames
            propagate_response = video_predictor.handle_request(
                request=dict(
                    type="propagate",
                    session_id=session_id,
                )
            )

            per_frame_results = propagate_response.get("per_frame_outputs", {})

            # Build one track entry per detected object. Per-frame boxes/masks from
            # per_frame_results are not serialized in this response.
            tracks = []
            for obj_id in object_ids:
                track = {
                    "object_id": int(obj_id) if hasattr(obj_id, 'item') else obj_id,
                    "frames": []
                }
                tracks.append(track)

            return {
                "prompt": prompt,
                "video_metadata": video_metadata,
                "objects_tracked": len(object_ids),
                "tracks": tracks,
                "session_id": session_id
            }

        finally:
            # Always remove the temporary video file
            if os.path.exists(video_path):
                os.unlink(video_path)


if __name__ == "__main__":
    handler = EndpointHandler()

    # Quick local smoke test: find all instances of a concept in a sample COCO image
    test_data = {
        "inputs": "http://images.cocodataset.org/val2017/000000077595.jpg",
        "parameters": {
            "prompt": "ear",
            "return_masks": False
        }
    }

    result = handler(test_data)
    print(f"Found {result['results'][0]['count']} instances of '{result['results'][0]['prompt']}'")
    for inst in result['results'][0]['instances']:
        print(f"  Box: {inst['box']}, Score: {inst['score']:.3f}")