michaelriedl's picture
Initial dump
002ca81
raw
history blame contribute delete
271 Bytes
from torch import nn
from .ChanNorm import ChanNorm
class PreNorm(nn.Module):
    """Wrap ``fn`` so that its input is passed through ``ChanNorm`` first.

    Calling the wrapper on ``x`` evaluates ``fn(ChanNorm(dim)(x))``.
    The output of ``fn`` is returned unchanged; no residual connection
    or post-normalisation is applied here.

    Args:
        dim: Forwarded verbatim to ``ChanNorm``. Presumably this is the
            channel count of the incoming tensor; TODO confirm against
            ChanNorm's definition.
        fn: Module (or any callable) applied to the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Keep the original registration order (fn first, then norm).
        # nn.Module enumerates submodules in assignment order, so this keeps
        # parameters() / state_dict() ordering identical.
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        """Normalise ``x`` with ``self.norm``, then return ``self.fn`` of it."""
        normalised = self.norm(x)
        return self.fn(normalised)