# Uploaded by InPeerReview ("Upload 161 files", revision 226675b, verified).
import math
import torch.nn as nn
import torch
import numpy as np
class CMConv(nn.Module):
    """Sum of two grouped dilated convolutions and a masked ordinary convolution.

    Branches (all with ``kernel_size``/``stride`` as given):

    * ``prim``: ``groups * dilation_set`` groups, dilation ``dilation``, padding ``dilation``.
    * ``prim_shift``: same grouping, dilation/padding ``2 * dilation``; applied to a copy of the
      input in which, inside each of ``groups`` channel chunks, the two channel halves are swapped.
    * ``conv``: ``groups`` groups, dilation 1, padding ``padding``.

    Inside every ``conv`` group, the weight entries that connect an output sub-block to the input
    sub-block already reached by ``prim`` (same sub-block) or by ``prim_shift`` (sub-block shifted by
    ``dilation_set // 2``) are set to zero at construction, and a backward hook zeroes their gradients.

    With the defaults (``kernel_size=3``, ``padding=1``) all three branches produce the same spatial
    size; other combinations are not checked here. ``nn.Conv2d`` itself requires ``in_ch`` and
    ``out_ch`` to be divisible by ``groups * dilation_set``.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        kernel_size, stride, padding: forwarded to the convolutions (``padding`` only to ``conv``).
        dilation: dilation of ``prim`` (``prim_shift`` uses twice this).
        groups: number of groups of ``conv`` and of channel chunks in the half-swap.
        dilation_set: sub-blocks per group used by ``prim``/``prim_shift``.
        bias: whether the convolutions have a bias.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, dilation=3, groups=1, dilation_set=4,
                 bias=False):
        super(CMConv, self).__init__()
        self.prim = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=dilation, dilation=dilation,
                              groups=groups * dilation_set, bias=bias)
        self.prim_shift = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=2 * dilation, dilation=2 * dilation,
                                    groups=groups * dilation_set, bias=bias)
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, groups=groups, bias=bias)

        # True marks a conv.weight entry duplicating a prim / prim_shift connection.
        # Created on the weight's own device (the former `.cuda()` broke CPU-only use and mixed devices
        # when indexing the still-CPU weight) and kept as a non-persistent buffer so it follows
        # .to()/.cuda() without adding a key to state_dict.
        mask = torch.zeros(self.conv.weight.shape, dtype=torch.bool, device=self.conv.weight.device)
        _in_channels = in_ch // (groups * dilation_set)
        _out_channels = out_ch // (groups * dilation_set)
        for j in range(groups):
            # conv.weight rows of group j are sub-blocks j*dilation_set ... j*dilation_set + dilation_set - 1.
            # (The previous `j * groups` offset was only right for groups in {1, dilation_set}.)
            base = j * dilation_set
            for i in range(dilation_set):
                cols = slice(i * _in_channels, (i + 1) * _in_channels)
                direct = base + i
                shifted = base + (i + dilation_set // 2) % dilation_set
                mask[direct * _out_channels:(direct + 1) * _out_channels, cols] = True
                mask[shifted * _out_channels:(shifted + 1) * _out_channels, cols] = True
        self.register_buffer('mask', mask, persistent=False)

        with torch.no_grad():
            self.conv.weight.masked_fill_(self.mask, 0)

        def backward_hook(grad):
            # Out-of-place, as tensor hooks must not modify `grad`; `self.mask` is looked up at call
            # time so the buffer on the module's current device is used.
            return grad.masked_fill(self.mask, 0)

        self.conv.weight.register_hook(backward_hook)
        self.groups = groups

    def forward(self, x):
        # Inside each of the `groups` channel chunks, swap the two halves before the shifted branch.
        x_split = (z.chunk(2, dim=1) for z in x.chunk(self.groups, dim=1))
        x_merge = torch.cat(tuple(torch.cat((x2, x1), dim=1) for (x1, x2) in x_split), dim=1)
        x_shift = self.prim_shift(x_merge)
        return self.prim(x) + self.conv(x) + x_shift
class SSFC(torch.nn.Module):
    """Parameter-free spatial re-weighting based on squared deviation from the per-map mean.

    For every (sample, channel) map of a 4-D input:
        mu  = spatial mean of the map
        s   = (x - mu) ** 2
        var = sum(s) / (H * W)
        out = x * sigmoid(s / (2 * var + eps) + 0.5),  eps = float32 machine epsilon

    Args:
        in_ch: accepted for interface compatibility; not used (the module has no parameters).
    """

    def __init__(self, in_ch):
        super(SSFC, self).__init__()

    def forward(self, x):
        _, _, height, width = x.shape
        deviation_sq = (x - x.mean(dim=[2, 3], keepdim=True)) ** 2
        variance = deviation_sq.sum(dim=[2, 3], keepdim=True) / (height * width)
        eps = np.finfo(np.float32).eps  # keeps the division finite for constant maps
        weights = torch.sigmoid(deviation_sq / (2 * variance + eps) + 0.5)
        return x * weights
class MSDConv_SSFC(nn.Module):
    """Native conv branch plus CMConv auxiliary maps re-weighted by SSFC, concatenated and trimmed.

    ``ceil(out_ch / ratio)`` native channels come from a ``kernel_size`` conv + BN + ReLU; the
    auxiliary branch turns them into ``native * (ratio - 1)`` channels via CMConv + BN + ReLU + SSFC.
    The concatenation (native first) is cut to the first ``out_ch`` channels.

    Args:
        in_ch, out_ch: input / output channels.
        kernel_size, stride, padding: parameters of the native conv.
        ratio: total-to-native channel ratio.
        aux_k: kernel size of the CMConv (its ``padding`` is fixed at 1).
        dilation: dilation passed to CMConv.
    """

    def __init__(self, in_ch, out_ch, kernel_size=1, stride=1, padding=0, ratio=2, aux_k=3, dilation=3):
        super(MSDConv_SSFC, self).__init__()
        self.out_ch = out_ch
        primary = math.ceil(out_ch / ratio)
        secondary = primary * (ratio - 1)
        cm_groups = int(primary / 4)

        self.native = nn.Sequential(
            nn.Conv2d(in_ch, primary, kernel_size, stride, padding=padding, dilation=1, bias=False),
            nn.BatchNorm2d(primary),
            nn.ReLU(inplace=True),
        )
        self.aux = nn.Sequential(
            CMConv(primary, secondary, aux_k, 1, padding=1, groups=cm_groups, dilation=dilation, bias=False),
            nn.BatchNorm2d(secondary),
            nn.ReLU(inplace=True),
        )
        self.att = SSFC(secondary)

    def forward(self, x):
        native_maps = self.native(x)
        aux_maps = self.att(self.aux(native_maps))
        stacked = torch.cat([native_maps, aux_maps], dim=1)
        return stacked[:, :self.out_ch]
class First_DoubleConv(nn.Module):
    """Two stacked (3x3 conv, padding 1 -> BatchNorm -> ReLU) stages.

    Args:
        in_ch: channels of the input.
        out_ch: channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super(First_DoubleConv, self).__init__()
        layers = []
        for source_ch in (in_ch, out_ch):
            layers += [
                nn.Conv2d(source_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer indices form the state_dict keys (conv.0.weight, ...).
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        return self.conv(input)
class DoubleConv(nn.Module):
    """Two stacked (MSDConv_SSFC with dilation 3 -> BatchNorm -> ReLU) stages.

    Args:
        in_ch: channels of the input.
        out_ch: channels produced by both stages.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        stages = []
        for source_ch in (in_ch, out_ch):
            stages.extend((
                MSDConv_SSFC(source_ch, out_ch, dilation=3),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ))
        # Capitalised attribute name kept: it is part of the state_dict keys (Conv.0...., ...).
        self.Conv = nn.Sequential(*stages)

    def forward(self, input):
        return self.Conv(input)
class USSFCNet_decoder(nn.Module):
    """Decoder: four (2x transposed-conv upsample, concat skip, DoubleConv) stages, then 1x1 conv + sigmoid.

    Reference widths 1024/512/256/128/64 are scaled by ``ratio`` and truncated with ``int``.

    Args:
        out_ch: channels of the final 1x1 convolution.
        ratio: width multiplier.

    ``forward`` takes a 5-element sequence ``(x1, x2, x3, x4, x5)``. ``x5`` is upsampled first and
    ``x4 ... x1`` are concatenated (skip first) at successive stages.
    """

    def __init__(self, out_ch, ratio=0.5):
        super(USSFCNet_decoder, self).__init__()

        def width(base):
            # Scale a reference width by `ratio`, truncating like int().
            return int(base * ratio)

        # Creation order (and therefore parameter order / RNG use) is Up5, Up_conv5, Up4, ...
        self.Up5 = nn.ConvTranspose2d(width(1024), width(512), 2, stride=2)
        self.Up_conv5 = DoubleConv(width(1024), width(512))
        self.Up4 = nn.ConvTranspose2d(width(512), width(256), 2, stride=2)
        self.Up_conv4 = DoubleConv(width(512), width(256))
        self.Up3 = nn.ConvTranspose2d(width(256), width(128), 2, stride=2)
        self.Up_conv3 = DoubleConv(width(256), width(128))
        self.Up2 = nn.ConvTranspose2d(width(128), width(64), 2, stride=2)
        self.Up_conv2 = DoubleConv(width(128), width(64))
        self.Conv_1x1 = nn.Conv2d(width(64), out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x1, x2, x3, x4, x5 = x
        stages = (
            (self.Up5, self.Up_conv5, x4),
            (self.Up4, self.Up_conv4, x3),
            (self.Up3, self.Up_conv3, x2),
            (self.Up2, self.Up_conv2, x1),
        )
        decoded = x5
        for upsample, fuse, skip in stages:
            decoded = fuse(torch.cat((skip, upsample(decoded)), dim=1))
        return torch.sigmoid(self.Conv_1x1(decoded))