import copy

import torch
from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm

from .attention import MultiheadAttention
from .transformer import _get_activation_fn

class TransformerEncoderLayerImproved(Module):
    """Pre-norm Transformer encoder layer with an optional second global conditioning input."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
        super(TransformerEncoderLayerImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        if d_global2 is not None:
            self.linear_global2 = Linear(d_global2, d_model)

        # Implementation of the feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Restore a sensible default when unpickling checkpoints saved without `activation`.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayerImproved, self).__setstate__(state)

    def forward(self, src, memory2=None, src_mask=None, src_key_padding_mask=None):
        # Pre-norm self-attention block with a residual connection.
        src1 = self.norm1(src)
        src2 = self.self_attn(src1, src1, src1, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)

        # Optional global conditioning: project memory2 to d_model and add it residually.
        if memory2 is not None:
            src2_2 = self.linear_global2(memory2)
            src = src + self.dropout2_2(src2_2)

        # Pre-norm feedforward block with a residual connection.
        src1 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src1))))
        src = src + self.dropout2(src2)
        return src

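# Illustrative sketch, not part of the original module: one possible way to call the
# encoder layer. It assumes the custom MultiheadAttention follows the
# (seq_len, batch, d_model) layout and boolean key_padding_mask convention of
# torch.nn.MultiheadAttention; all sizes below are arbitrary placeholders.
def _example_encoder_layer():
    layer = TransformerEncoderLayerImproved(d_model=256, nhead=8, dim_feedforward=512, dropout=0.1)
    src = torch.randn(10, 4, 256)  # (seq_len, batch, d_model)
    # Boolean mask of padded positions per batch element (True = ignore that position).
    src_key_padding_mask = torch.zeros(4, 10, dtype=torch.bool)
    out = layer(src, src_key_padding_mask=src_key_padding_mask)
    return out.shape  # expected: torch.Size([10, 4, 256])
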
class TransformerDecoderLayerImproved(Module):
    """Pre-norm Transformer decoder layer with self-attention and encoder-decoder attention."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayerImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of the feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Restore a sensible default when unpickling checkpoints saved without `activation`.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayerImproved, self).__setstate__(state)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Pre-norm self-attention block with a residual connection.
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # Pre-norm cross-attention over the encoder memory.
        tgt1 = self.norm2(tgt)
        tgt2 = self.multihead_attn(tgt1, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)

        # Pre-norm feedforward block with a residual connection.
        tgt1 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

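# Illustrative sketch, not part of the original module: pairing the decoder layer with
# encoder memory. Shapes again assume the (seq_len, batch, d_model) layout and are
# arbitrary; the causal mask is built as an additive float mask, assuming the custom
# MultiheadAttention accepts attn_mask like torch.nn.MultiheadAttention does.
def _example_decoder_layer():
    layer = TransformerDecoderLayerImproved(d_model=256, nhead=8, dim_feedforward=512)
    tgt = torch.randn(12, 4, 256)     # decoder input (tgt_len, batch, d_model)
    memory = torch.randn(10, 4, 256)  # encoder output (src_len, batch, d_model)
    # Upper-triangular -inf mask so position i cannot attend to positions > i.
    tgt_mask = torch.triu(torch.full((12, 12), float("-inf")), diagonal=1)
    out = layer(tgt, memory, tgt_mask=tgt_mask)
    return out.shape  # expected: torch.Size([12, 4, 256])
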
class TransformerDecoderLayerGlobalImproved(Module):
    """Pre-norm Transformer decoder layer conditioned on one (or two) global vectors
    instead of full encoder-decoder attention."""

    def __init__(self, d_model, d_global, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
        super(TransformerDecoderLayerGlobalImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear_global = Linear(d_global, d_model)

        if d_global2 is not None:
            self.linear_global2 = Linear(d_global2, d_model)

        # Implementation of the feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Restore a sensible default when unpickling checkpoints saved without `activation`.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayerGlobalImproved, self).__setstate__(state)

    def forward(self, tgt, memory, memory2=None, tgt_mask=None, tgt_key_padding_mask=None, *args, **kwargs):
        # Pre-norm self-attention block with a residual connection.
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # Project the global memory to d_model and add it to every target position;
        # the addition broadcasts implicitly over the sequence dimension.
        tgt2 = self.linear_global(memory)
        tgt = tgt + self.dropout2(tgt2)  # implicit broadcast

        # Optional second global conditioning input.
        if memory2 is not None:
            tgt2_2 = self.linear_global2(memory2)
            tgt = tgt + self.dropout2_2(tgt2_2)

        # Pre-norm feedforward block with a residual connection.
        tgt1 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
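
# Illustrative sketch, not part of the original module: here `memory` is a single
# global conditioning vector per batch element rather than a sequence. It is projected
# by linear_global and added to every target position through the implicit broadcast
# noted above. Dimensions (d_model=256, d_global=64) are arbitrary placeholders.
def _example_global_decoder_layer():
    layer = TransformerDecoderLayerGlobalImproved(d_model=256, d_global=64, nhead=8)
    tgt = torch.randn(12, 4, 256)   # (tgt_len, batch, d_model)
    memory = torch.randn(1, 4, 64)  # one global vector per batch element
    out = layer(tgt, memory)        # (1, 4, 256) broadcasts against (12, 4, 256)
    return out.shape  # expected: torch.Size([12, 4, 256])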