Emrys-Hong committed
Commit: 1b2ebf2
1 Parent(s): ae16192
Update
modeling_prismatic.py (+5 -8)
@@ -541,7 +541,7 @@ class EmmaxForActionPrediction(PrismaticForConditionalGeneration):
         return actions, generated_ids
 
     @torch.inference_mode()
-    def generate_actions(self,
+    def generate_actions(self, inputs, tokenizer, **kwargs: str) -> str:
         # For now, only support generation with a batch size of 1 for simplicity
         # image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer
 
@@ -557,18 +557,15 @@ class EmmaxForActionPrediction(PrismaticForConditionalGeneration):
         # raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
 
         # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
-        autocast_dtype = self.llm_backbone.half_precision_dtype
         # with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
-        with torch.autocast("cuda", dtype=torch.
+        with torch.autocast("cuda", dtype=torch.bfloat16):
             # fmt: off
             generated_ids = self.generate(
-                input_ids=input_ids,
-                pixel_values=pixel_values,  # Shape: [1, 3, res, res] or Dict[str, Shape[1, 3, res, res]]
-                **kwargs
+                **inputs, **kwargs
             )
             # fmt: on
 
-        generated_text = tokenizer.decode(generated_ids[0, input_ids.shape[1] :], skip_special_tokens=True).strip()
+        generated_text = tokenizer.decode(generated_ids[0, inputs['input_ids'].shape[1] :], skip_special_tokens=True).strip()
 
         s = solver
         actions, reasoning = s.extract_action_policies(generated_text)
@@ -586,7 +583,7 @@ class EmmaxForActionPrediction(PrismaticForConditionalGeneration):
             )
             _actions.append(action_norm)
 
-        return _actions, generated_text
+        return _actions[0], generated_text
 
 
     @staticmethod
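For context, the new signature replaces separate `pixel_values`/`input_ids` arguments with a single processor-built `inputs` dict (unpacked straight into `self.generate(**inputs, **kwargs)`) plus an explicit `tokenizer`, and now returns only the first action since generation is restricted to batch size 1. Below is a minimal caller sketch under those assumptions; the checkpoint id, prompt format, and processor behavior are illustrative, not taken from this commit:

import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

MODEL_ID = "declare-lab/Emma-X"  # hypothetical repo id; substitute the real checkpoint

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda")

image = Image.open("frame.png")  # current camera observation
prompt = "In: What action should the robot take to pick up the cup?\nOut:"

# `inputs` must carry at least `input_ids` and `pixel_values`: the method
# unpacks it into `self.generate(**inputs, **kwargs)` and reuses
# `inputs['input_ids'].shape[1]` to strip the prompt before decoding.
inputs = processor(prompt, image, return_tensors="pt").to("cuda", dtype=torch.bfloat16)

# Returns the single action (batch size 1) and the raw generated reasoning text.
action, generated_text = model.generate_actions(
    inputs, processor.tokenizer, do_sample=False, max_new_tokens=512
)

Returning `_actions[0]` rather than the list matches the stated batch-size-1 restriction: the caller gets one action per generation step instead of a singleton list.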