import torch
import torch.nn as nn
import torch.nn.functional as F

from rscd.models.backbones.vmamba import VSSM, LayerNorm2d, VSSBlock, Permute


class ChangeDecoder(nn.Module):
    """Decode a change feature map from two four-level feature pyramids.

    For every pyramid level the pre/post features are arranged in three
    different ways (channel concatenation, column interleaving, side-by-side
    along the width), each arrangement is passed through its own
    ``Conv2d -> VSSBlock`` branch, and the results are fused by a 1x1
    convolution.  Levels are then merged from the last pyramid level to the
    first with upsample-and-add followed by a ``ResBlock``.

    Args:
        encoder_dims: Channel counts of the encoder levels; the last four
            entries are used (``encoder_dims[-1]`` belongs to level 4).
        channel_first: When false, ``Permute`` layers are placed around each
            ``VSSBlock`` so it receives channel-last input.
        norm_layer: Normalisation layer forwarded to ``VSSBlock``.
        ssm_act_layer: Activation forwarded to ``VSSBlock`` as
            ``ssm_act_layer``.
        mlp_act_layer: Activation forwarded to ``VSSBlock`` as
            ``mlp_act_layer``.
        **kwargs: Must contain ``ssm_d_state``, ``ssm_ratio``,
            ``ssm_dt_rank``, ``ssm_conv``, ``ssm_conv_bias``,
            ``ssm_drop_rate``, ``ssm_init``, ``forward_type``, ``mlp_ratio``,
            ``mlp_drop_rate``, ``gmlp`` and ``use_checkpoint``; other keys are
            ignored.

    Raises:
        KeyError: If one of the required ``kwargs`` entries is missing.
    """

    def __init__(self, encoder_dims, channel_first, norm_layer, ssm_act_layer, mlp_act_layer, **kwargs):
        super(ChangeDecoder, self).__init__()

        def build_st_block(in_channels):
            # One branch for spatio-temporal relationship modelling:
            # 1x1 projection to 128 channels followed by a VSSBlock, with
            # layout permutes only when the block runs channel-last.
            return nn.Sequential(
                nn.Conv2d(kernel_size=1, in_channels=in_channels, out_channels=128),
                Permute(0, 2, 3, 1) if not channel_first else nn.Identity(),
                VSSBlock(
                    hidden_dim=128,
                    drop_path=0.1,
                    norm_layer=norm_layer,
                    channel_first=channel_first,
                    ssm_d_state=kwargs['ssm_d_state'],
                    ssm_ratio=kwargs['ssm_ratio'],
                    ssm_dt_rank=kwargs['ssm_dt_rank'],
                    ssm_act_layer=ssm_act_layer,
                    ssm_conv=kwargs['ssm_conv'],
                    ssm_conv_bias=kwargs['ssm_conv_bias'],
                    ssm_drop_rate=kwargs['ssm_drop_rate'],
                    ssm_init=kwargs['ssm_init'],
                    forward_type=kwargs['forward_type'],
                    mlp_ratio=kwargs['mlp_ratio'],
                    mlp_act_layer=mlp_act_layer,
                    mlp_drop_rate=kwargs['mlp_drop_rate'],
                    gmlp=kwargs['gmlp'],
                    use_checkpoint=kwargs['use_checkpoint'],
                ),
                Permute(0, 3, 1, 2) if not channel_first else nn.Identity(),
            )

        # Registered as st_block_{stage}{1,2,3}, from stage 4 down to stage 1,
        # so submodule names and creation order stay as they were when each
        # block was spelled out by hand.
        for stage in (4, 3, 2, 1):
            level_dim = encoder_dims[stage - 5]  # stage 4 -> [-1], ..., stage 1 -> [-4]
            # Branch 1 sees pre and post concatenated along channels.
            setattr(self, f"st_block_{stage}1", build_st_block(level_dim * 2))
            # Branches 2 and 3 see pre and post laid out along the width.
            setattr(self, f"st_block_{stage}2", build_st_block(level_dim))
            setattr(self, f"st_block_{stage}3", build_st_block(level_dim))

        # Fuse layer: five 128-channel maps per stage are merged back to 128.
        for stage in (4, 3, 2, 1):
            setattr(
                self,
                f"fuse_layer_{stage}",
                nn.Sequential(
                    nn.Conv2d(kernel_size=1, in_channels=128 * 5, out_channels=128),
                    nn.BatchNorm2d(128),
                    nn.ReLU(),
                ),
            )

        # Smooth layer
        self.smooth_layer_3 = ResBlock(in_channels=128, out_channels=128, stride=1)
        self.smooth_layer_2 = ResBlock(in_channels=128, out_channels=128, stride=1)
        self.smooth_layer_1 = ResBlock(in_channels=128, out_channels=128, stride=1)

    def _upsample_add(self, x, y):
        """Bilinearly resize ``x`` to the spatial size of ``y`` and add them."""
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear') + y

    def _fuse_stage(self, stage, pre_feat, post_feat):
        """Run the three branches of one pyramid stage and fuse their outputs.

        Args:
            stage: Stage number (1-4) selecting ``st_block_{stage}*`` and
                ``fuse_layer_{stage}``.
            pre_feat: 4-D ``(B, C, H, W)`` feature map of the first input.
            post_feat: Feature map of the second input, same shape.

        Returns:
            The fused 128-channel map produced by ``fuse_layer_{stage}``.
        """
        width = pre_feat.size(3)

        # Branch 1: channel-wise concatenation.
        out_channel_cat = getattr(self, f"st_block_{stage}1")(
            torch.cat([pre_feat, post_feat], dim=1)
        )

        # Branch 2: interleave along the width -> (B, C, H, 2*W) with pre in
        # the even-indexed columns and post in the odd-indexed columns.
        # stack/flatten keeps the inputs' device and dtype, so no explicit
        # .cuda() call is needed.
        interleaved = torch.stack([pre_feat, post_feat], dim=-1).flatten(3)
        out_interleaved = getattr(self, f"st_block_{stage}2")(interleaved)

        # Branch 3: pre in columns [0, W), post in columns [W, 2*W).
        side_by_side = torch.cat([pre_feat, post_feat], dim=3)
        out_side_by_side = getattr(self, f"st_block_{stage}3")(side_by_side)

        # Split the double-width outputs back into their pre/post halves and
        # fuse everything along the channel axis.
        fuse = getattr(self, f"fuse_layer_{stage}")
        return fuse(
            torch.cat(
                [
                    out_channel_cat,
                    out_interleaved[:, :, :, ::2],
                    out_interleaved[:, :, :, 1::2],
                    out_side_by_side[:, :, :, 0:width],
                    out_side_by_side[:, :, :, width:],
                ],
                dim=1,
            )
        )

    def forward(self, pre_features, post_features):
        """Decode a 128-channel change feature map.

        Args:
            pre_features: Sequence of four feature maps for the first input,
                ordered from level 1 to level 4.
            post_features: Sequence of four feature maps for the second
                input, in the same order.

        Returns:
            A 128-channel tensor at the spatial size of the level-1 features.
        """
        pre_feat_1, pre_feat_2, pre_feat_3, pre_feat_4 = pre_features
        post_feat_1, post_feat_2, post_feat_3, post_feat_4 = post_features

        # Stage I
        p4 = self._fuse_stage(4, pre_feat_4, post_feat_4)

        # Stage II
        p3 = self._fuse_stage(3, pre_feat_3, post_feat_3)
        p3 = self._upsample_add(p4, p3)
        p3 = self.smooth_layer_3(p3)

        # Stage III
        p2 = self._fuse_stage(2, pre_feat_2, post_feat_2)
        p2 = self._upsample_add(p3, p2)
        p2 = self.smooth_layer_2(p2)

        # Stage IV
        p1 = self._fuse_stage(1, pre_feat_1, post_feat_1)
        p1 = self._upsample_add(p2, p1)
        p1 = self.smooth_layer_1(p1)

        return p1


class ResBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a residual connection.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input before it is added
            to the residual branch; when ``None`` the input is added as is.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class CMDecoder(nn.Module):
    """Change-map head: ``ChangeDecoder`` followed by a 2-channel classifier.

    Args:
        **kwargs: Configuration dictionary.  ``norm_layer``,
            ``ssm_act_layer`` and ``mlp_act_layer`` are strings looked up
            case-insensitively in the tables below (an unknown name yields
            ``None``).  ``dims`` and ``depths`` define the encoder channel
            counts as ``dims * 2 ** i`` for ``i`` in ``range(len(depths))``.
            All remaining entries are forwarded to ``ChangeDecoder``.
    """

    def __init__(self, **kwargs):
        super(CMDecoder, self).__init__()

        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )

        _ACTLAYERS = dict(
            silu=nn.SiLU,
            gelu=nn.GELU,
            relu=nn.ReLU,
            sigmoid=nn.Sigmoid,
        )

        norm_layer: nn.Module = _NORMLAYERS.get(kwargs['norm_layer'].lower(), None)
        ssm_act_layer: nn.Module = _ACTLAYERS.get(kwargs['ssm_act_layer'].lower(), None)
        mlp_act_layer: nn.Module = _ACTLAYERS.get(kwargs['mlp_act_layer'].lower(), None)

        # Remove the explicitly passed args from kwargs to avoid "got multiple values" error
        clean_kwargs = {k: v for k, v in kwargs.items() if k not in ['norm_layer', 'ssm_act_layer', 'mlp_act_layer']}
        self.decoder = ChangeDecoder(
            encoder_dims=[int(kwargs['dims'] * 2 ** i_layer) for i_layer in range(len(kwargs['depths']))],
            channel_first=True,
            norm_layer=norm_layer,
            ssm_act_layer=ssm_act_layer,
            mlp_act_layer=mlp_act_layer,
            **clean_kwargs
        )

        self.main_clf = nn.Conv2d(in_channels=128, out_channels=2, kernel_size=1)

    def _upsample_add(self, x, y):
        """Bilinearly resize ``x`` to the spatial size of ``y`` and add them."""
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear') + y

    def forward(self, xs):
        """Produce a 2-channel map resized to ``pre_data_size``.

        Args:
            xs: Triple ``(pre_features, post_features, pre_data_size)``; the
                first two are passed to the decoder and the last is used as
                the ``size`` argument of the final bilinear interpolation.

        Returns:
            The classifier output interpolated to ``pre_data_size``.
        """
        pre_features, post_features, pre_data_size = xs

        # Decoder processing - passing encoder outputs to the decoder
        output = self.decoder(pre_features, post_features)

        output = self.main_clf(output)
        output = F.interpolate(output, size=pre_data_size, mode='bilinear')
        return output