""" V-JEPA 2 Custom Inference Handler for Hugging Face Inference Endpoints Model: facebook/vjepa2-vitl-fpc64-256 (Large variant - good balance of performance/resources) For ProofPath video assessment - extracts motion features from skill demonstration videos. """ from typing import Dict, List, Any, Optional import torch import numpy as np import base64 import io import tempfile import os class EndpointHandler: def __init__(self, path: str = ""): """ Initialize V-JEPA 2 model for video feature extraction. Args: path: Path to the model directory (provided by HF Inference Endpoints) """ from transformers import AutoVideoProcessor, AutoModel # Always load from the official Facebook model on HuggingFace Hub # (path points to /repository which is our custom handler, not the model weights) model_id = "facebook/vjepa2-vitl-fpc64-256" # Determine device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load processor and model self.processor = AutoVideoProcessor.from_pretrained(model_id) self.model = AutoModel.from_pretrained( model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" if torch.cuda.is_available() else None, attn_implementation="sdpa" # Use scaled dot product attention for efficiency ) if not torch.cuda.is_available(): self.model = self.model.to(self.device) self.model.eval() # Default config self.default_num_frames = 64 # V-JEPA 2 is trained with 64 frames def _decode_video(self, video_data: Any) -> torch.Tensor: """ Decode video from various input formats. Supports: - Base64 encoded video bytes - URL to video file - Raw bytes """ from torchcodec.decoders import VideoDecoder # Handle base64 encoded video if isinstance(video_data, str): if video_data.startswith(('http://', 'https://')): # URL - torchcodec can handle URLs directly vr = VideoDecoder(video_data) elif video_data.startswith('data:'): # Data URL format header, encoded = video_data.split(',', 1) video_bytes = base64.b64decode(encoded) # Write to temp file for torchcodec with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: f.write(video_bytes) temp_path = f.name vr = VideoDecoder(temp_path) os.unlink(temp_path) else: # Assume base64 encoded video_bytes = base64.b64decode(video_data) with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: f.write(video_bytes) temp_path = f.name vr = VideoDecoder(temp_path) os.unlink(temp_path) elif isinstance(video_data, bytes): with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: f.write(video_data) temp_path = f.name vr = VideoDecoder(temp_path) os.unlink(temp_path) else: raise ValueError(f"Unsupported video input type: {type(video_data)}") return vr def _sample_frames( self, video_decoder, num_frames: int = 64, sampling_strategy: str = "uniform" ) -> torch.Tensor: """ Sample frames from video decoder. 
Args: video_decoder: torchcodec VideoDecoder instance num_frames: Number of frames to sample sampling_strategy: "uniform" or "random" """ # Get video metadata metadata = video_decoder.metadata total_frames = metadata.num_frames if hasattr(metadata, 'num_frames') else 1000 if sampling_strategy == "uniform": # Uniformly sample frames across the video if total_frames <= num_frames: frame_idx = np.arange(total_frames) else: frame_idx = np.linspace(0, total_frames - 1, num_frames, dtype=int) elif sampling_strategy == "random": frame_idx = np.sort(np.random.choice(total_frames, min(num_frames, total_frames), replace=False)) else: # Default to sequential from start frame_idx = np.arange(min(num_frames, total_frames)) # Get frames: returns T x C x H x W frames = video_decoder.get_frames_at(indices=frame_idx.tolist()).data return frames def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: """ Process video and extract V-JEPA 2 features. Expected input format: { "inputs": , "parameters": { "num_frames": 64, # Optional: number of frames to sample "sampling_strategy": "uniform", # Optional: "uniform" or "random" "return_predictor": true, # Optional: also return predictor features "pooling": "mean" # Optional: "mean", "cls", or "none" } } Returns: { "encoder_features": [...], # Encoder output features "predictor_features": [...], # Optional predictor features "feature_shape": [T, D], # Shape of features } """ # Extract inputs inputs = data.get("inputs") if inputs is None: inputs = data.get("video") if inputs is None: raise ValueError("No video input provided. Use 'inputs' or 'video' key.") # Extract parameters params = data.get("parameters", {}) num_frames = params.get("num_frames", self.default_num_frames) sampling_strategy = params.get("sampling_strategy", "uniform") return_predictor = params.get("return_predictor", False) pooling = params.get("pooling", "mean") try: # Decode and sample video video_decoder = self._decode_video(inputs) frames = self._sample_frames(video_decoder, num_frames, sampling_strategy) # Process through V-JEPA 2 processor processed = self.processor(frames, return_tensors="pt") processed = {k: v.to(self.model.device) for k, v in processed.items()} # Run inference with torch.no_grad(): outputs = self.model(**processed) # Extract encoder features encoder_features = outputs.last_hidden_state # [batch, seq, hidden] # Apply pooling if pooling == "mean": encoder_pooled = encoder_features.mean(dim=1) # [batch, hidden] elif pooling == "cls": encoder_pooled = encoder_features[:, 0, :] # [batch, hidden] else: encoder_pooled = encoder_features # [batch, seq, hidden] result = { "encoder_features": encoder_pooled.cpu().numpy().tolist(), "feature_shape": list(encoder_pooled.shape), } # Optionally include predictor features if return_predictor and hasattr(outputs, 'predictor_output'): predictor_features = outputs.predictor_output.last_hidden_state if pooling == "mean": predictor_pooled = predictor_features.mean(dim=1) elif pooling == "cls": predictor_pooled = predictor_features[:, 0, :] else: predictor_pooled = predictor_features result["predictor_features"] = predictor_pooled.cpu().numpy().tolist() result["predictor_shape"] = list(predictor_pooled.shape) return result except Exception as e: return {"error": str(e), "error_type": type(e).__name__}
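

# ---------------------------------------------------------------------------
# Local smoke test - an illustrative sketch, not part of the endpoint runtime.
# It instantiates the handler and runs one request end to end using the
# payload format documented in EndpointHandler.__call__. "sample_clip.mp4" is
# a placeholder path; running this downloads the model weights and benefits
# from a GPU.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("sample_clip.mp4", "rb") as f:  # placeholder input video
        payload = {
            "inputs": base64.b64encode(f.read()).decode("utf-8"),
            "parameters": {"num_frames": 64, "pooling": "mean"},
        }
    result = handler(payload)
    # With mean pooling, feature_shape should be [1, hidden_size]
    # (hidden_size is 1024 for the ViT-L variant); on failure, the
    # handler returns a dict with "error" and "error_type" keys instead.
    print(result.get("feature_shape", result))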