import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use a 3x3 kernel with padding=1 and the default stride
    of 1, so the spatial size of the input is preserved; only the channel
    count changes from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels (also used for the middle layer).
    """

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        # Attribute name and layer order are kept fixed so existing
        # state_dict keys (e.g. "net.0.weight", "net.1.running_mean")
        # remain loadable.
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv/BN/ReLU stages to ``x``."""
        return self.net(x)


class Down(nn.Module):
    """Downsampling stage: 2x2 max-pool followed by a :class:`DoubleConv`.

    ``nn.MaxPool2d(2)`` halves each spatial dimension (rounding down for odd
    sizes); the DoubleConv then maps ``in_ch`` -> ``out_ch`` channels.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_ch, out_ch))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` by a factor of 2, then apply the DoubleConv."""
        return self.net(x)


class Up(nn.Module):
    """Upsampling stage with a skip connection.

    ``x1`` (the deeper feature map) is upsampled by a stride-2 transposed
    convolution that also maps ``in_ch`` -> ``in_ch // 2`` channels. It is
    then zero-padded to the spatial size of the skip tensor ``x2`` and
    concatenated with it along the channel axis (dim 1). A
    :class:`DoubleConv` built for ``in_ch`` input channels follows. For the
    shapes to line up, ``x2`` must therefore carry ``in_ch - in_ch // 2``
    channels.

    Args:
        in_ch: Channels of ``x1``; also the channel count after concatenation.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, 2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Upsample ``x1``, align it to ``x2``, concatenate and convolve.

        Args:
            x1: Deeper feature map to upsample; indexed as (N, C, H, W).
            x2: Skip-connection feature map whose H and W define the output
                spatial size.

        Returns:
            Tensor with ``out_ch`` channels and the spatial size of ``x2``.
        """
        x1 = self.up(x1)
        # Upsampling a max-pooled map can be one pixel short when the
        # pre-pool size was odd, so pad x1 to exactly match x2.
        diff_y = x2.shape[2] - x1.shape[2]
        diff_x = x2.shape[3] - x1.shape[3]
        # F.pad takes (left, right, top, bottom) for the last two dims; any
        # odd remainder goes to the right/bottom side.
        x1 = F.pad(
            x1,
            [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2],
        )
        # Skip tensor first, upsampled tensor second: this order is what the
        # DoubleConv weights see, so it must not change.
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class UNet(nn.Module):
    """Two-level U-Net with sigmoid-activated output.

    Encoder: ``inc`` (full resolution, ``base_c`` channels) -> ``down1``
    (1/2, ``2*base_c``) -> ``down2`` (1/4, ``4*base_c``). Decoder: ``up1``
    and ``up2`` restore the input resolution using the encoder features as
    skip connections. ``outc`` is a 1x1 conv to ``out_channels``.

    The forward pass applies ``torch.sigmoid`` to the final conv output, so
    results lie in [0, 1]. Note: if training with a loss that applies its own
    sigmoid (e.g. ``nn.BCEWithLogitsLoss``), this output is already
    activated.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map.
        base_c: Channel width of the first encoder stage.
    """

    def __init__(
        self, in_channels: int = 1, out_channels: int = 1, base_c: int = 32
    ) -> None:
        super().__init__()
        self.inc = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2)
        self.down2 = Down(base_c * 2, base_c * 4)
        # Up(in_ch, ...) halves in_ch in its transposed conv, so the
        # concatenation with the matching encoder output restores in_ch.
        self.up1 = Up(base_c * 4, base_c * 2)
        self.up2 = Up(base_c * 2, base_c)
        self.outc = nn.Conv2d(base_c, out_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network and return sigmoid-activated outputs.

        Args:
            x: Input batch indexed as (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W) with values in [0, 1].
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x = self.up1(x3, x2)
        x = self.up2(x, x1)
        x = self.outc(x)
        return torch.sigmoid(x)