|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import warnings
|
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Linear embedding of a feature map into per-position tokens.

    All dimensions from index 2 onward are flattened into a single
    "position" axis, and one shared ``nn.Linear`` is applied to the channel
    vector at every position.

    Args:
        input_dim: Channel count (dim 1) of the incoming tensor.
        embed_dim: Size of every output token.
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        """Project ``x`` to tokens.

        For an input of shape (N, C, H, W) the result has shape
        (N, H*W, embed_dim).
        """
        # (N, C, *spatial) -> (N, C, L) -> (N, L, C) so Linear acts on channels.
        tokens = torch.flatten(x, start_dim=2).permute(0, 2, 1)
        return self.proj(tokens)
|
|
|
|
|
|
class UpsampleConvLayer(torch.nn.Module):
    """Learned upsampling through a single transposed convolution.

    Wraps ``nn.ConvTranspose2d`` with ``padding=1``. The output spatial size
    per dimension is ``(in - 1) * stride - 2 + kernel_size``, so
    ``kernel_size=4, stride=2`` exactly doubles the resolution.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        kernel_size: Transposed-convolution kernel size.
        stride: Transposed-convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv2d = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=1,
        )

    def forward(self, x):
        """Apply the transposed convolution to ``x``."""
        return self.conv2d(x)
|
|
|
|
|
|
class ResidualBlock(torch.nn.Module):
    """Two 3x3 convolutions wrapped in a down-weighted residual connection.

    Computes ``x + 0.1 * conv2(relu(conv1(x)))``. Both convolutions use
    stride 1 and padding 1, so channel count and spatial size are preserved.

    Args:
        channels: Channel count of the input and output.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``x`` plus the scaled residual branch."""
        hidden = self.relu(self.conv1(x))
        # Scaling the branch by 0.1 keeps the block close to identity.
        scaled = self.conv2(hidden) * 0.1
        return torch.add(scaled, x)
|
|
|
|
|
|
class ConvLayer(nn.Module):
    """Thin wrapper around a single ``nn.Conv2d``.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding added to each spatial side.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        """Apply the wrapped convolution to ``x``."""
        return self.conv2d(x)
|
|
|
|
|
|
|
|
|
|
|
|
def conv_diff(in_channels, out_channels):
    """Build the Conv-ReLU-BN-Conv-ReLU block used to fuse paired features.

    Both convolutions are 3x3 with padding 1, so the spatial size is kept.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of every intermediate and output map.

    Returns:
        nn.Sequential: The five layers, in the order listed above.
    """
    def _conv3x3(channels_in):
        # Both convolutions share kernel/padding; only the input width differs.
        return nn.Conv2d(channels_in, out_channels, kernel_size=3, padding=1)

    layers = [
        _conv3x3(in_channels),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        _conv3x3(out_channels),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
|
|
|
|
def make_prediction(in_channels, out_channels):
    """Build a Conv-ReLU-BN-Conv prediction head.

    Both convolutions are 3x3 with padding 1, so the spatial size is kept.
    No activation follows the last convolution.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the intermediate and output maps.

    Returns:
        nn.Sequential: The four layers, in the order listed above.
    """
    def _conv3x3(channels_in):
        # Both convolutions share kernel/padding; only the input width differs.
        return nn.Conv2d(channels_in, out_channels, kernel_size=3, padding=1)

    layers = [
        _conv3x3(in_channels),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        _conv3x3(out_channels),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
|
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Interpolate ``input`` with ``F.interpolate``, optionally warning about misalignment.

    With ``align_corners`` enabled, corner-aligned upsampling maps the grid
    exactly only when ``output - 1`` is a multiple of ``input - 1`` in each
    spatial dimension. If ``warning`` is set, an explicit ``size`` is given,
    the call upsamples in height or width, and both dimensions violate that
    rule, a ``UserWarning`` is emitted. Interpolation happens either way.

    Args:
        input: Tensor to resize; the warning check unpacks ``input.shape[2:]``
            into (H, W), so it expects a 4-D (N, C, H, W) tensor.
        size: Target spatial size (H, W), or None when using ``scale_factor``.
        scale_factor: Spatial multiplier, forwarded to ``F.interpolate``.
        mode: Interpolation mode, forwarded to ``F.interpolate``.
        align_corners: Forwarded to ``F.interpolate``; also enables the check.
        warning: Set to False to skip the alignment check.

    Returns:
        Tensor: The result of ``F.interpolate``.
    """
    if warning and size is not None and align_corners:
        input_h, input_w = tuple(int(x) for x in input.shape[2:])
        output_h, output_w = tuple(int(x) for x in size)
        # Only upsampling can be misaligned. Width must be compared with the
        # input width (it was previously compared with the output height).
        upsampling = output_h > input_h or output_w > input_w
        all_above_one = (output_h > 1 and output_w > 1
                         and input_h > 1 and input_w > 1)
        if (upsampling and all_above_one
                and (output_h - 1) % (input_h - 1)
                and (output_w - 1) % (input_w - 1)):
            warnings.warn(
                f'When align_corners={align_corners}, '
                'the output would more aligned if '
                f'input size {(input_h, input_w)} is `x+1` and '
                f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size=size, scale_factor=scale_factor,
                         mode=mode, align_corners=align_corners)
|
|
|
|
|
|
|
|
|
class DecoderTransformer_v3(nn.Module):
    """Multi-scale decoder that fuses the features of two inputs into prediction maps.

    ``forward`` receives two feature lists. After ``_transform_inputs`` each
    must yield four maps (c1..c4; c1 is the resolution everything is fused
    at). For every level, coarsest first:

    * both maps are projected to ``embedding_dim`` channels by a shared
      token-wise ``MLP``,
    * the pair is concatenated and passed through a ``conv_diff`` block,
    * the bilinearly upsampled result of the next-coarser level is added,
    * a ``make_prediction`` head emits an auxiliary prediction.

    The four level results are then upsampled to c1's size, fused with a
    1x1 conv + BatchNorm, passed through two stride-2 transposed convolutions
    (each followed by a residual block), and mapped to the final prediction.

    Args:
        input_transform: 'resize_concat', 'multiple_select', or anything else
            to index ``inputs[in_index]`` directly (see ``_transform_inputs``).
        in_index: Which entries of each input list to use.
        align_corners: Used by the 'resize_concat' transform only; the
            upsampling inside ``forward`` always uses ``align_corners=False``.
        in_channels: Channel counts of the four selected levels (c1..c4).
        embedding_dim: Common channel width inside the decoder.
        output_nc: Channel count of every prediction.
        decoder_softmax: If True, every prediction is passed through
            ``nn.Sigmoid`` (despite the name, no softmax is applied).
        feature_strides: Stride per level; must match ``in_channels`` in
            length and start with its minimum.
    """

    def __init__(self, input_transform='multiple_select', in_index=(0, 1, 2, 3), align_corners=True,
                 in_channels=(32, 64, 128, 256), embedding_dim=64, output_nc=2,
                 decoder_softmax=False, feature_strides=(2, 4, 8, 16)):
        super(DecoderTransformer_v3, self).__init__()

        assert len(feature_strides) == len(in_channels), \
            'feature_strides and in_channels must have the same length'
        assert min(feature_strides) == feature_strides[0], \
            'feature_strides must start with the smallest stride'

        # Tuple defaults avoid shared mutable default arguments; list copies
        # keep the stored attributes independent of the caller's sequences.
        self.feature_strides = list(feature_strides)
        self.input_transform = input_transform
        self.in_index = list(in_index) if isinstance(in_index, (list, tuple)) else in_index
        self.align_corners = align_corners
        self.in_channels = list(in_channels)
        self.embedding_dim = embedding_dim
        self.output_nc = output_nc
        # Unpacking requires exactly four feature levels.
        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        # Per-level token-wise projections to embedding_dim (shared by both inputs).
        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)

        # Per-level fusion of the concatenated pair: 2*embedding_dim -> embedding_dim.
        self.diff_c4 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c3 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c2 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c1 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)

        # Per-level auxiliary prediction heads.
        self.make_pred_c4 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c3 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c2 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c1 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)

        # 1x1 fusion of all levels after they are brought to c1's resolution.
        self.linear_fuse = nn.Sequential(
            nn.Conv2d(in_channels=self.embedding_dim * len(in_channels), out_channels=self.embedding_dim,
                      kernel_size=1),
            nn.BatchNorm2d(self.embedding_dim),
        )

        # Two stride-2 transposed convolutions, each followed by a residual block.
        self.convd2x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
        self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim))
        self.convd1x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
        self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim))
        self.change_probability = ConvLayer(self.embedding_dim, self.output_nc, kernel_size=3, stride=1, padding=1)

        # NOTE: despite the flag name, predictions are passed through a sigmoid.
        self.output_softmax = decoder_softmax
        self.active = nn.Sigmoid()

    def _transform_inputs(self, inputs):
        """Select (and optionally merge) the feature levels used by the decoder.

        Args:
            inputs (list[Tensor]): Multi-level image features.

        Returns:
            list[Tensor] | Tensor: For 'multiple_select', the entries at
            ``self.in_index``. For 'resize_concat', those entries bilinearly
            resized to the first selected entry's size and concatenated along
            dim 1. Otherwise ``inputs[self.in_index]``.
        """
        # NOTE(review): forward() unpacks the result into four maps, so only a
        # transform that yields four feature maps (e.g. 'multiple_select') works there.
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @staticmethod
    def _embed(linear, feat):
        """Apply token-wise ``linear`` to (N, C, H, W) ``feat`` and return (N, E, H, W)."""
        n, _, h, w = feat.shape
        # MLP returns (N, H*W, E); move channels back to dim 1 and restore H, W.
        return linear(feat).permute(0, 2, 1).reshape(n, -1, h, w)

    def _diff_level(self, linear, diff, feat_1, feat_2, coarser=None):
        """Fuse one level of both inputs, adding the upsampled coarser result if given."""
        emb_1 = self._embed(linear, feat_1)
        emb_2 = self._embed(linear, feat_2)
        fused = diff(torch.cat((emb_1, emb_2), dim=1))
        if coarser is not None:
            # Upsample to the exact target size instead of scale_factor=2: this
            # matches the 2x result when the coarser map is exactly half the
            # size, and still lines up when sizes are not exact multiples.
            fused = fused + F.interpolate(coarser, size=fused.shape[2:], mode='bilinear',
                                          align_corners=False)
        return fused

    def forward(self, inputs1, inputs2):
        """Decode two sets of multi-level features.

        Args:
            inputs1: Feature list of the first input.
            inputs2: Feature list of the second input. After
                ``_transform_inputs`` both must yield four maps (c1..c4) with
                matching batch sizes and matching per-level shapes.

        Returns:
            list[Tensor]: ``[p_c4, p_c3, p_c2, p_c1, cp]`` — the four
            per-level auxiliary predictions followed by the final prediction,
            each with ``output_nc`` channels (sigmoid-activated when
            ``output_softmax`` is set).
        """
        c1_1, c2_1, c3_1, c4_1 = self._transform_inputs(inputs1)
        c1_2, c2_2, c3_2, c4_2 = self._transform_inputs(inputs2)

        # Every level is brought to c1's resolution before the final fusion.
        fuse_size = c1_2.shape[2:]
        outputs = []

        _c4 = self._diff_level(self.linear_c4, self.diff_c4, c4_1, c4_2)
        outputs.append(self.make_pred_c4(_c4))
        _c4_up = resize(_c4, size=fuse_size, mode='bilinear', align_corners=False)

        _c3 = self._diff_level(self.linear_c3, self.diff_c3, c3_1, c3_2, coarser=_c4)
        outputs.append(self.make_pred_c3(_c3))
        _c3_up = resize(_c3, size=fuse_size, mode='bilinear', align_corners=False)

        _c2 = self._diff_level(self.linear_c2, self.diff_c2, c2_1, c2_2, coarser=_c3)
        outputs.append(self.make_pred_c2(_c2))
        _c2_up = resize(_c2, size=fuse_size, mode='bilinear', align_corners=False)

        _c1 = self._diff_level(self.linear_c1, self.diff_c1, c1_1, c1_2, coarser=_c2)
        outputs.append(self.make_pred_c1(_c1))

        _c = self.linear_fuse(torch.cat((_c4_up, _c3_up, _c2_up, _c1), dim=1))

        # Two learned 2x upsampling stages, each refined by a residual block.
        x = self.dense_2x(self.convd2x(_c))
        x = self.dense_1x(self.convd1x(x))

        outputs.append(self.change_probability(x))

        if self.output_softmax:
            outputs = [self.active(pred) for pred in outputs]

        return outputs
|
|
|
|
|
|
class ChangeFormer_DE(nn.Module):
    """Decoder-only head that returns the final prediction of ``DecoderTransformer_v3``.

    The wrapped decoder is configured for four feature levels with channel
    counts 64/128/320/512, 'multiple_select' input selection of levels 0-3,
    and feature strides 2/4/8/16.

    Args:
        output_nc: Channel count of the prediction.
        decoder_softmax: Forwarded to the decoder.
        embed_dim: Internal embedding width of the decoder.
    """

    def __init__(self, output_nc=2, decoder_softmax=False, embed_dim=256):
        super().__init__()
        self.embed_dims = [64, 128, 320, 512]
        self.embedding_dim = embed_dim
        self.TDec_x2 = DecoderTransformer_v3(
            input_transform='multiple_select',
            in_index=[0, 1, 2, 3],
            align_corners=False,
            in_channels=self.embed_dims,
            embedding_dim=self.embedding_dim,
            output_nc=output_nc,
            decoder_softmax=decoder_softmax,
            feature_strides=[2, 4, 8, 16],
        )

    def forward(self, f):
        """Decode the feature pair ``(f[0], f[1])`` and return only the last prediction."""
        feats_a = f[0]
        feats_b = f[1]
        predictions = self.TDec_x2(feats_a, feats_b)
        return predictions[-1]