from typing import Optional, Tuple import torch import torch.nn as nn from indextts.gpt.conformer.attention import (MultiHeadedAttention, RelPositionMultiHeadedAttention) from indextts.gpt.conformer.embedding import (NoPositionalEncoding, PositionalEncoding, RelPositionalEncoding) from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2, Conv2dSubsampling4, Conv2dSubsampling6, Conv2dSubsampling8, LinearNoSubsampling) from indextts.utils.common import make_pad_mask class PositionwiseFeedForward(torch.nn.Module): """Positionwise feed forward layer. FeedForward are appied on each position of the sequence. The output dim is same with the input dim. Args: idim (int): Input dimenstion. hidden_units (int): The number of hidden units. dropout_rate (float): Dropout rate. activation (torch.nn.Module): Activation function """ def __init__(self, idim: int, hidden_units: int, dropout_rate: float, activation: torch.nn.Module = torch.nn.ReLU()): """Construct a PositionwiseFeedForward object.""" super(PositionwiseFeedForward, self).__init__() self.w_1 = torch.nn.Linear(idim, hidden_units) self.activation = activation self.dropout = torch.nn.Dropout(dropout_rate) self.w_2 = torch.nn.Linear(hidden_units, idim) def forward(self, xs: torch.Tensor) -> torch.Tensor: """Forward function. Args: xs: input tensor (B, L, D) Returns: output tensor, (B, L, D) """ return self.w_2(self.dropout(self.activation(self.w_1(xs)))) class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model.""" def __init__(self, channels: int, kernel_size: int = 15, activation: nn.Module = nn.ReLU(), bias: bool = True): """Construct an ConvolutionModule object. Args: channels (int): The number of channels of conv layers. kernel_size (int): Kernel size of conv layers. causal (int): Whether use causal convolution or not """ super().__init__() self.pointwise_conv1 = nn.Conv1d( channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, ) # self.lorder is used to distinguish if it's a causal convolution, # if self.lorder > 0: it's a causal convolution, the input will be # padded with self.lorder frames on the left in forward. # else: it's a symmetrical convolution # kernel_size should be an odd number for none causal convolution assert (kernel_size - 1) % 2 == 0 padding = (kernel_size - 1) // 2 self.lorder = 0 self.depthwise_conv = nn.Conv1d( channels, channels, kernel_size, stride=1, padding=padding, groups=channels, bias=bias, ) self.use_layer_norm = True self.norm = nn.LayerNorm(channels) self.pointwise_conv2 = nn.Conv1d( channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, ) self.activation = activation def forward( self, x: torch.Tensor, mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), cache: torch.Tensor = torch.zeros((0, 0, 0)), ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute convolution module. Args: x (torch.Tensor): Input tensor (#batch, time, channels). mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), (0, 0, 0) means fake mask. cache (torch.Tensor): left context cache, it is only used in causal convolution (#batch, channels, cache_t), (0, 0, 0) meas fake cache. Returns: torch.Tensor: Output tensor (#batch, time, channels). """ # exchange the temporal dimension and the feature dimension x = x.transpose(1, 2) # (#batch, channels, time) # mask batch padding if mask_pad.size(2) > 0: # time > 0 x.masked_fill_(~mask_pad, 0.0) if self.lorder > 0: if cache.size(2) == 0: # cache_t == 0 x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) else: assert cache.size(0) == x.size(0) # equal batch assert cache.size(1) == x.size(1) # equal channel x = torch.cat((cache, x), dim=2) assert (x.size(2) > self.lorder) new_cache = x[:, :, -self.lorder:] else: # It's better we just return None if no cache is required, # However, for JIT export, here we just fake one tensor instead of # None. new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channel, dim) x = nn.functional.glu(x, dim=1) # (batch, channel, dim) # 1D Depthwise Conv x = self.depthwise_conv(x) if self.use_layer_norm: x = x.transpose(1, 2) x = self.activation(self.norm(x)) if self.use_layer_norm: x = x.transpose(1, 2) x = self.pointwise_conv2(x) # mask batch padding if mask_pad.size(2) > 0: # time > 0 x.masked_fill_(~mask_pad, 0.0) return x.transpose(1, 2), new_cache class ConformerEncoderLayer(nn.Module): """Encoder layer module. Args: size (int): Input dimension. self_attn (torch.nn.Module): Self-attention module instance. `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance can be used as the argument. feed_forward (torch.nn.Module): Feed-forward module instance. `PositionwiseFeedForward` instance can be used as the argument. feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance. `PositionwiseFeedForward` instance can be used as the argument. conv_module (torch.nn.Module): Convolution module instance. `ConvlutionModule` instance can be used as the argument. dropout_rate (float): Dropout rate. normalize_before (bool): True: use layer_norm before each sub-block. False: use layer_norm after each sub-block. concat_after (bool): Whether to concat attention layer's input and output. True: x -> x + linear(concat(x, att(x))) False: x -> x + att(x) """ def __init__( self, size: int, self_attn: torch.nn.Module, feed_forward: Optional[nn.Module] = None, feed_forward_macaron: Optional[nn.Module] = None, conv_module: Optional[nn.Module] = None, dropout_rate: float = 0.1, normalize_before: bool = True, concat_after: bool = False, ): """Construct an EncoderLayer object.""" super().__init__() self.self_attn = self_attn self.feed_forward = feed_forward self.feed_forward_macaron = feed_forward_macaron self.conv_module = conv_module self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module if feed_forward_macaron is not None: self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5) self.ff_scale = 0.5 else: self.ff_scale = 1.0 if self.conv_module is not None: self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module self.norm_final = nn.LayerNorm( size, eps=1e-5) # for the final output of the block self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before self.concat_after = concat_after if self.concat_after: self.concat_linear = nn.Linear(size + size, size) else: self.concat_linear = nn.Identity() def forward( self, x: torch.Tensor, mask: torch.Tensor, pos_emb: torch.Tensor, mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute encoded features. Args: x (torch.Tensor): (#batch, time, size) mask (torch.Tensor): Mask tensor for the input (#batch, time,time), (0, 0, 0) means fake mask. pos_emb (torch.Tensor): positional encoding, must not be None for ConformerEncoderLayer. mask_pad (torch.Tensor): batch padding mask used for conv module. (#batch, 1,time), (0, 0, 0) means fake mask. att_cache (torch.Tensor): Cache tensor of the KEY & VALUE (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. cnn_cache (torch.Tensor): Convolution cache in conformer layer (#batch=1, size, cache_t2) Returns: torch.Tensor: Output tensor (#batch, time, size). torch.Tensor: Mask tensor (#batch, time, time). torch.Tensor: att_cache tensor, (#batch=1, head, cache_t1 + time, d_k * 2). torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). """ # whether to use macaron style if self.feed_forward_macaron is not None: residual = x if self.normalize_before: x = self.norm_ff_macaron(x) x = residual + self.ff_scale * self.dropout( self.feed_forward_macaron(x)) if not self.normalize_before: x = self.norm_ff_macaron(x) # multi-headed self-attention module residual = x if self.normalize_before: x = self.norm_mha(x) x_att, new_att_cache = self.self_attn( x, x, x, mask, pos_emb, att_cache) if self.concat_after: x_concat = torch.cat((x, x_att), dim=-1) x = residual + self.concat_linear(x_concat) else: x = residual + self.dropout(x_att) if not self.normalize_before: x = self.norm_mha(x) # convolution module # Fake new cnn cache here, and then change it in conv_module new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) if self.conv_module is not None: residual = x if self.normalize_before: x = self.norm_conv(x) x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) x = residual + self.dropout(x) if not self.normalize_before: x = self.norm_conv(x) # feed forward module residual = x if self.normalize_before: x = self.norm_ff(x) x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) if not self.normalize_before: x = self.norm_ff(x) if self.conv_module is not None: x = self.norm_final(x) return x, mask, new_att_cache, new_cnn_cache class BaseEncoder(torch.nn.Module): def __init__( self, input_size: int, output_size: int = 256, attention_heads: int = 4, linear_units: int = 2048, num_blocks: int = 6, dropout_rate: float = 0.0, input_layer: str = "conv2d", pos_enc_layer_type: str = "abs_pos", normalize_before: bool = True, concat_after: bool = False, ): """ Args: input_size (int): input dim output_size (int): dimension of attention attention_heads (int): the number of heads of multi head attention linear_units (int): the hidden units number of position-wise feed forward num_blocks (int): the number of decoder blocks dropout_rate (float): dropout rate attention_dropout_rate (float): dropout rate in attention positional_dropout_rate (float): dropout rate after adding positional encoding input_layer (str): input layer type. optional [linear, conv2d, conv2d6, conv2d8] pos_enc_layer_type (str): Encoder positional encoding layer type. opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] normalize_before (bool): True: use layer_norm before each sub-block of a layer. False: use layer_norm after each sub-block of a layer. concat_after (bool): whether to concat attention layer's input and output. True: x -> x + linear(concat(x, att(x))) False: x -> x + att(x) static_chunk_size (int): chunk size for static chunk training and decoding use_dynamic_chunk (bool): whether use dynamic chunk size for training or not, You can only use fixed chunk(chunk_size > 0) or dyanmic chunk size(use_dynamic_chunk = True) global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module use_dynamic_left_chunk (bool): whether use dynamic left chunk in dynamic chunk training """ super().__init__() self._output_size = output_size if pos_enc_layer_type == "abs_pos": pos_enc_class = PositionalEncoding elif pos_enc_layer_type == "rel_pos": pos_enc_class = RelPositionalEncoding elif pos_enc_layer_type == "no_pos": pos_enc_class = NoPositionalEncoding else: raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) if input_layer == "linear": subsampling_class = LinearNoSubsampling elif input_layer == "conv2d2": subsampling_class = Conv2dSubsampling2 elif input_layer == "conv2d": subsampling_class = Conv2dSubsampling4 elif input_layer == "conv2d6": subsampling_class = Conv2dSubsampling6 elif input_layer == "conv2d8": subsampling_class = Conv2dSubsampling8 else: raise ValueError("unknown input_layer: " + input_layer) self.embed = subsampling_class( input_size, output_size, dropout_rate, pos_enc_class(output_size, dropout_rate), ) self.normalize_before = normalize_before self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) def output_size(self) -> int: return self._output_size def forward( self, xs: torch.Tensor, xs_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Embed positions in tensor. Args: xs: padded input tensor (B, T, D) xs_lens: input length (B) decoding_chunk_size: decoding chunk size for dynamic chunk 0: default for training, use random dynamic chunk. <0: for decoding, use full chunk. >0: for decoding, use fixed chunk size as set. num_decoding_left_chunks: number of left chunks, this is for decoding, the chunk size is decoding_chunk_size. >=0: use num_decoding_left_chunks <0: use all left chunks Returns: encoder output tensor xs, and subsampled masks xs: padded output tensor (B, T' ~= T/subsample_rate, D) masks: torch.Tensor batch padding mask after subsample (B, 1, T' ~= T/subsample_rate) """ T = xs.size(1) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) xs, pos_emb, masks = self.embed(xs, masks) chunk_masks = masks mask_pad = masks # (B, 1, T/subsample_rate) for layer in self.encoders: xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) if self.normalize_before: xs = self.after_norm(xs) # Here we assume the mask is not changed in encoder layers, so just # return the masks before encoder layers, and the masks will be used # for cross attention with decoder later return xs, masks class ConformerEncoder(BaseEncoder): """Conformer encoder module.""" def __init__( self, input_size: int, output_size: int = 256, attention_heads: int = 4, linear_units: int = 2048, num_blocks: int = 6, dropout_rate: float = 0.0, input_layer: str = "conv2d", pos_enc_layer_type: str = "rel_pos", normalize_before: bool = True, concat_after: bool = False, macaron_style: bool = False, use_cnn_module: bool = True, cnn_module_kernel: int = 15, ): """Construct ConformerEncoder Args: input_size to use_dynamic_chunk, see in BaseEncoder positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. macaron_style (bool): Whether to use macaron style for positionwise layer. selfattention_layer_type (str): Encoder attention layer type, the parameter has no effect now, it's just for configure compatibility. activation_type (str): Encoder activation function type. use_cnn_module (bool): Whether to use convolution module. cnn_module_kernel (int): Kernel size of convolution module. causal (bool): whether to use causal convolution or not. """ super().__init__(input_size, output_size, attention_heads, linear_units, num_blocks, dropout_rate, input_layer, pos_enc_layer_type, normalize_before, concat_after) activation = torch.nn.SiLU() # self-attention module definition if pos_enc_layer_type != "rel_pos": encoder_selfattn_layer = MultiHeadedAttention else: encoder_selfattn_layer = RelPositionMultiHeadedAttention encoder_selfattn_layer_args = ( attention_heads, output_size, dropout_rate, ) # feed-forward module definition positionwise_layer = PositionwiseFeedForward positionwise_layer_args = ( output_size, linear_units, dropout_rate, activation, ) # convolution module definition convolution_layer = ConvolutionModule convolution_layer_args = (output_size, cnn_module_kernel, activation,) self.encoders = torch.nn.ModuleList([ ConformerEncoderLayer( output_size, encoder_selfattn_layer(*encoder_selfattn_layer_args), positionwise_layer(*positionwise_layer_args), positionwise_layer( *positionwise_layer_args) if macaron_style else None, convolution_layer( *convolution_layer_args) if use_cnn_module else None, dropout_rate, normalize_before, concat_after, ) for _ in range(num_blocks) ])