|
from torch import nn
|
|
from einops import reduce
|
|
from .helper_funcs import get_dct_weights
|
|
|
|
|
|
class FCANet(nn.Module):
    """Channel-gating module driven by a fixed, precomputed weight tensor.

    The input feature map is multiplied element-wise by the ``dct_weights``
    buffer, summed over its two spatial axes down to a 1x1 map, and passed
    through a small 1x1-conv bottleneck ending in a sigmoid. The result is
    one value in (0, 1) per output channel.

    NOTE(review): the weights come from ``get_dct_weights``, which is not
    visible here; presumably they are 2-D DCT basis filters -- confirm
    against the helper.

    Args:
        chan_in: Number of channels of the incoming feature map.
        chan_out: Number of channels of the produced gate.
        reduction: Divisor applied to ``chan_out`` to size the bottleneck;
            the bottleneck never goes below 3 channels.
        width: Forwarded as the first argument to ``get_dct_weights``
            (presumably the spatial size of the feature map -- TODO confirm).
    """

    def __init__(self, *, chan_in, chan_out, reduction=4, width):
        super().__init__()

        # Sixteen frequency index pairs: eight zeros paired with 0..7,
        # followed by 0..7 paired with eight zeros.
        num_freqs = 8
        zeros = [0] * num_freqs
        ramp = list(range(num_freqs))

        basis = get_dct_weights(width, chan_in, zeros + ramp, ramp + zeros)
        # Stored as a buffer: it follows the module across devices and is
        # saved in the state dict, but is not a trainable parameter.
        self.register_buffer("dct_weights", basis)

        hidden = max(3, chan_out // reduction)

        # The convolutions are created in the same order as they appear in
        # the Sequential so parameter initialisation is unaffected.
        squeeze = nn.Conv2d(chan_in, hidden, 1)
        expand = nn.Conv2d(hidden, chan_out, 1)
        self.net = nn.Sequential(
            squeeze,
            nn.LeakyReLU(0.1),
            expand,
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the per-channel gate for the 4-D input ``x`` (b, c, h, w).

        The output has shape (b, chan_out, 1, 1).
        """
        weighted = x * self.dct_weights
        # Sum over both spatial axes while keeping singleton h/w dimensions.
        pooled = reduce(
            weighted, "b c (h h1) (w w1) -> b c h1 w1", "sum", h1=1, w1=1
        )
        return self.net(pooled)
|
|
|