michaelriedl's picture
Initial dump
002ca81
raw
history blame contribute delete
974 Bytes
from torch import nn
from einops import reduce
from .helper_funcs import get_dct_weights
class FCANet(nn.Module):
    """Channel gate computed from a DCT-weighted spatial sum of the input.

    The input is multiplied by a fixed ``dct_weights`` buffer, summed over
    both spatial axes down to a 1x1 map, and passed through two 1x1
    convolutions ending in a sigmoid.

    Args:
        chan_in: Channel count of the tensor given to ``forward``.
        chan_out: Channel count of the produced gate.
        reduction: Divisor applied to ``chan_out`` to size the hidden
            layer (the hidden width is never smaller than 3).
        width: Forwarded to ``get_dct_weights``; presumably the spatial
            side length of the input feature map -- confirm against that
            helper.
    """

    def __init__(self, *, chan_in, chan_out, reduction=4, width):
        super().__init__()

        # in paper, it seems 16 frequencies was ideal
        num_freqs = 8
        zero_freqs = [0] * num_freqs
        ramp_freqs = list(range(num_freqs))

        # The two index lists pair (0, k) for k in 0..7, followed by
        # (k, 0) for k in 0..7. How get_dct_weights interprets each pair
        # is defined in helper_funcs, not here.
        weights = get_dct_weights(
            width,
            chan_in,
            zero_freqs + ramp_freqs,
            ramp_freqs + zero_freqs,
        )
        # Registered as a buffer: stored on the module, but not a Parameter.
        self.register_buffer("dct_weights", weights)

        # Bottleneck width, floored at 3 channels.
        hidden_chans = max(3, chan_out // reduction)

        gate_layers = [
            nn.Conv2d(chan_in, hidden_chans, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(hidden_chans, chan_out, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return the sigmoid gate for ``x``.

        ``x`` must be 4-D (the reduction pattern names axes ``b c h w``).
        Both spatial axes are summed away entirely, so the result handed
        to ``self.net`` has spatial size 1x1.
        """
        weighted = x * self.dct_weights
        # h1 = w1 = 1 means the whole h and w extents are folded into the sum.
        pooled = reduce(
            weighted,
            "b c (h h1) (w w1) -> b c h1 w1",
            "sum",
            h1=1,
            w1=1,
        )
        return self.net(pooled)