# MonsterForge-small / LightweightGAN / LinearAttention.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
from .DepthWiseConv2d import DepthWiseConv2d


class LinearAttention(nn.Module):
    """Hybrid attention block: a global linear-attention branch plus a local,
    conv-like full-attention branch, fused back to `dim` channels by a 1x1 conv."""
    def __init__(self, dim, dim_head=64, heads=8, kernel_size=3):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads

        self.kernel_size = kernel_size
        self.nonlin = nn.GELU()

        # projections for the global linear-attention branch
        self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding=1, bias=False)

        # projections for the local, conv-like full-attention branch
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

        # the two branch outputs are concatenated, then projected back to `dim`
        self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)

    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]

        # linear attention
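        # q is normalized along its feature dim and k along its sequence dim,
        # so (k^T v) can be contracted first into a small d x e context matrix;
        # the branch then costs O(n) in the number of pixels instead of O(n^2),
        # in the style of efficient/linear attention.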
        lin_q, lin_k, lin_v = (
            self.to_lin_q(fmap),
            *self.to_lin_kv(fmap).chunk(2, dim=1),
        )
        lin_q, lin_k, lin_v = map(
            lambda t: rearrange(t, "b (h c) x y -> (b h) (x y) c", h=h),
            (lin_q, lin_k, lin_v),
        )

        lin_q = lin_q.softmax(dim=-1)
        lin_k = lin_k.softmax(dim=-2)

        lin_q = lin_q * self.scale

        context = einsum("b n d, b n e -> b d e", lin_k, lin_v)
        lin_out = einsum("b n d, b d e -> b n e", lin_q, context)
        lin_out = rearrange(lin_out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)

        # conv-like full attention
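        # Each query pixel attends only to its kernel_size x kernel_size
        # neighborhood: F.unfold gathers the k/v values of every sliding window,
        # so this branch is ordinary softmax attention restricted to a local patch.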
        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim=1))
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> (b h) c x y", h=h), (q, k, v)
        )

        k = F.unfold(k, kernel_size=self.kernel_size, padding=self.kernel_size // 2)
        v = F.unfold(v, kernel_size=self.kernel_size, padding=self.kernel_size // 2)
        k, v = map(
            lambda t: rearrange(t, "b (d j) n -> b n j d", d=self.dim_head), (k, v)
        )

        q = rearrange(q, "b c ... -> b (...) c") * self.scale

        sim = einsum("b i d, b i j d -> b i j", q, k)
        # subtract the (detached) per-row max for a numerically stable softmax
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        full_out = einsum("b i j, b i j d -> b i d", attn, v)
        full_out = rearrange(full_out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)

        # combine the two branches: GELU on the linear-attention output, then
        # concatenate with the full-attention output and project back to `dim`
        lin_out = self.nonlin(lin_out)
        out = torch.cat((lin_out, full_out), dim=1)
        return self.to_out(out)
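

# A minimal shape-check sketch, not part of the original module. It assumes the
# sibling DepthWiseConv2d module is importable, i.e. that this file is run as
# part of the LightweightGAN package (e.g. `python -m LightweightGAN.LinearAttention`).
# The channel count and spatial size below are arbitrary illustrative values.
if __name__ == "__main__":
    attn = LinearAttention(dim=64, dim_head=32, heads=4, kernel_size=3)
    fmap = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
    out = attn(fmap)
    assert out.shape == fmap.shape  # the block preserves the input shape
    print(out.shape)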