import torch
import torch.nn as nn
import torch.nn.functional as F

from rscd.models.backbones.vmamba import VSSM, LayerNorm2d, VSSBlock, Permute


class ChangeDecoder(nn.Module):
    """Decoder that fuses bi-temporal encoder features with VSS (Mamba) blocks.

    At each of the four encoder scales the pre- and post-change features are
    combined in three spatio-temporal arrangements before fusion:
      * channel-wise concatenation              (st_block_x1),
      * column-wise interleaving along width    (st_block_x2),
      * side-by-side placement along width      (st_block_x3).
    The first digit of a block name is the encoder scale (4 = deepest,
    1 = shallowest); the second digit is the arrangement.
    """

    def __init__(self, encoder_dims, channel_first, norm_layer, ssm_act_layer, mlp_act_layer, **kwargs):
        super(ChangeDecoder, self).__init__()

        def _make_st_block(in_channels):
            # 1x1 conv projects to 128 channels, then a VSSBlock models long-range
            # dependencies; Permute converts NCHW <-> NHWC when the block expects
            # channels-last input.
            return nn.Sequential(
                nn.Conv2d(kernel_size=1, in_channels=in_channels, out_channels=128),
                Permute(0, 2, 3, 1) if not channel_first else nn.Identity(),
                VSSBlock(hidden_dim=128, drop_path=0.1, norm_layer=norm_layer, channel_first=channel_first,
                         ssm_d_state=kwargs['ssm_d_state'], ssm_ratio=kwargs['ssm_ratio'],
                         ssm_dt_rank=kwargs['ssm_dt_rank'], ssm_act_layer=ssm_act_layer,
                         ssm_conv=kwargs['ssm_conv'], ssm_conv_bias=kwargs['ssm_conv_bias'],
                         ssm_drop_rate=kwargs['ssm_drop_rate'], ssm_init=kwargs['ssm_init'],
                         forward_type=kwargs['forward_type'], mlp_ratio=kwargs['mlp_ratio'],
                         mlp_act_layer=mlp_act_layer, mlp_drop_rate=kwargs['mlp_drop_rate'],
                         gmlp=kwargs['gmlp'], use_checkpoint=kwargs['use_checkpoint']),
                Permute(0, 3, 1, 2) if not channel_first else nn.Identity(),
            )

        def _make_fuse_layer():
            # Fuses the five 128-channel maps produced by one decoding stage.
            return nn.Sequential(
                nn.Conv2d(kernel_size=1, in_channels=128 * 5, out_channels=128),
                nn.BatchNorm2d(128),
                nn.ReLU(),
            )

        # Spatio-temporal blocks, deepest (4) to shallowest (1) scale. The "*1"
        # blocks take channel-concatenated inputs (twice the encoder channels);
        # the "*2" and "*3" blocks take width-widened inputs at the original
        # channel count.
        self.st_block_41 = _make_st_block(encoder_dims[-1] * 2)
        self.st_block_42 = _make_st_block(encoder_dims[-1])
        self.st_block_43 = _make_st_block(encoder_dims[-1])

        self.st_block_31 = _make_st_block(encoder_dims[-2] * 2)
        self.st_block_32 = _make_st_block(encoder_dims[-2])
        self.st_block_33 = _make_st_block(encoder_dims[-2])

        self.st_block_21 = _make_st_block(encoder_dims[-3] * 2)
        self.st_block_22 = _make_st_block(encoder_dims[-3])
        self.st_block_23 = _make_st_block(encoder_dims[-3])

        self.st_block_11 = _make_st_block(encoder_dims[-4] * 2)
        self.st_block_12 = _make_st_block(encoder_dims[-4])
        self.st_block_13 = _make_st_block(encoder_dims[-4])

        # One fuse layer per scale.
        self.fuse_layer_4 = _make_fuse_layer()
        self.fuse_layer_3 = _make_fuse_layer()
        self.fuse_layer_2 = _make_fuse_layer()
        self.fuse_layer_1 = _make_fuse_layer()

        # Smoothing blocks applied after each top-down upsample-and-add step.
        self.smooth_layer_3 = ResBlock(in_channels=128, out_channels=128, stride=1)
        self.smooth_layer_2 = ResBlock(in_channels=128, out_channels=128, stride=1)
        self.smooth_layer_1 = ResBlock(in_channels=128, out_channels=128, stride=1)

    def _upsample_add(self, x, y):
        # Upsample x to y's spatial size and add it (FPN-style top-down pathway).
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear') + y

    def _st_fuse(self, pre_feat, post_feat, block_1, block_2, block_3, fuse_layer):
        # Fuse one scale of bi-temporal features using three spatio-temporal arrangements.
        B, C, H, W = pre_feat.size()

        # (1) Channel-wise concatenation.
        p_cat = block_1(torch.cat([pre_feat, post_feat], dim=1))

        # (2) Column-wise interleaving along the width axis
        #     (pre in even columns, post in odd columns).
        interleaved = torch.empty(B, C, H, 2 * W, device=pre_feat.device, dtype=pre_feat.dtype)
        interleaved[:, :, :, ::2] = pre_feat
        interleaved[:, :, :, 1::2] = post_feat
        p_inter = block_2(interleaved)

        # (3) Side-by-side placement along the width axis.
        p_side = block_3(torch.cat([pre_feat, post_feat], dim=3))

        # Split the widened outputs back into their pre/post halves and fuse the
        # resulting five 128-channel maps.
        return fuse_layer(torch.cat(
            [p_cat, p_inter[:, :, :, ::2], p_inter[:, :, :, 1::2], p_side[:, :, :, 0:W], p_side[:, :, :, W:]],
            dim=1))

    def forward(self, pre_features, post_features):
        pre_feat_1, pre_feat_2, pre_feat_3, pre_feat_4 = pre_features
        post_feat_1, post_feat_2, post_feat_3, post_feat_4 = post_features

        # Stage I: deepest scale.
        p4 = self._st_fuse(pre_feat_4, post_feat_4,
                           self.st_block_41, self.st_block_42, self.st_block_43, self.fuse_layer_4)

        # Stage II
        p3 = self._st_fuse(pre_feat_3, post_feat_3,
                           self.st_block_31, self.st_block_32, self.st_block_33, self.fuse_layer_3)
        p3 = self.smooth_layer_3(self._upsample_add(p4, p3))

        # Stage III
        p2 = self._st_fuse(pre_feat_2, post_feat_2,
                           self.st_block_21, self.st_block_22, self.st_block_23, self.fuse_layer_2)
        p2 = self.smooth_layer_2(self._upsample_add(p3, p2))

        # Stage IV: shallowest scale.
        p1 = self._st_fuse(pre_feat_1, post_feat_1,
                           self.st_block_11, self.st_block_12, self.st_block_13, self.fuse_layer_1)
        p1 = self.smooth_layer_1(self._upsample_add(p2, p1))

        return p1


class ResBlock(nn.Module):
    """Basic two-layer residual block used to smooth the fused pyramid features."""

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class CMDecoder(nn.Module):
    def __init__(self, **kwargs):
        super(CMDecoder, self).__init__()

        # Map the string options from the config to the corresponding layer classes.
        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )

        _ACTLAYERS = dict(
            silu=nn.SiLU,
            gelu=nn.GELU,
            relu=nn.ReLU,
            sigmoid=nn.Sigmoid,
        )

        norm_layer: nn.Module = _NORMLAYERS.get(kwargs['norm_layer'].lower(), None)
        ssm_act_layer: nn.Module = _ACTLAYERS.get(kwargs['ssm_act_layer'].lower(), None)
        mlp_act_layer: nn.Module = _ACTLAYERS.get(kwargs['mlp_act_layer'].lower(), None)

        # Drop the string-valued keys so they do not clash with the resolved layers.
        clean_kwargs = {k: v for k, v in kwargs.items() if k not in ['norm_layer', 'ssm_act_layer', 'mlp_act_layer']}
        self.decoder = ChangeDecoder(
            encoder_dims=[int(kwargs['dims'] * 2 ** i_layer) for i_layer in range(len(kwargs['depths']))],
            channel_first=True,
            norm_layer=norm_layer,
            ssm_act_layer=ssm_act_layer,
            mlp_act_layer=mlp_act_layer,
            **clean_kwargs
        )

        self.main_clf = nn.Conv2d(in_channels=128, out_channels=2, kernel_size=1)

    def _upsample_add(self, x, y):
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear') + y

    def forward(self, xs):
        pre_features, post_features, pre_data_size = xs

        # Decode the bi-temporal features, classify, and upsample to the input size.
        output = self.decoder(pre_features, post_features)
        output = self.main_clf(output)
        output = F.interpolate(output, size=pre_data_size, mode='bilinear')
        return output
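

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model). It only illustrates the tensor
# shapes the decoder expects. The hyper-parameter values below are assumptions
# (typical VMamba-style settings), not this repository's actual config, and
# running the VSSBlock forward may require the selective-scan kernels shipped
# with the vmamba backbone.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Assumed config: a 4-stage encoder with base width 96, doubling per stage.
    cfg = dict(
        dims=96, depths=[2, 2, 9, 2],
        norm_layer='ln2d', ssm_act_layer='silu', mlp_act_layer='gelu',
        ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank='auto',
        ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0, ssm_init='v0',
        forward_type='v2', mlp_ratio=4.0, mlp_drop_rate=0.0,
        gmlp=False, use_checkpoint=False,
    )
    model = CMDecoder(**cfg)

    # Dummy bi-temporal feature pyramids for a 256x256 input:
    # stage i has dims * 2**i channels at stride 4 * 2**i (NCHW, channel-first).
    B = 1
    pre_feats = [torch.randn(B, 96 * 2 ** i, 64 // 2 ** i, 64 // 2 ** i) for i in range(4)]
    post_feats = [torch.randn(B, 96 * 2 ** i, 64 // 2 ** i, 64 // 2 ** i) for i in range(4)]

    logits = model((pre_feats, post_feats, (256, 256)))
    print(logits.shape)  # expected: torch.Size([1, 2, 256, 256])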