import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig


class DoubleConv(nn.Module):
    """
    The core building block of the U-Net architecture: two consecutive
    convolutional layers, each followed by batch normalization and a
    ReLU activation.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # nn.Conv2d applies a 3x3 convolution; padding=1 keeps the output
        # spatial size equal to the input. The first conv maps
        # in_channels -> out_channels, the second keeps out_channels.
        # nn.BatchNorm2d normalizes activations across the batch and
        # spatial dimensions, which helps stabilize and speed up training
        # (reducing internal covariate shift).
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
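
# Quick shape check (illustrative example): because every conv in DoubleConv
# uses padding=1, the block changes only the channel count, never the
# spatial size:
#   block = DoubleConv(1, 64)
#   block(torch.randn(1, 1, 256, 256)).shape  # torch.Size([1, 64, 256, 256])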


class UNet(nn.Module):
    """
    A standard U-Net: a four-stage encoder, a bottleneck, and a four-stage
    decoder with skip connections from each encoder stage.
    """
    def __init__(self, in_channels=1, out_channels=3):
        super().__init__()
        # Encoder: each stage doubles the channel count; spatial size is
        # halved between stages by the 2x2 max-pools.
        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder: each stage upsamples with a transposed convolution, then
        # halves the channel count after concatenating the skip connection.
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        # Final 1x1 convolution mapping to the requested output channels
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder path: keep each stage's output for the skip connections.
        d1 = self.down1(x)
        d2 = self.down2(self.pool1(d1))
        d3 = self.down3(self.pool2(d2))
        d4 = self.down4(self.pool3(d3))

        # Bottleneck
        bn = self.bottleneck(self.pool4(d4))

        # Decoder path: upsample, then concatenate the matching encoder
        # feature map along the channel dimension (dim=1).
        up4 = self.up4(bn)
        dec4 = self.dec4(torch.cat([up4, d4], dim=1))
        up3 = self.up3(dec4)
        dec3 = self.dec3(torch.cat([up3, d3], dim=1))
        up2 = self.up2(dec3)
        dec2 = self.dec2(torch.cat([up2, d2], dim=1))
        up1 = self.up1(dec2)
        dec1 = self.dec1(torch.cat([up1, d1], dim=1))

        # Output projection
        return self.out_conv(dec1)
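
# Note: with four 2x2 max-pools, the input height and width must be
# divisible by 16, or the torch.cat calls above will fail on mismatched
# spatial sizes. For example, with the default channels a (1, 1, 256, 256)
# input produces a (1, 3, 256, 256) output.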


# --- Config ---
class UNetConfig(PretrainedConfig):
    model_type = "unet"

    def __init__(self, in_channels=1, out_channels=3, image_size=256, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.image_size = image_size
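
# Because UNetConfig subclasses PretrainedConfig, it round-trips through the
# standard Hugging Face serialization API. A minimal sketch (the directory
# name here is a hypothetical example):
#   cfg = UNetConfig(in_channels=3)
#   cfg.save_pretrained("unet-checkpoint")            # writes config.json
#   cfg = UNetConfig.from_pretrained("unet-checkpoint")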


# --- Transformer-compatible Model ---
class UNetTransformerModel(PreTrainedModel):
    config_class = UNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = UNet(config.in_channels, config.out_channels)

    def forward(self, pixel_values):
        return self.model(pixel_values)
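

# Minimal usage sketch (the input shape below is an illustrative
# assumption): build the config-driven wrapper and run a random
# grayscale image through it.
if __name__ == "__main__":
    config = UNetConfig(in_channels=1, out_channels=3, image_size=256)
    model = UNetTransformerModel(config)
    model.eval()  # use BatchNorm running statistics for inference
    pixel_values = torch.randn(
        1, config.in_channels, config.image_size, config.image_size
    )
    with torch.no_grad():
        logits = model(pixel_values)
    print(logits.shape)  # torch.Size([1, 3, 256, 256])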