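Two text encoders for HunyuanVideo follow: a LLaMA-based LLM encoder with an optional layer-by-layer offloading path for low-VRAM inference, and a LLaVA-based multimodal (MLLM) encoder for image-conditioned prompts.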
```python
from copy import deepcopy

import torch
from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration


class HunyuanVideoLLMEncoder(LlamaModel):
    """LLaMA-based text encoder that returns an intermediate hidden state.

    With `auto_offload` enabled, the weights stay on their original device
    (typically CPU) and each submodule is copied to the input's device only
    for the duration of its forward pass, keeping peak VRAM low.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.auto_offload = False

    def enable_auto_offload(self, **kwargs):
        self.auto_offload = True

    def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
        # Copy the embedding table to the input device on demand when offloading.
        embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
        inputs_embeds = embed_tokens(input_ids)
        past_key_values = DynamicCache()
        cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)
        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
        hidden_states = inputs_embeds

        # Create position embeddings to be shared across the decoder layers.
        rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
        position_embeddings = rotary_emb(hidden_states, position_ids)

        # Run the decoder layers, copying each one to the GPU only when needed.
        for layer_id, decoder_layer in enumerate(self.layers):
            if self.auto_offload:
                decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]
            # Stop early: the last `hidden_state_skip_layer` layers are never run,
            # so the returned embedding comes from an intermediate layer.
            if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
                break
        return hidden_states
```
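For reference, here is a minimal usage sketch for the LLM encoder (not part of the original code). The checkpoint path and tokenizer are assumptions; substitute the text-encoder weights that ship with your HunyuanVideo checkpoint.

```python
# A minimal sketch, assuming a LLaMA-style text-encoder checkpoint at a
# hypothetical path.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/text_encoder")  # hypothetical path
encoder = HunyuanVideoLLMEncoder.from_pretrained("path/to/text_encoder", torch_dtype=torch.float16)
encoder.enable_auto_offload()  # weights stay on CPU; modules are copied to GPU on demand

batch = tokenizer("a cat walks on the grass", return_tensors="pt").to("cuda")
with torch.no_grad():
    prompt_emb = encoder(batch.input_ids, batch.attention_mask)
print(prompt_emb.shape)  # (1, sequence_length, hidden_size)
```

Because each layer is deep-copied to the GPU and then discarded, only one decoder layer resides in VRAM at a time, at the cost of a host-to-device transfer per layer.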
```python
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
    """LLaVA-based multimodal (image + text) encoder for image-conditioned prompts."""

    def __init__(self, config):
        super().__init__(config)
        self.auto_offload = False

    def enable_auto_offload(self, **kwargs):
        self.auto_offload = True
        # TODO: implement low-VRAM inference for the MLLM.

    def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            pixel_values=pixel_values,
        )
        # Index from the end: take the hidden state `hidden_state_skip_layer`
        # layers before the final one, mirroring the early exit above.
        hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
        return hidden_state
```
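A similar sketch for the multimodal encoder (again, not from the original source): the processor and paths are assumptions, and the exact prompt template, including where the `<image>` placeholder goes, depends on the underlying LLaVA checkpoint.

```python
# A minimal sketch, assuming a LLaVA-style checkpoint with a matching
# processor at a hypothetical path.
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("path/to/mllm_text_encoder")  # hypothetical path
encoder = HunyuanVideoMLLMEncoder.from_pretrained(
    "path/to/mllm_text_encoder", torch_dtype=torch.float16
).to("cuda")

image = Image.open("reference_frame.png")  # hypothetical input image
inputs = processor(text="<image>\nDescribe the scene.", images=image, return_tensors="pt").to("cuda")
with torch.no_grad():
    prompt_emb = encoder(inputs.input_ids, inputs.pixel_values, inputs.attention_mask)
```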