zachzzc's picture
Upload tts playground and serving engine
07f1f64
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from .configuration_higgs_audio import HiggsAudioConfig
class HiggsAudioPreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` base class for the Higgs Audio models.

    Binds the model family to ``HiggsAudioConfig``, declares a handful of
    class-level settings, and supplies the ``_init_weights`` hook used to
    initialize sub-module parameters.
    """

    # NOTE(review): the attributes below are presumably read by the
    # ``PreTrainedModel`` base class (loading, device placement, attention
    # backend selection); their exact semantics are defined there, not here.
    config_class = HiggsAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the parameters of a single sub-module in place.

        The standard deviation is ``self.config.init_std`` when the config
        has that attribute, otherwise
        ``self.config.audio_encoder_config.init_std``.

        * ``nn.Linear`` / ``nn.Conv1d``: weight drawn from a normal
          distribution with mean 0 and that standard deviation; the bias,
          if present, is set to zero.
        * ``nn.Embedding``: weight drawn from the same distribution; the
          row at ``padding_idx``, if one is set, is zeroed afterwards.
        * Any other module type is left untouched.

        Args:
            module: The sub-module whose parameters should be initialized.
        """
        # Resolve the std first, regardless of the module type, so a config
        # lacking both attributes fails the same way for every module.
        config = self.config
        if hasattr(config, "init_std"):
            init_std = config.init_std
        else:
            init_std = config.audio_encoder_config.init_std

        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=init_std)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            weight = module.weight.data
            weight.normal_(mean=0.0, std=init_std)
            pad_row = module.padding_idx
            if pad_row is not None:
                # Zero the padding row after the random fill.
                weight[pad_row].zero_()