import math
from typing import Optional, Tuple, Union

import torch
from torch import Tensor, nn

from zipvoice.models.modules.scaling import FloatLike, ScheduledFloat, SwooshR
from zipvoice.models.modules.zipformer import (
    DownsampledZipformer2Encoder,
    TTSZipformer,
    Zipformer2Encoder,
    Zipformer2EncoderLayer,
)


def timestep_embedding(timesteps, dim, max_period=10000):
    """Create sinusoidal timestep embeddings.

    :param timesteps: a Tensor of shape (N) or (N, T).
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: a Tensor of positional embeddings, of shape (N, dim) or (T, N, dim).
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
        / half
    )

    if timesteps.dim() == 2:
        timesteps = timesteps.transpose(0, 1)

    args = timesteps[..., None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    # If dim is odd, zero-pad the last channel so the output has exactly `dim` dims.
    if dim % 2:
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[..., :1])], dim=-1
        )
    return embedding
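

# Usage sketch (illustrative values, not from the original source):
#
#   t = torch.tensor([0.0, 0.5, 1.0])
#   emb = timestep_embedding(t, dim=192)  # -> shape (3, 192)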


class TTSZipformerTwoStream(TTSZipformer):
    """
    Args:

    Note: all "int or Tuple[int]" arguments below will be treated as lists of
    the same length as downsampling_factor if they are single ints or
    one-element tuples. The length of downsampling_factor defines the number
    of stacks.

    in_dim (Tuple[int]): input feature dimension of each of the two streams.
    out_dim (Tuple[int]): output feature dimension of each of the two streams.
    downsampling_factor (Tuple[int]): downsampling factor for each encoder
        stack. Note: this is in addition to the downsampling factor of 2 that
        is applied in the frontend (self.encoder_embed).
    encoder_dim (int): embedding dimension of the encoder stacks, shared by
        all stacks.
    num_encoder_layers (int or Tuple[int]): number of encoder layers for each
        stack.
    query_head_dim (int or Tuple[int]): dimension of the query and key per
        attention head; per stack, if a tuple.
    pos_head_dim (int or Tuple[int]): dimension of the positional-encoding
        projection per attention head.
    value_head_dim (int or Tuple[int]): dimension of the value in each
        attention head.
    num_heads (int or Tuple[int]): number of heads in the self-attention
        mechanism. Must be at least 4.
    feedforward_dim (int or Tuple[int]): hidden dimension in feedforward
        modules.
    cnn_module_kernel (int or Tuple[int]): kernel size of the convolution
        module.
    pos_dim (int): the dimension of each positional-encoding vector prior to
        projection, e.g. 128.
    dropout (float): dropout rate.
    warmup_batches (float): number of batches to warm up over; this controls
        dropout of encoder layers.
    use_time_embed (bool): if True, take a time embedding as an additional
        input.
    time_embed_dim (int): the dimension of the time embedding.
    use_conv (bool): if True, use convolution modules in the encoder layers.
    """

    def __init__(
        self,
        in_dim: Tuple[int],
        out_dim: Tuple[int],
        downsampling_factor: Tuple[int] = (2, 4),
        num_encoder_layers: Union[int, Tuple[int]] = 4,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        encoder_dim: int = 384,
        query_head_dim: int = 24,
        pos_head_dim: int = 4,
        value_head_dim: int = 12,
        num_heads: int = 8,
        feedforward_dim: int = 1536,
        pos_dim: int = 192,
        dropout: FloatLike = None,
        warmup_batches: float = 4000.0,
        use_time_embed: bool = True,
        time_embed_dim: int = 192,
        use_conv: bool = True,
    ) -> None:
        nn.Module.__init__(self)

        if dropout is None:
            dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
        if isinstance(downsampling_factor, int):
            downsampling_factor = (downsampling_factor,)

        def _to_tuple(x):
            """Converts a single int or a 1-tuple of an int to a tuple with
            the same length as downsampling_factor."""
            if isinstance(x, int):
                x = (x,)
            if len(x) == 1:
                x = x * len(downsampling_factor)
            else:
                assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
            return x
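
        # For example (illustrative): with downsampling_factor == (1, 2, 1),
        # _to_tuple(4) returns (4, 4, 4), while _to_tuple((3, 5, 7)) is
        # returned unchanged.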

        def _assert_downsampling_factor(factors):
            """Assert that downsampling_factor follows a U-Net style."""
            assert factors[0] == 1 and factors[-1] == 1

            for i in range(1, len(factors) // 2 + 1):
                assert factors[i] == factors[i - 1] * 2

            for i in range(len(factors) // 2 + 1, len(factors)):
                assert factors[i] * 2 == factors[i - 1]
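
        # For example (illustrative), factors like (1, 2, 4, 2, 1) satisfy the
        # constraint: they double up to the middle stack, then halve back to 1.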

        _assert_downsampling_factor(downsampling_factor)
        self.downsampling_factor = downsampling_factor
        num_encoder_layers = _to_tuple(num_encoder_layers)
        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
        self.encoder_dim = encoder_dim
        self.num_encoder_layers = num_encoder_layers
        self.query_head_dim = query_head_dim
        self.value_head_dim = value_head_dim
        self.num_heads = num_heads

        self.use_time_embed = use_time_embed

        self.time_embed_dim = time_embed_dim
        if self.use_time_embed:
            assert time_embed_dim != -1
        else:
            time_embed_dim = -1

        assert len(in_dim) == len(out_dim) == 2

        self.in_dim = in_dim
        self.in_proj = nn.ModuleList(
            [nn.Linear(in_dim[0], encoder_dim), nn.Linear(in_dim[1], encoder_dim)]
        )
        self.out_dim = out_dim
        self.out_proj = nn.ModuleList(
            [nn.Linear(encoder_dim, out_dim[0]), nn.Linear(encoder_dim, out_dim[1])]
        )
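
        # NOTE: the two projection pairs above define the two streams; forward()
        # selects a stream by matching the input feature dimension against in_dim.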

        encoders = []

        num_encoders = len(downsampling_factor)
        for i in range(num_encoders):
            encoder_layer = Zipformer2EncoderLayer(
                embed_dim=encoder_dim,
                pos_dim=pos_dim,
                num_heads=num_heads,
                query_head_dim=query_head_dim,
                pos_head_dim=pos_head_dim,
                value_head_dim=value_head_dim,
                feedforward_dim=feedforward_dim,
                use_conv=use_conv,
                cnn_module_kernel=cnn_module_kernel[i],
                dropout=dropout,
            )

            encoder = Zipformer2Encoder(
                encoder_layer,
                num_encoder_layers[i],
                embed_dim=encoder_dim,
                time_embed_dim=time_embed_dim,
                pos_dim=pos_dim,
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
            )

            if downsampling_factor[i] != 1:
                encoder = DownsampledZipformer2Encoder(
                    encoder,
                    dim=encoder_dim,
                    downsample=downsampling_factor[i],
                )

            encoders.append(encoder)

        self.encoders = nn.ModuleList(encoders)
        if self.use_time_embed:
            self.time_embed = nn.Sequential(
                nn.Linear(time_embed_dim, time_embed_dim * 2),
                SwooshR(),
                nn.Linear(time_embed_dim * 2, time_embed_dim),
            )
        else:
            self.time_embed = None

    def forward(
        self,
        x: Tensor,
        t: Optional[Tensor] = None,
        padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
          x:
            The input tensor. Its shape is (batch_size, seq_len, feature_dim),
            where feature_dim must equal in_dim[0] or in_dim[1]; this selects
            which of the two streams is used.
          t:
            A timestep tensor of shape (batch_size,) or (batch_size, seq_len).
          padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position. May be None.
        Returns:
          The output tensor. Its shape is (batch_size, seq_len, out_dim), where
          out_dim is the output dimension of the selected stream.
        """
        assert x.size(2) in self.in_dim, f"{x.size(2)} in {self.in_dim}"
        # Select the stream whose input projection matches the feature dimension.
        if x.size(2) == self.in_dim[0]:
            index = 0
        else:
            index = 1
        x = x.permute(1, 0, 2)
        x = self.in_proj[index](x)

        if t is not None:
            assert t.dim() == 1 or t.dim() == 2, t.shape
            time_emb = timestep_embedding(t, self.time_embed_dim)
            time_emb = self.time_embed(time_emb)
        else:
            time_emb = None

        attn_mask = None

        for i, module in enumerate(self.encoders):
            x = module(
                x,
                time_emb=time_emb,
                src_key_padding_mask=padding_mask,
                attn_mask=attn_mask,
            )
        x = self.out_proj[index](x)
        x = x.permute(1, 0, 2)
        return x
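

# Usage sketch (illustrative; dimensions are hypothetical, not from the source).
# Note that downsampling_factor must satisfy the U-Net constraint checked in
# __init__, so the (2, 4) default is overridden here:
#
#   model = TTSZipformerTwoStream(
#       in_dim=(80, 100), out_dim=(80, 100), downsampling_factor=(1, 2, 1)
#   )
#   x = torch.randn(2, 50, 80)  # feature dim 80 matches in_dim[0] -> stream 0
#   t = torch.rand(2)           # one timestep per batch element
#   y = model(x, t=t)           # -> shape (2, 50, 80)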