Sylvest commited on
Commit
16a7dbd
·
verified ·
1 Parent(s): 9950f88

Upload modeling_prismatic.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_prismatic.py +1085 -0
modeling_prismatic.py ADDED
@@ -0,0 +1,1085 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ modeling_prismatic.py
3
+
4
+ Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
5
+ Inherits from the default `transformers.PreTrainedModel`. Meant to be standalone and self-contained,
6
+ but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.
7
+ """
8
+
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from functools import partial
12
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import timm
16
+ import tokenizers
17
+ import torch
18
+ import torch.nn as nn
19
+ import transformers
20
+ from timm.models.vision_transformer import LayerScale
21
+ from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import ModelOutput
23
+
24
+ from prismatic.training.train_utils import (
25
+ get_current_action_mask,
26
+ get_next_actions_mask,
27
+ )
28
+ from prismatic.vla.constants import (
29
+ ACTION_DIM,
30
+ ACTION_PROPRIO_NORMALIZATION_TYPE,
31
+ ACTION_TOKEN_BEGIN_IDX,
32
+ IGNORE_INDEX,
33
+ NUM_ACTIONS_CHUNK,
34
+ STOP_INDEX,
35
+ NormalizationType,
36
+ )
37
+
38
+ from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
39
+
40
+ # Set up logger
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ # === Utility Functions for Monkey-Patching ===
45
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """
    Wrap `fn` so that a tuple return value is collapsed to its first element.

    Args:
        fn: Callable whose result may or may not be a tuple

    Returns:
        A callable forwarding all positional/keyword arguments to `fn`; tuple results yield element 0,
        any other result is passed through unchanged.
    """

    def _first_if_tuple(*call_args: Any, **call_kwargs: Any) -> Any:
        outcome = fn(*call_args, **call_kwargs)
        if not isinstance(outcome, tuple):
            return outcome
        return outcome[0]

    return _first_if_tuple
51
+
52
+
53
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
54
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
55
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
56
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
58
+
59
+
60
def ls_apply_patch(ls_module: LayerScale):
    """
    Patch a single LayerScale module in place so that it no longer owns a parameter named `gamma`.

    The current `gamma` values are copied into a new `scale_factor` parameter, the module's `forward` is
    rebound to `_ls_new_forward` (which reads `scale_factor`), and the `gamma` attribute is deleted.
    """
    renamed_scale = nn.Parameter(ls_module.gamma.clone())
    bound_forward = _ls_new_forward.__get__(ls_module, LayerScale)

    ls_module.scale_factor = renamed_scale
    ls_module.forward = bound_forward
    del ls_module.gamma
64
+
65
+
66
+ # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
67
class PrismaticVisionBackbone(nn.Module):
    """
    Vision backbone for Prismatic models that handles image feature extraction.

    Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
    For fused backbones, features from both models are concatenated along the feature dimension.
    """

    def __init__(
        self,
        use_fused_vision_backbone: bool,
        image_sizes: List[int],
        timm_model_ids: List[str],
        timm_override_act_layers: List[Optional[str]],
    ) -> None:
        """
        Initialize the vision backbone.

        Args:
            use_fused_vision_backbone: Whether to use two backbones and fuse their features
            image_sizes: List of image sizes for each backbone
            timm_model_ids: List of TIMM model IDs to use for each backbone
            timm_override_act_layers: List of activation layer overrides for each backbone

        Raises:
            ValueError: If more than two model IDs are given, or if any per-backbone list has fewer entries
                than the number of backbones being built (2 when fused, otherwise 1).
        """
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.num_images_in_input = 1  # Default value, can be overridden later

        # Validate number of (fused) vision backbones
        if len(timm_model_ids) > 2:
            raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")

        # Fail fast (before any model is built) if a per-backbone list is too short to index below
        num_backbones = 2 if self.use_fused_vision_backbone else 1
        per_backbone_args = (
            ("image_sizes", image_sizes),
            ("timm_model_ids", timm_model_ids),
            ("timm_override_act_layers", timm_override_act_layers),
        )
        for arg_name, arg_values in per_backbone_args:
            if len(arg_values) < num_backbones:
                raise ValueError(
                    f"`{arg_name}` must provide at least {num_backbones} entries "
                    f"(use_fused_vision_backbone={bool(self.use_fused_vision_backbone)}), got {len(arg_values)}"
                )

        # Create primary featurizer
        self.featurizer = self._create_featurizer(
            model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
        )
        self.embed_dim = self.featurizer.embed_dim

        # Create secondary featurizer if using fused backbone
        if self.use_fused_vision_backbone:
            self.fused_featurizer = self._create_featurizer(
                model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
            )
            self.embed_dim += self.fused_featurizer.embed_dim

        # Patch LayerScale modules for HF compatibility
        self._patch_layer_scales()

    def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
        """
        Create a TIMM-based featurizer model with appropriate configurations.

        Args:
            model_id: The TIMM model ID to load
            img_size: Input image size for the model
            act_layer: Override for the activation layer type

        Returns:
            A configured featurizer model (built with `pretrained=False` and `num_classes=0`)
        """
        featurizer = timm.create_model(
            model_id,
            pretrained=False,
            num_classes=0,
            img_size=img_size,
            act_layer=act_layer,
        )

        # Monkey-patch the forward function to extract the second-to-last layer features.
        # NOTE: `n` is deliberately a one-element *set* holding a block index, not an int count
        # (presumably TIMM treats a collection as explicit block indices -- confirm against the pinned TIMM version).
        num_blocks = len(featurizer.blocks)
        featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))

        return featurizer

    def _patch_layer_scales(self) -> None:
        """
        Patch all LayerScale modules to be compatible with HF's parameter naming.

        HF Transformers overwrites parameters with names containing 'gamma',
        so we need to rename and modify the forward method.
        """
        # Primary featurizer always exists; the secondary one only for fused backbones
        featurizers = [self.featurizer]
        if self.use_fused_vision_backbone:
            featurizers.append(self.fused_featurizer)

        for featurizer in featurizers:
            for module in featurizer.modules():
                if isinstance(module, LayerScale):
                    ls_apply_patch(module)

    def get_num_patches(self) -> int:
        """
        Returns the number of vision patches output by the vision backbone.

        Returns:
            Number of patches per image (read from the primary featurizer's patch embedding)
        """
        return self.featurizer.patch_embed.num_patches

    def get_num_images_in_input(self) -> int:
        """
        Returns the number of input images for the vision backbone.

        Returns:
            Number of images expected in the input
        """
        return self.num_images_in_input

    def set_num_images_in_input(self, num_images_in_input: int) -> None:
        """
        Sets the number of input images for the vision backbone.

        Args:
            num_images_in_input: Number of images to expect in the input
        """
        self.num_images_in_input = num_images_in_input

    def _featurize_fused(self, img: torch.Tensor) -> torch.Tensor:
        """Split one 6-channel image into two 3-channel stacks, featurize each, and concatenate along dim 2."""
        img_regular, img_fused = torch.split(img, [3, 3], dim=1)
        patches = self.featurizer(img_regular)
        patches_fused = self.fused_featurizer(img_fused)
        return torch.cat([patches, patches_fused], dim=2)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Implements the forward pass for the vision backbone.

        If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
        (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).

        Args:
            pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W). For the fused backbone, C must be
                `6 * num_images_in_input` (3 channels per featurizer per image).

        Returns:
            Patch features; fused features are concatenated along dim 2, multiple images along dim 1.
        """
        if self.num_images_in_input == 1:
            if not self.use_fused_vision_backbone:
                return self.featurizer(pixel_values)

            # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
            return self._featurize_fused(pixel_values)

        assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"

        # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
        images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)

        # Featurize each image, then concatenate all patches along the patch dimension
        return torch.cat([self._featurize_fused(img) for img in images], dim=1)
228
+
229
+
230
+ # === Prismatic Projector (nn.Module) Definitions ===
231
class PrismaticProjector(nn.Module):
    """
    MLP that maps vision patch features (`vision_dim`) into the language model embedding space (`llm_dim`).

    Without a fused backbone: `fc1 -> GELU -> fc2`, with hidden width `llm_dim`.
    With a fused backbone:    `fc1 -> GELU -> fc2 -> GELU -> fc3`, where `fc1` widens to `4 * vision_dim`.
    """

    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.vision_dim = vision_dim
        self.llm_dim = llm_dim

        fused = self.use_fused_vision_backbone

        # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
        first_out_dim = 4 * vision_dim if fused else self.llm_dim

        # Submodules are created in the same order as before (linears first, then activations)
        self.fc1 = nn.Linear(self.vision_dim, first_out_dim, bias=True)
        self.fc2 = nn.Linear(first_out_dim, self.llm_dim, bias=True)
        if fused:
            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
        self.act_fn1 = nn.GELU()
        if fused:
            self.act_fn2 = nn.GELU()

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        """Project `img_patches` (last dim `vision_dim`) to features with last dim `llm_dim`."""
        # Both variants share the first two linear layers with a GELU in between
        hidden = self.fc2(self.act_fn1(self.fc1(img_patches)))
        if not self.use_fused_vision_backbone:
            return hidden

        # Fused variant adds one more GELU + linear stage
        return self.fc3(self.act_fn2(hidden))
263
+
264
+
265
+ # === Main HF Class Definitions ===
266
@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
    """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""

    # Fields copied over from the wrapped language model's output
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Additions for VLMs
    projector_features: Optional[torch.FloatTensor] = None
278
+
279
+
280
class PrismaticPreTrainedModel(PreTrainedModel):
    """Shared HF `PreTrainedModel` base for Prismatic models: class-level HF flags plus weight initialization."""

    config_class: PretrainedConfig = PrismaticConfig
    base_model_prefix: str = "model"
    supports_gradient_checkpointing: bool = True

    _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
    _skip_keys_device_placement: str = "past_key_values"
    _supports_flash_attn_2: bool = True

    def _init_weights(self, module: nn.Module) -> None:
        """
        Initialize `module` weights from a zero-mean normal distribution.

        The standard deviation is `config.initializer_range` when that attribute exists, otherwise
        `config.text_config.initializer_range`.
        """
        # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
        #   => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
        #      https://github.com/TRI-ML/prismatic-vlms
        cfg = self.config
        if hasattr(cfg, "initializer_range"):
            init_std = cfg.initializer_range
        else:
            init_std = cfg.text_config.initializer_range

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=init_std)

        is_dense = isinstance(module, (nn.Linear, nn.Conv2d))
        if not is_dense and not isinstance(module, nn.Embedding):
            return

        # Linear / Conv2d / Embedding all get normally-distributed weights
        module.weight.data.normal_(mean=0.0, std=init_std)
        if is_dense:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self) -> bool:
        """Check LLM supports SDPA Attention"""
        return self.language_model._supports_sdpa
315
+
316
+
317
class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
    """
    Prismatic VLM wrapper: a `PrismaticVisionBackbone`, a `PrismaticProjector`, and an HF causal language model.

    Projected image patch embeddings are spliced into the token embedding sequence right after the first
    (<BOS>) token before the sequence is handed to the language model.
    """

    def __init__(self, config: PrismaticConfig) -> None:
        """
        Build the vision backbone, projector, and language model from `config`.

        Raises:
            ValueError: If `config.use_fused_vision_backbone` is None.
            NotImplementedError: If the installed TIMM version is not one of the explicitly allowed versions.
        """
        super().__init__(config)

        # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
        if config.use_fused_vision_backbone is None:
            raise ValueError("Missing config field `use_fused_vision_backbone`")

        if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
            raise NotImplementedError(
                "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
                "if you urgently need support for latest TIMM versions."
            )

        if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
            # Only a warning (not an error): other versions may work but are untested
            logger.warning(
                f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
                f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
                f"there might be inference-time regressions due to dependency changes. If in doubt, please "
                f"use the above versions."
            )

        # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
        self.vision_backbone = PrismaticVisionBackbone(
            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
        )

        # Create Multimodal Projector
        self.projector = PrismaticProjector(
            config.use_fused_vision_backbone,
            vision_dim=self.vision_backbone.embed_dim,
            llm_dim=config.text_config.hidden_size,
        )

        # Instantiate LLM Backbone
        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.vocab_size = config.text_config.vocab_size
        self.pad_token_id = config.pad_token_id
        self.llm_dim = config.text_config.hidden_size

        # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
        self.post_init()

    # === `PreTrainedModel` Boilerplate ===
    def get_input_embeddings(self) -> nn.Module:
        """Return the language model's input embedding module."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        """Replace the language model's input embedding module."""
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        """Return the language model's output embedding module."""
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        """Replace the language model's output embedding module."""
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        """Return the language model's decoder."""
        return self.language_model.get_decoder()

    def set_decoder(self, decoder: nn.Module) -> None:
        """Replace the language model's decoder."""
        self.language_model.set_decoder(decoder)

    def tie_weights(self) -> None:
        """Delegate weight tying to the language model."""
        self.language_model.tie_weights()  # Note: `Llama-2` and `Mistral` don't tie weights (no-op)

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        """
        Resize the language model's token embeddings and keep `config.text_config.vocab_size` / `self.vocab_size`
        in sync with the new embedding count.

        Returns:
            The resized embedding module returned by the language model.
        """
        updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # Update config/instance variables
        self.config.text_config.vocab_size = updated_embeddings.num_embeddings
        self.vocab_size = updated_embeddings.num_embeddings

        return updated_embeddings

    def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
        """
        Replace embeddings in input_embeddings at positions where all_actions_mask is True
        with embeddings from noisy_action_features, using vectorized operations.

        Args:
            input_embeddings: Tensor of shape (B, S, D)
            all_actions_mask: Boolean tensor of shape (B, S)
            noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample

        Returns:
            A new tensor (the input is not modified) with the masked positions replaced
        """
        # Create a tensor with the same shape of input_embeddings to hold the noisy action features
        repositioned_noisy_action_features = torch.zeros_like(input_embeddings)

        # Create batch indices for splicing
        batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
        batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])

        # Get indices where mask is True for each sample
        # (`torch.stack` requires every sample to have the same number of True positions)
        masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])

        # Move the noisy action features into their correct positions
        repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features

        # Combine original input embeddings and noisy action embeddings using the mask.
        # `torch.where` builds a fresh tensor, so no defensive clone of `input_embeddings` is needed.
        return torch.where(all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, input_embeddings)

    def _process_action_masks(self, labels):
        """Helper to get action masks from labels (union of the current-action and next-actions masks)"""
        current_action_mask = get_current_action_mask(labels)
        next_actions_mask = get_next_actions_mask(labels)
        all_actions_mask = current_action_mask | next_actions_mask  # (B, seq_len)
        return all_actions_mask

    def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
        """Process vision features with optional FiLM conditioning, then project them into the LLM embedding space"""
        if use_film:
            # FiLM: Infuse language inputs into visual features
            # NOTE(review): `PrismaticVisionBackbone.forward` in this file accepts only `pixel_values`; presumably
            # `self.vision_backbone` is replaced by a FiLM-capable wrapper before `use_film=True` is used -- confirm.
            patch_features = self.vision_backbone(pixel_values, language_embeddings)  # (bsz, 256 * num_images, D)
        else:
            patch_features = self.vision_backbone(pixel_values)  # (bsz, 256 * num_images, D)

        # Project patch embeddings into language embedding space
        return self.projector(patch_features)

    def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
        """Process proprioceptive features and append to vision features (no-op if either argument is None)"""
        if proprio_projector is not None and proprio is not None:
            # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
            # proprio: (bsz, proprio_dim) or (proprio_dim,)
            proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1)  # (bsz, proprio_dim)
            proprio_features = proprio_projector(proprio)  # (bsz, llm_dim)
            proprio_features = proprio_features.unsqueeze(dim=1)  # (bsz, 1, llm_dim)
            # For simplicity, just append proprio token to the end of projected vision patch tokens
            return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
        return projected_patch_embeddings

    def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
        """
        Build multimodal embeddings and attention mask.

        Returns:
            `(multimodal_embeddings, multimodal_attention_mask)`; the mask is None when `attention_mask` is None.
        """
        # Build multimodal embeddings; insert patch embeddings after <BOS> token (1:)
        multimodal_embeddings = torch.cat(
            [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
        )

        if attention_mask is None:
            return multimodal_embeddings, None

        # Every inserted patch position is attended to
        projected_patch_attention_mask = torch.full(
            (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
            fill_value=True,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        multimodal_attention_mask = torch.cat(
            [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
        )

        return multimodal_embeddings, multimodal_attention_mask

    def _build_multimodal_labels(self, labels, projected_patch_embeddings):
        """Build multimodal labels with IGNORE_INDEX for patch embeddings (returns None when `labels` is None)"""
        if labels is not None:
            projected_patch_labels = torch.full(
                (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                fill_value=IGNORE_INDEX,
                dtype=labels.dtype,
                device=labels.device,
            )
            return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
        return None

    # === Core Prismatic VLM `forward()` Logic ===
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        proprio=None,
        proprio_projector=None,
        noisy_actions=None,
        noisy_action_projector=None,
        diffusion_timestep_embeddings=None,
        use_film: bool = False,
    ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
        """
        Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance.

        Dispatches on the inputs:
          * `input_ids` with a single token  =>> cached generation step (needs `past_key_values`, batch size 1)
          * `pixel_values is None`           =>> language-only forward
          * otherwise                        =>> multimodal forward (the mask helpers are called on `labels`)

        Raises:
            ValueError: If `input_ids` is missing (every path reads it), or if the text and image batch sizes differ.

        Returns:
            A `PrismaticCausalLMOutputWithPast`, or (when `return_dict` is falsy) the language model's raw output,
            with the projector features appended when `output_projector_features` is set and they were computed.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_projector_features = output_projector_features if output_projector_features is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
        use_cache = use_cache and not self.training

        # Instantiate Placeholder for Projector Features
        projected_patch_embeddings = None

        # Every branch below dereferences `input_ids`; reject its absence with a descriptive error instead of
        # failing on `None.shape`
        if input_ids is None:
            raise ValueError(
                "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `input_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )

        # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
        if input_ids.shape[1] == 1:
            assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
            assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
            assert labels is None, "Unexpected key `labels` provided during cached generation!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=None,
                position_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=None,
                labels=None,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Unimodal Forward ===
        elif pixel_values is None:
            assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Multimodal Forward ===
        else:
            # Text and image batches must line up; `inputs_embeds` is only consulted when it was actually given
            image_batch_size = pixel_values.shape[0]
            batch_sizes_match = (input_ids.shape[0] == image_batch_size) or (
                inputs_embeds is not None and inputs_embeds.shape[0] == image_batch_size
            )
            if not batch_sizes_match:
                raise ValueError(
                    "Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!"
                )

            assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"

            # Get input embeddings (from language model embeddings)
            input_embeddings = self.get_input_embeddings()(input_ids)  # (B, seq_len, D)

            # Extract action masks
            all_actions_mask = self._process_action_masks(labels)

            # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
            language_embeddings = input_embeddings[~all_actions_mask].reshape(
                input_embeddings.shape[0], -1, input_embeddings.shape[2]
            )  # (B, lang_seq_len, llm_dim)

            # Get visual features
            projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)

            # Add proprioceptive state if provided
            projected_patch_embeddings = self._process_proprio_features(
                projected_patch_embeddings, proprio, proprio_projector
            )

            # [Diffusion] Add diffusion timestep embedding if provided
            if diffusion_timestep_embeddings is not None:
                # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
                projected_patch_embeddings = torch.cat(
                    (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
                )

            # Process action embeddings
            if noisy_actions is not None:
                # Get mask corresponding to all action tokens
                all_actions_mask = self._process_action_masks(labels)

                # Reshape noisy actions into individual action tokens
                # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
                B = noisy_actions.shape[0]
                noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)

                # Project noisy action tokens into language model embedding space
                noisy_action_features = noisy_action_projector(noisy_actions)  # (B, chunk_len * action_dim, llm_dim)

                # Replace embeddings of the action tokens with noisy action embeddings
                input_embeddings = self._replace_input_embeddings(
                    input_embeddings, all_actions_mask, noisy_action_features
                )
            else:
                # Replace the embeddings of the action tokens with zeros
                # (Later on, the positional embeddings will be added to them)
                all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)
                input_embeddings = input_embeddings * ~all_actions_mask

            # Build multimodal embeddings & attention mask
            multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
                input_embeddings, projected_patch_embeddings, attention_mask
            )

            # Build labels for multimodal sequence if needed
            multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)

            # Dispatch to language model
            language_model_output = self.language_model(
                input_ids=None,
                attention_mask=multimodal_attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=multimodal_embeddings,
                labels=multimodal_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
        if not return_dict:
            if output_projector_features and (projected_patch_embeddings is not None):
                return *language_model_output, projected_patch_embeddings

            return language_model_output

        return PrismaticCausalLMOutputWithPast(
            loss=language_model_output.loss,
            logits=language_model_output.logits,
            past_key_values=language_model_output.past_key_values,
            hidden_states=language_model_output.hidden_states,
            attentions=language_model_output.attentions,
            projector_features=projected_patch_embeddings,
        )

    # === GenerationMixin Methods ===
    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Dict[str, torch.Tensor]:
        """
        Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic.

        Returns:
            Keyword arguments for `forward()`: either `inputs_embeds` (first step only) or `input_ids`, plus
            `attention_mask`, `pixel_values`, `past_key_values`, and `use_cache` (taken from `kwargs`).

        Raises:
            ValueError: If `input_ids` or `inputs_embeds` has a batch size greater than 1.
        """
        if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
            (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
        ):
            raise ValueError("Generation with batch size > 1 is not currently supported!")

        # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step.
        # The key must be `inputs_embeds` to match the `forward()` parameter name.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # Make sure `pixel_values` are preserved in `model_inputs`
        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )

        return model_inputs

    # Defer to Language Model (all handle this differently, with different return types)
    def _reorder_cache(self, *args, **kwargs) -> Any:
        """Delegate cache reordering to the language model."""
        return self.language_model._reorder_cache(*args, **kwargs)
718
+
719
+
720
class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
    """Prismatic model extended with action prediction.

    `predict_action` supports three decoding modes, selected by its arguments:
    discrete action tokens (no `action_head`), continuous regression through `action_head.predict_action`,
    and reverse diffusion (an `action_head` exposing `noise_scheduler` plus a `noisy_action_projector`).
    Predicted actions are un-normalized with the dataset statistics stored in `config.norm_stats`.
    """

    config_class: PretrainedConfig = OpenVLAConfig

    def __init__(self, config: OpenVLAConfig) -> None:
        """Build the base model and pre-compute the discretization bins used for de-tokenization."""
        super().__init__(config)
        self.norm_stats = config.norm_stats

        # Compute action bins
        self.bins = np.linspace(-1, 1, config.n_action_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        # Compute vocab size for de-tokenization -- revert added "multiple of"
        self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of

    def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
        """Append action placeholder tokens and a stop token to `input_ids`, and extend `attention_mask` to match.

        Returns:
            Tuple `(input_ids, attention_mask)` with the same last-dimension length.
        """
        batch_size = input_ids.shape[0]

        # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens (all ones) to simulate action tokens.
        # Tensors are created directly with the target dtype/device instead of float-on-CPU followed by casts.
        placeholder_action_token_ids = torch.ones(
            (batch_size, ACTION_DIM * NUM_ACTIONS_CHUNK), dtype=input_ids.dtype, device=input_ids.device
        )

        # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
        stop_token_id = torch.full((batch_size, 1), STOP_INDEX, dtype=input_ids.dtype, device=input_ids.device)
        input_ids = torch.cat([input_ids, placeholder_action_token_ids, stop_token_id], dim=-1)

        # Extend the attention mask to fit the new shape of input
        # Note: Only batch size == 1 supported right now
        mask_extension = torch.ones(
            (attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)

        return input_ids, attention_mask

    def _prepare_labels_for_action_prediction(self, labels, input_ids):
        """Pad `labels` to the length of `input_ids` with a fake action token id and end it with the stop token."""
        # Extend labels tensor with fake action labels
        arbitrary_action_token_idx = ACTION_TOKEN_BEGIN_IDX + 1
        labels_extension = torch.full(
            (labels.shape[0], input_ids.shape[-1] - labels.shape[-1]),
            arbitrary_action_token_idx,
            dtype=labels.dtype,
            device=labels.device,
        )
        labels = torch.cat([labels, labels_extension], dim=-1)

        # Replace last label token with stop token
        labels[:, -1] = STOP_INDEX

        return labels

    def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
        """Map normalized actions back to dataset scale using the statistics selected by `unnorm_key`.

        Dimensions where the statistics' `mask` is False are passed through unchanged; when no `mask`
        entry exists, every dimension is un-normalized.

        Raises:
            ValueError: If `ACTION_PROPRIO_NORMALIZATION_TYPE` is neither `BOUNDS` nor `BOUNDS_Q99`.
        """
        action_norm_stats = self.get_action_stats(unnorm_key)

        # Pick which pair of statistics defines the lower/upper bound
        if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
            low_key, high_key = "min", "max"
        elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
            low_key, high_key = "q01", "q99"
        else:
            raise ValueError("Unsupported action/proprio normalization type detected!")

        mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats[low_key], dtype=bool))
        action_high = np.array(action_norm_stats[high_key])
        action_low = np.array(action_norm_stats[low_key])

        # Linear map from [-1, 1] to [low, high]; the 1e-8 term is carried over from the original formula
        actions = np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
            normalized_actions,
        )

        return actions

    def _run_diffusion_prediction(
        self,
        input_embeddings,
        all_actions_mask,
        noise,
        action_head,
        projected_patch_embeddings,
        labels,
        attention_mask,
        NUM_PATCHES,
        NUM_PROMPT_TOKENS,
        noisy_action_projector,
    ):
        """Run diffusion-based action prediction.

        Starting from `noise`, iterates over `action_head.noise_scheduler.timesteps`; at each step the action
        token embeddings are replaced by projected noisy actions, the language model is run, and the scheduler
        produces the next (less noisy) sample.

        Returns:
            Tuple of (actions as a float numpy array of shape (NUM_ACTIONS_CHUNK, ACTION_DIM),
            action hidden states from the last step, or None if the scheduler has no timesteps).
        """
        # Clone embedding for reuse in each timestep
        orig_projected_patch_embeddings = projected_patch_embeddings.clone()
        curr_noisy_actions = noise

        # Positions of the action tokens in the multimodal sequence (loop-invariant)
        action_start = NUM_PATCHES + NUM_PROMPT_TOKENS
        action_end = action_start + ACTION_DIM * NUM_ACTIONS_CHUNK

        # Defined up-front so the return statement is valid even if the scheduler yields no timesteps
        actions_hidden_states = None

        # Reverse diffusion: Iteratively denoise to generate action prediction
        for t in action_head.noise_scheduler.timesteps:
            # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
            # embedding, and diffusion timestep embedding)
            timesteps = torch.Tensor([t]).to(labels.device)
            diffusion_timestep_embeddings = (
                action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
            )  # (B, llm_dim)
            diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1)  # (B, 1, llm_dim)

            # [Diffusion] Replace the embeddings of the action tokens with noisy actions
            # (Later on, the positional embeddings will be added to them)

            # For simplicity, append diffusion timestep embedding to the end of projected vision tokens
            projected_patch_embeddings = torch.cat(
                (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
            )

            # Reshape and project noisy actions into language embedding space
            B = curr_noisy_actions.shape[0]
            orig_curr_noisy_actions_shape = curr_noisy_actions.shape
            curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
            noisy_action_features = noisy_action_projector(curr_noisy_actions)
            curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)

            # Replace action token embeddings with noisy action embeddings
            input_embeddings = self._replace_input_embeddings(
                input_embeddings.clone(), all_actions_mask, noisy_action_features
            )

            # Build multimodal embeddings and attention mask
            multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
                input_embeddings, projected_patch_embeddings, attention_mask
            )

            # Forward pass through language model
            language_model_output = self.language_model(
                input_ids=None,
                attention_mask=multimodal_attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=multimodal_embeddings,
                labels=None,
                use_cache=None,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )

            # Extract hidden states for action portion of response
            last_hidden_states = language_model_output.hidden_states[-1]  # (B, seq_len, D)
            actions_hidden_states = last_hidden_states[:, action_start:action_end, :]  # (B, act_chunk_len, D)

            # Predict noise and update noisy actions: x_t -> x_{t-1}
            noise_pred = action_head.predict_noise(actions_hidden_states)
            curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample

        curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)

        # Return final actions
        return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states

    def _regression_or_discrete_prediction(
        self,
        input_embeddings,
        all_actions_mask,
        projected_patch_embeddings,
        attention_mask,
        labels,
        NUM_PATCHES,
        NUM_PROMPT_TOKENS,
        action_head=None,
    ):
        """Run L1 regression-based continuous action prediction or discrete action tokens prediction.

        With an `action_head`, actions come from `action_head.predict_action` on the action hidden states.
        Without one, the argmax token ids at the action positions are mapped to bin centers.
        `labels` is accepted for signature parity with the diffusion path and is not used here.

        Returns:
            Tuple of (normalized actions as a numpy array of shape (NUM_ACTIONS_CHUNK, ACTION_DIM),
            action hidden states).
        """
        # Zero out action token embeddings
        all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)
        input_embeddings = input_embeddings * ~all_actions_mask

        # Build multimodal embeddings and attention mask
        multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
            input_embeddings, projected_patch_embeddings, attention_mask
        )

        # Forward pass through language model
        language_model_output = self.language_model(
            input_ids=None,
            attention_mask=multimodal_attention_mask,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=multimodal_embeddings,
            labels=None,
            use_cache=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )

        # Positions of the action tokens in the multimodal sequence
        action_start = NUM_PATCHES + NUM_PROMPT_TOKENS
        action_end = action_start + ACTION_DIM * NUM_ACTIONS_CHUNK

        # Extract hidden states for action tokens
        last_hidden_states = language_model_output.hidden_states[-1]  # (B, seq_len, D)
        actions_hidden_states = last_hidden_states[:, action_start:action_end, :]  # (B, act_chunk_len, D)

        # Handle different prediction methods
        if action_head is not None:
            # L1 regression prediction
            normalized_actions = action_head.predict_action(actions_hidden_states)
            normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
            normalized_actions = normalized_actions.float().cpu().detach().numpy()
        else:
            # Discrete token-based prediction
            predicted_action_token_ids = (
                language_model_output.logits[:, action_start:action_end].argmax(dim=2).cpu().numpy()
            )
            # Token ids count down from the end of the (unpadded) vocabulary; clip into the valid bin range
            discretized_actions = self.vocab_size - predicted_action_token_ids
            discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
            normalized_actions = self.bin_centers[discretized_actions]
            normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)

        return normalized_actions, actions_hidden_states

    def predict_action(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        unnorm_key: Optional[str] = None,
        proprio=None,
        proprio_projector=None,
        action_head=None,
        noisy_action_projector=None,
        use_film: bool = False,
        **kwargs: Any,
    ) -> "Tuple[np.ndarray, torch.Tensor]":
        """Predict actions from input sequence, with options for different prediction methods.

        Args:
            input_ids: Input token ids
            unnorm_key: Key for unnormalization statistics
            proprio: Proprioceptive features
            proprio_projector: Projector for proprioceptive features
            action_head: Optional head for L1 regression or diffusion-based prediction
            noisy_action_projector: Projector for noisy actions in diffusion-based prediction
            use_film: Whether to use FiLM conditioning
            **kwargs: Additional arguments; `pixel_values` and `attention_mask` are required

        Returns:
            Tuple of (unnormalized_actions, action_hidden_states)

        Raises:
            KeyError: If `pixel_values` or `attention_mask` is missing from `kwargs`.
        """
        # If the special empty token ('') does not already appear after the colon (':') token in the prompt
        # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
        empty_token_id = 29871
        if not torch.all(input_ids[:, -1] == empty_token_id):
            # One empty token per batch row, created directly on the input's device
            empty_token = torch.full(
                (input_ids.shape[0], 1), empty_token_id, dtype=torch.long, device=input_ids.device
            )
            input_ids = torch.cat((input_ids, empty_token), dim=1)

        pixel_values = kwargs["pixel_values"]
        attention_mask = kwargs["attention_mask"]

        # Create fake labels tensor (needed for action mask)
        labels = torch.full_like(input_ids, IGNORE_INDEX)

        # Get number of tokens in prompt, excluding the start token.
        # (Action placeholder tokens and the stop token have not been appended yet at this point.)
        NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1

        # Prepare inputs by adding necessary tokens
        input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)

        # Update labels tensor for action mask computation later
        labels = self._prepare_labels_for_action_prediction(labels, input_ids)

        # Get input embeddings and action masks
        input_embeddings = self.get_input_embeddings()(input_ids)
        all_actions_mask = self._process_action_masks(labels)

        # Extract language embeddings
        language_embeddings = input_embeddings[~all_actions_mask].reshape(
            input_embeddings.shape[0], -1, input_embeddings.shape[2]
        )

        # Process vision features
        projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)

        # Add proprioceptive features if provided
        use_proprio = proprio_projector is not None and proprio is not None
        if use_proprio:
            proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
            projected_patch_embeddings = self._process_proprio_features(
                projected_patch_embeddings, proprio, proprio_projector
            )

        # Use diffusion if provided, otherwise use regression or discrete prediction
        use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")

        # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
        NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
        if use_proprio:
            NUM_PATCHES += 1
        if use_diffusion:
            NUM_PATCHES += 1

        if use_diffusion:
            # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
            noise = torch.randn(
                size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
            )

            # Run diffusion-based prediction
            normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
                input_embeddings,
                all_actions_mask,
                noise,
                action_head,
                projected_patch_embeddings,
                labels,
                attention_mask,
                NUM_PATCHES,
                NUM_PROMPT_TOKENS,
                noisy_action_projector,
            )
        else:
            # Run regression or discrete token-based prediction
            normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
                input_embeddings,
                all_actions_mask,
                projected_patch_embeddings,
                attention_mask,
                labels,
                NUM_PATCHES,
                NUM_PROMPT_TOKENS,
                action_head,
            )

        # Unnormalize predicted actions
        actions = self._unnormalize_actions(normalized_actions, unnorm_key)

        return actions, actions_hidden_states

    @staticmethod
    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
        """Validate and resolve the unnormalization key for action statistics.

        If `unnorm_key` is None, `norm_stats` must contain exactly one dataset, whose key is returned.

        Raises:
            AssertionError: If `unnorm_key` is None with more than one dataset available, or if the key
                is not present in `norm_stats`.
        """
        # NOTE: validation uses `assert` (kept so callers see the same exception type); it is skipped under `-O`.
        if unnorm_key is None:
            assert len(norm_stats) == 1, (
                "Your model was trained on more than one dataset, "
                "please pass a `unnorm_key` from the following options to choose the statistics "
                f"used for un-normalizing actions: {norm_stats.keys()}"
            )
            unnorm_key = next(iter(norm_stats.keys()))

        assert unnorm_key in norm_stats, (
            "The `unnorm_key` you chose is not in the set of available dataset statistics, "
            f"please choose from: {norm_stats.keys()}"
        )
        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Get the dimensionality of the policy's action space (length of the dataset's `min` statistic)."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return len(self.norm_stats[unnorm_key]["action"]["min"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
        """Get all the logged action statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return self.norm_stats[unnorm_key]["action"]