import math

import torch
from torch import nn
from einops import rearrange
from inspect import isfunction


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        try:
            m.weight.data.normal_(0.0, 0.02)
        except AttributeError:
            # some modules matching 'Conv' have no learnable weight
            pass
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal timestep embedding, as in 'Attention Is All You Need'."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class LayerNorm(nn.Module):
    """Channel-wise LayerNorm for NCHW feature maps."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)
        # self.norm = nn.BatchNorm2d(dim)
        # self.norm = nn.GroupNorm(dim // 32, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


# building block modules

class ConvNextBlock(nn.Module):
    """https://arxiv.org/abs/2201.03545"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        # projects the time embedding to a per-channel scale and shift
        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(time_emb_dim, dim * 2)
        ) if exists(time_emb_dim) else None

        # 7x7 depthwise convolution
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(dim_out * mult, dim_out, 3, 1, 1),
        )
        # self.noise_adding = NoiseInjection(dim_out)

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp):
            assert exists(time_emb), 'time emb must be passed in'
            # FiLM-style conditioning: scale and shift derived from the time embedding
            condition = self.mlp(time_emb)
            condition = rearrange(condition, 'b c -> b c 1 1')
            weight, bias = torch.split(condition, x.shape[1], dim=1)
            h = h * (1 + weight) + bias

        h = self.net(h)
        # h = self.noise_adding(h)
        return h + self.res_conv(x)


class ConvNextBlock_dis(nn.Module):
    """Discriminator variant: BatchNorm, no time conditioning.

    https://arxiv.org/abs/2201.03545
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        # unused unless time_emb_dim is given; forward() ignores it
        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(time_emb_dim, dim * 2)
        ) if exists(time_emb_dim) else None

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.BatchNorm2d(dim) if norm else nn.Identity(),
            # LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(dim_out * mult, dim_out, 3, 1, 1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        h = self.ds_conv(x)
        h = self.net(h)
        return h + self.res_conv(x)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
        q = q * self.scale

        # linear attention: softmax over keys, then aggregate in O(n) memory
        k = k.softmax(dim=-1)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)
        return self.to_out(out)


# model

class UNet(nn.Module):
    def __init__(
        self,
        dim=32,
        dim_mults=(1, 2, 4, 8, 16, 32, 32),
        channels=3,
    ):
        super().__init__()
        self.channels = channels
        dims = [dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.model_depth = len(dim_mults)

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim)
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        self.initial = nn.Conv2d(channels, dim, 7, 1, 3, bias=False)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            self.downs.append(nn.ModuleList([
                ConvNextBlock(dim_in, dim_out, time_emb_dim=time_dim, norm=ind != 0),
                nn.AvgPool2d(2),
                # attention only at the three lowest resolutions
                Residual(PreNorm(dim_out, LinearAttention(dim_out))) if ind >= (num_resolutions - 3) else nn.Identity(),
                ConvNextBlock(dim_out, dim_out, time_emb_dim=time_dim),
            ]))

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            self.ups.append(nn.ModuleList([
                # dim_out * 2 because the skip connection is concatenated
                ConvNextBlock(dim_out * 2, dim_in, time_emb_dim=time_dim),
                nn.Upsample(scale_factor=2, mode='nearest'),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))) if ind < 3 else nn.Identity(),
                ConvNextBlock(dim_in, dim_in, time_emb_dim=time_dim),
            ]))

        self.final_conv = nn.Conv2d(dim, channels, 1, bias=False)

    def forward(self, x, time):
        x = self.initial(x)
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        h = []

        for convnext, downsample, attn, convnext2 in self.downs:
            x = convnext(x, t)
            x = downsample(x)
            h.append(x)
            x = attn(x)
            x = convnext2(x, t)

        for convnext, upsample, attn, convnext2 in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = convnext(x, t)
            x = upsample(x)
            x = attn(x)
            x = convnext2(x, t)

        return self.final_conv(x)


class Discriminator(nn.Module):
    def __init__(
        self,
        dim=32,
        dim_mults=(1, 2, 4, 8, 16, 32, 32),
        channels=3,
        with_time_emb=True,  # unused; kept for signature compatibility
    ):
        super().__init__()
        self.channels = channels
        dims = [dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.model_depth = len(dim_mults)

        self.downs = nn.ModuleList([])

        self.initial = nn.Conv2d(channels, dim, 7, 1, 3, bias=False)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            self.downs.append(nn.ModuleList([
                ConvNextBlock_dis(dim_in, dim_out, norm=ind != 0),
                nn.AvgPool2d(2),
                ConvNextBlock_dis(dim_out, dim_out),
            ]))

        self.out = nn.Conv2d(dims[-1], 1, 1, bias=False)

    def forward(self, x):
        x = self.initial(x)
        for convnext, downsample, convnext2 in self.downs:
            x = convnext(x)
            x = downsample(x)
            x = convnext2(x)
        return self.out(x).view(x.shape[0], -1)
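

# ---------------------------------------------------------------------------
# Minimal smoke test (an illustrative sketch, not part of the original
# module). It assumes 128x128 RGB inputs: with the default dim_mults of
# length 7, the input resolution must be divisible by 2 ** 7 = 128 so that
# the seven AvgPool2d(2) stages never shrink a feature map below 1x1.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch_size, image_size = 2, 128

    unet = UNet(dim=32)
    disc = Discriminator(dim=32)

    x = torch.randn(batch_size, 3, image_size, image_size)
    # diffusion timesteps; float because SinusoidalPosEmb feeds an nn.Linear
    t = torch.randint(0, 1000, (batch_size,)).float()

    with torch.no_grad():
        out = unet(x, t)
        logits = disc(x)

    print(out.shape)     # torch.Size([2, 3, 128, 128]) -- same shape as input
    print(logits.shape)  # torch.Size([2, 1]) -- one logit per image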