import torch
import copy
from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm

from .attention import MultiheadAttention
from .transformer import _get_activation_fn


class TransformerEncoderLayerImproved(Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
        super(TransformerEncoderLayerImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Optional second global conditioning input, projected to the model dimension
        if d_global2 is not None:
            self.linear_global2 = Linear(d_global2, d_model)

        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayerImproved, self).__setstate__(state)

    def forward(self, src, memory2=None, src_mask=None, src_key_padding_mask=None):
        # Pre-norm self-attention block with residual connection
        src1 = self.norm1(src)
        src2 = self.self_attn(src1, src1, src1, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)

        # Optional global conditioning added to the residual stream
        if memory2 is not None:
            src2_2 = self.linear_global2(memory2)
            src = src + self.dropout2_2(src2_2)

        # Pre-norm feed-forward block with residual connection
        src1 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src1))))
        src = src + self.dropout2(src2)
        return src


class TransformerDecoderLayerImproved(Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayerImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayerImproved, self).__setstate__(state)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Pre-norm self-attention block with residual connection
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # Pre-norm cross-attention over the encoder memory
        tgt1 = self.norm2(tgt)
        tgt2 = self.multihead_attn(tgt1, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)

        # Pre-norm feed-forward block with residual connection
        tgt1 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt


class TransformerDecoderLayerGlobalImproved(Module):
    def __init__(self, d_model, d_global, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
        super(TransformerDecoderLayerGlobalImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Global conditioning vector(s) are projected to the model dimension and added
        self.linear_global = Linear(d_global, d_model)

        if d_global2 is not None:
            self.linear_global2 = Linear(d_global2, d_model)

        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayerGlobalImproved, self).__setstate__(state)

    def forward(self, tgt, memory, memory2=None, tgt_mask=None, tgt_key_padding_mask=None, *args, **kwargs):
        # Pre-norm self-attention block with residual connection
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # Global conditioning: project memory to d_model and add it to every position
        tgt2 = self.linear_global(memory)
        tgt = tgt + self.dropout2(tgt2)  # implicit broadcast over the sequence dimension

        if memory2 is not None:
            tgt2_2 = self.linear_global2(memory2)
            tgt = tgt + self.dropout2_2(tgt2_2)

        # Pre-norm feed-forward block with residual connection
        tgt1 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
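

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the library API). It assumes the custom
    # MultiheadAttention imported above follows the same (seq_len, batch, d_model)
    # input convention as torch.nn.MultiheadAttention; the dimensions below are
    # illustrative only. Run as a module so the relative imports resolve,
    # e.g. `python -m <package>.<this_module>`.
    d_model, d_global, nhead = 256, 64, 8
    seq_len, batch = 10, 4

    encoder_layer = TransformerEncoderLayerImproved(d_model, nhead, dim_feedforward=512)
    decoder_layer = TransformerDecoderLayerGlobalImproved(d_model, d_global, nhead, dim_feedforward=512)

    src = torch.randn(seq_len, batch, d_model)
    z = torch.randn(1, batch, d_global)  # global latent, broadcast over the sequence dimension

    enc_out = encoder_layer(src)      # (seq_len, batch, d_model)
    dec_out = decoder_layer(src, z)   # (seq_len, batch, d_model)
    print(enc_out.shape, dec_out.shape)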