File size: 2,679 Bytes
002ca81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn.functional as F

from torch import nn, einsum
from einops import rearrange

from .DepthWiseConv2d import DepthWiseConv2d


class LinearAttention(nn.Module):
    """Hybrid attention over 2D feature maps.

    Combines two branches and fuses them with a 1x1 conv:
      1. a global *linear* attention branch (softmax over feature dim,
         O(n) in the number of spatial positions), and
      2. a local *full* attention branch restricted to a
         ``kernel_size x kernel_size`` neighborhood via ``F.unfold``.

    Args:
        dim: number of input (and output) channels.
        dim_head: channels per attention head.
        heads: number of attention heads.
        kernel_size: neighborhood size for the local full-attention branch.
    """

    def __init__(self, dim, dim_head=64, heads=8, kernel_size=3):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head**-0.5
        self.kernel_size = kernel_size

        self.nonlin = nn.GELU()

        # projections for the global linear-attention branch
        self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding=1, bias=False)

        # projections for the local full-attention branch
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

        # both branch outputs are concatenated channel-wise, hence inner_dim * 2
        self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)

    def forward(self, fmap):
        heads = self.heads
        x, y = fmap.shape[-2:]

        # --- global linear attention ---

        lq = self.to_lin_q(fmap)
        lk, lv = self.to_lin_kv(fmap).chunk(2, dim=1)
        lq, lk, lv = map(
            lambda t: rearrange(t, "b (h c) x y -> (b h) (x y) c", h=heads),
            (lq, lk, lv),
        )

        # softmax over features for q, over positions for k (linear attention)
        lq = lq.softmax(dim=-1)
        lk = lk.softmax(dim=-2)
        lq = lq * self.scale

        # aggregate values into a per-head d x e context, then query it
        ctx = einsum("b n d, b n e -> b d e", lk, lv)
        lin_out = einsum("b n d, b d e -> b n e", lq, ctx)
        lin_out = rearrange(lin_out, "(b h) (x y) d -> b (h d) x y", h=heads, x=x, y=y)

        # --- local (conv-like) full attention ---

        q = self.to_q(fmap)
        k, v = self.to_kv(fmap).chunk(2, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> (b h) c x y", h=heads),
            (q, k, v),
        )

        # gather each position's kernel_size^2 neighborhood for keys/values
        pad = self.kernel_size // 2
        k = F.unfold(k, kernel_size=self.kernel_size, padding=pad)
        v = F.unfold(v, kernel_size=self.kernel_size, padding=pad)
        k, v = map(
            lambda t: rearrange(t, "b (d j) n -> b n j d", d=self.dim_head),
            (k, v),
        )

        q = rearrange(q, "b c ... -> b (...) c") * self.scale

        sim = einsum("b i d, b i j d -> b i j", q, k)
        # subtract the per-row max (detached) for numerical stability
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        full_out = einsum("b i j, b i j d -> b i d", attn, v)
        full_out = rearrange(full_out, "(b h) (x y) d -> b (h d) x y", h=heads, x=x, y=y)

        # fuse: nonlinearity on the linear branch only, then concat + project
        return self.to_out(torch.cat((self.nonlin(lin_out), full_out), dim=1))