"""Projector that maps hidden states from the LLM component to multimodal logits.""" import torch from torch import nn from dataclasses import dataclass from typing import Optional, Tuple from .common import HiggsAudioPreTrainedModel from .configuration_higgs_audio import HiggsAudioConfig @dataclass class HiggsAudioDecoderLayerOutput: logits: torch.FloatTensor audio_logits: torch.FloatTensor attentions: Optional[Tuple[torch.FloatTensor, ...]] = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel): """Projection layers that map hidden states from the LLM component to audio / text logits. We support two type of audio head: - Basic Audio Head: Directly map the hidden states to audio logits for all the codebooks. """ def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None): super().__init__(config) self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.audio_lm_head = nn.Linear( config.text_config.hidden_size, config.audio_num_codebooks * (config.audio_codebook_size + 2), bias=False, ) # Initialize weights and apply final processing self.post_init() def forward( self, hidden_states, audio_out_mask, label_audio_ids=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, output_audio_hidden_states=False, cache_position=None, ): """ Args: hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`): Hidden states from the LLM component audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): Mask for identifying the audio out tokens. label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`): Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used. attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): Mask to avoid performing attention on padding token indices position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`): Position ids for the input tokens Returns: logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`): Logits for text tokens audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * audio_codebook_size)`): Logits for audio tokens. We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len` """ logits = self.text_lm_head(hidden_states) all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None # TODO(sxjscience) Need to check if DeepSpeed Zero3 supports zero-shape input. if self.config.audio_decoder_proj_num_layers > 0: # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.transformer_layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) if output_attentions: all_self_attns += (layer_outputs[1],) if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] next_cache = next_decoder_cache if use_cache else None audio_logits = self.audio_lm_head(hidden_states[audio_out_mask]) if output_audio_hidden_states: audio_hidden_states = hidden_states[audio_out_mask] else: audio_hidden_states = None return ( logits, audio_logits, all_self_attns, all_hidden_states, audio_hidden_states, next_cache, )