MindOmni / src /image_decoder /transformer.py
stevengrove's picture
Update src/image_decoder/transformer.py
3cded83 verified
from typing import List, Optional, Tuple, Union
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from .modeling_phi3 import Phi3Model
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Phi3Transformer(Phi3Model):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
We only modified the attention mask
Args:
config: Phi3Config
"""
def prefetch_layer(self, layer_idx: int, device: torch.device):
"Starts prefetching the next layer cache"
with torch.cuda.stream(self.prefetch_stream):
# Prefetch next layer tensors to GPU
for name, param in self.layers[layer_idx].named_parameters():
param.data = param.data.to(device, non_blocking=True)
def evict_previous_layer(self, layer_idx: int):
"Moves the previous layer cache to the CPU"
prev_layer_idx = layer_idx - 1
for name, param in self.layers[prev_layer_idx].named_parameters():
param.data = param.data.to("cpu", non_blocking=True)
def get_offlaod_layer(self, layer_idx: int, device: torch.device):
# init stream
if not hasattr(self, "prefetch_stream"):
self.prefetch_stream = torch.cuda.Stream()
# delete previous layer
torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)
# make sure the current layer is ready
torch.cuda.synchronize(self.prefetch_stream)
# load next layer
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
offload_model: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)
if attention_mask is not None and attention_mask.dim() == 3:
dtype = inputs_embeds.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = (1 - attention_mask) * min_dtype
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
else:
raise Exception("attention_mask parameter was unavailable or invalid")
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
layer_idx = -1
for decoder_layer in self.layers:
layer_idx += 1
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
if offload_model and not self.training:
self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)