File size: 974 Bytes
002ca81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from torch import nn
from einops import reduce
from .helper_funcs import get_dct_weights
class FCANet(nn.Module):
    """Channel gate computed from a weighted spatial sum of the input.

    The input is multiplied element-wise by a fixed (non-trainable) buffer
    produced by ``get_dct_weights``, summed over both spatial axes down to a
    1x1 map, and then pushed through a small 1x1-convolution bottleneck that
    ends in a sigmoid.

    Args:
        chan_in: Channel count of the incoming feature map. Also forwarded
            to ``get_dct_weights``.
        chan_out: Channel count of the produced gate.
        reduction: Divisor applied to ``chan_out`` to size the bottleneck;
            the bottleneck never has fewer than 3 channels.
        width: Forwarded as the first argument of ``get_dct_weights``.
            NOTE(review): presumably the spatial side length of the feature
            map -- confirm against ``get_dct_weights``.
    """

    def __init__(self, *, chan_in, chan_out, reduction=4, width):
        super().__init__()

        # Two index lists of length 8 are concatenated in both orders, which
        # gives 16 frequency index pairs; the original author notes the
        # paper appears to find 16 frequencies ideal.
        zero_idx = [0] * 8
        ramp_idx = list(range(8))
        weights = get_dct_weights(
            width,
            chan_in,
            zero_idx + ramp_idx,
            ramp_idx + zero_idx,
        )
        # Registered as a buffer: it follows the module across devices and
        # is stored in the state dict, but is not a trainable parameter.
        self.register_buffer("dct_weights", weights)

        hidden = max(3, chan_out // reduction)

        # Layers are created in the same order as they are applied.
        squeeze = nn.Conv2d(chan_in, hidden, 1)
        activation = nn.LeakyReLU(0.1)
        expand = nn.Conv2d(hidden, chan_out, 1)
        gate = nn.Sigmoid()
        self.net = nn.Sequential(squeeze, activation, expand, gate)

    def forward(self, x):
        """Return the sigmoid-activated gate for ``x``.

        ``x`` must be a 4-D tensor (the reduction pattern names four axes)
        that broadcasts against the ``dct_weights`` buffer. The result has
        spatial size 1x1 and ``chan_out`` channels.
        """
        weighted = x * self.dct_weights
        # Sum over the full height and width, keeping singleton spatial axes
        # so the 1x1 convolutions in ``self.net`` can consume the result.
        pooled = reduce(
            weighted,
            "b c (h h1) (w w1) -> b c h1 w1",
            "sum",
            h1=1,
            w1=1,
        )
        return self.net(pooled)
|