import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F


# Transformer Decoder

class MLP(nn.Module):
    """Linear embedding applied at every spatial position.

    Flattens the trailing dimensions of an ``(N, C, *spatial)`` tensor and
    projects the channel axis, returning ``(N, prod(spatial), embed_dim)``.
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        # (N, C, *spatial) -> (N, L, C) so that nn.Linear acts on channels.
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class UpsampleConvLayer(torch.nn.Module):
    """Transposed 2-D convolution with a fixed ``padding=1``.

    With ``kernel_size=4, stride=2`` (as used by the decoder below) the
    output spatial size is exactly twice the input: ``(H-1)*2 - 2 + 4 = 2H``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(UpsampleConvLayer, self).__init__()
        self.conv2d = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=1)

    def forward(self, x):
        out = self.conv2d(x)
        return out


class ConvLayer(nn.Module):
    """Plain ``nn.Conv2d`` stored as ``self.conv2d``."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvLayer, self).__init__()
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

    def forward(self, x):
        out = self.conv2d(x)
        return out


class ResidualBlock(torch.nn.Module):
    """Residual unit computing ``x + 0.1 * conv2(relu(conv1(x)))``.

    Both convolutions are 3x3 / stride 1 / padding 1, so the shape of ``x``
    is preserved; the residual branch is scaled by 0.1 before the addition.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out) * 0.1
        out = torch.add(out, residual)
        return out


# Difference module
def conv_diff(in_channels, out_channels):
    """Conv3x3-ReLU-BN-Conv3x3-ReLU block (spatial size preserved).

    The decoder feeds it the channel-wise concatenation of both images'
    embeddings (``in_channels = 2 * embedding_dim``).
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
    )


# Intermediate prediction module
def make_prediction(in_channels, out_channels):
    """Conv3x3-ReLU-BN-Conv3x3 prediction head with no final activation."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
    )


def resize(input, size=None, scale_factor=None, mode='nearest', align_corners=None,
           warning=True):
    """``F.interpolate`` wrapper that can warn about poorly aligned upsampling.

    When ``warning`` is true, ``size`` is given and ``align_corners`` is truthy,
    a warning is emitted if the output is larger than the input in either
    spatial dimension, all sizes exceed 1, and neither dimension has the
    aligned form input ``x+1`` / output ``n*x+1``.

    Args:
        input: tensor whose last two dimensions are (H, W).
        size, scale_factor, mode, align_corners: forwarded to ``F.interpolate``.
        warning: enable the alignment check.

    Returns:
        The interpolated tensor.
    """
    if warning and size is not None and align_corners:
        input_h, input_w = tuple(int(x) for x in input.shape[2:])
        output_h, output_w = tuple(int(x) for x in size)
        # Each output dimension is compared with its own input dimension.
        if output_h > input_h or output_w > input_w:
            if ((output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
                    and (output_h - 1) % (input_h - 1)
                    and (output_w - 1) % (input_w - 1)):
                warnings.warn(
                    f'When align_corners={align_corners}, '
                    'the output would more aligned if '
                    f'input size {(input_h, input_w)} is `x+1` and '
                    f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)


class DecoderTransformer_v3(nn.Module):
    """Change-detection decoder over two pyramids of four feature maps.

    For each level, both images' features are embedded by a shared MLP head,
    concatenated and passed through a difference module. From the second
    coarsest level on, the difference map is added to the upsampled map of
    the next coarser level. Every level also feeds an intermediate prediction
    head. All levels are then resized to the finest level and fused by a 1x1
    conv + BN. The result is upsampled twice (transposed conv + residual block
    each) and projected to ``output_nc`` channels.

    Args:
        input_transform: ``'multiple_select'`` (pick the ``in_index`` entries),
            ``'resize_concat'`` (pick, resize to the first and concatenate), or
            anything else (use ``inputs[in_index]`` directly).
        in_index: indices of the feature maps to use.
        align_corners: passed to ``resize`` in the ``'resize_concat'`` path.
        in_channels: channel counts of the four selected maps, ordered from
            finest (c1) to coarsest (c4).
        embedding_dim: common channel width inside the decoder.
        output_nc: number of channels of every prediction.
        decoder_softmax: if true, apply ``nn.Sigmoid`` (despite the name) to
            every returned prediction.
        feature_strides: only validated (same length as ``in_channels``,
            smallest first); not otherwise used.

    ``forward(inputs1, inputs2)`` returns a list of five tensors: the four
    intermediate predictions (c4, c3, c2, c1 order) followed by the final
    prediction at 4x the spatial size of the finest level.
    """

    # NOTE: the list defaults are shared between calls but never mutated here.
    def __init__(self, input_transform='multiple_select', in_index=[0, 1, 2, 3],
                 align_corners=True, in_channels=[32, 64, 128, 256], embedding_dim=64,
                 output_nc=2, decoder_softmax=False, feature_strides=[2, 4, 8, 16]):
        super(DecoderTransformer_v3, self).__init__()
        assert len(feature_strides) == len(in_channels)
        assert min(feature_strides) == feature_strides[0]

        # Settings
        self.feature_strides = feature_strides
        self.input_transform = input_transform
        self.in_index = in_index
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.output_nc = output_nc
        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        # MLP decoder heads (one per level, shared by both images)
        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)

        # Convolutional difference modules (input: both embeddings concatenated)
        self.diff_c4 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c3 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c2 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
        self.diff_c1 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)

        # Intermediate prediction heads, one per level
        self.make_pred_c4 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c3 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c2 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c1 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)

        # Final linear fusion layer over the concatenated levels
        self.linear_fuse = nn.Sequential(
            nn.Conv2d(in_channels=self.embedding_dim * len(in_channels),
                      out_channels=self.embedding_dim, kernel_size=1),
            nn.BatchNorm2d(self.embedding_dim),
        )

        # Final prediction head: two 2x upsampling steps, then a 3x3 projection
        self.convd2x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
        self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim))
        self.convd1x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
        self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim))
        self.change_probability = ConvLayer(self.embedding_dim, self.output_nc,
                                            kernel_size=3, stride=1, padding=1)

        # Final activation (a sigmoid, applied only when decoder_softmax is true)
        self.output_softmax = decoder_softmax
        self.active = nn.Sigmoid()

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            The selected feature list (``'multiple_select'``), a single
            concatenated tensor (``'resize_concat'``), or ``inputs[in_index]``.
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(input=x, size=inputs[0].shape[2:], mode='bilinear',
                       align_corners=self.align_corners)
                for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]
        return inputs

    @staticmethod
    def _embed(linear, feat):
        """Apply an MLP head to ``feat`` (N, C, H, W) and restore a (N, E, H, W) layout."""
        n, _, h, w = feat.shape
        return linear(feat).permute(0, 2, 1).reshape(n, -1, h, w)

    def _difference(self, linear, diff, feat_1, feat_2):
        """Embed both images' features with the shared head and fuse them with ``diff``."""
        emb_1 = self._embed(linear, feat_1)
        emb_2 = self._embed(linear, feat_2)
        return diff(torch.cat((emb_1, emb_2), dim=1))

    @staticmethod
    def _upsample_to(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref`` (align_corners=False)."""
        return resize(x, size=ref.shape[2:], mode='bilinear', align_corners=False)

    def forward(self, inputs1, inputs2):
        # Select the four feature levels of each image (finest c1 ... coarsest c4).
        c1_1, c2_1, c3_1, c4_1 = self._transform_inputs(inputs1)
        c1_2, c2_2, c3_2, c4_2 = self._transform_inputs(inputs2)
        outputs = []

        # Stage 4 (coarsest level)
        _c4 = self._difference(self.linear_c4, self.diff_c4, c4_1, c4_2)
        outputs.append(self.make_pred_c4(_c4))
        _c4_up = self._upsample_to(_c4, c1_2)

        # Stages 3..1: add the coarser difference map resized to this level's
        # exact size. For an exact 2x ratio this equals scale_factor=2, and it
        # also works when sizes are not exact multiples (e.g. odd inputs).
        _c3 = self._difference(self.linear_c3, self.diff_c3, c3_1, c3_2)
        _c3 = _c3 + self._upsample_to(_c4, _c3)
        outputs.append(self.make_pred_c3(_c3))
        _c3_up = self._upsample_to(_c3, c1_2)

        _c2 = self._difference(self.linear_c2, self.diff_c2, c2_1, c2_2)
        _c2 = _c2 + self._upsample_to(_c3, _c2)
        outputs.append(self.make_pred_c2(_c2))
        _c2_up = self._upsample_to(_c2, c1_2)

        _c1 = self._difference(self.linear_c1, self.diff_c1, c1_1, c1_2)
        _c1 = _c1 + self._upsample_to(_c2, _c1)
        outputs.append(self.make_pred_c1(_c1))

        # Linear fusion of the difference maps of all levels at the finest resolution
        _c = self.linear_fuse(torch.cat((_c4_up, _c3_up, _c2_up, _c1), dim=1))

        # Two 2x upsampling steps, each followed by a residual block
        x = self.dense_2x(self.convd2x(_c))
        x = self.dense_1x(self.convd1x(x))

        # Final prediction
        outputs.append(self.change_probability(x))

        if self.output_softmax:
            outputs = [self.active(pred) for pred in outputs]
        return outputs


class ChangeFormer_DE(nn.Module):
    """``DecoderTransformer_v3`` configured for features of widths 64/128/320/512.

    ``forward(f)`` takes ``f`` with ``f[0]`` / ``f[1]`` the feature lists of
    the two images and returns only the final (full-head) prediction.
    """

    def __init__(self, output_nc=2, decoder_softmax=False, embed_dim=256):
        super(ChangeFormer_DE, self).__init__()
        # Channel widths of the four encoder feature maps this decoder expects.
        self.embed_dims = [64, 128, 320, 512]
        self.embedding_dim = embed_dim

        # Transformer Decoder
        self.TDec_x2 = DecoderTransformer_v3(
            input_transform='multiple_select', in_index=[0, 1, 2, 3], align_corners=False,
            in_channels=self.embed_dims, embedding_dim=self.embedding_dim,
            output_nc=output_nc, decoder_softmax=decoder_softmax,
            feature_strides=[2, 4, 8, 16])

    def forward(self, f):
        fx1, fx2 = f[0], f[1]
        cp = self.TDec_x2(fx1, fx2)
        return cp[-1]