File size: 974 Bytes
002ca81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torch import nn
from einops import reduce
from .helper_funcs import get_dct_weights


class FCANet(nn.Module):
    """Channel gate computed from a DCT-weighted spatial sum of the input.

    The input is multiplied elementwise by a fixed ``dct_weights`` buffer,
    summed over its two trailing axes down to a 1x1 map, and pushed through
    a two-layer 1x1-convolution bottleneck that ends in a sigmoid.

    Args:
        chan_in: Number of channels of the incoming feature map.
        chan_out: Number of channels of the produced gate.
        reduction: Divisor applied to ``chan_out`` to size the bottleneck;
            the bottleneck is never narrower than 3 channels.
        width: Passed through as the first argument of ``get_dct_weights``;
            presumably the spatial size of the feature map (TODO confirm
            against the helper).
    """

    def __init__(self, *, chan_in, chan_out, reduction=4, width):
        super().__init__()

        # Two index lists of length 16: one is eight zeros followed by 0..7,
        # the other is 0..7 followed by eight zeros. The original author
        # notes that 16 frequencies appeared to be ideal in the paper.
        num_per_axis = 8
        pinned = [0] * num_per_axis
        sweep = list(range(num_per_axis))
        first_indices = pinned + sweep
        second_indices = sweep + pinned

        # Stored as a buffer so it follows the module across devices and is
        # saved with it, without being treated as a trainable parameter.
        self.register_buffer(
            "dct_weights",
            get_dct_weights(width, chan_in, first_indices, second_indices),
        )

        bottleneck = max(3, chan_out // reduction)

        # Convolutions are created in the same order they appear in the
        # stack so parameter initialisation order stays stable.
        squeeze = nn.Conv2d(chan_in, bottleneck, 1)
        expand = nn.Conv2d(bottleneck, chan_out, 1)
        self.net = nn.Sequential(squeeze, nn.LeakyReLU(0.1), expand, nn.Sigmoid())

    def forward(self, x):
        """Return the sigmoid gate for ``x``.

        ``x`` must be 4-D (the reduction pattern names axes ``b c h w``) and
        broadcastable against ``dct_weights``. The result has 1x1 trailing
        axes and ``chan_out`` channels.
        """
        weighted = x * self.dct_weights
        # With h1 = w1 = 1 the grouped axes collapse entirely, i.e. this is a
        # sum over the last two axes that keeps them as size-1 dimensions.
        pooled = reduce(
            weighted, "b c (h h1) (w w1) -> b c h1 w1", "sum", h1=1, w1=1
        )
        return self.net(pooled)