| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
class TimeEmbedding(nn.Module):
    """Fixed sinusoidal embedding of timesteps (no learnable parameters).

    Frequencies are ``10000 ** (-i / max(half_dim - 1, 1))`` for ``i`` in
    ``[0, half_dim)`` with ``half_dim = dim // 2``; the output is
    ``[sin(t * f), cos(t * f)]`` concatenated along the last axis.

    Args:
        dim: Width of the returned embedding. For odd ``dim`` the last column
            is zero padding so the output is always exactly ``dim`` wide.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        # Floor the denominator at 1: for dim in {2, 3}, half_dim - 1 == 0 and
        # dividing by it gave inf, which turned the first frequency into NaN.
        denom = max(half_dim - 1, 1)
        scale = torch.log(torch.tensor(10000.0)) / denom
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -scale)
        # Buffer (not a Parameter): moves with .to()/.cuda() but is not trained.
        self.register_buffer('emb', emb)

    def forward(self, t):
        """Embed timesteps.

        Args:
            t: 1-D tensor of timesteps (one per batch element); a 0-d scalar
                is treated as a batch of one. Cast to float before use.

        Returns:
            Float tensor of shape ``(len(t), dim)``.
        """
        t = t.float()
        if t.dim() == 0:
            # Indexing with [:, None] needs at least one dimension.
            t = t.unsqueeze(0)
        args = t[:, None] * self.emb[None, :]
        emb = torch.cat((torch.sin(args), torch.cos(args)), dim=-1)
        if emb.shape[-1] < self.dim:
            # Odd dim: pad so consumers sized as nn.Linear(dim, ...) accept it.
            emb = F.pad(emb, (0, self.dim - emb.shape[-1]))
        return emb
| |
|
class Block(nn.Module):
    """Convolution, plus a per-channel time bias, then GroupNorm and SiLU.

    With ``up=True`` the convolution is a stride-2 transposed conv
    (kernel 4, padding 1), which doubles H and W. Otherwise a 3x3 conv with
    padding 1 keeps H and W unchanged.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count; must be divisible by 8 because of
            ``GroupNorm(8, out_ch)``.
        time_emb_dim: Width of the time embedding passed to ``forward``.
        up: Select the upsampling (transposed) convolution.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Linear projection: one additive bias per output channel.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv = (
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
            if up
            else nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        )
        self.norm = nn.GroupNorm(8, out_ch)
        self.act = nn.SiLU()

    def forward(self, x, t):
        """Apply conv, add the projected time embedding, normalize, activate."""
        # Broadcast the (batch, channels) bias over the spatial axes.
        channel_bias = self.time_mlp(t)[:, :, None, None]
        features = self.conv(x) + channel_bias
        return self.act(self.norm(features))
| |
|
class SmoothDiffusionUNet(nn.Module):
    """Three-level U-Net whose output has the same shape as its input.

    Encoder: ``down1`` (H), ``down2`` (H/2), ``down3`` (H/4), with 2x
    max-pooling between levels. The bottleneck ``mid1``/``mid2`` runs at H/8.
    Decoder: each ``up`` block doubles H and W with a transposed conv. Its
    output is concatenated with the matching encoder map (skip connection).

    Channel bookkeeping with ``c = config.base_channels``:
    up1 4c -> 2c (+ h3 4c = 6c), up2 6c -> c (+ h2 2c = 3c),
    up3 3c -> c (+ h1 c = 2c), out 2c -> in_channels.

    Args:
        config: Object exposing ``in_channels``, ``base_channels`` and
            ``time_emb_dim``. ``base_channels`` must be divisible by 8,
            because ``Block`` normalizes with ``GroupNorm(8, ...)``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Sinusoidal timestep embedding (attribute name kept for state_dict).
        self.time_mlp = TimeEmbedding(config.time_emb_dim)

        # Encoder.
        self.down1 = Block(config.in_channels, config.base_channels, config.time_emb_dim)
        self.down2 = Block(config.base_channels, config.base_channels*2, config.time_emb_dim)
        self.down3 = Block(config.base_channels*2, config.base_channels*4, config.time_emb_dim)

        # Bottleneck.
        self.mid1 = Block(config.base_channels*4, config.base_channels*4, config.time_emb_dim)
        self.mid2 = Block(config.base_channels*4, config.base_channels*4, config.time_emb_dim)

        # Decoder; input widths include the concatenated skip channels.
        self.up1 = Block(config.base_channels*4, config.base_channels*2, config.time_emb_dim, up=True)
        self.up2 = Block(config.base_channels*6, config.base_channels, config.time_emb_dim, up=True)
        self.up3 = Block(config.base_channels*3, config.base_channels, config.time_emb_dim, up=True)

        self.out = nn.Conv2d(config.base_channels*2, config.in_channels, kernel_size=3, padding=1)

    @staticmethod
    def _merge_skip(h, skip):
        """Concatenate decoder features ``h`` with encoder ``skip`` on channels.

        Max-pooling floors odd sizes, so when H or W is not divisible by 8 the
        upsampled ``h`` can be smaller than ``skip`` (e.g. 28 -> 14 -> 7 -> 3,
        then upsampled to 6 against a skip of 7). ``h`` is therefore resized
        to the skip's spatial size first. For sizes divisible by 8 the shapes
        already match, and this is a plain ``torch.cat``.
        """
        if h.shape[-2:] != skip.shape[-2:]:
            h = F.interpolate(h, size=tuple(skip.shape[-2:]), mode='nearest')
        return torch.cat([h, skip], dim=1)

    def forward(self, x, t):
        """Run the U-Net.

        Args:
            x: 4-D batch ``(B, in_channels, H, W)``. H and W must be at least
                8, so that three 2x max-pools leave a non-empty map.
            t: Timesteps passed to ``TimeEmbedding``. Presumably one per batch
                element (TODO confirm against the caller).

        Returns:
            Tensor of shape ``(B, in_channels, H, W)``.
        """
        t_emb = self.time_mlp(t)

        # Encoder; each level is kept for its skip connection.
        h1 = self.down1(x, t_emb)
        h2 = self.down2(F.max_pool2d(h1, 2), t_emb)
        h3 = self.down3(F.max_pool2d(h2, 2), t_emb)

        # Bottleneck at 1/8 resolution.
        h = self.mid1(F.max_pool2d(h3, 2), t_emb)
        h = self.mid2(h, t_emb)

        # Decoder: upsample, then merge with the matching encoder level.
        h = self._merge_skip(self.up1(h, t_emb), h3)
        h = self._merge_skip(self.up2(h, t_emb), h2)
        h = self._merge_skip(self.up3(h, t_emb), h1)

        return self.out(h)