""" | |
Copyright (c) Facebook, Inc. and its affiliates. | |
This source code is licensed under the MIT license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |

import torch
from torch import nn
from torch.nn import functional as F


class Unet(nn.Module):
    """
    PyTorch implementation of a U-Net model.

    O. Ronneberger, P. Fischer, and T. Brox. U-Net: Convolutional networks
    for biomedical image segmentation. In International Conference on Medical
    Image Computing and Computer-Assisted Intervention, pages 234–241.
    Springer, 2015.
    """

    def __init__(
        self,
        in_chans: int,
        out_chans: int,
        chans: int = 32,
        num_pool_layers: int = 4,
        drop_prob: float = 0.0,
    ):
        """
        Parameters
        ----------
        in_chans : int
            Number of channels in the input to the U-Net model.
        out_chans : int
            Number of channels in the output of the U-Net model.
        chans : int, optional
            Number of output channels of the first convolution layer. Default is 32.
        num_pool_layers : int, optional
            Number of down-sampling and up-sampling layers. Default is 4.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob

        # Encoder: each block doubles the channel count; the pooling that halves
        # the spatial size happens in forward().
        self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
        ch = chans
        for _ in range(num_pool_layers - 1):
            self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob))
            ch *= 2
        # Bottleneck block at the lowest resolution.
        self.conv = ConvBlock(ch, ch * 2, drop_prob)

        # Decoder: transpose convolutions up-sample, then ConvBlocks fuse the
        # concatenated skip connections (hence the ch * 2 input channels).
        self.up_conv = nn.ModuleList()
        self.up_transpose_conv = nn.ModuleList()
        for _ in range(num_pool_layers - 1):
            self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
            self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob))
            ch //= 2
        self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
        self.up_conv.append(
            nn.Sequential(
                ConvBlock(ch * 2, ch, drop_prob),
                nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
            )
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        image : torch.Tensor
            Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `(N, out_chans, H, W)`.
        """
        stack = []
        output = image

        # apply down-sampling layers
        for layer in self.down_sample_layers:
            output = layer(output)
            stack.append(output)
            output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)

        output = self.conv(output)

        # apply up-sampling layers
        for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
            downsample_layer = stack.pop()
            output = transpose_conv(output)

            # reflect-pad on the right/bottom if needed to handle odd input dimensions
            padding = [0, 0, 0, 0]
            if output.shape[-1] != downsample_layer.shape[-1]:
                padding[1] = 1  # padding right
            if output.shape[-2] != downsample_layer.shape[-2]:
                padding[3] = 1  # padding bottom
            if torch.sum(torch.tensor(padding)) != 0:
                output = F.pad(output, padding, "reflect")

            output = torch.cat([output, downsample_layer], dim=1)
            output = conv(output)

        return output
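
# Shape note (editorial sketch, not in the original file): the reflect padding in
# Unet.forward compensates for odd spatial sizes. With num_pool_layers = 4, an
# input height of 25 yields encoder sizes 25, 12, 6, 3 (bottleneck 1); each
# decoder stage doubles the size, and where doubling falls one short of the skip
# connection (2 vs. 3, 24 vs. 25) a single row or column of padding restores the
# match before concatenation.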


class ConvBlock(nn.Module):
    """
    A convolutional block that consists of two convolution layers, each followed
    by instance normalization, LeakyReLU activation, and dropout.
    """

    def __init__(self, in_chans: int, out_chans: int, drop_prob: float):
        """
        Parameters
        ----------
        in_chans : int
            Number of channels in the input.
        out_chans : int
            Number of channels in the output.
        drop_prob : float
            Dropout probability.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_prob = drop_prob

        self.layers = nn.Sequential(
            nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout2d(drop_prob),
            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout2d(drop_prob),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        image : torch.Tensor
            Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `(N, out_chans, H, W)`.
        """
        return self.layers(image)
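
# Design note (editorial, not in the original file): the 3x3 convolutions above
# use padding=1, so ConvBlock preserves spatial size and only the channel count
# changes. bias=False is safe because the InstanceNorm2d that follows subtracts
# the per-channel spatial mean, which would cancel any constant bias; with the
# default affine=False the normalization adds no learnable parameters of its own.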


class TransposeConvBlock(nn.Module):
    """
    A transpose convolutional block that consists of one transpose convolution
    layer followed by instance normalization and LeakyReLU activation.
    """

    def __init__(self, in_chans: int, out_chans: int):
        """
        Parameters
        ----------
        in_chans : int
            Number of channels in the input.
        out_chans : int
            Number of channels in the output.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans

        self.layers = nn.Sequential(
            nn.ConvTranspose2d(
                in_chans, out_chans, kernel_size=2, stride=2, bias=False
            ),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        image : torch.Tensor
            Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `(N, out_chans, H*2, W*2)`.
        """
        return self.layers(image)
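

if __name__ == "__main__":
    # Minimal smoke test (an editorial sketch, not part of the original file):
    # run a forward pass with the default configuration and check that odd
    # spatial sizes survive the reflect-padding path in Unet.forward.
    model = Unet(in_chans=1, out_chans=1)
    x = torch.randn(1, 1, 321, 321)  # odd H and W exercise the padding branch
    with torch.no_grad():
        y = model(x)
    assert y.shape == (1, 1, 321, 321), y.shape
    print("output shape:", tuple(y.shape))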