"""
V-JEPA 2 custom inference handler for Hugging Face Inference Endpoints.

Model: facebook/vjepa2-vitl-fpc64-256 (Large variant, a good balance of
performance and resource cost).

Used for ProofPath video assessment: extracts motion features from skill
demonstration videos.
"""

import base64
import os
import tempfile
from typing import Any, Dict

import numpy as np
import torch


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the V-JEPA 2 model for video feature extraction.

        Args:
            path: Path to the model directory (provided by HF Inference
                Endpoints). The model and processor are loaded from the Hub
                by ID instead.
        """
        from transformers import AutoModel, AutoVideoProcessor

        model_id = "facebook/vjepa2-vitl-fpc64-256"

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.processor = AutoVideoProcessor.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(
            model_id,
            # fp16 halves memory on GPU; fall back to fp32 on CPU.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            attn_implementation="sdpa",
        )

        # device_map="auto" already places the model on GPU; only move it
        # explicitly when running on CPU.
        if not torch.cuda.is_available():
            self.model = self.model.to(self.device)

        self.model.eval()

        # Matches the fpc64 checkpoint's 64-frames-per-clip pretraining.
        self.default_num_frames = 64

    def _decode_video(self, video_data: Any) -> Any:
        """
        Decode video from various input formats.

        Supports:
        - URL to a video file
        - Base64-encoded video (with or without a data: URI prefix)
        - Raw bytes

        Returns a torchcodec VideoDecoder (not a tensor); frames are sampled
        from it later by _sample_frames.
        """
        from torchcodec.decoders import VideoDecoder

        def decoder_from_bytes(video_bytes: bytes) -> VideoDecoder:
            # Write to a temp file so FFmpeg can probe the container. The
            # file is unlinked as soon as the decoder has opened it; on
            # Linux the open file handle keeps the data readable.
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
                f.write(video_bytes)
                temp_path = f.name
            decoder = VideoDecoder(temp_path)
            os.unlink(temp_path)
            return decoder

        if isinstance(video_data, str):
            if video_data.startswith(("http://", "https://")):
                # torchcodec can read directly from a URL.
                return VideoDecoder(video_data)
            if video_data.startswith("data:"):
                # Data URI: strip the "data:video/mp4;base64," style header.
                _header, encoded = video_data.split(",", 1)
                return decoder_from_bytes(base64.b64decode(encoded))
            # Bare base64 string.
            return decoder_from_bytes(base64.b64decode(video_data))
        if isinstance(video_data, bytes):
            return decoder_from_bytes(video_data)
        raise ValueError(f"Unsupported video input type: {type(video_data)}")

    def _sample_frames(
        self,
        video_decoder,
        num_frames: int = 64,
        sampling_strategy: str = "uniform",
    ) -> torch.Tensor:
        """
        Sample frames from a video decoder.

        Args:
            video_decoder: torchcodec VideoDecoder instance
            num_frames: Number of frames to sample
            sampling_strategy: "uniform" or "random"

        Returns:
            A uint8 tensor of frames with shape (frames, channels, height, width).
        """
        total_frames = getattr(video_decoder.metadata, "num_frames", None)
        if not total_frames:
            raise ValueError("Could not determine the video's frame count from its metadata.")

        if sampling_strategy == "uniform":
            # Evenly spaced indices spanning the whole clip.
            if total_frames <= num_frames:
                frame_idx = np.arange(total_frames)
            else:
                frame_idx = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        elif sampling_strategy == "random":
            # A sorted random subset, sampled without replacement.
            frame_idx = np.sort(
                np.random.choice(total_frames, min(num_frames, total_frames), replace=False)
            )
        else:
            # Unknown strategy: fall back to the first num_frames frames.
            frame_idx = np.arange(min(num_frames, total_frames))

        return video_decoder.get_frames_at(indices=frame_idx.tolist()).data

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a video and extract V-JEPA 2 features.

        Expected input format (see the example client call at the bottom of
        this file):
        {
            "inputs": <base64_video_string or video_url>,
            "parameters": {
                "num_frames": 64,               # Optional: number of frames to sample
                "sampling_strategy": "uniform", # Optional: "uniform" or "random"
                "return_predictor": true,       # Optional: also return predictor features
                "pooling": "mean"               # Optional: "mean", "cls", or "none"
            }
        }

        Returns:
        {
            "encoder_features": [...],   # Encoder output features
            "predictor_features": [...], # Optional predictor features
            "feature_shape": [...],      # Shape of the returned encoder features,
                                         # e.g. [1, hidden_dim] after mean pooling
        }
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = data.get("video")
        if inputs is None:
            raise ValueError("No video input provided. Use 'inputs' or 'video' key.")

        params = data.get("parameters", {})
        num_frames = params.get("num_frames", self.default_num_frames)
        sampling_strategy = params.get("sampling_strategy", "uniform")
        return_predictor = params.get("return_predictor", False)
        pooling = params.get("pooling", "mean")

        try:
            # Decode the video and sample a fixed number of frames.
            video_decoder = self._decode_video(inputs)
            frames = self._sample_frames(video_decoder, num_frames, sampling_strategy)

            # Preprocess (resize, normalize) and move to the model's device.
            processed = self.processor(frames, return_tensors="pt")
            processed = {k: v.to(self.model.device) for k, v in processed.items()}

            with torch.no_grad():
                outputs = self.model(**processed)

            # (batch, tokens, hidden) encoder representation.
            encoder_features = outputs.last_hidden_state

            if pooling == "mean":
                encoder_pooled = encoder_features.mean(dim=1)
            elif pooling == "cls":
                # V-JEPA 2 has no dedicated CLS token; this takes the first
                # patch token as a lightweight summary.
                encoder_pooled = encoder_features[:, 0, :]
            else:
                encoder_pooled = encoder_features

            result = {
                "encoder_features": encoder_pooled.cpu().numpy().tolist(),
                "feature_shape": list(encoder_pooled.shape),
            }

            # Optionally include features from the V-JEPA 2 predictor head,
            # pooled the same way as the encoder output.
            if return_predictor and hasattr(outputs, "predictor_output"):
                predictor_features = outputs.predictor_output.last_hidden_state
                if pooling == "mean":
                    predictor_pooled = predictor_features.mean(dim=1)
                elif pooling == "cls":
                    predictor_pooled = predictor_features[:, 0, :]
                else:
                    predictor_pooled = predictor_features
                result["predictor_features"] = predictor_pooled.cpu().numpy().tolist()
                result["predictor_shape"] = list(predictor_pooled.shape)

            return result

        except Exception as e:
            # Surface the failure to the client instead of a bare 500.
            return {"error": str(e), "error_type": type(e).__name__}
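

if __name__ == "__main__":
    # Minimal local smoke test (a sketch): assumes a short clip at
    # ./sample.mp4. Not part of the Inference Endpoints contract, which
    # only instantiates EndpointHandler and calls it per request.
    handler = EndpointHandler()
    with open("sample.mp4", "rb") as f:
        video_b64 = base64.b64encode(f.read()).decode("utf-8")
    out = handler({"inputs": video_b64, "parameters": {"num_frames": 16}})
    if "error" in out:
        print(f"error ({out['error_type']}): {out['error']}")
    else:
        print("encoder feature shape:", out["feature_shape"])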