import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.hub import load_state_dict_from_url

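# This module implements a Siamese change-detection head: it consumes two
# multi-scale feature pyramids (one per temporal image, e.g. from a shared
# ResNet backbone, cf. model_urls below), cross-attends the two branches at
# each level, takes per-level absolute differences, relates a pooled scene
# vector to every scale (SceneRelation), fuses scales with cross attention,
# and decodes a num_classes-channel change map.
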
# 1x1 global average pooling: squeezes each (N, C, H, W) map to (N, C, 1, 1).
GlobalAvgPool2D = lambda: nn.AdaptiveAvgPool2d(1)

# ImageNet-pretrained ResNet checkpoints.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}


class Cross_transformer_backbone(nn.Module):
    """Cross attention between the two temporal branches at one pyramid level.

    Both inputs are (N, C, H, W); the output keeps that shape.
    """

    def __init__(self, in_channels=48):
        super(Cross_transformer_backbone, self).__init__()

        # Note: to_key / to_value and gamma_cam_lay3 are unused in forward().
        self.to_key = nn.Linear(in_channels * 2, in_channels, bias=False)
        self.to_value = nn.Linear(in_channels * 2, in_channels, bias=False)
        self.softmax = nn.Softmax(dim=-1)

        self.gamma_cam_lay3 = nn.Parameter(torch.zeros(1))
        self.cam_layer0 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU()
        )
        self.cam_layer1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.cam_layer2 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, input_feature, features):
        # Query from the current branch, key/value from the other branch.
        Query_features = self.cam_layer0(input_feature)
        key_features = self.cam_layer1(features)
        value_features = self.cam_layer2(features)

        # The einsum reads the conv maps as (batch, len, heads, dim) = (N, C, H, W),
        # i.e. attention over channel pairs, computed independently per image row.
        QK = torch.einsum("nlhd,nshd->nlsh", Query_features, key_features)
        softmax_temp = 1. / Query_features.size(3) ** .5  # 1/sqrt(W)
        A = torch.softmax(softmax_temp * QK, dim=2)
        queried_values = torch.einsum("nlsh,nshd->nlhd", A, value_features).contiguous()
        message = self.mlp(torch.cat([input_feature, queried_values], dim=1))

        # Residual connection back onto the query branch.
        return input_feature + message

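# Hypothetical usage sketch for Cross_transformer_backbone (names and shapes
# are illustrative, not part of the original code):
#   cab = Cross_transformer_backbone(in_channels=64)
#   t1 = torch.randn(2, 64, 32, 32)   # branch-1 features at one level
#   t2 = torch.randn(2, 64, 32, 32)   # branch-2 features at the same level
#   mixed = cab(t1, t2)               # -> (2, 64, 32, 32)
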
class Cross_transformer(nn.Module):
    """Fuses one pyramid level (the query) with three other levels via
    channel-wise cross attention. All inputs must share (N, C, H, W)."""

    def __init__(self, in_channels=48):
        super(Cross_transformer, self).__init__()
        self.fa = nn.Linear(in_channels, in_channels, bias=False)
        self.fb = nn.Linear(in_channels, in_channels, bias=False)
        self.fc = nn.Linear(in_channels, in_channels, bias=False)
        self.fd = nn.Linear(in_channels, in_channels, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.to_out = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.gamma_cam_lay3 = nn.Parameter(torch.zeros(1))
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channels * 4, in_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )

    def attention_layer(self, q, k, v, m_batchsize, C, height, width):
        # q, k, v: (N, C, H*W). Channel-affinity attention; the logits are
        # shifted as rowmax - energy (the trick used in DANet's channel
        # attention module) before the softmax.
        k = k.permute(0, 2, 1)
        energy = torch.bmm(q, k)  # (N, C, C)
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        out = torch.bmm(attention, v)  # (N, C, H*W)
        return out.view(m_batchsize, C, height, width)

    def forward(self, input_feature, features):
        fa = input_feature
        fb, fc, fd = features

        # Project each map channel-wise: (N, C, H, W) -> (N, C, H*W).
        m_batchsize, C, height, width = fa.size()
        fa = self.fa(fa.view(m_batchsize, C, -1).permute(0, 2, 1)).permute(0, 2, 1)
        fb = self.fb(fb.view(m_batchsize, C, -1).permute(0, 2, 1)).permute(0, 2, 1)
        fc = self.fc(fc.view(m_batchsize, C, -1).permute(0, 2, 1)).permute(0, 2, 1)
        fd = self.fd(fd.view(m_batchsize, C, -1).permute(0, 2, 1)).permute(0, 2, 1)

        # Self attention on the query level plus cross attention to the others.
        qkv_1 = self.attention_layer(fa, fa, fa, m_batchsize, C, height, width)
        qkv_2 = self.attention_layer(fa, fb, fb, m_batchsize, C, height, width)
        qkv_3 = self.attention_layer(fa, fc, fc, m_batchsize, C, height, width)
        qkv_4 = self.attention_layer(fa, fd, fd, m_batchsize, C, height, width)

        atten = self.fuse(torch.cat((qkv_1, qkv_2, qkv_3, qkv_4), dim=1))

        # Learned residual gate (gamma starts at zero).
        out = self.gamma_cam_lay3 * atten + input_feature
        return self.to_out(out)

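# Hypothetical usage sketch for Cross_transformer (illustrative shapes only):
#   ct = Cross_transformer(in_channels=64)
#   query = torch.randn(2, 64, 64, 64)
#   others = [torch.randn(2, 64, 64, 64) for _ in range(3)]
#   fused = ct(query, others)         # -> (2, 64, 64, 64)
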
class SceneRelation(nn.Module):
    """Relates a pooled scene vector to per-scale content features and gates
    each scale with the sigmoid-normalized relation, in the style of FarSeg's
    scene-relation module."""

    def __init__(self,
                 in_channels,
                 channel_list,
                 out_channels,
                 scale_aware_proj=True):
        super(SceneRelation, self).__init__()
        self.scale_aware_proj = scale_aware_proj

        if scale_aware_proj:
            # One scene projection per scale.
            self.scene_encoder = nn.ModuleList(
                [nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 1),
                    nn.ReLU(True),
                    nn.Conv2d(out_channels, out_channels, 1),
                ) for _ in range(len(channel_list))]
            )
        else:
            # A single scene projection shared across scales.
            self.scene_encoder = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.ReLU(True),
                nn.Conv2d(out_channels, out_channels, 1),
            )
        self.content_encoders = nn.ModuleList()
        self.feature_reencoders = nn.ModuleList()  # unused in forward()
        for c in channel_list:
            self.content_encoders.append(
                nn.Sequential(
                    nn.Conv2d(c, out_channels, 1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(True)
                )
            )
            self.feature_reencoders.append(
                nn.Sequential(
                    nn.Conv2d(c, out_channels, 1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(True)
                )
            )

        self.normalizer = nn.Sigmoid()

    def forward(self, scene_feature, features: list):
        content_feats = [c_en(p_feat) for c_en, p_feat in zip(self.content_encoders, features)]

        if self.scale_aware_proj:
            scene_feats = [op(scene_feature) for op in self.scene_encoder]
        else:
            # With a shared encoder, apply it once and broadcast across scales
            # (iterating over an nn.Sequential would apply its layers separately).
            scene_feats = [self.scene_encoder(scene_feature)] * len(features)

        relations = [self.normalizer(sf) * cf
                     for sf, cf in zip(scene_feats, content_feats)]
        return relations

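# Hypothetical usage sketch for SceneRelation (illustrative shapes only):
#   sr = SceneRelation(in_channels=512, channel_list=[64, 128, 256, 512],
#                      out_channels=512)
#   scene = torch.randn(2, 512, 1, 1)                   # pooled scene vector
#   feats = [torch.randn(2, c, 64, 64) for c in [64, 128, 256, 512]]
#   gated = sr(scene, feats)                            # four (2, 512, 64, 64) maps
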
class PSPModule(nn.Module):
    """PSPNet-style pyramid pooling context module."""

    def __init__(self, in_channels, bin_sizes=(1, 2, 4, 6)):
        super(PSPModule, self).__init__()
        out_channels = in_channels // len(bin_sizes)
        self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s)
                                     for b_s in bin_sizes])
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels + (out_channels * len(bin_sizes)), in_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1)
        )

    def _make_stages(self, in_channels, out_channels, bin_sz):
        # Pool to the bin size, then project with a 1x1 conv (bin_sz was
        # previously ignored, which made every stage a full-resolution conv).
        prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        bn = nn.BatchNorm2d(out_channels)
        relu = nn.ReLU(inplace=True)
        return nn.Sequential(prior, conv, bn, relu)

    def forward(self, features):
        h, w = features.size(2), features.size(3)
        pyramids = [features]
        # Pool at each bin size, then upsample back to the input resolution.
        pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear',
                                       align_corners=True) for stage in self.stages])
        output = self.bottleneck(torch.cat(pyramids, dim=1))
        return output

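# PSPModule channel check (for in_channels=512 and four bins): each stage
# emits 512 // 4 = 128 channels, so the bottleneck sees 512 + 4 * 128 = 1024
# channels and maps them back to 512.
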
class Change_detection(nn.Module):

    def __init__(self, num_classes=2, use_aux=True, fpn_out=48, freeze_bn=False, **_):
        super(Change_detection, self).__init__()

        # Channel widths of the four backbone levels (e.g. ResNet-18/34).
        f_channels = [64, 128, 256, 512]

        # Pyramid pooling on the deepest difference features.
        self.PPN = PSPModule(f_channels[-1])

        # Temporal cross attention, two rounds per level and per branch
        # (a* modules serve branch 1, b* serve branch 2; suffixes 0-3 index levels).
        self.Cross_transformer_backbone_a3 = Cross_transformer_backbone(in_channels=f_channels[3])
        self.Cross_transformer_backbone_a2 = Cross_transformer_backbone(in_channels=f_channels[2])
        self.Cross_transformer_backbone_a1 = Cross_transformer_backbone(in_channels=f_channels[1])
        self.Cross_transformer_backbone_a0 = Cross_transformer_backbone(in_channels=f_channels[0])
        self.Cross_transformer_backbone_a33 = Cross_transformer_backbone(in_channels=f_channels[3])
        self.Cross_transformer_backbone_a22 = Cross_transformer_backbone(in_channels=f_channels[2])
        self.Cross_transformer_backbone_a11 = Cross_transformer_backbone(in_channels=f_channels[1])
        self.Cross_transformer_backbone_a00 = Cross_transformer_backbone(in_channels=f_channels[0])

        self.Cross_transformer_backbone_b3 = Cross_transformer_backbone(in_channels=f_channels[3])
        self.Cross_transformer_backbone_b2 = Cross_transformer_backbone(in_channels=f_channels[2])
        self.Cross_transformer_backbone_b1 = Cross_transformer_backbone(in_channels=f_channels[1])
        self.Cross_transformer_backbone_b0 = Cross_transformer_backbone(in_channels=f_channels[0])
        self.Cross_transformer_backbone_b33 = Cross_transformer_backbone(in_channels=f_channels[3])
        self.Cross_transformer_backbone_b22 = Cross_transformer_backbone(in_channels=f_channels[2])
        self.Cross_transformer_backbone_b11 = Cross_transformer_backbone(in_channels=f_channels[1])
        self.Cross_transformer_backbone_b00 = Cross_transformer_backbone(in_channels=f_channels[0])

        self.sig = nn.Sigmoid()  # unused in forward()
        self.gap = GlobalAvgPool2D()

        # Scene-relation gating, one module per decoder level.
        self.sr1 = SceneRelation(in_channels=f_channels[3], channel_list=f_channels, out_channels=f_channels[3], scale_aware_proj=True)
        self.sr2 = SceneRelation(in_channels=f_channels[2], channel_list=f_channels, out_channels=f_channels[2], scale_aware_proj=True)
        self.sr3 = SceneRelation(in_channels=f_channels[1], channel_list=f_channels, out_channels=f_channels[1], scale_aware_proj=True)
        self.sr4 = SceneRelation(in_channels=f_channels[0], channel_list=f_channels, out_channels=f_channels[0], scale_aware_proj=True)

        # Cross-scale fusion, one module per decoder level.
        self.Cross_transformer1 = Cross_transformer(in_channels=f_channels[3])
        self.Cross_transformer2 = Cross_transformer(in_channels=f_channels[2])
        self.Cross_transformer3 = Cross_transformer(in_channels=f_channels[1])
        self.Cross_transformer4 = Cross_transformer(in_channels=f_channels[0])

        # 960 = sum(f_channels): the four refined maps are concatenated below.
        self.conv_fusion = nn.Sequential(
            nn.Conv2d(960, fpn_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(fpn_out),
            nn.ReLU(inplace=True)
        )

        # 2x learned upsampling followed by the classification head.
        self.output_fill = nn.Sequential(
            nn.ConvTranspose2d(fpn_out, fpn_out, kernel_size=2, stride=2, bias=False),
            nn.BatchNorm2d(fpn_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(fpn_out, num_classes, kernel_size=3, padding=1)
        )
        self.active = nn.Sigmoid()  # unused in forward()

    def forward(self, x):
        # x is a pair of multi-scale feature pyramids (lists of four maps,
        # shallowest first) extracted from the two temporal images.
        features1, features2 = x
        features, features11, features22 = [], [], []

        # Temporal cross attention: each branch queries the other, twice per level.
        cross_a = [
            (self.Cross_transformer_backbone_a0, self.Cross_transformer_backbone_a00),
            (self.Cross_transformer_backbone_a1, self.Cross_transformer_backbone_a11),
            (self.Cross_transformer_backbone_a2, self.Cross_transformer_backbone_a22),
            (self.Cross_transformer_backbone_a3, self.Cross_transformer_backbone_a33),
        ]
        cross_b = [
            (self.Cross_transformer_backbone_b0, self.Cross_transformer_backbone_b00),
            (self.Cross_transformer_backbone_b1, self.Cross_transformer_backbone_b11),
            (self.Cross_transformer_backbone_b2, self.Cross_transformer_backbone_b22),
            (self.Cross_transformer_backbone_b3, self.Cross_transformer_backbone_b33),
        ]
        for i, (f1, f2) in enumerate(zip(features1, features2)):
            a_in, a_out = cross_a[i]
            b_in, b_out = cross_b[i]
            features11.append(a_out(f1, a_in(f1, f2)))
            features22.append(b_out(f2, b_in(f2, f1)))

        # Per-level absolute difference; pyramid pooling on the deepest level.
        for i in range(len(features11)):
            features.append(torch.abs(features11[i] - features22[i]))
        features[-1] = self.PPN(features[-1])

        # Scene vectors: global average pooling of each difference level.
        c6 = self.gap(features[-1])
        c7 = self.gap(features[-2])
        c8 = self.gap(features[-3])
        c9 = self.gap(features[-4])

        # Resample every level to a common size, gate it against the scene
        # vector, then fuse scales with cross attention (the query level varies).
        feats_64 = [F.interpolate(f, size=(64, 64), mode='nearest') for f in features]
        feats_128 = [F.interpolate(f, size=(128, 128), mode='nearest') for f in features]

        list_3 = self.sr1(c6, feats_64)
        fe3 = self.Cross_transformer1(list_3[-1], [list_3[-2], list_3[-3], list_3[-4]])

        list_2 = self.sr2(c7, feats_64)
        fe2 = self.Cross_transformer2(list_2[-2], [list_2[-1], list_2[-3], list_2[-4]])

        list_1 = self.sr3(c8, feats_64)
        fe1 = self.Cross_transformer3(list_1[-3], [list_1[-1], list_1[-2], list_1[-4]])

        list_0 = self.sr4(c9, feats_128)
        fe0 = self.Cross_transformer4(list_0[-4], [list_0[-1], list_0[-2], list_0[-3]])

        # Upsample every refined level to a common resolution, then fuse and decode.
        refined_fpn_feat_list = [fe3, fe2, fe1, fe0]
        refined_fpn_feat_list[0] = F.interpolate(refined_fpn_feat_list[0], scale_factor=4, mode='nearest')
        refined_fpn_feat_list[1] = F.interpolate(refined_fpn_feat_list[1], scale_factor=4, mode='nearest')
        refined_fpn_feat_list[2] = F.interpolate(refined_fpn_feat_list[2], scale_factor=4, mode='nearest')
        refined_fpn_feat_list[3] = F.interpolate(refined_fpn_feat_list[3], scale_factor=2, mode='nearest')

        x = self.conv_fusion(torch.cat(refined_fpn_feat_list, dim=1))
        x = self.output_fill(x)
        return x

if __name__ == '__main__':
    # Minimal smoke test. The network expects a pair of feature pyramids, not
    # raw images; the shapes below are illustrative, assuming a shared backbone
    # applied to 512x512 inputs with features at strides 4/8/16/32.
    shapes = [(64, 128), (128, 64), (256, 32), (512, 16)]  # (channels, size)
    feats_a = [torch.randn(4, c, s, s) for c, s in shapes]
    feats_b = [torch.randn(4, c, s, s) for c, s in shapes]
    net = Change_detection()
    out = net((feats_a, feats_b))
    print(out.shape)  # torch.Size([4, 2, 512, 512])