"""
V-JEPA 2 Custom Inference Handler for Hugging Face Inference Endpoints
Model: facebook/vjepa2-vitl-fpc64-256 (Large variant - good balance of performance/resources)
For ProofPath video assessment - extracts motion features from skill demonstration videos.
"""
from typing import Dict, List, Any, Optional
import torch
import numpy as np
import base64
import io
import tempfile
import os


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize V-JEPA 2 model for video feature extraction.

        Args:
            path: Path to the model directory (provided by HF Inference Endpoints)
        """
        from transformers import AutoVideoProcessor, AutoModel

        # Always load from the official Facebook model on HuggingFace Hub
        # (path points to /repository which is our custom handler, not the model weights)
        model_id = "facebook/vjepa2-vitl-fpc64-256"

        # Determine device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load processor and model
        self.processor = AutoVideoProcessor.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            attn_implementation="sdpa"  # Use scaled dot product attention for efficiency
        )
        if not torch.cuda.is_available():
            self.model = self.model.to(self.device)
        self.model.eval()

        # Default config
        self.default_num_frames = 64  # V-JEPA 2 is trained with 64 frames
    def _decode_video(self, video_data: Any):
        """
        Decode video from various input formats and return a torchcodec VideoDecoder.

        Supports:
        - Base64 encoded video bytes
        - URL to video file
        - Raw bytes
        """
        from torchcodec.decoders import VideoDecoder

        # Handle string inputs (URL, data URL, or base64 encoded video)
        if isinstance(video_data, str):
            if video_data.startswith(('http://', 'https://')):
                # URL - torchcodec can handle URLs directly
                vr = VideoDecoder(video_data)
            elif video_data.startswith('data:'):
                # Data URL format
                header, encoded = video_data.split(',', 1)
                video_bytes = base64.b64decode(encoded)
                # Write to temp file for torchcodec
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                    f.write(video_bytes)
                    temp_path = f.name
                vr = VideoDecoder(temp_path)
                os.unlink(temp_path)
            else:
                # Assume base64 encoded
                video_bytes = base64.b64decode(video_data)
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                    f.write(video_bytes)
                    temp_path = f.name
                vr = VideoDecoder(temp_path)
                os.unlink(temp_path)
        elif isinstance(video_data, bytes):
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                f.write(video_data)
                temp_path = f.name
            vr = VideoDecoder(temp_path)
            os.unlink(temp_path)
        else:
            raise ValueError(f"Unsupported video input type: {type(video_data)}")
        return vr
    def _sample_frames(
        self,
        video_decoder,
        num_frames: int = 64,
        sampling_strategy: str = "uniform"
    ) -> torch.Tensor:
        """
        Sample frames from a video decoder.

        Args:
            video_decoder: torchcodec VideoDecoder instance
            num_frames: Number of frames to sample
            sampling_strategy: "uniform" or "random"
        """
        # Get video metadata
        metadata = video_decoder.metadata
        total_frames = metadata.num_frames if hasattr(metadata, 'num_frames') else 1000

        if sampling_strategy == "uniform":
            # Uniformly sample frames across the video
            if total_frames <= num_frames:
                frame_idx = np.arange(total_frames)
            else:
                frame_idx = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        elif sampling_strategy == "random":
            frame_idx = np.sort(np.random.choice(total_frames, min(num_frames, total_frames), replace=False))
        else:
            # Default to sequential from start
            frame_idx = np.arange(min(num_frames, total_frames))

        # Get frames: returns T x C x H x W
        frames = video_decoder.get_frames_at(indices=frame_idx.tolist()).data
        return frames
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process video and extract V-JEPA 2 features.

        Expected input format:
        {
            "inputs": <base64_video_string or video_url>,
            "parameters": {
                "num_frames": 64,               # Optional: number of frames to sample
                "sampling_strategy": "uniform", # Optional: "uniform" or "random"
                "return_predictor": true,       # Optional: also return predictor features
                "pooling": "mean"               # Optional: "mean", "cls", or "none"
            }
        }

        Returns:
        {
            "encoder_features": [...],   # Encoder output features
            "predictor_features": [...], # Optional predictor features (if requested)
            "feature_shape": [...],      # Shape of the encoder features (depends on pooling)
        }
        """
        # Extract inputs (accept either "inputs" or "video" as the key)
        inputs = data.get("inputs")
        if inputs is None:
            inputs = data.get("video")
        if inputs is None:
            raise ValueError("No video input provided. Use 'inputs' or 'video' key.")

        # Extract parameters
        params = data.get("parameters", {})
        num_frames = params.get("num_frames", self.default_num_frames)
        sampling_strategy = params.get("sampling_strategy", "uniform")
        return_predictor = params.get("return_predictor", False)
        pooling = params.get("pooling", "mean")

        try:
            # Decode and sample video
            video_decoder = self._decode_video(inputs)
            frames = self._sample_frames(video_decoder, num_frames, sampling_strategy)

            # Process through V-JEPA 2 processor
            processed = self.processor(frames, return_tensors="pt")
            processed = {k: v.to(self.model.device) for k, v in processed.items()}

            # Run inference
            with torch.no_grad():
                outputs = self.model(**processed)

            # Extract encoder features
            encoder_features = outputs.last_hidden_state  # [batch, seq, hidden]

            # Apply pooling
            if pooling == "mean":
                encoder_pooled = encoder_features.mean(dim=1)  # [batch, hidden]
            elif pooling == "cls":
                encoder_pooled = encoder_features[:, 0, :]  # [batch, hidden]
            else:
                encoder_pooled = encoder_features  # [batch, seq, hidden]

            result = {
                "encoder_features": encoder_pooled.cpu().numpy().tolist(),
                "feature_shape": list(encoder_pooled.shape),
            }

            # Optionally include predictor features
            if return_predictor and hasattr(outputs, 'predictor_output'):
                predictor_features = outputs.predictor_output.last_hidden_state
                if pooling == "mean":
                    predictor_pooled = predictor_features.mean(dim=1)
                elif pooling == "cls":
                    predictor_pooled = predictor_features[:, 0, :]
                else:
                    predictor_pooled = predictor_features
                result["predictor_features"] = predictor_pooled.cpu().numpy().tolist()
                result["predictor_shape"] = list(predictor_pooled.shape)

            return result
        except Exception as e:
            return {"error": str(e), "error_type": type(e).__name__}
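

# --- Illustrative local smoke test (not part of the handler) ---
# A minimal sketch of how the request format documented in __call__ maps onto this
# handler when invoked locally. The file name "sample_clip.mp4" and the parameter
# values are hypothetical; this assumes transformers, torch, numpy, and torchcodec
# (with its codec dependencies) are installed in the environment.
if __name__ == "__main__":
    # Encode a local video as base64, the same way a client might before POSTing
    # {"inputs": "<base64>", "parameters": {...}} to the deployed endpoint.
    with open("sample_clip.mp4", "rb") as f:  # hypothetical example file
        video_b64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    response = handler({
        "inputs": video_b64,
        "parameters": {
            "num_frames": 32,
            "sampling_strategy": "uniform",
            "pooling": "mean",
        },
    })

    if "error" in response:
        print(f"Inference failed: {response['error_type']}: {response['error']}")
    else:
        # With mean pooling the handler returns one feature vector per clip: [batch, hidden].
        print("Feature shape:", response["feature_shape"])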