import math

import numpy as np
import torch
import torch.nn as nn


class CMConv(nn.Module):
    """Three-branch convolution: a standard conv whose kernel carries a fixed
    sparsity pattern, plus two grouped convs at dilations d and 2*d; the three
    branch outputs are summed."""

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, dilation=3, groups=1, dilation_set=4,
                 bias=False):
        super(CMConv, self).__init__()
        self.prim = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=dilation, dilation=dilation,
                              groups=groups * dilation_set, bias=bias)
        self.prim_shift = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=2 * dilation, dilation=2 * dilation,
                                    groups=groups * dilation_set, bias=bias)
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, groups=groups, bias=bias)

        def backward_hook(grad):
            # Zero the gradient at masked positions so pruned weights stay zero.
            out = grad.clone()
            out[self.mask] = 0
            return out

        # Boolean mask of weight entries to prune. A non-persistent buffer
        # follows the module across devices (the original used deprecated
        # byte-mask indexing and hard-coded .cuda(), breaking CPU execution).
        self.register_buffer('mask', torch.zeros_like(self.conv.weight, dtype=torch.bool), persistent=False)
        _in_channels = in_ch // (groups * dilation_set)
        _out_channels = out_ch // (groups * dilation_set)
        for i in range(dilation_set):
            for j in range(groups):
                self.mask[(i + j * groups) * _out_channels: (i + j * groups + 1) * _out_channels,
                          i * _in_channels: (i + 1) * _in_channels, :, :] = True
                self.mask[((i + dilation_set // 2) % dilation_set + j * groups) * _out_channels:
                          ((i + dilation_set // 2) % dilation_set + j * groups + 1) * _out_channels,
                          i * _in_channels: (i + 1) * _in_channels, :, :] = True
        self.conv.weight.data[self.mask] = 0
        self.conv.weight.register_hook(backward_hook)
        self.groups = groups

    def forward(self, x):
        # Swap the two halves of each channel group before the 2*d-dilated branch.
        x_split = (z.chunk(2, dim=1) for z in x.chunk(self.groups, dim=1))
        x_merge = torch.cat(tuple(torch.cat((x2, x1), dim=1) for (x1, x2) in x_split), dim=1)
        x_shift = self.prim_shift(x_merge)
        return self.prim(x) + self.conv(x) + x_shift
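

# A minimal shape check (assumed usage, not from the original file): in_ch and
# out_ch must both be divisible by groups * dilation_set so the grouped convs
# and the mask blocks line up, e.g.
#
#     conv = CMConv(32, 64, groups=1, dilation_set=4)
#     y = conv(torch.randn(2, 32, 64, 64))  # -> torch.Size([2, 64, 64, 64])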


class SSFC(nn.Module):
    """Parameter-free spatial attention: each position is reweighted by its
    squared deviation from the spatial mean, normalized by the variance."""

    def __init__(self, in_ch):
        super(SSFC, self).__init__()
        # in_ch is unused; the argument is kept so call sites stay uniform.

    def forward(self, x):
        _, _, h, w = x.size()

        # Spatial mean per channel (the "query") against the full map (the "keys").
        q = x.mean(dim=[2, 3], keepdim=True)
        k = x
        square = (k - q).pow(2)
        # sigma is the per-channel spatial variance of x.
        sigma = square.sum(dim=[2, 3], keepdim=True) / (h * w)
        att_score = square / (2 * sigma + np.finfo(np.float32).eps) + 0.5
        att_weight = torch.sigmoid(att_score)

        return x * att_weight
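

# Sketch of intended use (assumed): SSFC has no learnable parameters and
# preserves shape, so it can wrap any feature map, e.g.
#
#     att = SSFC(64)
#     y = att(torch.randn(2, 64, 32, 32))  # same shape, spatially reweighted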


class MSDConv_SSFC(nn.Module):
    """A cheap "native" conv produces a fraction of the output channels, CMConv
    generates the rest from them, and SSFC reweights the auxiliary branch
    before the two are concatenated."""

    def __init__(self, in_ch, out_ch, kernel_size=1, stride=1, padding=0, ratio=2, aux_k=3, dilation=3):
        super(MSDConv_SSFC, self).__init__()
        self.out_ch = out_ch
        native_ch = math.ceil(out_ch / ratio)
        aux_ch = native_ch * (ratio - 1)

        self.native = nn.Sequential(
            nn.Conv2d(in_ch, native_ch, kernel_size, stride, padding=padding, dilation=1, bias=False),
            nn.BatchNorm2d(native_ch),
            nn.ReLU(inplace=True),
        )

        # native_ch must be divisible by 4 for the grouped convs inside CMConv.
        self.aux = nn.Sequential(
            CMConv(native_ch, aux_ch, aux_k, 1, padding=1, groups=native_ch // 4, dilation=dilation,
                   bias=False),
            nn.BatchNorm2d(aux_ch),
            nn.ReLU(inplace=True),
        )

        self.att = SSFC(aux_ch)

    def forward(self, x):
        x1 = self.native(x)
        x2 = self.att(self.aux(x1))
        out = torch.cat([x1, x2], dim=1)
        # native_ch + aux_ch can exceed out_ch when out_ch is odd; trim the excess.
        return out[:, :self.out_ch, :, :]
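

# Example (assumed): with the default ratio=2, half of the output channels come
# from the cheap 1x1 "native" branch and half from the CMConv branch, e.g.
#
#     conv = MSDConv_SSFC(32, 64, dilation=3)
#     y = conv(torch.randn(2, 32, 64, 64))  # -> torch.Size([2, 64, 64, 64])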


class First_DoubleConv(nn.Module):
    """Stem block: two plain 3x3 conv-BN-ReLU layers."""

    def __init__(self, in_ch, out_ch):
        super(First_DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class DoubleConv(nn.Module):
    """Two MSDConv_SSFC-BN-ReLU layers: the drop-in replacement for the plain
    double conv used at every stage after the stem."""

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.Conv = nn.Sequential(
            MSDConv_SSFC(in_ch, out_ch, dilation=3),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            MSDConv_SSFC(out_ch, out_ch, dilation=3),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.Conv(x)
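

# Example (assumed): DoubleConv preserves spatial size, so it substitutes
# one-for-one for a standard double 3x3 conv block, e.g.
#
#     block = DoubleConv(64, 128)
#     y = block(torch.randn(2, 64, 32, 32))  # -> torch.Size([2, 128, 32, 32])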


class USSFCNet_decoder(nn.Module):
    """U-Net-style decoder: each stage upsamples with a transposed conv,
    concatenates the matching encoder feature, and fuses with a DoubleConv."""

    def __init__(self, out_ch, ratio=0.5):
        super(USSFCNet_decoder, self).__init__()

        self.Up5 = nn.ConvTranspose2d(int(1024 * ratio), int(512 * ratio), 2, stride=2)
        self.Up_conv5 = DoubleConv(int(1024 * ratio), int(512 * ratio))

        self.Up4 = nn.ConvTranspose2d(int(512 * ratio), int(256 * ratio), 2, stride=2)
        self.Up_conv4 = DoubleConv(int(512 * ratio), int(256 * ratio))

        self.Up3 = nn.ConvTranspose2d(int(256 * ratio), int(128 * ratio), 2, stride=2)
        self.Up_conv3 = DoubleConv(int(256 * ratio), int(128 * ratio))

        self.Up2 = nn.ConvTranspose2d(int(128 * ratio), int(64 * ratio), 2, stride=2)
        self.Up_conv2 = DoubleConv(int(128 * ratio), int(64 * ratio))

        self.Conv_1x1 = nn.Conv2d(int(64 * ratio), out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # x is the tuple of encoder features, shallowest (x1) to deepest (x5).
        x1, x2, x3, x4, x5 = x

        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        out = torch.sigmoid(d1)

        return out
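

# Minimal smoke test (assumed shapes, not part of the original file): with
# ratio=0.5 the decoder expects encoder features of 32/64/128/256/512 channels
# at full down to 1/16 resolution, as a U-Net encoder built from
# First_DoubleConv + DoubleConv stages would produce.
if __name__ == "__main__":
    n, size = 1, 64
    feats = tuple(
        torch.randn(n, c, size // s, size // s)
        for c, s in [(32, 1), (64, 2), (128, 4), (256, 8), (512, 16)]
    )
    decoder = USSFCNet_decoder(out_ch=1, ratio=0.5)
    out = decoder(feats)
    print(out.shape)  # expected: torch.Size([1, 1, 64, 64])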