import os
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import warnings
from transformers import (
    ByT5Tokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from transformers.models.t5 import T5Config
from transformers.models.t5.modeling_t5 import *
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from torch.nn import CrossEntropyLoss
from collections.abc import Mapping
from dataclasses import dataclass
from random import randint
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from transformers.utils import PaddingStrategy
from shubert import SignHubertModel, SignHubertConfig

class SignHubertAdapter(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # Adjust intermediate_dim based on the number of channels
        intermediate_dim_shubert = 1024
        self.signhubert = SignHubertModel(SignHubertConfig(
            channels=channels,
            intermediate_dim=intermediate_dim_shubert
        ))

    def forward(self, x):
        features = self.signhubert.extract_features(x, padding_mask=None, kmeans_labels=None, mask=False)
        # Collect the hidden states from every transformer layer
        layer_outputs = []
        for layer in features['layer_results']:
            layer_output = layer[-1]  # Shape: [T, B, D] (time-first, matching the permute in LinearAdapter)
            layer_outputs.append(layer_output)
        # Stack the outputs from all layers
        stacked_features = torch.stack(layer_outputs, dim=1)  # Shape: [T, L, B, D]
        return stacked_features

class LinearAdapter(nn.Module):
    def __init__(self, face_dim, hand_dim, pose_dim, representations_dim, out_dim, extraction_layer, channels):
        super().__init__()
        self.signhubert_adapter = SignHubertAdapter(channels)
        self.layer_weights = nn.Parameter(torch.ones(12))  # Learnable weight per transformer layer (assumes 12 layers)
        self.final_layer = nn.Linear(representations_dim, out_dim)
        self.extraction_layer = extraction_layer

    def forward(self, face_features, left_hand_features, right_hand_features, body_posture_features):
        dtype = torch.float32
        face_features = face_features.to(dtype=dtype)
        left_hand_features = left_hand_features.to(dtype=dtype)
        right_hand_features = right_hand_features.to(dtype=dtype)
        body_posture_features = body_posture_features.to(dtype=dtype)
        batch_size, seq_len = face_features.shape[:2]
        dummy_labels = torch.zeros((seq_len, 1), dtype=dtype, device=face_features.device)
        # SignHubert expects a list of per-sample dicts, one entry per channel
        source = []
        for i in range(batch_size):
            source.append({
                "face": face_features[i],
                "left_hand": left_hand_features[i],
                "right_hand": right_hand_features[i],
                "body_posture": body_posture_features[i],
                "label_face": dummy_labels,
                "label_left_hand": dummy_labels,
                "label_right_hand": dummy_labels,
                "label_body_posture": dummy_labels
            })
        # Get representations from SignHubert
        representations_features = self.signhubert_adapter(source)  # [T, L, B, D]
        representations_features = representations_features.permute(2, 1, 0, 3)  # [B, L, T, D]
        if self.extraction_layer == 0:
            # Learned weighted sum over all layers (weights are used as-is, without softmax normalization)
            layer_weights = self.layer_weights
            weighted_representations = representations_features * layer_weights.view(1, -1, 1, 1)
            representations_for_downstream_task = torch.sum(weighted_representations, dim=1)
        else:
            # Take the output of a single, fixed layer (1-indexed)
            representations_for_downstream_task = representations_features[:, self.extraction_layer - 1, :, :]
        final_output = self.final_layer(representations_for_downstream_task)
        return final_output
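
# Shape sketch for LinearAdapter.forward (assuming the time-first SignHubert layout noted above):
#   face/hand/pose inputs:  [B, T, channel_dim] per channel
#   SignHubert stack:       [T, L, B, representations_dim]  ->  permuted to [B, L, T, representations_dim]
#   layer selection / sum:  [B, T, representations_dim]
#   final_layer output:     [B, T, out_dim]   (out_dim is the T5 d_model when used by the encoder below)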

class SignLanguageByT5Config(T5Config):
    def __init__(
        self,
        representations_dim=768,
        adapter="linear",
        finetune_signhubert=False,
        face_dim=384,
        hand_dim=384,
        pose_dim=14,
        extraction_layer=0,  # 0 = learned weighted sum over all layers; N > 0 = use layer N only
        channels="face,left_hand,right_hand,body_posture",
        **kwargs
    ):
        self.representations_dim = representations_dim
        self.adapter = adapter
        self.finetune_signhubert = finetune_signhubert
        self.face_dim = face_dim
        self.hand_dim = hand_dim
        self.pose_dim = pose_dim
        self.extraction_layer = extraction_layer
        self.channels = channels
        super().__init__(**kwargs)

class SignLanguageByT5Encoder(T5PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # Initialize the adapter based on the configuration
        if config.adapter == "linear":
            self.adapter = LinearAdapter(
                config.face_dim,
                config.hand_dim,
                config.pose_dim,
                config.representations_dim,
                config.d_model,
                config.extraction_layer,
                config.channels
            )
        else:
            raise NotImplementedError("Adapter type not implemented.")
        self.is_decoder = config.is_decoder
        # Define the encoder blocks
        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)
        # Initialize weights and apply final processing
        self.post_init()
        # Model parallel settings
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    def parallelize(self, device_map=None):
        warnings.warn(
            "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
            " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
            " 'block.1': 1, ...}",
            FutureWarning,
        )
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # Load onto devices
        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)
        # This encoder has no embed_tokens by default (inputs come from the adapter),
        # so only move it if one has been set via set_input_embeddings()
        if getattr(self, "embed_tokens", None) is not None:
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Set final layer norm to last device
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)

    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        for i in range(len(self.block)):
            self.block[i] = self.block[i].to("cpu")
        if getattr(self, "embed_tokens", None) is not None:
            self.embed_tokens = self.embed_tokens.to("cpu")
        self.final_layer_norm = self.final_layer_norm.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        face_features=None,
        left_hand_features=None,
        right_hand_features=None,
        pose_features=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Set default values if not provided
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Use the adapter to convert representation features into embeddings
        inputs_embeds = self.adapter(face_features, left_hand_features, right_hand_features, pose_features)
        input_shape = inputs_embeds.shape[:2]
        batch_size, seq_length = input_shape
        mask_seq_length = seq_length
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        # Initialize past_key_values if not provided
        if past_key_values is None:
            past_key_values = [None] * len(self.block)
        # Extend attention mask
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        hidden_states = self.dropout(inputs_embeds)
        # Iterate over each encoder block
        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            layer_head_mask = head_mask[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # Note: unlike the stock T5Stack, the relative position bias computed by block 0
            # is not propagated to later blocks here; they fall back to a zero position bias.
            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=None,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                encoder_decoder_position_bias=None,
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_head_mask,
                past_key_value=past_key_value,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]
            if use_cache:
                present_key_value_states = present_key_value_states + (layer_outputs[1],)
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, present_key_value_states, all_hidden_states, all_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=None,
        )

class SignLanguageByT5ForConditionalGeneration(T5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [
        "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
    _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.model_dim = config.d_model
        # Initialize the decoder embedding
        self.decoder_emb = nn.Embedding(config.vocab_size, config.d_model)
        # Initialize the encoder with the custom SignLanguageByT5Encoder
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = SignLanguageByT5Encoder(encoder_config)
        # Initialize the decoder
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.decoder_emb)
        # Initialize the language modeling head
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()
        # Model parallel settings
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        warnings.warn(
            "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
            " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
            " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
            " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.decoder_emb

    def set_input_embeddings(self, new_embeddings):
        # Keep the stored decoder embedding in sync with get_input_embeddings()
        self.decoder_emb = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # Cut decoder_input_ids if past is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past_key_values, beam_idx):
        # If decoder past is not included in the output,
        # speedy decoding is disabled and there is no need to reorder
        if past_key_values is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past_key_values
        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # Get the correct batch idx from the layer past batch dim;
            # the batch dim of `past` is at the 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # Need to set the correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )
            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
                raise ValueError(
                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
                )
            if len(reordered_layer_past_states) != len(layer_past_states):
                raise ValueError(
                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
                )
            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past

    def forward(
        self,
        face_features=None,
        left_hand_features=None,
        right_hand_features=None,
        pose_features=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        labels=None,  # Keep this for training compatibility
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Set default values if not provided
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # If only `head_mask` is given, mirror it onto the decoder (legacy T5 behaviour)
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(
                    "`head_mask` was passed without `decoder_head_mask`; `decoder_head_mask` is set to copy"
                    " `head_mask`. This behaviour is deprecated; pass `decoder_head_mask` explicitly.",
                    FutureWarning,
                )
                decoder_head_mask = head_mask
        # Encode if encoder outputs are not provided
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                face_features=face_features,
                left_hand_features=left_hand_features,
                right_hand_features=right_hand_features,
                pose_features=pose_features,
                attention_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutputWithPastAndCrossAttentions):
            encoder_outputs = BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )
        hidden_states = encoder_outputs[0]
        # Prepare decoder inputs
        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = self._shift_right(labels)
        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = decoder_outputs[0]
        # Scale sequence output if embeddings are tied
        if self.config.tie_word_embeddings:
            sequence_output = sequence_output * (self.model_dim ** -0.5)
        # Compute language modeling logits
        lm_logits = self.lm_head(sequence_output)
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output
        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def generate(
        self,
        face_features=None,
        left_hand_features=None,
        right_hand_features=None,
        pose_features=None,
        attention_mask=None,
        **kwargs
    ):
        """
        Generate output sequences from sign language features.

        The encoder is run here on the feature streams, and its outputs are handed
        to the standard seq2seq `generate` loop of the parent class.
        """
        # Compute encoder outputs using sign language features
        encoder_outputs = self.encoder(
            face_features=face_features,
            left_hand_features=left_hand_features,
            right_hand_features=right_hand_features,
            pose_features=pose_features,
            attention_mask=attention_mask,
            return_dict=True
        )
        # Pass encoder outputs to the decoder through the generation loop
        kwargs["encoder_outputs"] = encoder_outputs
        # Generate sequences using the parent class's generate method
        return super().generate(
            attention_mask=attention_mask,
            **kwargs
        )

@dataclass
class SignLanguageT5Collator:
    # Written in the style of HF data collators; decorated as a dataclass so the
    # fields below can be set via keyword arguments at construction time.
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        face_embeds = [feature["face_features"] for feature in features]
        left_hand_embeds = [feature["left_hand_features"] for feature in features]
        right_hand_embeds = [feature["right_hand_features"] for feature in features]
        pose_embeds = [feature["pose_features"] for feature in features]
        # Pad every stream to the longest face sequence in the batch
        max_len = max([emb.shape[0] for emb in face_embeds])

        def pad_embeds(embeds):
            padded_embeds = []
            for emb in embeds:
                if emb.dim() == 3:  # For 3D tensors (pose features): pad the time dimension (dim 1)
                    pad_len = max_len - emb.shape[1]
                    emb_pad = torch.nn.functional.pad(emb, (0, 0, 0, pad_len, 0, 0), value=0)
                else:  # For 2D tensors (face, hand features): pad the time dimension (dim 0)
                    pad_len = max_len - emb.shape[0]
                    emb_pad = torch.nn.functional.pad(emb, (0, 0, 0, pad_len), value=0)
                padded_embeds.append(emb_pad)
            return padded_embeds

        padded_face_embeds = pad_embeds(face_embeds)
        padded_left_hand_embeds = pad_embeds(left_hand_embeds)
        padded_right_hand_embeds = pad_embeds(right_hand_embeds)
        padded_pose_embeds = pad_embeds(pose_embeds)
        batch = {}
        batch["face_features"] = torch.stack(padded_face_embeds, dim=0)
        batch["left_hand_features"] = torch.stack(padded_left_hand_embeds, dim=0)
        batch["right_hand_features"] = torch.stack(padded_right_hand_embeds, dim=0)
        batch["pose_features"] = torch.stack(padded_pose_embeds, dim=0)
        # For inference, decoder_input_ids are not needed; model.generate() handles decoding
        return batch

class TranslationFeatures(torch.utils.data.Dataset):
    """Single-example dataset wrapping one set of pre-extracted feature streams."""

    def __init__(self, face_embeddings, left_hand_embeddings, right_hand_embeddings, body_posture_embeddings):
        self.face_embeddings = face_embeddings
        self.left_hand_embeddings = left_hand_embeddings
        self.right_hand_embeddings = right_hand_embeddings
        self.body_posture_embeddings = body_posture_embeddings

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return {
            "face_features": torch.tensor(self.face_embeddings),
            "left_hand_features": torch.tensor(self.left_hand_embeddings),
            "right_hand_features": torch.tensor(self.right_hand_embeddings),
            "pose_features": torch.tensor(self.body_posture_embeddings),
        }
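
# Sketch (not executed here): the dataset and collator above can be combined with a
# torch DataLoader for batched inference. The variable names below are illustrative only.
#
#   dataset = TranslationFeatures(face_np, left_hand_np, right_hand_np, pose_np)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=SignLanguageT5Collator())
#   batch = next(iter(loader))  # dict with "face_features", "left_hand_features", ...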

def generate_text_from_features(
    face_embeddings: np.ndarray,
    left_hand_embeddings: np.ndarray,
    right_hand_embeddings: np.ndarray,
    body_posture_embeddings: np.ndarray,
    model_config: str,
    model_checkpoint: str,
    tokenizer_checkpoint: str,
    output_dir: str,
    generation_max_length: int = 2048,
    generation_num_beams: int = 5,
):
    """
    Direct inference function that generates text from sign language features.
    """
    # Load model and tokenizer
    config = SignLanguageByT5Config.from_pretrained(model_config)
    model = SignLanguageByT5ForConditionalGeneration.from_pretrained(
        model_checkpoint,
        # config=config,
        cache_dir=os.path.join(output_dir, "cache"),
    )
    tokenizer = ByT5Tokenizer.from_pretrained(tokenizer_checkpoint)
    # Move model to the appropriate device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    # Convert inputs to tensors, add a batch dimension, and move to device
    face_tensor = torch.tensor(face_embeddings, dtype=torch.float32).unsqueeze(0).to(device)
    left_hand_tensor = torch.tensor(left_hand_embeddings, dtype=torch.float32).unsqueeze(0).to(device)
    right_hand_tensor = torch.tensor(right_hand_embeddings, dtype=torch.float32).unsqueeze(0).to(device)
    pose_tensor = torch.tensor(body_posture_embeddings, dtype=torch.float32).unsqueeze(0).to(device)
    # Generate text
    with torch.no_grad():
        generated_ids = model.generate(
            face_features=face_tensor,
            left_hand_features=left_hand_tensor,
            right_hand_features=right_hand_tensor,
            pose_features=pose_tensor,
            max_length=generation_max_length,
            num_beams=generation_num_beams,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode the generated token ids into text
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text

def test(
    face_embeddings: np.ndarray,
    left_hand_embeddings: np.ndarray,
    right_hand_embeddings: np.ndarray,
    body_posture_embeddings: np.ndarray,
    model_config: str,
    model_checkpoint: str,
    tokenizer_checkpoint: str,
    output_dir: str,
):
    """
    Test function for inference - generates text from sign language features.
    This is a thin wrapper around the direct inference function that uses the default generation settings.
    """
    return generate_text_from_features(
        face_embeddings=face_embeddings,
        left_hand_embeddings=left_hand_embeddings,
        right_hand_embeddings=right_hand_embeddings,
        body_posture_embeddings=body_posture_embeddings,
        model_config=model_config,
        model_checkpoint=model_checkpoint,
        tokenizer_checkpoint=tokenizer_checkpoint,
        output_dir=output_dir,
    )
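
# Minimal usage sketch. Assumptions: feature arrays are time-major float arrays with the
# per-frame dimensions given by the SignLanguageByT5Config defaults (384/384/14), and the
# config/checkpoint paths below are placeholders rather than real artifacts in this repo.
if __name__ == "__main__":
    T = 32  # number of video frames in this toy example
    text = test(
        face_embeddings=np.random.randn(T, 384).astype(np.float32),
        left_hand_embeddings=np.random.randn(T, 384).astype(np.float32),
        right_hand_embeddings=np.random.randn(T, 384).astype(np.float32),
        body_posture_embeddings=np.random.randn(T, 14).astype(np.float32),
        model_config="path/to/config",            # placeholder
        model_checkpoint="path/to/checkpoint",    # placeholder
        tokenizer_checkpoint="google/byt5-base",  # any ByT5 tokenizer; placeholder choice
        output_dir="outputs",
    )
    print(text)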