# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoder
from fairseq.modules import (
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
from fairseq.models.transformer import (
    TransformerConfig,
)

from speechut.modules import transformer_layer, LearnedPositionalEmbedding
from speechut.modules import RelativePositionalEncoding


# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
    if module_name == "TransformerEncoderBase":
        return "TransformerEncoder"
    else:
        return module_name


class TransformerEncoderBase(FairseqEncoder):
    """
    Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        cfg (TransformerConfig): model configuration
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
        use_rel_pos_enc (bool, optional): if True, add relative positional
            encodings to the self-attention of every layer (default: False)
        scaling_for_att (float, optional): scaling factor forwarded to the
            self-attention of every layer (default: 1.0)
    """

    def __init__(self, cfg, dictionary, embed_tokens, use_rel_pos_enc=False, scaling_for_att=1.0):
        self.cfg = cfg
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))

        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
        )
        self.encoder_layerdrop = cfg.encoder.layerdrop

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = cfg.max_source_positions

        self.embed_tokens = embed_tokens

        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=cfg.encoder.learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )
        if cfg.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layernorm_embedding = None

        if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                cfg.quant_noise.pq,
                cfg.quant_noise.pq_block_size,
            )
        else:
            self.quant_noise = None

        if self.encoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.use_rel_pos_enc = use_rel_pos_enc
        self.scaling_for_att = scaling_for_att
        self.layers.extend(
            [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)]
        )
        self.num_layers = len(self.layers)

        if cfg.encoder.normalize_before:
            self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layer_norm = None
        if self.use_rel_pos_enc:
            self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.encoder.attention_heads, 160)

    def build_encoder_layer(self, cfg):
        layer = transformer_layer.TransformerEncoderLayerBase(
            cfg,
            has_relative_attention_bias=self.use_rel_pos_enc,
            scaling_for_att=self.scaling_for_att,
        )
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer
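
    # Note: `checkpoint_wrapper` is only applied when `cfg.checkpoint_activations`
    # is set, and `fsdp_wrap` only wraps the layer when construction happens inside
    # fairseq's FSDP wrapping context (and the layer exceeds `min_num_params`);
    # otherwise the plain encoder layer is returned unchanged.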

    def forward_embedding(
        self, src_tokens, token_embedding: Optional[torch.Tensor] = None
    ):
        # embed tokens and positions
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        if self.quant_noise is not None:
            x = self.quant_noise(x)
        return x, embed
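
    # Both returned tensors have shape (batch, src_len, embed_dim): `x` is the
    # embedding after positional encoding, layer norm, dropout and (optional)
    # quant noise, while `embed` is the raw scaled token embedding.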

    def forward(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        uniformity_layers: Optional[List[int]] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings;
                default `None` will recompute embeddings
            uniformity_layers (List[int], optional): layer indices (0 = embedding
                output) at which the hidden states are L2-normalized and collected
                into `uniformity_hiddens`

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
                - **uniformity_hiddens** (List[Tensor]): L2-normalized hidden
                  states of shape `(src_len, batch, embed_dim)` for the layers
                  listed in *uniformity_layers*.
        """
        return self.forward_scriptable(
            src_tokens, src_lengths, return_all_hiddens, token_embeddings, uniformity_layers
        )

    # TorchScript doesn't support super() method so that the scriptable Subclass
    # can't access the base class model in Torchscript.
    # Current workaround is to add a helper function with different name and
    # call the helper function from scriptable Subclass.
    def forward_scriptable(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        uniformity_layers: Optional[List[int]] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings;
                default `None` will recompute embeddings
            uniformity_layers (List[int], optional): layer indices (0 = embedding
                output) at which the hidden states are L2-normalized and collected
                into `uniformity_hiddens`

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
                - **uniformity_hiddens** (List[Tensor]): L2-normalized hidden
                  states of shape `(src_len, batch, embed_dim)` for the layers
                  listed in *uniformity_layers*.
        """
        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()

        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

        # account for padding while computing the representation
        if has_pads:
            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if self.use_rel_pos_enc:
            # relative offsets i - j between all pairs of positions; only the key
            # component of the relative positional embedding is passed to the
            # layers as an attention bias
            x_len = x.shape[0]
            pos_seq = torch.arange(0, x_len).long().to(x.device)
            pos_seq = pos_seq[:, None] - pos_seq[None, :]
            pos_k, pos_v = self.pos_emb(pos_seq)
        else:
            pos_k = None

        encoder_states = []
        uniformity_hiddens = []

        if return_all_hiddens:
            encoder_states.append(x)
        if uniformity_layers is not None and 0 in uniformity_layers:
            x = F.normalize(x.float(), dim=-1).type_as(x)
            uniformity_hiddens.append(x)

        # encoder layers
        for i, layer in enumerate(self.layers):
            x = layer(
                x,
                encoder_padding_mask=encoder_padding_mask if has_pads else None,
                pos_bias=pos_k,
            )
            if uniformity_layers is not None and i + 1 in uniformity_layers:
                x = F.normalize(x.float(), dim=-1).type_as(x)
                uniformity_hiddens.append(x)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # The Pytorch Mobile lite interpreter does not support returning NamedTuple in
        # `forward` so we use a dictionary instead.
        # TorchScript does not support mixed values so the values are all lists.
        # The empty list is equivalent to None.
        src_lengths = (
            src_tokens.ne(self.padding_idx)
            .sum(dim=1, dtype=torch.int32)
            .reshape(-1, 1)
            .contiguous()
        )
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "uniformity_hiddens": uniformity_hiddens,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [src_lengths],
        }

    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        if len(encoder_out["encoder_embedding"]) == 0:
            new_encoder_embedding = []
        else:
            new_encoder_embedding = [
                encoder_out["encoder_embedding"][0].index_select(0, new_order)
            ]

        if len(encoder_out["src_tokens"]) == 0:
            src_tokens = []
        else:
            src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]

        if len(encoder_out["src_lengths"]) == 0:
            src_lengths = []
        else:
            src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": src_tokens,  # B x T
            "src_lengths": src_lengths,  # B x 1
        }
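
    # Typical use: during beam search the sequence generator reorders the cached
    # encoder output with a `new_order` that repeats each batch element once per
    # beam, e.g. (illustrative sketch):
    #   new_order = torch.arange(bsz).repeat_interleave(beam_size).to(device)
    #   encoder_out = encoder.reorder_encoder_out(encoder_out, new_order)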

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = "{}.embed_positions.weights".format(name)
            if weights_key in state_dict:
                print("deleting {0}".format(weights_key))
                del state_dict[weights_key]
            state_dict[
                "{}.embed_positions._float_tensor".format(name)
            ] = torch.FloatTensor(1)
        for i in range(self.num_layers):
            # update layer norms
            self.layers[i].upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, i)
            )

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict


class TransformerEncoder(TransformerEncoderBase):
    def __init__(self, args, dictionary, embed_tokens):
        self.args = args
        super().__init__(
            TransformerConfig.from_namespace(args),
            dictionary,
            embed_tokens,
            use_rel_pos_enc=getattr(args, "use_rel_pos_enc", False),
            scaling_for_att=getattr(args, "scaling_for_att", 1.0),
        )

    def build_encoder_layer(self, args):
        return super().build_encoder_layer(
            TransformerConfig.from_namespace(args),
        )
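
# With the legacy argparse entry point above, the relative positional encoding and
# the attention scaling factor are read off the namespace when present, e.g.
# (illustrative values, not from the original file):
#   args.use_rel_pos_enc = True
#   args.scaling_for_att = 1.0
#   encoder = TransformerEncoder(args, dictionary, embed_tokens)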


def PositionalEmbedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int,
    learned: bool = False,
):
    if learned:
        # if padding_idx is specified then offset the embedding ids by
        # this index and adjust num_embeddings appropriately
        # TODO: The right place for this offset would be inside
        # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
        if padding_idx is not None:
            num_embeddings = num_embeddings + padding_idx + 1
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
        if padding_idx is not None:
            nn.init.constant_(m.weight[padding_idx], 0)
    else:
        m = SinusoidalPositionalEmbedding(
            embedding_dim,
            padding_idx,
            init_size=num_embeddings + padding_idx + 1,
        )
    return m
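

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch only, not part of the original
    # module. It assumes fairseq and speechut are importable and relies on the
    # default values of TransformerConfig for every field not set explicitly.
    from fairseq.data import Dictionary

    dictionary = Dictionary()
    for symbol in "abcd":
        dictionary.add_symbol(symbol)

    cfg = TransformerConfig()
    cfg.encoder.layers = 2
    cfg.encoder.embed_dim = 64
    cfg.encoder.ffn_embed_dim = 128
    cfg.encoder.attention_heads = 4

    embed_tokens = nn.Embedding(len(dictionary), cfg.encoder.embed_dim, dictionary.pad())
    encoder = TransformerEncoderBase(cfg, dictionary, embed_tokens)

    # a batch of 2 sequences of length 7, sampled away from the special/pad symbols
    src_tokens = torch.randint(dictionary.nspecial, len(dictionary), (2, 7))
    out = encoder(src_tokens, return_all_hiddens=True)
    print(out["encoder_out"][0].shape)  # expected: torch.Size([7, 2, 64])
    print(len(out["encoder_states"]))   # embedding output + one entry per layer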