|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import timm |
|
import numpy as np |
|
import cv2 |
|
|
|
|
|
|
|
|
|
class PyramidPoolingModule(nn.Module):
    """Pyramid pooling block that fuses multi-scale pooled context into a feature map.

    For every entry in ``pool_sizes`` the input is adaptively average-pooled
    to that grid size, projected to ``in_channels // 4`` channels with a 1x1
    convolution, group-normalised (8 groups) and passed through ReLU.  The
    branches are resized back to the input resolution, concatenated with the
    input along the channel axis and projected back to ``in_channels``.

    Args:
        in_channels: Channel count of the input (and output) feature map.
            ``in_channels // 4`` must be divisible by 8 because each branch
            uses ``GroupNorm(num_groups=8)``.
        pool_sizes: Output grid sizes of the adaptive average pools, one
            branch per entry.
    """

    def __init__(self, in_channels, pool_sizes=(1, 2, 3, 6)):
        super().__init__()
        # An immutable default avoids sharing one list object across
        # instances; normalising to a tuple also lets callers pass any
        # iterable while keeping ``len()`` valid below.
        pool_sizes = tuple(pool_sizes)
        branch_channels = in_channels // 4
        self.pool_layers = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(pool_size),
                nn.Conv2d(in_channels, branch_channels, kernel_size=1, bias=False),
                nn.GroupNorm(num_groups=8, num_channels=branch_channels),
                nn.ReLU(inplace=True),
            )
            for pool_size in pool_sizes
        ])
        # The fusion conv sees the original input plus one branch per pool size.
        total_channels = in_channels + len(pool_sizes) * branch_channels
        self.conv = nn.Conv2d(total_channels, in_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Return a tensor with the same shape as ``x`` enriched with pooled context."""
        pooled_features = [x]
        for layer in self.pool_layers:
            pooled = layer(x)
            # Bring every branch back to the input's spatial size so the
            # channel-wise concatenation below is valid.
            pooled = F.interpolate(
                pooled, size=x.shape[2:], mode='bilinear', align_corners=False
            )
            pooled_features.append(pooled)
        x = torch.cat(pooled_features, dim=1)
        return self.conv(x)
|
|
|
|
|
|
|
|
|
class UPerNetDecoder(nn.Module):
    """Top-down decoder that fuses four encoder feature maps into a segmentation map.

    The deepest feature map goes through a ``PyramidPoolingModule``; the
    result is then repeatedly upsampled to the next shallower level, added
    to a 1x1 lateral projection of that level and passed through a 1x1
    convolution followed by ``Dropout2d``.  The finest fused map is resized
    to ``output_size`` and projected to ``num_classes`` channels.

    Args:
        encoder_channels: Channel counts of the encoder feature maps, ordered
            shallow to deep.  Indices 0, 1, 2 and -1 are used.
        num_classes: Number of output channels of the segmentation head.
        dropout_rate: Drop probability of the ``Dropout2d`` layers that
            follow each fusion convolution.
        output_size: ``(height, width)`` the fused features are resized to
            before the segmentation head.  Defaults to ``(224, 224)``, the
            value that was previously hard-coded in ``forward``.
    """

    def __init__(self, encoder_channels, num_classes=1, dropout_rate=0.1,
                 output_size=(224, 224)):
        super().__init__()
        self.output_size = tuple(output_size)

        self.ppm = PyramidPoolingModule(encoder_channels[-1])

        # Level 2: lift f2 to the deepest channel count, then reduce after fusion.
        self.lateral_conv2 = nn.Conv2d(encoder_channels[2], encoder_channels[-1], kernel_size=1)
        self.conv3 = nn.Sequential(
            nn.Conv2d(encoder_channels[-1], encoder_channels[2], kernel_size=1),
            nn.Dropout2d(p=dropout_rate),
        )

        # Level 1.
        self.lateral_conv1 = nn.Conv2d(encoder_channels[1], encoder_channels[2], kernel_size=1)
        self.conv2 = nn.Sequential(
            nn.Conv2d(encoder_channels[2], encoder_channels[1], kernel_size=1),
            nn.Dropout2d(p=dropout_rate),
        )

        # Level 0 (finest).
        self.lateral_conv0 = nn.Conv2d(encoder_channels[0], encoder_channels[1], kernel_size=1)
        self.conv1 = nn.Sequential(
            nn.Conv2d(encoder_channels[1], encoder_channels[0], kernel_size=1),
            nn.Dropout2d(p=dropout_rate),
        )

        self.segmentation_head = nn.Conv2d(encoder_channels[0], num_classes, kernel_size=1)

    def forward(self, features):
        """Decode four channels-first feature maps into segmentation logits.

        Args:
            features: Iterable of exactly four tensors ``(f0, f1, f2, f3)``
                ordered shallow to deep; unpacking fails otherwise.

        Returns:
            Logits with ``num_classes`` channels and spatial size
            ``output_size``.
        """
        f0, f1, f2, f3 = features

        x3 = self.ppm(f3)

        x3_up = F.interpolate(x3, size=f2.shape[2:], mode="bilinear", align_corners=False)
        fuse2 = self.conv3(x3_up + self.lateral_conv2(f2))

        fuse2_up = F.interpolate(fuse2, size=f1.shape[2:], mode="bilinear", align_corners=False)
        fuse1 = self.conv2(fuse2_up + self.lateral_conv1(f1))

        fuse1_up = F.interpolate(fuse1, size=f0.shape[2:], mode="bilinear", align_corners=False)
        fuse0 = self.conv1(fuse1_up + self.lateral_conv0(f0))

        # Resize before the head, as in the original implementation, so
        # existing checkpoints produce the same outputs.
        x_out = F.interpolate(fuse0, size=self.output_size, mode="bilinear", align_corners=False)
        return self.segmentation_head(x_out)
|
|
|
|
|
|
|
|
|
class SwinTinyUPerNet(nn.Module):
    """Segmentation network: timm Swin-Tiny feature encoder plus ``UPerNetDecoder``.

    Args:
        num_classes: Number of output channels of the segmentation map.
        dropout_rate: Drop probability forwarded to the decoder.
        pretrained: Passed to ``timm.create_model``.  Keep the default
            ``True`` to initialise the encoder from the published weights;
            pass ``False`` to skip that step, e.g. when a full checkpoint is
            loaded right after construction.
    """

    def __init__(self, num_classes=1, dropout_rate=0.1, pretrained=True):
        super().__init__()
        self.encoder = timm.create_model(
            "swin_tiny_patch4_window7_224.ms_in22k_ft_in1k",
            pretrained=pretrained,
            features_only=True,
        )
        encoder_channels = self.encoder.feature_info.channels()
        self.decoder = UPerNetDecoder(encoder_channels, num_classes, dropout_rate=dropout_rate)

    def forward(self, x):
        """Return segmentation logits resized to 224x224 for the input batch ``x``."""
        features = self.encoder(x)
        # Move the last axis of every 4-D feature map to position 1
        # (channels-last -> channels-first), the layout the decoder's
        # Conv2d layers expect.
        features = [f.permute(0, 3, 1, 2) if f.dim() == 4 else f for f in features]
        output = self.decoder(features)
        # Kept from the original implementation: guarantees a 224x224 result
        # regardless of the spatial size the decoder emits.
        return F.interpolate(output, size=(224, 224), mode="bilinear", align_corners=False)
|
|
|
|
|
|
|
|
|
def load_model(checkpoint_path="best_swin_upernet_main.pth"):
    """Build a single-class ``SwinTinyUPerNet``, load a checkpoint and return it in eval mode.

    Args:
        checkpoint_path: Path of the state-dict file.  Defaults to the file
            name that was previously hard-coded.

    Returns:
        The model, with tensors mapped to the CPU and ``eval()`` applied.

    Notes:
        Loading is non-strict, so key mismatches do not raise; they are
        reported through ``warnings.warn`` instead of being silently
        ignored.
    """
    import warnings

    model = SwinTinyUPerNet(num_classes=1)
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            "Checkpoint %r did not match the model exactly: "
            "%d missing key(s) %s, %d unexpected key(s) %s"
            % (
                checkpoint_path,
                len(load_result.missing_keys),
                list(load_result.missing_keys),
                len(load_result.unexpected_keys),
                list(load_result.unexpected_keys),
            )
        )
    model.eval()
    return model
|
|
|
|
|
|
|
|
|
def enable_dropout(m):
    """Put ``m`` into training mode if it is a ``Dropout`` or ``Dropout2d`` layer.

    Intended for ``nn.Module.apply`` so dropout stays stochastic while the
    rest of the network remains in eval mode.  Any other module is left
    untouched.  Returns ``None``.
    """
    dropout_types = (nn.Dropout, nn.Dropout2d)
    if not isinstance(m, dropout_types):
        return
    m.train()
|
|
|
|
|
|
|
|
|
def predict_with_uncertainty(image_tensor, num_samples=10, model=None):
    """Estimate a prediction and its uncertainty with Monte Carlo dropout.

    The model runs ``num_samples`` times with its dropout layers active and
    the sigmoid outputs are aggregated per pixel.

    Args:
        image_tensor: Input batch fed to the model unchanged.
        num_samples: Number of stochastic forward passes; must be >= 1.
        model: Optional, already constructed model.  When ``None`` (default)
            the model is loaded with ``load_model()`` as before.  Passing a
            model avoids re-reading the checkpoint on every call.  The given
            model is switched to eval mode with dropout layers re-enabled and
            is moved to the device of ``image_tensor``.

    Returns:
        ``(preds_mean, preds_uncertainty)`` as NumPy arrays with singleton
        dimensions squeezed out.  ``preds_mean`` is the mean of the sigmoid
        outputs; ``preds_uncertainty`` is their standard deviation, min-max
        normalised to ``[0, 1]``.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1, got %r" % (num_samples,))

    if model is None:
        model = load_model()
    else:
        model.eval()
    # Run on whatever device the input lives on; a no-op for CPU inputs.
    model = model.to(image_tensor.device)
    # Re-activate only the dropout layers; everything else stays in eval mode.
    model.apply(enable_dropout)

    with torch.no_grad():
        preds_list = [torch.sigmoid(model(image_tensor)) for _ in range(num_samples)]

    preds_array = torch.stack(preds_list, dim=0)
    preds_mean = preds_array.mean(dim=0).squeeze().cpu().numpy()

    if num_samples > 1:
        spread = preds_array.std(dim=0)
    else:
        # The sample standard deviation of one observation is NaN; report
        # zero uncertainty instead.
        spread = torch.zeros_like(preds_array[0])
    preds_uncertainty = spread.squeeze().cpu().numpy()

    # Min-max normalise; the epsilon guards against a constant map.
    lowest = preds_uncertainty.min()
    highest = preds_uncertainty.max()
    preds_uncertainty = (preds_uncertainty - lowest) / (highest - lowest + 1e-8)
    return preds_mean, preds_uncertainty