# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from zipvoice.models.modules.zipformer_two_stream import TTSZipformerTwoStream
from zipvoice.models.zipvoice import ZipVoice
from zipvoice.utils.common import condition_time_mask_suffix, make_pad_mask, pad_labels


class ZipVoiceDialog(ZipVoice):
    """The ZipVoice-Dialog model."""

    def __init__(
        self,
        fm_decoder_downsampling_factor: List[int] = [1, 2, 4, 2, 1],
        fm_decoder_num_layers: List[int] = [2, 2, 4, 4, 4],
        fm_decoder_cnn_module_kernel: List[int] = [31, 15, 7, 15, 31],
        fm_decoder_feedforward_dim: int = 1536,
        fm_decoder_num_heads: int = 4,
        fm_decoder_dim: int = 512,
        text_encoder_num_layers: int = 4,
        text_encoder_feedforward_dim: int = 512,
        text_encoder_cnn_module_kernel: int = 9,
        text_encoder_num_heads: int = 4,
        text_encoder_dim: int = 192,
        time_embed_dim: int = 192,
        text_embed_dim: int = 192,
        query_head_dim: int = 32,
        value_head_dim: int = 12,
        pos_head_dim: int = 4,
        pos_dim: int = 48,
        feat_dim: int = 100,
        vocab_size: int = 26,
        pad_id: int = 0,
        spk_a_id: int = 360,
        spk_b_id: int = 361,
    ):
        """
        Initialize the model with the specified configuration parameters.

        Args:
            fm_decoder_downsampling_factor: List of downsampling factors for each
                layer in the flow-matching decoder.
            fm_decoder_num_layers: List of the number of layers for each block
                in the flow-matching decoder.
            fm_decoder_cnn_module_kernel: List of kernel sizes for CNN modules
                in the flow-matching decoder.
            fm_decoder_feedforward_dim: Dimension of the feedforward network
                in the flow-matching decoder.
            fm_decoder_num_heads: Number of attention heads in the
                flow-matching decoder.
            fm_decoder_dim: Hidden dimension of the flow-matching decoder.
            text_encoder_num_layers: Number of layers in the text encoder.
            text_encoder_feedforward_dim: Dimension of the feedforward network
                in the text encoder.
            text_encoder_cnn_module_kernel: Kernel size for the CNN module
                in the text encoder.
            text_encoder_num_heads: Number of attention heads in the text encoder.
            text_encoder_dim: Hidden dimension of the text encoder.
            time_embed_dim: Dimension of the time embedding.
            text_embed_dim: Dimension of the text embedding.
            query_head_dim: Dimension of the query attention head.
            value_head_dim: Dimension of the value attention head.
            pos_head_dim: Dimension of the position attention head.
            pos_dim: Dimension of the positional encoding.
            feat_dim: Dimension of the acoustic features.
            vocab_size: Size of the vocabulary.
            pad_id: ID used for padding tokens.
            spk_a_id: ID of speaker A / [S1].
            spk_b_id: ID of speaker B / [S2].
""" super().__init__( fm_decoder_downsampling_factor=fm_decoder_downsampling_factor, fm_decoder_num_layers=fm_decoder_num_layers, fm_decoder_cnn_module_kernel=fm_decoder_cnn_module_kernel, fm_decoder_feedforward_dim=fm_decoder_feedforward_dim, fm_decoder_num_heads=fm_decoder_num_heads, fm_decoder_dim=fm_decoder_dim, text_encoder_num_layers=text_encoder_num_layers, text_encoder_feedforward_dim=text_encoder_feedforward_dim, text_encoder_cnn_module_kernel=text_encoder_cnn_module_kernel, text_encoder_num_heads=text_encoder_num_heads, text_encoder_dim=text_encoder_dim, time_embed_dim=time_embed_dim, text_embed_dim=text_embed_dim, query_head_dim=query_head_dim, value_head_dim=value_head_dim, pos_head_dim=pos_head_dim, pos_dim=pos_dim, feat_dim=feat_dim, vocab_size=vocab_size, pad_id=pad_id, ) self.spk_a_id = spk_a_id self.spk_b_id = spk_b_id self.spk_embed = nn.Embedding(2, feat_dim) torch.nn.init.normal_(self.spk_embed.weight, mean=0, std=0.1) def extract_spk_indices(self, tensor): turn_mask = ((tensor == self.spk_a_id) | (tensor == self.spk_b_id)).long() turn_counts = turn_mask.cumsum(dim=1) spk_mask = turn_counts % 2 spk_mask = torch.where(tensor == self.pad_id, -1, spk_mask) spk_a_indices = torch.where(spk_mask == 0) spk_b_indices = torch.where(spk_mask == 1) return spk_a_indices, spk_b_indices def forward_text_embed( self, tokens: List[List[int]], ): """ Get the text embeddings. Args: tokens: a list of list of token ids. Returns: embed: the text embeddings, shape (batch, seq_len, emb_dim). tokens_lens: the length of each token sequence, shape (batch,). """ device = ( self.device if isinstance(self, DDP) else next(self.parameters()).device ) tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S) embed = self.embed(tokens_padded) # (B, S, C) spk_a_indices, spk_b_indices = self.extract_spk_indices(tokens_padded) tokens_lens = torch.tensor( [len(token) for token in tokens], dtype=torch.int64, device=device ) tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S) embed = self.text_encoder( x=embed, t=None, padding_mask=tokens_padding_mask ) # (B, S, C) embed[spk_a_indices] += self.spk_embed(torch.tensor(0, device=device)).to( embed.dtype ) embed[spk_b_indices] += self.spk_embed(torch.tensor(1, device=device)).to( embed.dtype ) return embed, tokens_lens def forward( self, tokens: List[List[int]], features: torch.Tensor, features_lens: torch.Tensor, noise: torch.Tensor, t: torch.Tensor, condition_drop_ratio: float = 0.0, ) -> torch.Tensor: """Forward pass of the model for training. Args: tokens: a list of list of token ids. features: the acoustic features, with the shape (batch, seq_len, feat_dim). features_lens: the length of each acoustic feature sequence, shape (batch,). noise: the intitial noise, with the shape (batch, seq_len, feat_dim). t: the time step, with the shape (batch, 1, 1). condition_drop_ratio: the ratio of dropped text condition. Returns: fm_loss: the flow-matching loss. 
""" (text_condition, padding_mask,) = self.forward_text_train( tokens=tokens, features_lens=features_lens, ) speech_condition_mask = condition_time_mask_suffix( features_lens=features_lens, mask_percent=(0.5, 1.0), max_len=features.size(1), ) speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features) if condition_drop_ratio > 0.0: drop_mask = ( torch.rand(text_condition.size(0), 1, 1).to(text_condition.device) > condition_drop_ratio ) text_condition = text_condition * drop_mask xt = features * t + noise * (1 - t) ut = features - noise # (B, T, F) vt = self.forward_fm_decoder( t=t, xt=xt, text_condition=text_condition, speech_condition=speech_condition, padding_mask=padding_mask, ) loss_mask = speech_condition_mask & (~padding_mask) fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2) return fm_loss class ZipVoiceDialogStereo(ZipVoiceDialog): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) required_params = { "feat_dim", "fm_decoder_downsampling_factor", "fm_decoder_num_layers", "fm_decoder_cnn_module_kernel", "fm_decoder_dim", "fm_decoder_feedforward_dim", "fm_decoder_num_heads", "query_head_dim", "pos_head_dim", "value_head_dim", "pos_dim", "time_embed_dim", } missing = [p for p in required_params if p not in kwargs] if missing: raise ValueError(f"Missing required parameters: {', '.join(missing)}") self.fm_decoder = TTSZipformerTwoStream( in_dim=(kwargs["feat_dim"] * 5, kwargs["feat_dim"] * 3), out_dim=(kwargs["feat_dim"] * 2, kwargs["feat_dim"]), downsampling_factor=kwargs["fm_decoder_downsampling_factor"], num_encoder_layers=kwargs["fm_decoder_num_layers"], cnn_module_kernel=kwargs["fm_decoder_cnn_module_kernel"], encoder_dim=kwargs["fm_decoder_dim"], feedforward_dim=kwargs["fm_decoder_feedforward_dim"], num_heads=kwargs["fm_decoder_num_heads"], query_head_dim=kwargs["query_head_dim"], pos_head_dim=kwargs["pos_head_dim"], value_head_dim=kwargs["value_head_dim"], pos_dim=kwargs["pos_dim"], use_time_embed=True, time_embed_dim=kwargs["time_embed_dim"], ) def forward( self, tokens: List[List[int]], features: torch.Tensor, features_lens: torch.Tensor, noise: torch.Tensor, t: torch.Tensor, condition_drop_ratio: float = 0.0, se_weight: float = 1.0, ) -> torch.Tensor: """Forward pass of the model for training. Args: tokens: a list of list of token ids. features: the acoustic features, with the shape (batch, seq_len, feat_dim). features_lens: the length of each acoustic feature sequence, shape (batch,). noise: the intitial noise, with the shape (batch, seq_len, feat_dim). t: the time step, with the shape (batch, 1, 1). condition_drop_ratio: the ratio of dropped text condition. se_weight: the weight of the speaker exclusive loss. Returns: fm_loss: the flow-matching loss. 
""" (text_condition, padding_mask,) = self.forward_text_train( tokens=tokens, features_lens=features_lens, ) speech_condition_mask = condition_time_mask_suffix( features_lens=features_lens, mask_percent=(0.5, 1.0), max_len=features.size(1), ) speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features) if condition_drop_ratio > 0.0: drop_mask = ( torch.rand(text_condition.size(0), 1, 1).to(text_condition.device) > condition_drop_ratio ) text_condition = text_condition * drop_mask xt = features * t + noise * (1 - t) ut = features - noise # (B, T, F) vt = self.forward_fm_decoder( t=t, xt=xt, text_condition=text_condition, speech_condition=speech_condition, padding_mask=padding_mask, ) loss_mask = speech_condition_mask & (~padding_mask) fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2) if se_weight > 0: target = xt + vt * (1 - t) fbank_1 = target[:, :, : self.feat_dim] fbank_2 = target[:, :, self.feat_dim :] energy_loss = torch.mean( self.energy_based_loss(fbank_1, fbank_2, features)[loss_mask] ) loss = fm_loss + energy_loss * se_weight else: loss = fm_loss return loss def energy_based_loss(self, fbank1, fbank2, gt_fbank): energy1 = self.energy(fbank1) energy2 = self.energy(fbank2) energy_thresholds = self.adaptive_threshold_from_gt( torch.cat( [ gt_fbank[:, :, : self.feat_dim], gt_fbank[:, :, self.feat_dim :], ], dim=1, ) ) both_speaking = ( (energy1 > energy_thresholds) & (energy2 > energy_thresholds) ).float() penalty = ( both_speaking * (energy1 - energy_thresholds) * (energy2 - energy_thresholds) ) return penalty def energy(self, fbank): return torch.mean(fbank, dim=-1) def adaptive_threshold_from_gt(self, gt_fbank, percentile=50): frame_energies = self.energy(gt_fbank) thresholds = torch.quantile(frame_energies, q=percentile / 100, dim=1) return thresholds.unsqueeze(1)