import math

import numpy as np
import torch
import torch.nn as nn


class CMConv(nn.Module):
    """Multi-scale grouped convolution: a standard conv whose pruned channel
    blocks are compensated by two dilated branches (prim and prim_shift)."""

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1,
                 dilation=3, groups=1, dilation_set=4, bias=False):
        super(CMConv, self).__init__()
        self.prim = nn.Conv2d(in_ch, out_ch, kernel_size, stride,
                              padding=dilation, dilation=dilation,
                              groups=groups * dilation_set, bias=bias)
        self.prim_shift = nn.Conv2d(in_ch, out_ch, kernel_size, stride,
                                    padding=2 * dilation, dilation=2 * dilation,
                                    groups=groups * dilation_set, bias=bias)
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding,
                              groups=groups, bias=bias)

        def backward_hook(grad):
            # Zero the gradient on masked positions so pruned weights stay zero
            # during training.
            out = grad.clone()
            out[self.mask] = 0
            return out

        # Boolean mask marking the channel blocks of self.conv that are pruned
        # (those connections are handled by the dilated branches instead).
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('mask', torch.zeros(self.conv.weight.shape, dtype=torch.bool))
        _in_channels = in_ch // (groups * dilation_set)
        _out_channels = out_ch // (groups * dilation_set)
        for i in range(dilation_set):
            for j in range(groups):
                # Diagonal channel block ...
                self.mask[(i + j * groups) * _out_channels:
                          (i + j * groups + 1) * _out_channels,
                          i * _in_channels:(i + 1) * _in_channels, :, :] = 1
                # ... and the half-cycle-shifted diagonal block.
                self.mask[((i + dilation_set // 2) % dilation_set + j * groups) * _out_channels:
                          ((i + dilation_set // 2) % dilation_set + j * groups + 1) * _out_channels,
                          i * _in_channels:(i + 1) * _in_channels, :, :] = 1
        self.conv.weight.data[self.mask] = 0
        self.conv.weight.register_hook(backward_hook)
        self.groups = groups

    def forward(self, x):
        # Swap the two halves of each channel group before the shifted branch.
        x_split = (z.chunk(2, dim=1) for z in x.chunk(self.groups, dim=1))
        x_merge = torch.cat(tuple(torch.cat((x2, x1), dim=1) for (x1, x2) in x_split), dim=1)
        x_shift = self.prim_shift(x_merge)
        return self.prim(x) + self.conv(x) + x_shift


class SSFC(nn.Module):
    """Parameter-free spatial attention: weights each activation by its squared
    deviation from the per-channel spatial mean, normalized by the variance."""

    def __init__(self, in_ch):
        super(SSFC, self).__init__()
        # self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1)  # generate k by conv

    def forward(self, x):
        _, _, h, w = x.size()

        q = x.mean(dim=[2, 3], keepdim=True)  # per-channel spatial mean
        # k = self.proj(x)
        k = x
        square = (k - q).pow(2)
        sigma = square.sum(dim=[2, 3], keepdim=True) / (h * w)  # per-channel variance
        att_score = square / (2 * sigma + np.finfo(np.float32).eps) + 0.5
        att_weight = torch.sigmoid(att_score)

        return x * att_weight


class MSDConv_SSFC(nn.Module):
    """Produces `native_ch` maps with a cheap conv and derives the remaining
    maps with CMConv + SSFC attention, then concatenates the two sets."""

    def __init__(self, in_ch, out_ch, kernel_size=1, stride=1, padding=0,
                 ratio=2, aux_k=3, dilation=3):
        super(MSDConv_SSFC, self).__init__()
        self.out_ch = out_ch
        native_ch = math.ceil(out_ch / ratio)
        aux_ch = native_ch * (ratio - 1)

        # native feature maps
        self.native = nn.Sequential(
            nn.Conv2d(in_ch, native_ch, kernel_size, stride, padding=padding, dilation=1, bias=False),
            nn.BatchNorm2d(native_ch),
            nn.ReLU(inplace=True),
        )

        # auxiliary feature maps (native_ch must be a multiple of 4 for the grouping)
        self.aux = nn.Sequential(
            CMConv(native_ch, aux_ch, aux_k, 1, padding=1, groups=native_ch // 4,
                   dilation=dilation, bias=False),
            nn.BatchNorm2d(aux_ch),
            nn.ReLU(inplace=True),
        )

        self.att = SSFC(aux_ch)

    def forward(self, x):
        x1 = self.native(x)
        x2 = self.att(self.aux(x1))
        out = torch.cat([x1, x2], dim=1)
        return out[:, :self.out_ch, :, :]
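
# Illustrative shape check for the block above (not part of the original
# network; the tensor sizes are arbitrary assumptions chosen to satisfy the
# channel constraints). Called from the __main__ block at the end of the file.
def _msdconv_ssfc_smoke_test():
    x = torch.randn(2, 16, 64, 64)     # batch 2, 16 channels, 64x64
    block = MSDConv_SSFC(16, 32)       # native_ch = 16, aux_ch = 16
    y = block(x)
    assert y.shape == (2, 32, 64, 64)  # spatial size preserved, channels doubled
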
class First_DoubleConv(nn.Module):
    """Plain double 3x3 conv block used at the first scale of each branch."""

    def __init__(self, in_ch, out_ch):
        super(First_DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class DoubleConv(nn.Module):
    """Double MSDConv_SSFC block used at the deeper scales."""

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.Conv = nn.Sequential(
            MSDConv_SSFC(in_ch, out_ch, dilation=3),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            MSDConv_SSFC(out_ch, out_ch, dilation=3),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.Conv(x)


class USSFCNet_encoder(nn.Module):
    """Pseudo-siamese encoder: two parallel branches with separate weights
    encode the bi-temporal inputs, and the absolute difference of their
    features is returned at each of the five scales."""

    def __init__(self, in_ch, ratio=0.5):
        super(USSFCNet_encoder, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Conv1_1 = First_DoubleConv(in_ch, int(64 * ratio))
        self.Conv1_2 = First_DoubleConv(in_ch, int(64 * ratio))
        self.Conv2_1 = DoubleConv(int(64 * ratio), int(128 * ratio))
        self.Conv2_2 = DoubleConv(int(64 * ratio), int(128 * ratio))
        self.Conv3_1 = DoubleConv(int(128 * ratio), int(256 * ratio))
        self.Conv3_2 = DoubleConv(int(128 * ratio), int(256 * ratio))
        self.Conv4_1 = DoubleConv(int(256 * ratio), int(512 * ratio))
        self.Conv4_2 = DoubleConv(int(256 * ratio), int(512 * ratio))
        self.Conv5_1 = DoubleConv(int(512 * ratio), int(1024 * ratio))
        self.Conv5_2 = DoubleConv(int(512 * ratio), int(1024 * ratio))

    def forward(self, x1, x2):
        # encoding
        # x1, x2 = torch.unsqueeze(x1[0], dim=0), torch.unsqueeze(x1[1], dim=0)
        c1_1 = self.Conv1_1(x1)
        c1_2 = self.Conv1_2(x2)
        x1 = torch.abs(torch.sub(c1_1, c1_2))

        c2_1 = self.Maxpool(c1_1)
        c2_1 = self.Conv2_1(c2_1)
        c2_2 = self.Maxpool(c1_2)
        c2_2 = self.Conv2_2(c2_2)
        x2 = torch.abs(torch.sub(c2_1, c2_2))

        c3_1 = self.Maxpool(c2_1)
        c3_1 = self.Conv3_1(c3_1)
        c3_2 = self.Maxpool(c2_2)
        c3_2 = self.Conv3_2(c3_2)
        x3 = torch.abs(torch.sub(c3_1, c3_2))

        c4_1 = self.Maxpool(c3_1)
        c4_1 = self.Conv4_1(c4_1)
        c4_2 = self.Maxpool(c3_2)
        c4_2 = self.Conv4_2(c4_2)
        x4 = torch.abs(torch.sub(c4_1, c4_2))

        c5_1 = self.Maxpool(c4_1)
        c5_1 = self.Conv5_1(c5_1)
        c5_2 = self.Maxpool(c4_2)
        c5_2 = self.Conv5_2(c5_2)
        x5 = torch.abs(torch.sub(c5_1, c5_2))

        return [x1, x2, x3, x4, x5]
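

if __name__ == "__main__":
    # Minimal sketch of how the encoder might be exercised (illustrative only;
    # the 256x256 resolution and batch size are assumptions, not values
    # prescribed by the original code).
    _msdconv_ssfc_smoke_test()
    net = USSFCNet_encoder(in_ch=3, ratio=0.5)
    t1 = torch.randn(1, 3, 256, 256)  # bi-temporal RGB image pair
    t2 = torch.randn(1, 3, 256, 256)
    for i, f in enumerate(net(t1, t2), start=1):
        print(f"x{i}: {tuple(f.shape)}")  # channels 32/64/128/256/512 at strides 1/2/4/8/16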