import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig


class DoubleConv(nn.Module):
    """
    Core building block of the U-Net architecture: two consecutive
    convolutional layers, each followed by batch normalization and a
    ReLU activation.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # nn.Conv2d applies a 3x3 convolution; padding=1 keeps the output
        # spatial size equal to the input. The first conv maps in_channels
        # to out_channels; the second keeps the width at out_channels.
        # nn.BatchNorm2d normalizes activations per channel across the
        # batch, which stabilizes and speeds up training by reducing
        # internal covariate shift.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)
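

# Shape sanity check (illustrative sketch, not part of the original model):
# with 3x3 kernels and padding=1, DoubleConv preserves height and width and
# only changes the channel count. The tensor sizes below are arbitrary.
def _demo_double_conv():
    block = DoubleConv(in_channels=1, out_channels=64)
    x = torch.randn(2, 1, 256, 256)  # (batch, channels, height, width)
    out = block(x)
    assert out.shape == (2, 64, 256, 256)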


class UNet(nn.Module):
    """
    Standard U-Net: a four-stage encoder, a bottleneck, and a four-stage
    decoder with skip connections between mirrored stages.
    """

    def __init__(self, in_channels=1, out_channels=3):
        super().__init__()

        # Encoder: each stage doubles the channel count; each max-pool
        # halves the spatial resolution.
        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)

        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        self.down3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.down4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck at 1/16 of the input resolution.
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder: each ConvTranspose2d doubles the spatial resolution.
        # Each DoubleConv takes twice the channels because the upsampled
        # features are concatenated with the matching encoder output.
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        # 1x1 convolution maps the final 64 channels to out_channels.
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder path; keep each stage's output for the skip connections.
        d1 = self.down1(x)
        d2 = self.down2(self.pool1(d1))
        d3 = self.down3(self.pool2(d2))
        d4 = self.down4(self.pool3(d3))

        bn = self.bottleneck(self.pool4(d4))

        # Decoder path: upsample, concatenate the skip connection along
        # the channel dimension, then refine with a DoubleConv.
        up4 = self.up4(bn)
        dec4 = self.dec4(torch.cat([up4, d4], dim=1))

        up3 = self.up3(dec4)
        dec3 = self.dec3(torch.cat([up3, d3], dim=1))

        up2 = self.up2(dec3)
        dec2 = self.dec2(torch.cat([up2, d2], dim=1))

        up1 = self.up1(dec2)
        dec1 = self.dec1(torch.cat([up1, d1], dim=1))

        return self.out_conv(dec1)
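

# Forward-pass sketch (illustrative, not part of the original model). The
# four pooling stages each halve H and W, so input spatial dimensions
# should be divisible by 16 for the decoder's skip connections to line up.
def _demo_unet():
    model = UNet(in_channels=1, out_channels=3)
    x = torch.randn(1, 1, 256, 256)  # 256 is divisible by 16
    out = model(x)
    assert out.shape == (1, 3, 256, 256)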


class UNetConfig(PretrainedConfig):
    """Hugging Face config holding the U-Net hyperparameters."""

    model_type = "unet"

    def __init__(self, in_channels=1, out_channels=3, image_size=256, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.image_size = image_size


class UNetTransformerModel(PreTrainedModel):
    """Wraps the plain UNet so it gains the PreTrainedModel API
    (save_pretrained, from_pretrained, etc.)."""

    config_class = UNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = UNet(config.in_channels, config.out_channels)

    def forward(self, pixel_values):
        return self.model(pixel_values)
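

# Usage sketch for the Hugging Face wrapper (assumes transformers is
# installed; save_pretrained/from_pretrained are inherited from
# PreTrainedModel). The checkpoint path below is hypothetical.
if __name__ == "__main__":
    config = UNetConfig(in_channels=1, out_channels=3, image_size=256)
    model = UNetTransformerModel(config)

    pixel_values = torch.randn(
        1, config.in_channels, config.image_size, config.image_size
    )
    logits = model(pixel_values)
    print(logits.shape)  # torch.Size([1, 3, 256, 256])

    # model.save_pretrained("unet-checkpoint")
    # reloaded = UNetTransformerModel.from_pretrained("unet-checkpoint")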