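"""Pre-norm Transformer encoder/decoder layers with optional global conditioning.

Each layer applies LayerNorm before its sub-block and adds the result back to the
residual stream (the "norm-first" variant, unlike torch.nn's post-norm
TransformerEncoderLayer/TransformerDecoderLayer). The *Global* decoder layer replaces
cross-attention with a learned linear projection of a global memory vector that is
broadcast-added to every target position.
"""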
import torch
import copy
from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from .attention import MultiheadAttention
from .transformer import _get_activation_fn


class TransformerEncoderLayerImproved(Module):
    """Pre-norm encoder layer: self-attention followed by a feedforward block, each
    wrapped in a LayerNorm + residual connection. An optional second "global" input
    of width d_global2 can be projected to d_model and added after self-attention."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
        super(TransformerEncoderLayerImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        if d_global2 is not None:
            self.linear_global2 = Linear(d_global2, d_model)

        # Implementation of the feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Older checkpoints may lack the activation attribute; default to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayerImproved, self).__setstate__(state)

    def forward(self, src, memory2=None, src_mask=None, src_key_padding_mask=None):
        # Self-attention sub-block (pre-norm residual).
        src1 = self.norm1(src)
        src2 = self.self_attn(src1, src1, src1, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)

        # Optional global conditioning: project memory2 to d_model and add it.
        if memory2 is not None:
            src2_2 = self.linear_global2(memory2)
            src = src + self.dropout2_2(src2_2)

        # Feedforward sub-block (pre-norm residual).
        src1 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src1))))
        src = src + self.dropout2(src2)
        return src


class TransformerDecoderLayerImproved(Module):
    """Pre-norm decoder layer: self-attention over the target, cross-attention over
    the encoder memory, then a feedforward block, each with a residual connection."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayerImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of the feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Older checkpoints may lack the activation attribute; default to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayerImproved, self).__setstate__(state)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Self-attention over the target sequence (pre-norm residual).
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # Cross-attention over the encoder memory.
        tgt1 = self.norm2(tgt)
        tgt2 = self.multihead_attn(tgt1, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)

        # Feedforward sub-block.
        tgt1 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt


class TransformerDecoderLayerGlobalImproved(Module):
    """Pre-norm decoder layer that replaces cross-attention with a learned linear
    projection of a global memory vector, broadcast-added to every target position.
    A second optional global input of width d_global2 can be added the same way."""

    def __init__(self, d_model, d_global, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
        super(TransformerDecoderLayerGlobalImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear_global = Linear(d_global, d_model)
        if d_global2 is not None:
            self.linear_global2 = Linear(d_global2, d_model)

        # Implementation of the feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Older checkpoints may lack the activation attribute; default to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayerGlobalImproved, self).__setstate__(state)

    def forward(self, tgt, memory, memory2=None, tgt_mask=None, tgt_key_padding_mask=None, *args, **kwargs):
        # Self-attention over the target sequence (pre-norm residual).
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # Global conditioning: project the memory to d_model and add it to every
        # position (implicit broadcast over the sequence dimension).
        tgt2 = self.linear_global(memory)
        tgt = tgt + self.dropout2(tgt2)
        if memory2 is not None:
            tgt2_2 = self.linear_global2(memory2)
            tgt = tgt + self.dropout2_2(tgt2_2)

        # Feedforward sub-block.
        tgt1 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
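

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). The tensor shapes below are
# assumptions based on torch.nn.MultiheadAttention's default sequence-first
# layout; all three layers follow the pre-norm residual pattern
# x = x + Dropout(Sublayer(LayerNorm(x))). Assuming the relative imports above
# resolve inside their package:
#
#   enc_layer = TransformerEncoderLayerImproved(d_model=256, nhead=8)
#   src = torch.randn(10, 4, 256)        # (seq_len, batch, d_model)
#   out = enc_layer(src)                 # -> (10, 4, 256)
#
#   dec_layer = TransformerDecoderLayerGlobalImproved(d_model=256, d_global=64, nhead=8)
#   tgt = torch.randn(10, 4, 256)        # target sequence
#   z = torch.randn(1, 4, 64)            # global code, broadcast over seq_len
#   out = dec_layer(tgt, z)              # -> (10, 4, 256)
# ---------------------------------------------------------------------------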