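# Bi-temporal change-detection decoder built on VMamba/VSS blocks: each
# encoder level is mixed through three spatio-temporal arrangements
# (channel-wise concat, column interleave, side-by-side tiling), fused,
# and merged top-down FPN-style into a two-class change map.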
import torch
import torch.nn as nn
import torch.nn.functional as F
from rscd.models.backbones.vmamba import VSSM, LayerNorm2d, VSSBlock, Permute
class ChangeDecoder(nn.Module):
def __init__(self, encoder_dims, channel_first, norm_layer, ssm_act_layer, mlp_act_layer, **kwargs):
        super().__init__()
        # Define the VSS blocks for spatio-temporal relationship modelling.
        # All blocks share the same structure: a 1x1 convolution projecting to
        # 128 channels, followed by a VSSBlock (wrapped in NCHW<->NHWC permutes
        # when the backbone is channel-last).
        def _make_st_block(in_channels):
            return nn.Sequential(
                nn.Conv2d(kernel_size=1, in_channels=in_channels, out_channels=128),
                Permute(0, 2, 3, 1) if not channel_first else nn.Identity(),
                VSSBlock(hidden_dim=128, drop_path=0.1, norm_layer=norm_layer, channel_first=channel_first,
                         ssm_d_state=kwargs['ssm_d_state'], ssm_ratio=kwargs['ssm_ratio'],
                         ssm_dt_rank=kwargs['ssm_dt_rank'], ssm_act_layer=ssm_act_layer,
                         ssm_conv=kwargs['ssm_conv'], ssm_conv_bias=kwargs['ssm_conv_bias'],
                         ssm_drop_rate=kwargs['ssm_drop_rate'], ssm_init=kwargs['ssm_init'],
                         forward_type=kwargs['forward_type'], mlp_ratio=kwargs['mlp_ratio'],
                         mlp_act_layer=mlp_act_layer, mlp_drop_rate=kwargs['mlp_drop_rate'],
                         gmlp=kwargs['gmlp'], use_checkpoint=kwargs['use_checkpoint']),
                Permute(0, 3, 1, 2) if not channel_first else nn.Identity(),
            )

        # For each encoder level, branch *1 consumes the channel-wise
        # concatenation of the bi-temporal features (2x channels), while
        # branches *2 and *3 consume the width-wise interleaved and
        # side-by-side arrangements (original channel count).
        self.st_block_41 = _make_st_block(encoder_dims[-1] * 2)
        self.st_block_42 = _make_st_block(encoder_dims[-1])
        self.st_block_43 = _make_st_block(encoder_dims[-1])
        self.st_block_31 = _make_st_block(encoder_dims[-2] * 2)
        self.st_block_32 = _make_st_block(encoder_dims[-2])
        self.st_block_33 = _make_st_block(encoder_dims[-2])
        self.st_block_21 = _make_st_block(encoder_dims[-3] * 2)
        self.st_block_22 = _make_st_block(encoder_dims[-3])
        self.st_block_23 = _make_st_block(encoder_dims[-3])
        self.st_block_11 = _make_st_block(encoder_dims[-4] * 2)
        self.st_block_12 = _make_st_block(encoder_dims[-4])
        self.st_block_13 = _make_st_block(encoder_dims[-4])
        # Fuse layers: a 1x1 convolution reducing the five concatenated
        # 128-channel branch outputs back to 128 channels.
        def _make_fuse_layer():
            return nn.Sequential(
                nn.Conv2d(kernel_size=1, in_channels=128 * 5, out_channels=128),
                nn.BatchNorm2d(128),
                nn.ReLU(),
            )

        self.fuse_layer_4 = _make_fuse_layer()
        self.fuse_layer_3 = _make_fuse_layer()
        self.fuse_layer_2 = _make_fuse_layer()
        self.fuse_layer_1 = _make_fuse_layer()
# Smooth layer
self.smooth_layer_3 = ResBlock(in_channels=128, out_channels=128, stride=1)
self.smooth_layer_2 = ResBlock(in_channels=128, out_channels=128, stride=1)
self.smooth_layer_1 = ResBlock(in_channels=128, out_channels=128, stride=1)
    def _upsample_add(self, x, y):
        # Bilinearly upsample x to y's spatial size and add it to y.
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y

    @staticmethod
    def _interleave_width(a, b):
        # Weave two (B, C, H, W) maps column by column into (B, C, H, 2W):
        # a fills the even columns, b the odd ones. Allocating on a's device
        # and dtype keeps the module device-agnostic.
        B, C, H, W = a.size()
        out = torch.empty(B, C, H, 2 * W, device=a.device, dtype=a.dtype)
        out[:, :, :, ::2] = a
        out[:, :, :, 1::2] = b
        return out
    def forward(self, pre_features, post_features):
        # Encoder features are ordered shallow to deep.
        pre_feat_1, pre_feat_2, pre_feat_3, pre_feat_4 = pre_features
        post_feat_1, post_feat_2, post_feat_3, post_feat_4 = post_features
        '''
        Stage I
        '''
        B, C, H, W = pre_feat_4.size()
        # Branch 1: channel-wise concatenation of the bi-temporal features.
        p41 = self.st_block_41(torch.cat([pre_feat_4, post_feat_4], dim=1))
        # Branch 2: column-interleaved arrangement (pre-change features in the
        # even columns, post-change in the odd columns).
        p42 = self.st_block_42(self._interleave_width(pre_feat_4, post_feat_4))
        # Branch 3: the two maps placed side by side along the width axis.
        p43 = self.st_block_43(torch.cat([pre_feat_4, post_feat_4], dim=3))
        # Split the branch outputs back into per-image maps and fuse all five
        # 128-channel tensors.
        p4 = self.fuse_layer_4(torch.cat(
            [p41, p42[:, :, :, ::2], p42[:, :, :, 1::2], p43[:, :, :, 0:W], p43[:, :, :, W:]], dim=1))
        '''
        Stage II
        '''
        B, C, H, W = pre_feat_3.size()
        p31 = self.st_block_31(torch.cat([pre_feat_3, post_feat_3], dim=1))
        p32 = self.st_block_32(self._interleave_width(pre_feat_3, post_feat_3))
        p33 = self.st_block_33(torch.cat([pre_feat_3, post_feat_3], dim=3))
        p3 = self.fuse_layer_3(torch.cat(
            [p31, p32[:, :, :, ::2], p32[:, :, :, 1::2], p33[:, :, :, 0:W], p33[:, :, :, W:]], dim=1))
        # Top-down pathway: upsample the deeper map, add, and smooth.
        p3 = self._upsample_add(p4, p3)
        p3 = self.smooth_layer_3(p3)
        '''
        Stage III
        '''
        B, C, H, W = pre_feat_2.size()
        p21 = self.st_block_21(torch.cat([pre_feat_2, post_feat_2], dim=1))
        p22 = self.st_block_22(self._interleave_width(pre_feat_2, post_feat_2))
        p23 = self.st_block_23(torch.cat([pre_feat_2, post_feat_2], dim=3))
        p2 = self.fuse_layer_2(torch.cat(
            [p21, p22[:, :, :, ::2], p22[:, :, :, 1::2], p23[:, :, :, 0:W], p23[:, :, :, W:]], dim=1))
        p2 = self._upsample_add(p3, p2)
        p2 = self.smooth_layer_2(p2)
        '''
        Stage IV
        '''
        B, C, H, W = pre_feat_1.size()
        p11 = self.st_block_11(torch.cat([pre_feat_1, post_feat_1], dim=1))
        p12 = self.st_block_12(self._interleave_width(pre_feat_1, post_feat_1))
        p13 = self.st_block_13(torch.cat([pre_feat_1, post_feat_1], dim=3))
        p1 = self.fuse_layer_1(torch.cat(
            [p11, p12[:, :, :, ::2], p12[:, :, :, 1::2], p13[:, :, :, 0:W], p13[:, :, :, W:]], dim=1))
        p1 = self._upsample_add(p2, p1)
        p1 = self.smooth_layer_1(p1)
return p1
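
# A minimal, self-contained sketch of the three spatio-temporal arrangements
# fed to the st_block_* branches above, on toy tensors. The shapes below are
# illustrative assumptions only; the real inputs are the VSSM encoder features.
def _st_arrangement_demo():
    a = torch.randn(2, 4, 8, 8)         # toy "pre-change" features (B, C, H, W)
    b = torch.randn(2, 4, 8, 8)         # toy "post-change" features
    concat = torch.cat([a, b], dim=1)   # branch 1: (B, 2C, H, W)
    weave = torch.empty(2, 4, 8, 16)    # branch 2: (B, C, H, 2W)
    weave[:, :, :, ::2] = a             # even columns <- pre
    weave[:, :, :, 1::2] = b            # odd columns  <- post
    tile = torch.cat([a, b], dim=3)     # branch 3: (B, C, H, 2W), side by side
    return concat, weave, tile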
class ResBlock(nn.Module):
    # Basic residual block: two 3x3 conv + BN layers with an identity
    # shortcut, or an optional downsample module when shapes change.
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class CMDecoder(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
_NORMLAYERS = dict(
ln=nn.LayerNorm,
ln2d=LayerNorm2d,
bn=nn.BatchNorm2d,
)
_ACTLAYERS = dict(
silu=nn.SiLU,
gelu=nn.GELU,
relu=nn.ReLU,
sigmoid=nn.Sigmoid,
)
norm_layer: nn.Module = _NORMLAYERS.get(kwargs['norm_layer'].lower(), None)
ssm_act_layer: nn.Module = _ACTLAYERS.get(kwargs['ssm_act_layer'].lower(), None)
mlp_act_layer: nn.Module = _ACTLAYERS.get(kwargs['mlp_act_layer'].lower(), None)
# Remove the explicitly passed args from kwargs to avoid "got multiple values" error
clean_kwargs = {k: v for k, v in kwargs.items() if k not in ['norm_layer', 'ssm_act_layer', 'mlp_act_layer']}
self.decoder = ChangeDecoder(
            encoder_dims=[int(kwargs['dims'] * 2 ** i_layer) for i_layer in range(len(kwargs['depths']))],
channel_first=True,
norm_layer=norm_layer,
ssm_act_layer=ssm_act_layer,
mlp_act_layer=mlp_act_layer,
**clean_kwargs
)
self.main_clf = nn.Conv2d(in_channels=128, out_channels=2, kernel_size=1)
    def _upsample_add(self, x, y):
        # Bilinearly upsample x to y's spatial size and add it to y.
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y
    def forward(self, xs):
        pre_features, post_features, pre_data_size = xs
        # Decode the bi-temporal encoder features into a 128-channel map,
        # classify into a two-class change map, and upsample to the input size.
        output = self.decoder(pre_features, post_features)
        output = self.main_clf(output)
        output = F.interpolate(output, size=pre_data_size, mode='bilinear', align_corners=False)
        return output
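
if __name__ == '__main__':
    # Minimal shape check. The configuration below is an illustrative
    # assumption (a VMamba-Tiny-like setup), not the training recipe, and
    # the selective-scan kernels inside VSSBlock may require a CUDA build.
    decoder = CMDecoder(
        dims=96, depths=[2, 2, 9, 2],
        norm_layer='ln2d', ssm_act_layer='silu', mlp_act_layer='gelu',
        ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank='auto',
        ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0, ssm_init='v0',
        forward_type='v2', mlp_ratio=4.0, mlp_drop_rate=0.0,
        gmlp=False, use_checkpoint=False,
    )
    decoder.eval()
    # Assumed encoder pyramid for a 256x256 input: strides 4/8/16/32.
    sizes = [(96, 64), (192, 32), (384, 16), (768, 8)]  # (channels, spatial)
    pre = [torch.randn(1, c, s, s) for c, s in sizes]
    post = [torch.randn(1, c, s, s) for c, s in sizes]
    with torch.no_grad():
        out = decoder((pre, post, (256, 256)))
    print(out.shape)  # expected: torch.Size([1, 2, 256, 256])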