import torch import torch.nn as nn import torch.nn.functional as F from torchvision import transforms class UNet(nn.Module): def __init__(self, num_classes): super(UNet, self).__init__() self.num_classes = num_classes self.contracting_11 = self.conv_block(in_channels=3, out_channels=64) self.contracting_12 = nn.MaxPool2d(kernel_size=2, stride=2) self.contracting_21 = self.conv_block(in_channels=64, out_channels=128) self.contracting_22 = nn.MaxPool2d(kernel_size=2, stride=2) self.contracting_31 = self.conv_block(in_channels=128, out_channels=256) self.contracting_32 = nn.MaxPool2d(kernel_size=2, stride=2) self.contracting_41 = self.conv_block(in_channels=256, out_channels=512) self.contracting_42 = nn.MaxPool2d(kernel_size=2, stride=2) self.middle = self.conv_block(in_channels=512, out_channels=1024) self.expansive_11 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1) self.expansive_12 = self.conv_block(in_channels=1024, out_channels=512) self.expansive_21 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1) self.expansive_22 = self.conv_block(in_channels=512, out_channels=256) self.expansive_31 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1) self.expansive_32 = self.conv_block(in_channels=256, out_channels=128) self.expansive_41 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1) self.expansive_42 = self.conv_block(in_channels=128, out_channels=64) self.output = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=3, stride=1, padding=1) def conv_block(self, in_channels, out_channels): block = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=out_channels), nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=out_channels)) return block def forward(self, X): contracting_11_out = self.contracting_11(X) # [-1, 64, 256, 256] contracting_12_out = self.contracting_12(contracting_11_out) # [-1, 64, 128, 128] contracting_21_out = self.contracting_21(contracting_12_out) # [-1, 128, 128, 128] contracting_22_out = self.contracting_22(contracting_21_out) # [-1, 128, 64, 64] contracting_31_out = self.contracting_31(contracting_22_out) # [-1, 256, 64, 64] contracting_32_out = self.contracting_32(contracting_31_out) # [-1, 256, 32, 32] contracting_41_out = self.contracting_41(contracting_32_out) # [-1, 512, 32, 32] contracting_42_out = self.contracting_42(contracting_41_out) # [-1, 512, 16, 16] middle_out = self.middle(contracting_42_out) # [-1, 1024, 16, 16] expansive_11_out = self.expansive_11(middle_out) # [-1, 512, 32, 32] expansive_12_out = self.expansive_12(torch.cat((expansive_11_out, contracting_41_out), dim=1)) # [-1, 1024, 32, 32] -> [-1, 512, 32, 32] expansive_21_out = self.expansive_21(expansive_12_out) # [-1, 256, 64, 64] expansive_22_out = self.expansive_22(torch.cat((expansive_21_out, contracting_31_out), dim=1)) # [-1, 512, 64, 64] -> [-1, 256, 64, 64] expansive_31_out = self.expansive_31(expansive_22_out) # [-1, 128, 128, 128] expansive_32_out = self.expansive_32(torch.cat((expansive_31_out, contracting_21_out), dim=1)) # [-1, 256, 128, 128] -> [-1, 128, 128, 128] expansive_41_out = self.expansive_41(expansive_32_out) # [-1, 64, 256, 256] expansive_42_out = self.expansive_42(torch.cat((expansive_41_out, contracting_11_out), dim=1)) # [-1, 128, 256, 256] -> [-1, 64, 256, 256] output_out = self.output(expansive_42_out) # [-1, num_classes, 256, 256] return output_out