#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union
import torch
from torch import Tensor, nn
from zipvoice.models.modules.scaling import FloatLike, ScheduledFloat, SwooshR
from zipvoice.models.modules.zipformer import (
DownsampledZipformer2Encoder,
TTSZipformer,
Zipformer2Encoder,
Zipformer2EncoderLayer,
)
def timestep_embedding(timesteps, dim, max_period=10000):
"""Create sinusoidal timestep embeddings.
:param timesteps: shape of (N) or (N, T)
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an Tensor of positional embeddings. shape of (N, dim) or (T, N, dim)
"""
half = dim // 2
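    # The frequencies decrease geometrically from 1 down to roughly 1/max_period;
    # cos and sin of each frequency together make up the embedding.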
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
/ half
)
if timesteps.dim() == 2:
timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N)
args = timesteps[..., None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
return embedding
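
# Illustrative shape contract (arbitrary sizes, not tied to any ZipVoice config):
#   timestep_embedding(torch.rand(3), 8) has shape (3, 8)
#   timestep_embedding(torch.rand(3, 50), 8) has shape (50, 3, 8)
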
class TTSZipformerTwoStream(TTSZipformer):
"""
Args:
Note: all "int or Tuple[int]" arguments below will be treated as lists of the same
length as downsampling_factor if they are single ints or one-element tuples.
The length of downsampling_factor defines the number of stacks.
downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
Note: this is in addition to the downsampling factor of 2 that is applied in
the frontend (self.encoder_embed).
encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks,
one per encoder stack.
num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
query_head_dim (int or Tuple[int]): dimension of query and key per attention
head: per stack, if a tuple..
pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection
per attention head
value_head_dim (int or Tuple[int]): dimension of value in each attention head
num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
Must be at least 4.
feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
pos_dim (int): the dimension of each positional-encoding vector prior to
projection, e.g. 128.
dropout (float): dropout rate
warmup_batches (float): number of batches to warm up over; this controls
dropout of encoder layers.
use_time_embed: (bool): if True, do not take time embedding as additional input.
time_embed_dim: (int): the dimension of the time embedding.
"""
def __init__(
self,
in_dim: Tuple[int],
out_dim: Tuple[int],
downsampling_factor: Tuple[int] = (2, 4),
num_encoder_layers: Union[int, Tuple[int]] = 4,
cnn_module_kernel: Union[int, Tuple[int]] = 31,
encoder_dim: int = 384,
query_head_dim: int = 24,
pos_head_dim: int = 4,
value_head_dim: int = 12,
num_heads: int = 8,
feedforward_dim: int = 1536,
pos_dim: int = 192,
dropout: FloatLike = None, # see code below for default
warmup_batches: float = 4000.0,
use_time_embed: bool = True,
time_embed_dim: int = 192,
use_conv: bool = True,
) -> None:
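        # Initialize nn.Module directly (rather than TTSZipformer.__init__):
        # this class builds all of its own sub-modules below.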
nn.Module.__init__(self)
if dropout is None:
dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
if isinstance(downsampling_factor, int):
downsampling_factor = (downsampling_factor,)
def _to_tuple(x):
"""Converts a single int or a 1-tuple of an int to a tuple with the same
length as downsampling_factor"""
if isinstance(x, int):
x = (x,)
if len(x) == 1:
x = x * len(downsampling_factor)
else:
assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
return x
def _assert_downsampling_factor(factors):
"""assert downsampling_factor follows u-net style"""
assert factors[0] == 1 and factors[-1] == 1
for i in range(1, len(factors) // 2 + 1):
assert factors[i] == factors[i - 1] * 2
for i in range(len(factors) // 2 + 1, len(factors)):
assert factors[i] * 2 == factors[i - 1]
_assert_downsampling_factor(downsampling_factor)
self.downsampling_factor = downsampling_factor # tuple
num_encoder_layers = _to_tuple(num_encoder_layers)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
self.encoder_dim = encoder_dim
self.num_encoder_layers = num_encoder_layers
self.query_head_dim = query_head_dim
self.value_head_dim = value_head_dim
self.num_heads = num_heads
self.use_time_embed = use_time_embed
self.time_embed_dim = time_embed_dim
if self.use_time_embed:
assert time_embed_dim != -1
else:
time_embed_dim = -1
assert len(in_dim) == len(out_dim) == 2
self.in_dim = in_dim
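        # One input projection and one output projection per stream; the encoder
        # stacks below are shared between the two streams.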
self.in_proj = nn.ModuleList(
[nn.Linear(in_dim[0], encoder_dim), nn.Linear(in_dim[1], encoder_dim)]
)
self.out_dim = out_dim
self.out_proj = nn.ModuleList(
[nn.Linear(encoder_dim, out_dim[0]), nn.Linear(encoder_dim, out_dim[1])]
)
# each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
encoders = []
num_encoders = len(downsampling_factor)
for i in range(num_encoders):
encoder_layer = Zipformer2EncoderLayer(
embed_dim=encoder_dim,
pos_dim=pos_dim,
num_heads=num_heads,
query_head_dim=query_head_dim,
pos_head_dim=pos_head_dim,
value_head_dim=value_head_dim,
feedforward_dim=feedforward_dim,
use_conv=use_conv,
cnn_module_kernel=cnn_module_kernel[i],
dropout=dropout,
)
            # Stagger the warmup schedule across the encoder stacks, so each
            # stack starts and finishes warming up after the previous one.
encoder = Zipformer2Encoder(
encoder_layer,
num_encoder_layers[i],
embed_dim=encoder_dim,
time_embed_dim=time_embed_dim,
pos_dim=pos_dim,
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
)
if downsampling_factor[i] != 1:
encoder = DownsampledZipformer2Encoder(
encoder,
dim=encoder_dim,
downsample=downsampling_factor[i],
)
encoders.append(encoder)
self.encoders = nn.ModuleList(encoders)
if self.use_time_embed:
self.time_embed = nn.Sequential(
nn.Linear(time_embed_dim, time_embed_dim * 2),
SwooshR(),
nn.Linear(time_embed_dim * 2, time_embed_dim),
)
else:
self.time_embed = None
def forward(
self,
x: Tensor,
t: Optional[Tensor] = None,
padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
          x:
            The input tensor, of shape (batch_size, seq_len, feature_dim);
            feature_dim must equal one of the two entries of self.in_dim, which
            selects the stream whose projections are used.
          t:
            The timestep tensor, of shape (batch_size,) or (batch_size, seq_len).
            May be None.
          padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position. May be None.

        Returns:
          The output embeddings of the selected stream, of shape
          (batch_size, seq_len, out_dim).
        """
        assert x.size(2) in self.in_dim, f"{x.size(2)} in {self.in_dim}"
        # Select the stream whose input projection matches the input feature dimension.
        if x.size(2) == self.in_dim[0]:
            index = 0
        else:
            index = 1
        x = x.permute(1, 0, 2)  # (batch, seq_len, dim) -> (seq_len, batch, dim)
        x = self.in_proj[index](x)
if t is not None:
assert t.dim() == 1 or t.dim() == 2, t.shape
time_emb = timestep_embedding(t, self.time_embed_dim)
time_emb = self.time_embed(time_emb)
else:
time_emb = None
        # No attention mask is used here; padding is handled via src_key_padding_mask.
        attn_mask = None
        for module in self.encoders:
            x = module(
                x,
                time_emb=time_emb,
                src_key_padding_mask=padding_mask,
                attn_mask=attn_mask,
            )
        x = self.out_proj[index](x)
        x = x.permute(1, 0, 2)  # (seq_len, batch, dim) -> (batch, seq_len, dim)
        return x
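

if __name__ == "__main__":
    # Minimal smoke test. This is an illustrative sketch only: every dimension
    # below is an arbitrary assumption and does not reflect the configuration
    # ZipVoice actually uses.
    model = TTSZipformerTwoStream(
        in_dim=(100, 80),
        out_dim=(100, 80),
        downsampling_factor=(1, 2, 1),  # must follow the U-Net pattern asserted above
        num_encoder_layers=2,
        encoder_dim=192,
        num_heads=4,
        feedforward_dim=512,
        time_embed_dim=192,
    )
    model.eval()
    with torch.no_grad():
        # A batch from the first stream (feature_dim == in_dim[0]) with one
        # timestep per utterance.
        x = torch.randn(2, 50, 100)
        t = torch.rand(2)
        y = model(x, t=t)
    print(y.shape)  # expected (under these assumptions): torch.Size([2, 50, 100])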