import math

import numpy as np
import torch
import torch.nn as nn


class CMConv(nn.Module):
    """Sum of two dilated grouped convolutions and one masked grouped convolution.

    All three branches map ``in_ch`` -> ``out_ch`` channels and their outputs are added:

    * ``prim``: grouped conv with ``groups * dilation_set`` groups and dilation ``dilation``;
    * ``prim_shift``: grouped conv with ``groups * dilation_set`` groups and dilation
      ``2 * dilation``, applied to the input after the two channel halves of each of
      the ``groups`` groups have been swapped;
    * ``conv``: undilated conv with ``groups`` groups. Its weights at the
      input/output sub-group pairs already handled by ``prim`` and ``prim_shift``
      are set to zero at construction, and their gradients are zeroed by a hook.

    Args:
        in_ch: input channels; must be divisible by ``groups * dilation_set``
            (required by ``nn.Conv2d`` for ``prim``).
        out_ch: output channels; must be divisible by ``groups * dilation_set``.
        kernel_size: kernel size shared by all three branches.
        stride: stride shared by all three branches.
        padding: padding of the undilated ``conv`` branch only; ``prim`` uses
            ``dilation`` and ``prim_shift`` uses ``2 * dilation``.
        dilation: dilation of ``prim`` (``prim_shift`` uses twice this value).
        groups: number of groups of the ``conv`` branch.
        dilation_set: number of sub-groups each of the ``groups`` groups is split
            into for the dilated branches. The sub-group mapping used for the mask
            assumes it is even (the half swap in ``forward`` must align with
            sub-group boundaries).
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, dilation=3, groups=1, dilation_set=4,
                 bias=False):
        super(CMConv, self).__init__()
        n_sub = groups * dilation_set
        # Keep the original construction order so parameter initialisation draws
        # from the RNG in the same sequence.
        self.prim = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=dilation, dilation=dilation,
                              groups=n_sub, bias=bias)
        self.prim_shift = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=2 * dilation,
                                    dilation=2 * dilation, groups=n_sub, bias=bias)
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, groups=groups, bias=bias)

        # Registered as a buffer so `.to()` / `.cuda()` move it along with the
        # weights; non-persistent so existing checkpoints (which have no "mask"
        # key) still load with strict=True.
        self.register_buffer("mask", self._build_mask(in_ch, out_ch, groups, dilation_set), persistent=False)

        def backward_hook(grad):
            # Drop the gradient of masked weights so they receive no update from it.
            return grad.masked_fill(self.mask, 0)

        with torch.no_grad():
            self.conv.weight.masked_fill_(self.mask, 0)
        self.conv.weight.register_hook(backward_hook)
        self.groups = groups

    def _build_mask(self, in_ch, out_ch, groups, dilation_set):
        """Return a bool mask over ``self.conv.weight`` marking sub-group pairs covered by the dilated branches."""
        weight = self.conv.weight
        n_sub = groups * dilation_set
        sub_in = in_ch // n_sub
        sub_out = out_ch // n_sub
        mask = torch.zeros(weight.shape, dtype=torch.bool, device=weight.device)
        for j in range(groups):
            for i in range(dilation_set):
                # Input columns of `self.conv.weight` are relative to the group,
                # so sub-group i of group j always sits at columns i*sub_in...
                cols = slice(i * sub_in, (i + 1) * sub_in)
                # ...while output rows are absolute: group j owns output
                # sub-groups j*dilation_set .. j*dilation_set + dilation_set - 1.
                direct = j * dilation_set + i  # fed by input sub-group i via `prim`
                # Fed by input sub-group i via `prim_shift`: the half swap in
                # forward() rotates sub-groups by dilation_set // 2.
                shifted = j * dilation_set + (i + dilation_set // 2) % dilation_set
                for row in (direct, shifted):
                    mask[row * sub_out:(row + 1) * sub_out, cols] = True
        return mask

    def forward(self, x):
        # Swap the two channel halves inside each of the `groups` groups.
        swapped = [
            torch.cat((second, first), dim=1)
            for first, second in (g.chunk(2, dim=1) for g in x.chunk(self.groups, dim=1))
        ]
        x_shift = self.prim_shift(torch.cat(swapped, dim=1))
        return self.prim(x) + self.conv(x) + x_shift


class SSFC(torch.nn.Module):
    """Parameter-free attention that reweights every activation by its spatial deviation.

    For each sample and channel, the squared deviation of each position from the
    spatial mean is divided by twice the spatial variance (plus float32 eps). The
    result is shifted by 0.5, passed through a sigmoid, and multiplied onto the input.

    Args:
        in_ch: accepted for interface compatibility; not used (no learned layers).
    """

    def __init__(self, in_ch):
        super(SSFC, self).__init__()

    def forward(self, x):
        # x is 4-D; the last two dims are the spatial ones that get reduced.
        _, _, h, w = x.size()
        q = x.mean(dim=[2, 3], keepdim=True)
        k = x
        square = (k - q).pow(2)
        sigma = square.sum(dim=[2, 3], keepdim=True) / (h * w)  # per-channel spatial variance
        # eps guards against division by zero for spatially constant channels.
        att_score = square / (2 * sigma + np.finfo(np.float32).eps) + 0.5
        att_weight = torch.sigmoid(att_score)
        return x * att_weight


class MSDConv_SSFC(nn.Module):
    """Cheap "native + auxiliary" feature block.

    A native conv (+BN+ReLU) produces ``ceil(out_ch / ratio)`` maps. From those, a
    ``CMConv`` (+BN+ReLU) followed by ``SSFC`` attention produces
    ``native_ch * (ratio - 1)`` auxiliary maps. Both are concatenated along the
    channel dim and truncated to ``out_ch`` channels.

    Note: ``CMConv`` is built with ``groups = int(native_ch / 4)`` and its default
    ``dilation_set`` of 4, so ``native_ch`` must be a positive multiple of 4 for the
    grouped convolutions to be valid. ``ratio`` is presumably an integer >= 2
    (it is used directly as a channel multiplier) — confirm with callers.
    """

    def __init__(self, in_ch, out_ch, kernel_size=1, stride=1, padding=0, ratio=2, aux_k=3, dilation=3):
        super(MSDConv_SSFC, self).__init__()
        self.out_ch = out_ch
        native_ch = math.ceil(out_ch / ratio)
        aux_ch = native_ch * (ratio - 1)

        # native feature maps
        self.native = nn.Sequential(
            nn.Conv2d(in_ch, native_ch, kernel_size, stride, padding=padding, dilation=1, bias=False),
            nn.BatchNorm2d(native_ch),
            nn.ReLU(inplace=True),
        )

        # auxiliary feature maps (padding=1 matches the default aux_k=3)
        self.aux = nn.Sequential(
            CMConv(native_ch, aux_ch, aux_k, 1, padding=1, groups=int(native_ch / 4), dilation=dilation,
                   bias=False),
            nn.BatchNorm2d(aux_ch),
            nn.ReLU(inplace=True),
        )

        self.att = SSFC(aux_ch)

    def forward(self, x):
        x1 = self.native(x)
        x2 = self.att(self.aux(x1))
        out = torch.cat([x1, x2], dim=1)
        # native_ch * ratio may exceed out_ch because of the ceil above.
        return out[:, :self.out_ch, :, :]


class First_DoubleConv(nn.Module):
    """Two standard 3x3 conv -> BatchNorm -> ReLU stages (padding 1)."""

    def __init__(self, in_ch, out_ch):
        super(First_DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        return self.conv(input)


class DoubleConv(nn.Module):
    """Two ``MSDConv_SSFC`` stages, each followed by an extra BatchNorm + ReLU.

    ``MSDConv_SSFC`` already applies BatchNorm + ReLU inside each of its branches;
    the outer BatchNorm + ReLU here are applied on top of that.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.Conv = nn.Sequential(
            MSDConv_SSFC(in_ch, out_ch, dilation=3),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            MSDConv_SSFC(out_ch, out_ch, dilation=3),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        return self.Conv(input)


class USSFCNet_decoder(nn.Module):
    """U-Net style decoder: four (2x transposed-conv upsample -> skip concat -> DoubleConv) stages.

    ``forward`` takes a sequence ``(x1, x2, x3, x4, x5)`` whose channel counts are
    ``int(64 * ratio)``, ``int(128 * ratio)``, ``int(256 * ratio)``,
    ``int(512 * ratio)`` and ``int(1024 * ratio)``. Each ``x_k`` must have twice the
    spatial size of ``x_{k+1}``, because the kernel-2 / stride-2 transposed convs
    double the size before concatenation on the channel dim. The result is a sigmoid
    of a 1x1 conv with ``out_ch`` channels at the spatial size of ``x1``.
    """

    def __init__(self, out_ch, ratio=0.5):
        super(USSFCNet_decoder, self).__init__()
        self.Up5 = nn.ConvTranspose2d(int(1024 * ratio), int(512 * ratio), 2, stride=2)
        self.Up_conv5 = DoubleConv(int(1024 * ratio), int(512 * ratio))

        self.Up4 = nn.ConvTranspose2d(int(512 * ratio), int(256 * ratio), 2, stride=2)
        self.Up_conv4 = DoubleConv(int(512 * ratio), int(256 * ratio))

        self.Up3 = nn.ConvTranspose2d(int(256 * ratio), int(128 * ratio), 2, stride=2)
        self.Up_conv3 = DoubleConv(int(256 * ratio), int(128 * ratio))

        self.Up2 = nn.ConvTranspose2d(int(128 * ratio), int(64 * ratio), 2, stride=2)
        self.Up_conv2 = DoubleConv(int(128 * ratio), int(64 * ratio))

        self.Conv_1x1 = nn.Conv2d(int(64 * ratio), out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x1, x2, x3, x4, x5 = x

        # decoding: upsample, put the skip connection first, then fuse
        d = x5
        stages = (
            (self.Up5, self.Up_conv5, x4),
            (self.Up4, self.Up_conv4, x3),
            (self.Up3, self.Up_conv3, x2),
            (self.Up2, self.Up_conv2, x1),
        )
        for up, fuse, skip in stages:
            d = fuse(torch.cat((skip, up(d)), dim=1))

        d1 = self.Conv_1x1(d)
        return torch.sigmoid(d1)