| import torch.nn as nn | |
class ConvINRelu(nn.Module):
    """
    Convolution -> Instance Normalization -> ReLU, applied in that order.

    Args:
        channels_in: number of channels of the input feature map.
        channels_out: number of channels produced by the convolution.
        stride: stride of the 3x3 convolution (padding is fixed to 1).
    """

    def __init__(self, channels_in, channels_out, stride):
        super().__init__()
        # Build the three stages as locals (not attributes) so that the only
        # registered sub-module is ``self.layers`` and parameter names stay
        # ``layers.0.*``.
        conv = nn.Conv2d(
            in_channels=channels_in,
            out_channels=channels_out,
            kernel_size=3,
            stride=stride,
            padding=1,
        )
        norm = nn.InstanceNorm2d(channels_out)
        activation = nn.ReLU(inplace=True)
        self.layers = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through the conv / norm / ReLU pipeline."""
        out = self.layers(x)
        return out
class ConvBlock(nn.Module):
    """
    A network composed of ``blocks`` ConvINRelu layers applied in sequence.

    The first layer maps ``in_channels`` to ``out_channels`` using ``stride``;
    every following layer maps ``out_channels`` to ``out_channels`` with
    stride 1.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by every layer.
        blocks: how many ConvINRelu layers to stack. ``0`` builds an empty
            ``nn.Sequential``, which returns its input unchanged.
        stride: stride of the first layer only.

    Raises:
        ValueError: if ``blocks`` is negative.
    """

    def __init__(self, in_channels, out_channels, blocks=1, stride=1):
        super(ConvBlock, self).__init__()
        # A negative count used to slip past the ``!= 0`` check and silently
        # build a single layer; reject it explicitly instead.
        if blocks < 0:
            raise ValueError(
                "blocks must be a non-negative integer, got {!r}".format(blocks)
            )

        layers = []
        if blocks > 0:
            # Only the first layer changes the channel count and applies
            # the caller's stride.
            layers.append(ConvINRelu(in_channels, out_channels, stride))
            layers.extend(
                ConvINRelu(out_channels, out_channels, 1)
                for _ in range(blocks - 1)
            )
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        return self.layers(x)