from torch import nn

from .ChanNorm import ChanNorm


class PreNorm(nn.Module):
    """Applies channel normalization to the input before calling the wrapped module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        # Normalize over the channel dimension, then apply the wrapped function.
        return self.fn(self.norm(x))
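# Usage sketch (illustrative only, not executed on import): wrap any sub-layer so it
# receives channel-normalized input first. Assumes ChanNorm normalizes NCHW feature
# maps across the channel dimension; the wrapped nn.Conv2d here is just an example.
#
#     import torch
#     block = PreNorm(64, nn.Conv2d(64, 64, kernel_size=3, padding=1))
#     out = block(torch.randn(2, 64, 32, 32))  # output shape matches the input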