import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange

from .DepthWiseConv2d import DepthWiseConv2d
class LinearAttention(nn.Module):
    """Hybrid attention over a 2D feature map: a global linear-attention
    branch and a local full-attention branch, fused channel-wise."""

    def __init__(self, dim, dim_head=64, heads=8, kernel_size=3):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads

        self.kernel_size = kernel_size
        self.nonlin = nn.GELU()

        # projections for the global linear-attention branch; keys/values get
        # a depthwise 3x3 conv to mix in some local context
        self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding=1, bias=False)

        # projections for the local full-attention branch
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

        # the two branches are concatenated before the output projection,
        # hence inner_dim * 2 input channels
        self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)
    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]

        # ----- global linear attention -----
        lin_q, lin_k, lin_v = (
            self.to_lin_q(fmap),
            *self.to_lin_kv(fmap).chunk(2, dim=1),
        )
        lin_q, lin_k, lin_v = map(
            lambda t: rearrange(t, "b (h c) x y -> (b h) (x y) c", h=h),
            (lin_q, lin_k, lin_v),
        )

        # softmax queries over the feature dim and keys over the sequence dim,
        # so attention factorizes and costs linear in the number of positions
        lin_q = lin_q.softmax(dim=-1)
        lin_k = lin_k.softmax(dim=-2)

        lin_q = lin_q * self.scale

        # aggregate values against keys first (a d-by-e context matrix),
        # then read the context out with the queries
        context = einsum("b n d, b n e -> b d e", lin_k, lin_v)
        lin_out = einsum("b n d, b d e -> b n e", lin_q, context)
        lin_out = rearrange(lin_out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)
        # ----- local full attention -----
        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim=1))
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> (b h) c x y", h=h), (q, k, v)
        )

        # gather a kernel_size x kernel_size neighborhood of keys and values
        # around every spatial position
        k = F.unfold(k, kernel_size=self.kernel_size, padding=self.kernel_size // 2)
        v = F.unfold(v, kernel_size=self.kernel_size, padding=self.kernel_size // 2)

        # unfold lays channels out as (dim_head, kernel positions)
        k, v = map(
            lambda t: rearrange(t, "b (d j) n -> b n j d", d=self.dim_head), (k, v)
        )

        q = rearrange(q, "b c ... -> b (...) c") * self.scale

        # each position attends only to its j local neighbors
        sim = einsum("b i d, b i j d -> b i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()  # numerical stability

        attn = sim.softmax(dim=-1)

        full_out = einsum("b i j, b i j d -> b i d", attn, v)
        full_out = rearrange(full_out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)
        # gate the linear branch, then fuse both branches channel-wise
        lin_out = self.nonlin(lin_out)
        out = torch.cat((lin_out, full_out), dim=1)
        return self.to_out(out)
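
# A minimal smoke-test sketch, not part of the original module: it assumes the
# package layout above (relative import of DepthWiseConv2d), so run it as a
# module, e.g. `python -m <package>.<this_file>`. It only checks that the
# layer is shape-preserving on a random feature map.
if __name__ == "__main__":
    attn = LinearAttention(dim=64, dim_head=32, heads=4, kernel_size=3)
    fmap = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
    out = attn(fmap)
    assert out.shape == fmap.shape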