"""
V-JEPA 2 Custom Inference Handler for Hugging Face Inference Endpoints
Model: facebook/vjepa2-vitl-fpc64-256 (Large variant - good balance of performance/resources)

For ProofPath video assessment - extracts motion features from skill demonstration videos.
"""

import base64
import os
import tempfile
from typing import Any, Dict

import numpy as np
import torch


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize V-JEPA 2 model for video feature extraction.
        
        Args:
            path: Path to the model directory (provided by HF Inference Endpoints)
        """
        from transformers import AutoVideoProcessor, AutoModel
        
        # Always load from the official Facebook model on HuggingFace Hub
        # (path points to /repository which is our custom handler, not the model weights)
        model_id = "facebook/vjepa2-vitl-fpc64-256"
        
        # Determine device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Load processor and model
        self.processor = AutoVideoProcessor.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            attn_implementation="sdpa"  # Use scaled dot product attention for efficiency
        )
        
        if not torch.cuda.is_available():
            self.model = self.model.to(self.device)
        
        self.model.eval()
        
        # Default config
        self.default_num_frames = 64  # V-JEPA 2 is trained with 64 frames
        
    def _decode_video(self, video_data: Any):
        """
        Decode video from various input formats and return a torchcodec VideoDecoder.
        
        Supports:
        - Base64 encoded video (plain string or data URL)
        - URL to video file
        - Raw bytes
        """
        from torchcodec.decoders import VideoDecoder
        
        def decoder_from_bytes(video_bytes: bytes) -> VideoDecoder:
            # Write to a temp file for torchcodec. Unlinking right after the decoder
            # opens the file relies on POSIX semantics (the open file descriptor keeps
            # the data readable), which holds on the Linux containers HF endpoints use.
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                f.write(video_bytes)
                temp_path = f.name
            try:
                return VideoDecoder(temp_path)
            finally:
                os.unlink(temp_path)
        
        if isinstance(video_data, str):
            if video_data.startswith(('http://', 'https://')):
                # URL - torchcodec can handle URLs directly
                vr = VideoDecoder(video_data)
            elif video_data.startswith('data:'):
                # Data URL: strip the "data:<mime>;base64," header, then decode
                _header, encoded = video_data.split(',', 1)
                vr = decoder_from_bytes(base64.b64decode(encoded))
            else:
                # Assume a plain base64-encoded video string
                vr = decoder_from_bytes(base64.b64decode(video_data))
        elif isinstance(video_data, bytes):
            vr = decoder_from_bytes(video_data)
        else:
            raise ValueError(f"Unsupported video input type: {type(video_data)}")
        
        return vr
    
    def _sample_frames(
        self, 
        video_decoder, 
        num_frames: int = 64,
        sampling_strategy: str = "uniform"
    ) -> torch.Tensor:
        """
        Sample frames from video decoder.
        
        Args:
            video_decoder: torchcodec VideoDecoder instance
            num_frames: Number of frames to sample
            sampling_strategy: "uniform" or "random"
        """
        # Get the frame count from the decoder metadata; fall back to len(decoder)
        # (the number of decodable frames) rather than an arbitrary constant
        metadata = video_decoder.metadata
        total_frames = metadata.num_frames if getattr(metadata, 'num_frames', None) else len(video_decoder)
        
        if sampling_strategy == "uniform":
            # Uniformly sample frames across the video
            if total_frames <= num_frames:
                frame_idx = np.arange(total_frames)
            else:
                frame_idx = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        elif sampling_strategy == "random":
            frame_idx = np.sort(np.random.choice(total_frames, min(num_frames, total_frames), replace=False))
        else:
            # Default to sequential from start
            frame_idx = np.arange(min(num_frames, total_frames))
        
        # Get frames: returns T x C x H x W
        frames = video_decoder.get_frames_at(indices=frame_idx.tolist()).data
        
        return frames
    
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process video and extract V-JEPA 2 features.
        
        Expected input format:
        {
            "inputs": <base64_video_string or video_url>,
            "parameters": {
                "num_frames": 64,           # Optional: number of frames to sample
                "sampling_strategy": "uniform",  # Optional: "uniform" or "random"
                "return_predictor": true,   # Optional: also return predictor features
                "pooling": "mean"           # Optional: "mean", "cls", or "none"
            }
        }
        
        Returns:
        {
            "encoder_features": [...],      # Encoder output features
            "predictor_features": [...],    # Optional predictor features
            "feature_shape": [T, D],        # Shape of features
        }
        """
        # Extract inputs
        inputs = data.get("inputs")
        if inputs is None:
            inputs = data.get("video")
        if inputs is None:
            raise ValueError("No video input provided. Use 'inputs' or 'video' key.")
        
        # Extract parameters
        params = data.get("parameters", {})
        num_frames = params.get("num_frames", self.default_num_frames)
        sampling_strategy = params.get("sampling_strategy", "uniform")
        return_predictor = params.get("return_predictor", False)
        pooling = params.get("pooling", "mean")
        
        try:
            # Decode and sample video
            video_decoder = self._decode_video(inputs)
            frames = self._sample_frames(video_decoder, num_frames, sampling_strategy)
            
            # Process through the V-JEPA 2 processor, then move tensors to the model's
            # device and cast floating-point inputs to the model's dtype (fp16 on GPU)
            processed = self.processor(frames, return_tensors="pt")
            processed = {
                k: v.to(self.model.device, dtype=self.model.dtype) if v.is_floating_point() else v.to(self.model.device)
                for k, v in processed.items()
            }
            
            # Run inference
            with torch.no_grad():
                outputs = self.model(**processed)
            
            # Extract encoder features
            encoder_features = outputs.last_hidden_state  # [batch, seq, hidden]
            
            # Apply pooling
            if pooling == "mean":
                encoder_pooled = encoder_features.mean(dim=1)  # [batch, hidden]
            elif pooling == "cls":
                encoder_pooled = encoder_features[:, 0, :]  # [batch, hidden]
            else:
                encoder_pooled = encoder_features  # [batch, seq, hidden]
            
            result = {
                "encoder_features": encoder_pooled.cpu().numpy().tolist(),
                "feature_shape": list(encoder_pooled.shape),
            }
            
            # Optionally include predictor features
            if return_predictor and hasattr(outputs, 'predictor_output'):
                predictor_features = outputs.predictor_output.last_hidden_state
                if pooling == "mean":
                    predictor_pooled = predictor_features.mean(dim=1)
                elif pooling == "cls":
                    predictor_pooled = predictor_features[:, 0, :]
                else:
                    predictor_pooled = predictor_features
                result["predictor_features"] = predictor_pooled.cpu().numpy().tolist()
                result["predictor_shape"] = list(predictor_pooled.shape)
            
            return result
            
        except Exception as e:
            return {"error": str(e), "error_type": type(e).__name__}