| | import torch |
| | import torch.nn as nn |
| |
|
class ConditionalGenerator(nn.Module):
    """Image-to-image generator conditioned on a discrete style index.

    The generator works in four steps:

    1. A learned style embedding is tiled over the spatial grid.
    2. The tiled embedding is concatenated to the 3-channel input.
    3. The result runs through an encoder (7x7 conv, then one stride-2 conv
       down to 128 channels) and a bottleneck of residual blocks.
    4. A decoder (one stride-2 transposed conv, then a 7x7 conv) ends in
       ``tanh``.

    Args:
        num_styles: Number of rows in the style embedding table.
        style_dim: Width of the style embedding, i.e. how many channels are
            appended to the input. Defaults to 64, the original fixed size.
        num_res_blocks: Number of ``ResidualBlock(128)`` modules in the
            bottleneck. Defaults to 6, the original fixed count.

    Note:
        The single down/up-sampling pair returns the input height/width only
        when it is even. An odd size ``S`` comes back as ``S + 1``
        (for example 5 -> 3 -> 6).
    """

    def __init__(self, num_styles=3, style_dim=64, num_res_blocks=6):
        super().__init__()

        self.style_embed = nn.Embedding(num_styles, style_dim)

        self.downsample = nn.Sequential(
            nn.Conv2d(3 + style_dim, 64, 7, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
        )
        # ResidualBlock is defined outside this view. It presumably keeps the
        # 128-channel shape, since the decoder below expects 128 input
        # channels. TODO confirm.
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(128) for _ in range(num_res_blocks)]
        )
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 3, 7, padding=3),
            nn.Tanh(),
        )

    def _style_map(self, x, style_id):
        """Return the style embedding broadcast to ``(B, style_dim, H, W)``."""
        style_id = torch.as_tensor(style_id, dtype=torch.long, device=x.device)
        if style_id.dim() == 0:
            # A single id applies to every sample in the batch.
            style_id = style_id.expand(x.shape[0])
        style = self.style_embed(style_id)[:, :, None, None]
        # expand() is a zero-copy view; torch.cat in forward() makes the only
        # copy (the original repeat() made an extra full-size copy first).
        return style.expand(-1, -1, x.shape[2], x.shape[3])

    def forward(self, x, style_id):
        """Translate ``x`` into the style selected by ``style_id``.

        Args:
            x: Image batch of shape ``(B, 3, H, W)``.
            style_id: Integer style index per sample, shape ``(B,)``. A Python
                int or 0-d tensor is broadcast to the whole batch.

        Returns:
            Tensor of shape ``(B, 3, H', W')`` with values in ``(-1, 1)``.
            ``H' == H`` when ``H`` is even, and likewise for ``W``.
        """
        x = torch.cat([x, self._style_map(x, style_id)], dim=1)
        x = self.downsample(x)
        x = self.res_blocks(x)
        return self.upsample(x)
| |
|
class Discriminator(nn.Module):
    """Convolutional critic conditioned on a discrete style index.

    A learned style embedding is tiled over the spatial grid and concatenated
    to the 3-channel input. Three stride-2 4x4 convolutions (64 -> 128 -> 256
    channels, LeakyReLU(0.2), instance norm after the second and third) are
    followed by a stride-1 4x4 convolution down to one channel. The result is
    a spatial map of raw scores; no sigmoid is applied.

    Args:
        num_styles: Number of rows in the style embedding table.
        style_dim: Width of the style embedding, i.e. how many channels are
            appended to the input. Defaults to 64, the original fixed size.
    """

    def __init__(self, num_styles=3, style_dim=64):
        super().__init__()
        self.style_embed = nn.Embedding(num_styles, style_dim)
        self.model = nn.Sequential(
            nn.Conv2d(3 + style_dim, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 1, 4, padding=1),
        )

    def _style_map(self, x, style_id):
        """Return the style embedding broadcast to ``(B, style_dim, H, W)``."""
        style_id = torch.as_tensor(style_id, dtype=torch.long, device=x.device)
        if style_id.dim() == 0:
            # A single id applies to every sample in the batch.
            style_id = style_id.expand(x.shape[0])
        style = self.style_embed(style_id)[:, :, None, None]
        # expand() is a zero-copy view; torch.cat in forward() makes the only
        # copy (the original repeat() made an extra full-size copy first).
        return style.expand(-1, -1, x.shape[2], x.shape[3])

    def forward(self, x, style_id):
        """Score ``x`` under the style selected by ``style_id``.

        Args:
            x: Image batch of shape ``(B, 3, H, W)``.
            style_id: Integer style index per sample, shape ``(B,)``. A Python
                int or 0-d tensor is broadcast to the whole batch.

        Returns:
            Raw score map of shape ``(B, 1, h, w)``. For example, a 32x32
            input gives a 3x3 map.
        """
        x = torch.cat([x, self._style_map(x, style_id)], dim=1)
        return self.model(x)