|
|
|
from skimage.measure import label, regionprops |
|
import torch |
|
import torchvision |
|
from torch.nn import functional as F |
|
import torch.nn as nn |
|
import numpy as np |
|
import cv2 |
|
import torch |
|
from collections import namedtuple |
|
|
|
|
|
|
|
from timm.models.swin_transformer import swin_base_patch4_window12_384_in22k, SwinTransformer |
|
|
|
# Seed torch's global RNG when the module is imported so random ops
# (e.g. weight initialisation of the layers built below) are repeatable.
torch.manual_seed(0)

# Prefer the GPU when torch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Default crop margin, in pixels, for extract_regions_Last.
pad_value = 10
|
|
|
def forward_features(self, x):
    """Run the Swin backbone and collect the output of every stage.

    This function is assigned onto ``SwinTransformer`` at module level, so
    ``self`` is a SwinTransformer instance. No pooling and no final norm
    is applied.

    Args:
        self: the SwinTransformer instance being patched.
        x: input batch, passed straight to ``self.patch_embed``.

    Returns:
        list: one tensor per entry of ``self.layers``, in stage order.
        These are presumably (B, H*W, C) token tensors; TODO confirm against
        the installed timm version.
    """
    x = self.patch_embed(x)
    if self.absolute_pos_embed is not None:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)

    hidden_states = []
    for stage in self.layers:
        x = stage(x)
        hidden_states.append(x)

    # A previous version applied ``self.norm`` to the last stage here but
    # discarded the result. That dead computation has been removed; the
    # returned list is identical.
    return hidden_states
|
|
|
def forward(self, x):
    """Forward pass for the patched backbone: delegate to ``forward_features``.

    Whatever ``self.forward_features`` produces is handed back unchanged.
    """
    return self.forward_features(x)
|
|
|
# Monkey-patch timm's SwinTransformer class (affects every instance) so that
# calling the model runs the module-level forward_features/forward functions,
# which return the list of per-stage hidden states.
setattr(SwinTransformer, "forward_features", forward_features)
setattr(SwinTransformer, "forward", forward)
|
|
|
|
|
|
|
|
|
def _fit_pad(pad, room):
    """Shrink a crop margin to fit in ``room`` pixels: keep it, else 5, else 0."""
    if room - pad < 0:
        pad = 5
    if room - pad < 0:
        pad = 0
    return pad


def extract_regions_Last(img_test, ytruth, pad1=pad_value, pad2=pad_value, pad3=pad_value, pad4=pad_value):
    """Crop ``img_test`` and ``ytruth`` around the largest connected mask region.

    Pixels labelled 2 are merged into label 1 before connected-component
    labelling, so both classes form one object. This presumably means a cup
    inside a disc for a 0/1/2 mask; confirm with callers. The bounding box
    of the largest region (first one wins on ties) is grown by a margin on
    each side. A margin that would leave the image falls back to 5 px, then
    to 0 px. The margins are fitted for that region only; they no longer
    carry over from smaller regions visited earlier.

    Args:
        img_test: image indexed as ``[row, col, channel]``.
        ytruth: label mask indexed as ``[row, col]``, aligned with ``img_test``.
        pad1: requested margin above the region (rows).
        pad2: requested margin left of the region (cols).
        pad3: requested margin below the region (rows).
        pad4: requested margin right of the region (cols).

    Returns:
        dict: ``'image'`` (cropped image), ``'truth'`` (cropped mask) and
        ``'cord'`` = ``[row_start, row_stop, col_start, col_stop]`` of the
        crop. Returns an empty dict when the mask has no foreground pixels.
    """
    merged = ytruth.copy()
    merged[merged == 2] = 1
    label_img = label(merged)

    regions = regionprops(label_img)
    if not regions:
        return {}

    # max() returns the first maximal element, matching a strict ">" scan.
    largest = max(regions, key=lambda region: region.area)
    minr, minc, maxr, maxc = largest.bbox

    top = _fit_pad(pad1, minr)
    left = _fit_pad(pad2, minc)
    bottom = _fit_pad(pad3, label_img.shape[0] - maxr)
    right = _fit_pad(pad4, label_img.shape[1] - maxc)

    row_start, row_stop = minr - top, maxr + bottom
    col_start, col_stop = minc - left, maxc + right

    return {
        'image': img_test[row_start:row_stop, col_start:col_stop, :],
        'truth': ytruth[row_start:row_stop, col_start:col_stop],
        'cord': [row_start, row_stop, col_start, col_stop],
    }
|
|
|
|
|
class BasicBlock(nn.Module):
    """Residual bottleneck unit: 1x1 conv to 48 channels, 3x3 conv back, skip add.

    Each conv is followed by GroupNorm with 8 groups and a GELU. GroupNorm
    requires ``channel_num`` to be divisible by 8. The output has the same
    shape as the input.
    """

    def __init__(self, channel_num):
        super().__init__()
        hidden = 48

        # Squeeze: 1x1 projection into the fixed-width hidden space.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(channel_num, hidden, 1, padding=0),
            nn.GroupNorm(num_groups=8, num_channels=hidden),
            nn.GELU(),
        )
        # Expand: 3x3 conv with padding 1 (spatial size preserved) back to channel_num.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(hidden, channel_num, 3, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=channel_num),
            nn.GELU(),
        )
        # Not used by forward(); retained so the module's attributes stay the same.
        self.relu = nn.GELU()

    def forward(self, x):
        return self.conv_block2(self.conv_block1(x)) + x
|
|
|
|
|
class ASPP(nn.Module):
    """Output head that bilinearly resizes a feature map to ``image_dim`` x ``image_dim``.

    Despite the name, no atrous/dilated convolutions are applied: forward()
    only interpolates (``align_corners=True``). The channel count is
    unchanged.
    """

    def __init__(self, image_dim=384, head=1):
        super().__init__()
        self.image_dim = image_dim
        # NOTE(review): Residual2 and pixel_shuffle are never used in forward().
        # Presumably they are kept so existing state_dicts still load; confirm
        # before removing them. Also, with the default head=1, BasicBlock builds
        # GroupNorm(num_groups=8, num_channels=1), which PyTorch rejects. The
        # callers in this file pass head=128.
        self.Residual2 = BasicBlock(channel_num=head)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.head = head

    def forward(self, x):
        side = self.image_dim
        return F.interpolate(x, size=(side, side), mode='bilinear', align_corners=True)
|
|
|
|
|
|
|
class Transformer_Regression(nn.Module):
    """Swin-B backbone with a fused multi-scale segmentation head.

    The first three backbone stages are reshaped to 48x48, 24x24 and 12x12
    maps. The coarser two are pixel-shuffled up to 48x48, and all three are
    concatenated (256 + 128 + 64 = 448 channels) and projected to 128
    channels. Two parallel heads then resize to ``image_dim`` and predict
    ``num_classes`` logits each.

    Args:
        image_dim: side of the square output maps.
        dim_patch: stored as ``self.dim_patch``; not read inside this class.
        num_classes: output channels of each classifier.
        scale: accepted for compatibility; unused.
        feat_dim: stored as ``self.feat_dim``; not read inside this class.
    """

    # (side, channels) of backbone stages 0-2. These are hard-coded.
    # NOTE(review): they match the *_window12_384 backbone at a 384x384 input;
    # other input sizes will fail the reshape.
    _STAGE_SHAPES = ((48, 256), (24, 512), (12, 1024))

    def __init__(self, image_dim=224, dim_patch=24, num_classes=3, scale=1, feat_dim=192):
        super().__init__()
        # NOTE(review): pretrained=True makes timm fetch weights at construction time.
        self.backbone = swin_base_patch4_window12_384_in22k(pretrained=True)
        self.aux = 1
        self.dim_patch = dim_patch
        self.image_dim = image_dim
        self.num_classes = num_classes
        self.ASPP1 = ASPP(image_dim, head=128)
        self.ASPP2 = ASPP(image_dim, head=128)
        self.feat_dim = feat_dim

        self.Classifier_main = nn.Sequential(
            nn.Conv2d(128, self.num_classes, 3, bias=True, padding=1),
        )
        self.Classifier_aux1 = nn.Sequential(
            nn.Conv2d(128, self.num_classes, 3, bias=True, padding=1),
        )

        # 1x1 projection of the 448 fused channels down to 128. padding must be
        # 0. With padding=1, a 1x1 kernel grows the 48x48 map to 50x50 and adds
        # a border ring of constant GELU(bias) values, which the bilinear resize
        # then stretches into the output, shifting it spatially. Parameter
        # shapes are unchanged, so older state_dicts still load. Models trained
        # with the old padding should be re-validated.
        self.conv1 = nn.Sequential(nn.Conv2d(448, 128, kernel_size=(1, 1), padding=0), nn.GELU())
        self.pixelshufler1 = nn.PixelShuffle(2)
        self.pixelshufler2 = nn.PixelShuffle(4)

    def forward(self, x):
        """Return ``{'seg': main logits, 'seg_aux_1': auxiliary logits}``.

        Both maps have shape (B, num_classes, image_dim, image_dim).
        """
        hidden = self.backbone(x)

        # Token sequences -> NCHW feature maps for the first three stages.
        maps = [
            tokens.reshape(-1, side, side, channels).permute(0, 3, 1, 2)
            for tokens, (side, channels) in zip(hidden, self._STAGE_SHAPES)
        ]

        # Pixel shuffle trades channels for resolution so every map is 48x48:
        # 512 -> 128 channels (x2 upscale), 1024 -> 64 channels (x4 upscale).
        fused = torch.cat(
            (maps[0], self.pixelshufler1(maps[1]), self.pixelshufler2(maps[2])), 1
        )
        fused = self.conv1(fused)

        main = self.Classifier_main(self.ASPP1(fused))
        aux = self.Classifier_aux1(self.ASPP2(fused))
        return {'seg': main, 'seg_aux_1': aux}
|
|
|
|
|
Ratios = namedtuple("Ratios", 'cdr hcdr vcdr')

# Smoothing term so empty masks give a ratio of 1 instead of dividing by zero.
eps = np.finfo(np.float32).eps


def compute_ratios(mask_image):
    '''
    Given an input image containing the cup and disc masks, return the area,
    horizontal, and vertical cup-to-disc ratios.

    Input:
        mask_image: an image with values (0,1,2) or (255,128,0)
                    for bg, disc, cup respectively. The encoding is detected
                    from the maximum value: a mask whose max exceeds 2 is read
                    as (255,128,0), otherwise as (0,1,2). An all-zero image is
                    therefore read as all background.
                    In both encodings, the disc region includes the cup.
    Output:
        Ratios(cdr,hcdr,vcdr): a named tuple where
            cdr  = cup pixel count / disc pixel count,
            hcdr = max cup pixels in any row / max disc pixels in any row,
            vcdr = max cup pixels in any column / max disc pixels in any column.
        Each numerator and denominator gets ``eps`` added.
    '''
    mask_image = np.asarray(mask_image)

    if mask_image.max() > 2:
        # (255,128,0) encoding: background is 255, cup is 0.
        disc = np.uint8(mask_image < 255)
        cup = np.uint8(mask_image == 0)
    else:
        # (0,1,2) encoding: background is 0, cup is 2.
        disc = np.uint8(mask_image > 0)
        cup = np.uint8(mask_image > 1)

    disc_area = np.sum(disc)
    cup_area = np.sum(cup)

    # Column sums (axis=0) give vertical extents; row sums (axis=1) give horizontal ones.
    cup_vert = np.sum(cup, axis=0).max().astype(np.int32)
    cup_horz = np.sum(cup, axis=1).max().astype(np.int32)

    disc_vert = np.sum(disc, axis=0).max().astype(np.int32)
    disc_horz = np.sum(disc, axis=1).max().astype(np.int32)

    cdr = (cup_area + eps) / (disc_area + eps)

    hcdr = (cup_horz + eps) / (disc_horz + eps)
    vcdr = (cup_vert + eps) / (disc_vert + eps)

    return Ratios(cdr, hcdr, vcdr)
|
|
|
|
|
|
|
|