|
|
|
|
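"""Attention and feed-forward building blocks for the MatAnyone model: pre-norm
self-/cross-attention wrappers around nn.MultiheadAttention plus FFN variants."""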
|
from typing import List, Callable, Optional, Tuple
|
|
|
import torch |
|
from torch import Tensor |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from matanyone.model.channel_attn import CAResBlock |
|
|
|
|
|
class SelfAttention(nn.Module): |
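    """Pre-norm residual self-attention block.

    `add_pe_to_qkv` selects which of (query, key, value) have the positional
    encoding `pe` added before attention is computed.
    """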
|
def __init__(self, |
|
dim: int, |
|
nhead: int, |
|
dropout: float = 0.0, |
|
batch_first: bool = True, |
|
add_pe_to_qkv: List[bool] = [True, True, False]): |
|
super().__init__() |
|
self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first) |
|
self.norm = nn.LayerNorm(dim) |
|
self.dropout = nn.Dropout(dropout) |
|
self.add_pe_to_qkv = add_pe_to_qkv |
|
|
|
def forward(self, |
|
x: torch.Tensor, |
|
pe: torch.Tensor, |
|
                attn_mask: Optional[torch.Tensor] = None,

                key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
x = self.norm(x) |
|
if any(self.add_pe_to_qkv): |
|
x_with_pe = x + pe |
|
q = x_with_pe if self.add_pe_to_qkv[0] else x |
|
k = x_with_pe if self.add_pe_to_qkv[1] else x |
|
v = x_with_pe if self.add_pe_to_qkv[2] else x |
|
else: |
|
q = k = v = x |
|
|
|
r = x |
|
x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0] |
|
return r + self.dropout(x) |
|
|
|
|
|
|
|
class CrossAttention(nn.Module): |
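    """Pre-norm residual cross-attention from a query stream `x` to a memory stream `mem`.

    `add_pe_to_qkv` selects which of (query, key, value) receive their positional
    encoding; `residual` and `norm` toggle the skip connection and the LayerNorm on `x`.
    """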
|
def __init__(self, |
|
dim: int, |
|
nhead: int, |
|
dropout: float = 0.0, |
|
batch_first: bool = True, |
|
add_pe_to_qkv: List[bool] = [True, True, False], |
|
residual: bool = True, |
|
norm: bool = True): |
|
super().__init__() |
|
self.cross_attn = nn.MultiheadAttention(dim, |
|
nhead, |
|
dropout=dropout, |
|
batch_first=batch_first) |
|
if norm: |
|
self.norm = nn.LayerNorm(dim) |
|
else: |
|
self.norm = nn.Identity() |
|
self.dropout = nn.Dropout(dropout) |
|
self.add_pe_to_qkv = add_pe_to_qkv |
|
self.residual = residual |
|
|
|
def forward(self, |
|
x: torch.Tensor, |
|
mem: torch.Tensor, |
|
x_pe: torch.Tensor, |
|
mem_pe: torch.Tensor, |
|
                attn_mask: Optional[torch.Tensor] = None,

                *,

                need_weights: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
|
x = self.norm(x) |
|
if self.add_pe_to_qkv[0]: |
|
q = x + x_pe |
|
else: |
|
q = x |
|
|
|
if any(self.add_pe_to_qkv[1:]): |
|
mem_with_pe = mem + mem_pe |
|
k = mem_with_pe if self.add_pe_to_qkv[1] else mem |
|
v = mem_with_pe if self.add_pe_to_qkv[2] else mem |
|
else: |
|
k = v = mem |
|
r = x |
|
x, weights = self.cross_attn(q, |
|
k, |
|
v, |
|
attn_mask=attn_mask, |
|
need_weights=need_weights, |
|
average_attn_weights=False) |
|
|
|
if self.residual: |
|
return r + self.dropout(x), weights |
|
else: |
|
return self.dropout(x), weights |
|
|
|
|
|
class FFN(nn.Module): |
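    """Pre-norm residual two-layer feed-forward block."""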
|
def __init__(self, dim_in: int, dim_ff: int, activation=F.relu): |
|
super().__init__() |
|
self.linear1 = nn.Linear(dim_in, dim_ff) |
|
self.linear2 = nn.Linear(dim_ff, dim_in) |
|
self.norm = nn.LayerNorm(dim_in) |
|
|
|
if isinstance(activation, str): |
|
self.activation = _get_activation_fn(activation) |
|
else: |
|
self.activation = activation |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
r = x |
|
x = self.norm(x) |
|
x = self.linear2(self.activation(self.linear1(x))) |
|
x = r + x |
|
return x |
|
|
|
|
|
class PixelFFN(nn.Module): |
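    """Feed-forward over pixel features: reshapes flattened pixel tokens back to a
    spatial map and applies a channel-attention residual block (CAResBlock)."""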
|
def __init__(self, dim: int): |
|
super().__init__() |
|
self.dim = dim |
|
self.conv = CAResBlock(dim, dim) |
|
|
|
    def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor:
        # pixel: (batch_size, num_objects, dim, H, W)
        # pixel_flat: (batch_size * num_objects, H * W, dim)

        bs, num_objects, _, h, w = pixel.shape

        pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim)

        pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous()
|
|
|
x = self.conv(pixel_flat) |
|
x = x.view(bs, num_objects, self.dim, h, w) |
|
return x |
|
|
|
|
|
class OutputFFN(nn.Module): |
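    """Two-layer projection head with no normalization or residual connection."""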
|
def __init__(self, dim_in: int, dim_out: int, activation=F.relu): |
|
super().__init__() |
|
self.linear1 = nn.Linear(dim_in, dim_out) |
|
self.linear2 = nn.Linear(dim_out, dim_out) |
|
|
|
if isinstance(activation, str): |
|
self.activation = _get_activation_fn(activation) |
|
else: |
|
self.activation = activation |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.linear2(self.activation(self.linear1(x))) |
|
return x |
|
|
|
|
|
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: |
|
if activation == "relu": |
|
return F.relu |
|
elif activation == "gelu": |
|
return F.gelu |
|
|
|
raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) |
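

# Minimal shape sanity check: illustrative sizes only (the dims, head counts, and
# token counts below are assumptions for this sketch, not values prescribed by the model).
if __name__ == "__main__":
    dim, nhead, n_tokens, bs = 64, 4, 16, 2
    num_objects, h, w = 3, 4, 4

    x = torch.randn(bs, n_tokens, dim)
    pe = torch.randn(bs, n_tokens, dim)
    mem = torch.randn(bs, n_tokens * 2, dim)
    mem_pe = torch.randn(bs, n_tokens * 2, dim)

    self_attn = SelfAttention(dim, nhead)
    cross_attn = CrossAttention(dim, nhead)
    ffn = FFN(dim, dim * 4)
    out_ffn = OutputFFN(dim, dim // 2)
    pixel_ffn = PixelFFN(dim)

    y = self_attn(x, pe)                   # (bs, n_tokens, dim)
    z, _ = cross_attn(y, mem, pe, mem_pe)  # weights are None when need_weights=False
    z = ffn(z)
    print(z.shape)           # torch.Size([2, 16, 64])
    print(out_ffn(z).shape)  # torch.Size([2, 16, 32])

    pixel = torch.randn(bs, num_objects, dim, h, w)
    pixel_flat = torch.randn(bs * num_objects, h * w, dim)
    print(pixel_ffn(pixel, pixel_flat).shape)  # torch.Size([2, 3, 64, 4, 4])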
|
|