# deepsvg/model/layers/improved_transformer.py
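"""Pre-LayerNorm ("improved") Transformer layers for DeepSVG.

Each sublayer normalizes its input *before* attention / feed-forward and then
adds the residual (pre-LN), rather than normalizing after the residual as in
``torch.nn.TransformerEncoderLayer``. The *Global* decoder variant replaces
cross-attention with a learned linear projection of a global feature that is
added directly to the residual stream.
"""
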
import torch
import copy
from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from .attention import MultiheadAttention
from .transformer import _get_activation_fn
class TransformerEncoderLayerImproved(Module):
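    """Pre-LN Transformer encoder layer.

    If ``d_global2`` is given, ``forward`` accepts an extra ``memory2`` input
    whose projection is added to the residual stream after self-attention.
    """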
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
super(TransformerEncoderLayerImproved, self).__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
if d_global2 is not None:
self.linear_global2 = Linear(d_global2, d_model)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.dropout1 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)  # for the optional projected memory2 branch
self.dropout2 = Dropout(dropout)
self.activation = _get_activation_fn(activation)
    def __setstate__(self, state):
        # Backward compatibility: checkpoints pickled before 'activation' was
        # stored fall back to ReLU, mirroring torch.nn's transformer layers.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayerImproved, self).__setstate__(state)
    def forward(self, src, memory2=None, src_mask=None, src_key_padding_mask=None):
        # Pre-LN self-attention block: normalize, attend, then add the residual.
        src1 = self.norm1(src)
        src2 = self.self_attn(src1, src1, src1, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)

        # Optional second global input, projected to d_model and added to the residual stream.
        if memory2 is not None:
            src2_2 = self.linear_global2(memory2)
            src = src + self.dropout2_2(src2_2)

        # Pre-LN feed-forward block.
        src1 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src1))))
        src = src + self.dropout2(src2)
        return src
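
# A minimal usage sketch (illustrative, not part of the original module;
# shapes assume the (S, N, E) = (seq_len, batch, d_model) convention of
# torch.nn.MultiheadAttention):
#
#     layer = TransformerEncoderLayerImproved(d_model=256, nhead=8)
#     src = torch.randn(10, 4, 256)   # (S, N, E)
#     out = layer(src)                # residual structure preserves the shape
#     assert out.shape == src.shape
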
class TransformerDecoderLayerImproved(Module):
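    """Pre-LN Transformer decoder layer: masked self-attention, cross-attention
    over the encoder memory, and a feed-forward block, each wrapped in a
    norm-first residual."""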
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
super(TransformerDecoderLayerImproved, self).__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.norm3 = LayerNorm(d_model)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
self.dropout3 = Dropout(dropout)
self.activation = _get_activation_fn(activation)
def __setstate__(self, state):
if 'activation' not in state:
state['activation'] = F.relu
super(TransformerDecoderLayerImproved, self).__setstate__(state)
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Pre-LN masked self-attention over the target sequence.
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # Pre-LN cross-attention over the encoder memory.
        tgt1 = self.norm2(tgt)
        tgt2 = self.multihead_attn(tgt1, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)

        # Pre-LN feed-forward block.
        tgt1 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
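
# A minimal usage sketch for the decoder layer (illustrative; same assumed
# (S, N, E) layout, and a causal tgt_mask of shape (T, T) is typical but
# optional here):
#
#     layer = TransformerDecoderLayerImproved(d_model=256, nhead=8)
#     tgt = torch.randn(12, 4, 256)      # target sequence
#     memory = torch.randn(10, 4, 256)   # encoder output
#     out = layer(tgt, memory)           # -> (12, 4, 256)
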
class TransformerDecoderLayerGlobalImproved(Module):
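    """Pre-LN decoder layer that conditions on a global feature via a linear
    projection added to the residual stream, instead of cross-attention.
    Supports an optional second global input ``memory2`` (enabled by
    ``d_global2``)."""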
def __init__(self, d_model, d_global, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
super(TransformerDecoderLayerGlobalImproved, self).__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear_global = Linear(d_global, d_model)
if d_global2 is not None:
self.linear_global2 = Linear(d_global2, d_model)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
self.dropout2_2 = Dropout(dropout)
self.dropout3 = Dropout(dropout)
self.activation = _get_activation_fn(activation)
def __setstate__(self, state):
if 'activation' not in state:
state['activation'] = F.relu
super(TransformerDecoderLayerGlobalImproved, self).__setstate__(state)
    def forward(self, tgt, memory, memory2=None, tgt_mask=None, tgt_key_padding_mask=None, *args, **kwargs):
        # Pre-LN masked self-attention over the target sequence.
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # Instead of cross-attention, project the global feature to d_model and
        # add it to the residual stream; e.g. a (1, N, d_model) projection
        # broadcasts over the target sequence length.
        tgt2 = self.linear_global(memory)
        tgt = tgt + self.dropout2(tgt2)  # implicit broadcast
        if memory2 is not None:
            tgt2_2 = self.linear_global2(memory2)
            tgt = tgt + self.dropout2_2(tgt2_2)

        # Pre-LN feed-forward block.
        tgt1 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
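

if __name__ == "__main__":
    # Hedged smoke test, not part of the original module. Because of the
    # relative imports above, run it as a module, e.g.
    #   python -m deepsvg.model.layers.improved_transformer
    # Shapes assume the (S, N, E) convention of torch.nn.MultiheadAttention.
    d_model, nhead, batch = 64, 4, 2
    src = torch.randn(10, batch, d_model)
    tgt = torch.randn(12, batch, d_model)
    z = torch.randn(1, batch, 32)  # global latent, broadcast over the sequence

    enc = TransformerEncoderLayerImproved(d_model, nhead)
    dec = TransformerDecoderLayerImproved(d_model, nhead)
    dec_g = TransformerDecoderLayerGlobalImproved(d_model, d_global=32, nhead=nhead)

    memory = enc(src)
    out = dec(tgt, memory)
    out_g = dec_g(tgt, z)
    assert out.shape == tgt.shape and out_g.shape == tgt.shape
    print("ok:", out.shape, out_g.shape)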