from __future__ import annotations

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.segresnet_block import get_conv_layer, get_upsample_layer
from monai.networks.layers.factories import Dropout
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import UpsampleMode

from models.Blocks import CAB, SAB, VSSBlock, ShallowFusionAttnBlock
from models.mamba_customer import ConvMamba, M3, PatchEmbed, PatchUnEmbed

warnings.filterwarnings("ignore")


def get_dwconv_layer(
    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, bias: bool = False
):
    """Build a depthwise-separable convolution: depthwise conv followed by a 1x1 pointwise conv.

    Args:
        spatial_dims: number of spatial dimensions, forwarded to MONAI ``Convolution``.
        in_channels: number of input channels (also the depthwise group count).
        out_channels: number of output channels produced by the pointwise conv.
        kernel_size: kernel size of the depthwise conv.
        stride: stride of the whole layer. It is applied once, by the depthwise conv.
        bias: whether both convolutions use a bias term.

    Returns:
        ``torch.nn.Sequential(depth_conv, point_conv)``.
    """
    depth_conv = Convolution(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        out_channels=in_channels,
        strides=stride,
        kernel_size=kernel_size,
        bias=bias,
        conv_only=True,
        groups=in_channels,
    )
    # The pointwise conv must not be strided as well; otherwise the layer
    # would downsample by stride**2 instead of stride.
    point_conv = Convolution(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        strides=1,
        kernel_size=1,
        bias=bias,
        conv_only=True,
        groups=1,
    )
    return torch.nn.Sequential(depth_conv, point_conv)


class SRCMLayer(nn.Module):
    """Token-mixing layer: LayerNorm -> ConvMamba (+ scaled skip) -> LayerNorm -> Linear.

    The spatial dimensions of a channels-first input are flattened into a token
    sequence. The tokens are mixed by ``ConvMamba``, projected from ``input_dim``
    to ``output_dim`` channels, and reshaped back to the input's spatial layout.
    """

    def __init__(self, input_dim, output_dim, d_state=16, d_conv=4, expand=2, conv_mode='deepwise'):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.norm = nn.LayerNorm(input_dim)
        self.convmamba = ConvMamba(
            d_model=input_dim,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            bimamba_type="v2",
            conv_mode=conv_mode,
        )
        self.proj = nn.Linear(input_dim, output_dim)
        # Learnable scale of the residual connection around ConvMamba.
        self.skip_scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(B, input_dim, *spatial)``.

        float16 input is upcast to float32 first.

        Returns:
            Tensor of shape ``(B, output_dim, *spatial)``.
        """
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        B, C = x.shape[:2]
        assert C == self.input_dim, f"expected {self.input_dim} channels, got {C}"
        img_dims = x.shape[2:]
        n_tokens = img_dims.numel()
        # (B, C, *spatial) -> (B, n_tokens, C)
        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
        x_norm = self.norm(x_flat)
        x_mamba = self.convmamba(x_norm) + self.skip_scale * x_flat
        # The same LayerNorm instance (shared parameters) is applied a second time.
        x_mamba = self.norm(x_mamba)
        x_mamba = self.proj(x_mamba)
        # (B, n_tokens, output_dim) -> (B, output_dim, *spatial)
        out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims)
        return out


def get_srcm_layer(
    spatial_dims: int, in_channels: int, out_channels: int, stride: int = 1, conv_mode: str = "deepwise"
):
    """Build an ``SRCMLayer``, followed by max-pooling when ``stride != 1``.

    Args:
        spatial_dims: number of spatial dimensions (1, 2 or 3). Selects the pooling op.
        in_channels: input channels of the SRCM layer.
        out_channels: output channels of the SRCM layer.
        stride: downsampling factor. Pooling uses ``kernel_size=stride, stride=stride``.
        conv_mode: forwarded to ``SRCMLayer``.

    Raises:
        ValueError: if ``stride != 1`` and ``spatial_dims`` is not 1, 2 or 3.
    """
    srcm_layer = SRCMLayer(input_dim=in_channels, output_dim=out_channels, conv_mode=conv_mode)
    if stride == 1:
        return srcm_layer
    pool_types = {1: nn.MaxPool1d, 2: nn.MaxPool2d, 3: nn.MaxPool3d}
    if spatial_dims not in pool_types:
        # Previously the stride was silently ignored here, so no downsampling happened.
        raise ValueError(f"Unsupported spatial_dims={spatial_dims} for a strided SRCM layer.")
    return nn.Sequential(srcm_layer, pool_types[spatial_dims](kernel_size=stride, stride=stride))


class SRCMBlock(nn.Module):
    """Pre-activation residual block: (norm -> act -> SRCMLayer) x 2, plus identity skip.

    Input and output channel counts are both ``in_channels``.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        norm: tuple | str,
        kernel_size: int = 3,
        conv_mode: str = "deepwise",
        act: tuple | str = ("RELU", {"inplace": True}),
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2 or 3.
            in_channels: number of input channels.
            norm: feature normalization type and arguments.
            kernel_size: only validated (must be odd). It is not otherwise used by this block.
                Defaults to 3.
            conv_mode: forwarded to the two inner ``SRCMLayer``s.
            act: activation type and arguments. Defaults to ``RELU``.

        Raises:
            AssertionError: if ``kernel_size`` is even.
        """
        super().__init__()

        if kernel_size % 2 != 1:
            raise AssertionError("kernel_size should be an odd number.")

        self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
        self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
        self.act = get_act_layer(act)
        self.conv1 = get_srcm_layer(
            spatial_dims, in_channels=in_channels, out_channels=in_channels, conv_mode=conv_mode
        )
        self.conv2 = get_srcm_layer(
            spatial_dims, in_channels=in_channels, out_channels=in_channels, conv_mode=conv_mode
        )

    def forward(self, x):
        identity = x

        x = self.norm1(x)
        x = self.act(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.act(x)
        x = self.conv2(x)

        # In-place add on conv2's fresh output; `identity` itself is not modified.
        x += identity

        return x


class CSI(nn.Module):
    """Fuse two same-level feature maps (one per input image) into a single map.

    Both maps first pass through ``ShallowFusionAttnBlock``. The two fused maps and
    their absolute difference are tokenised with ``PatchEmbed`` and combined by
    ``M3``. The result is mapped back to an ``(h, w)`` feature map with ``PatchUnEmbed``.
    """

    def __init__(self, dim):
        super(CSI, self).__init__()
        self.shallow_fusion_attn = ShallowFusionAttnBlock(dim)
        self.m3 = M3(dim)
        # Not used in forward(). Kept so the module's parameters / state_dict keys stay unchanged.
        self.vss = VSSBlock(hidden_dim=dim)
        self.patch_embed = PatchEmbed(in_chans=dim, embed_dim=dim)
        self.patch_unembed = PatchUnEmbed(in_chans=dim, embed_dim=dim)

    def forward(self, I1, I2, h, w):
        """Fuse ``I1`` and ``I2``. Assumes both are ``(B, dim, h, w)`` (TODO confirm with callers)."""
        I1_fuse, I2_fuse = self.shallow_fusion_attn(I1, I2, h, w)
        fusion = torch.abs(I1_fuse - I2_fuse)
        I1_token = self.patch_embed(I1_fuse)
        I2_token = self.patch_embed(I2_fuse)
        fusion_token = self.patch_embed(fusion)
        test_h, test_w = fusion.shape[2], fusion.shape[3]
        # M3 returns a pair; only the first element is used here.
        fusion_token, _ = self.m3(I1_token, I2_token, fusion_token, test_h, test_w)
        fusion_out = self.patch_unembed(fusion_token, (h, w))
        return fusion_out


class STNR(nn.Module):
    """Siamese change-detection network built from SRCM (ConvMamba-based) blocks.

    Both input images go through the same encoder (shared weights). At every
    encoder level the two feature maps are fused. A ``CSI`` block is used when
    ``mode == "FUSION"`` and the level index is below ``stage``. Otherwise the
    absolute difference of the two maps is used. The fused pyramid is then decoded
    into ``out_channels`` maps.

    Notes:
        * Only ``up_mode="SRCM"`` and ``diff_abs="later"`` are implemented.
        * ``local_query_model``, ``resdiual`` and ``mamba_act`` are stored as
          attributes but are not used by this class.
        * The decoder merge convolutions are ``nn.Conv2d``, so ``forward`` requires
          4-D ``(B, C, H, W)`` inputs.
    """

    def __init__(
        self,
        spatial_dims: int = 2,
        init_filters: int = 16,
        in_channels: int = 1,
        out_channels: int = 2,
        conv_mode: str = "deepwise",
        local_query_model="orignal_dinner",
        dropout_prob: float | None = None,
        act: tuple | str = ("RELU", {"inplace": True}),
        norm: tuple | str = ("GROUP", {"num_groups": 8}),
        norm_name: str = "",
        num_groups: int = 8,
        use_conv_final: bool = True,
        blocks_down: tuple = (1, 2, 2, 4),
        blocks_up: tuple = (1, 1, 1),
        mode: str = "",
        up_mode="ResMamba",
        up_conv_mode="deepwise",
        resdiual=False,
        stage=4,
        diff_abs="later",
        mamba_act="silu",
        upsample_mode: UpsampleMode | str = UpsampleMode.NONTRAINABLE,
    ):
        super().__init__()

        if spatial_dims not in (2, 3):
            raise ValueError("`spatial_dims` can only be 2 or 3.")
        self.mode = mode
        self.stage = stage
        self.up_conv_mode = up_conv_mode
        self.mamba_act = mamba_act
        self.resdiual = resdiual
        self.up_mode = up_mode
        self.diff_abs = diff_abs
        self.conv_mode = conv_mode
        self.local_query_model = local_query_model
        self.spatial_dims = spatial_dims
        self.init_filters = init_filters
        # Channel width of each encoder level, shallowest first.
        self.channels_list = [init_filters, init_filters * 2, init_filters * 4, init_filters * 8]
        self.in_channels = in_channels
        self.blocks_down = blocks_down
        self.blocks_up = blocks_up
        self.dropout_prob = dropout_prob
        self.act = act  # input options
        self.act_mod = get_act_layer(act)
        if norm_name:
            if norm_name.lower() != "group":
                raise ValueError(f"Deprecating option 'norm_name={norm_name}', please use 'norm' instead.")
            norm = ("group", {"num_groups": num_groups})
        self.norm = norm
        self.upsample_mode = UpsampleMode(upsample_mode)
        self.use_conv_final = use_conv_final

        # Sub-modules are created in the original order so parameter registration
        # (and RNG-driven initialisation) is unchanged.
        self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
        self.srcm_encoder_layers = self._make_srcm_encoder_layers()
        self.srcm_decoder_layers, self.up_samples = self._make_srcm_decoder_layers(up_mode=self.up_mode)
        self.conv_final = self._make_final_conv(out_channels)

        self.fusion_blocks = nn.ModuleList([CSI(self.channels_list[i]) for i in range(self.stage)])
        # Decoder levels, deepest-but-one first: channels 4f, 2f, f.
        decoder_channels = self.channels_list[::-1][1:]
        self.cab_layers = nn.ModuleList([CAB(ch) for ch in decoder_channels])
        self.sab_layers = nn.ModuleList([SAB(kernel_size=7) for _ in range(len(self.blocks_up))])
        # 1x1 convs that merge the concatenated attention / SRCM branches back to `ch` channels.
        self.conv_down_layers = nn.ModuleList(
            [nn.Conv2d(ch * 2, ch, kernel_size=1, stride=1, padding=0) for ch in decoder_channels]
        )

        if dropout_prob is not None:
            self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)

    def _make_srcm_encoder_layers(self):
        """Build one encoder stage per entry of ``blocks_down``.

        Stage ``i`` has ``init_filters * 2**i`` channels. Every stage after the
        first starts with a stride-2 SRCM layer that doubles the channels and
        halves the resolution.
        """
        srcm_encoder_layers = nn.ModuleList()
        for i, n_blocks in enumerate(self.blocks_down):
            layer_in_channels = self.init_filters * 2 ** i
            if i > 0:
                downsample_mamba = get_srcm_layer(
                    self.spatial_dims, layer_in_channels // 2, layer_in_channels, stride=2, conv_mode=self.conv_mode
                )
            else:
                downsample_mamba = nn.Identity()
            blocks = [
                SRCMBlock(self.spatial_dims, layer_in_channels, norm=self.norm, act=self.act, conv_mode=self.conv_mode)
                for _ in range(n_blocks)
            ]
            srcm_encoder_layers.append(nn.Sequential(downsample_mamba, *blocks))
        return srcm_encoder_layers

    def _make_srcm_decoder_layers(self, up_mode):
        """Build the decoder block stacks and the matching upsampling modules.

        Level ``i`` upsamples from ``init_filters * 2**(n_up - i)`` channels to half
        that. It does this with a 1x1 conv followed by ``get_upsample_layer``, then
        applies ``blocks_up[i]`` SRCM blocks.

        Raises:
            ValueError: if ``up_mode`` is not ``"SRCM"`` (the only implemented mode).
        """
        if up_mode != "SRCM":
            raise ValueError(f"Unsupported up_mode={up_mode!r}; only 'SRCM' is implemented.")
        block_up = SRCMBlock

        srcm_decoder_layers, up_samples = nn.ModuleList(), nn.ModuleList()
        spatial_dims, filters, norm = self.spatial_dims, self.init_filters, self.norm
        n_up = len(self.blocks_up)
        for i in range(n_up):
            sample_in_channels = filters * 2 ** (n_up - i)
            out_ch = sample_in_channels // 2
            srcm_decoder_layers.append(
                nn.Sequential(
                    *[
                        block_up(spatial_dims, out_ch, norm=norm, act=self.act, conv_mode=self.up_conv_mode)
                        for _ in range(self.blocks_up[i])
                    ]
                )
            )
            up_samples.append(
                nn.Sequential(
                    get_conv_layer(spatial_dims, sample_in_channels, out_ch, kernel_size=1),
                    get_upsample_layer(spatial_dims, out_ch, upsample_mode=self.upsample_mode),
                )
            )
        return srcm_decoder_layers, up_samples

    def _make_final_conv(self, out_channels: int):
        """Output head: norm -> activation -> 1x1 conv (with bias) to ``out_channels``."""
        return nn.Sequential(
            get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.init_filters),
            self.act_mod,
            get_conv_layer(self.spatial_dims, self.init_filters, out_channels, kernel_size=1, bias=True),
        )

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the encoder.

        Returns:
            The deepest feature map and the list of every stage's output, shallowest first.
        """
        x = self.convInit(x)
        if self.dropout_prob is not None:
            x = self.dropout(x)

        down_x = []
        for down in self.srcm_encoder_layers:
            x = down(x)
            down_x.append(x)

        return x, down_x

    def decode(self, x: torch.Tensor, down_x: list[torch.Tensor]) -> torch.Tensor:
        """Run the decoder.

        Args:
            x: deepest (fused) feature map.
            down_x: fused skip features ordered deepest first. ``down_x[i + 1]`` is
                the skip added at decoder level ``i``.
        """
        for i, (up, upl) in enumerate(zip(self.up_samples, self.srcm_decoder_layers)):
            skip = down_x[i + 1]
            x_up = up(x) + skip
            # Attention branch: CAB then SAB outputs act as multiplicative gates
            # (presumably channel / spatial attention maps).
            x_cab = self.cab_layers[i](x_up) * x_up
            x_sab = self.sab_layers[i](x_cab) * x_cab
            # SRCM branch on the same upsampled features.
            x_srcm = upl(x_up)
            # Concatenate both branches and reduce back to the level width.
            x = self.conv_down_layers[i](torch.cat([x_sab, x_srcm], dim=1))

        if self.use_conv_final:
            x = self.conv_final(x)
        return x

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Predict a change map from the image pair ``(x1, x2)``, both ``(B, C, H, W)``.

        Raises:
            ValueError: if ``x1`` is not 4-D, or ``diff_abs`` is not ``"later"``.
        """
        if x1.dim() != 4:
            raise ValueError(f"expected a 4-D (B, C, H, W) input, got shape {tuple(x1.shape)}")
        if self.diff_abs != "later":
            # Previously `fusion` was left unbound (or stale) for any other value.
            raise ValueError(f"Unsupported diff_abs={self.diff_abs!r}; only 'later' is implemented.")

        _, down_x1 = self.encode(x1)
        _, down_x2 = self.encode(x2)

        down_x = []
        for i, (x1_level, x2_level) in enumerate(zip(down_x1, down_x2)):
            if self.mode == "FUSION" and i < self.stage:
                h_i, w_i = x1_level.shape[2], x1_level.shape[3]
                fusion = self.fusion_blocks[i](x1_level, x2_level, h_i, w_i)
            else:
                fusion = torch.abs(x1_level - x2_level)
            down_x.append(fusion)

        # Decoder consumes features deepest first.
        down_x.reverse()
        return self.decode(down_x[0], down_x)


if __name__ == "__main__":
    device = "cuda:0"
    model = STNR(
        spatial_dims=2,
        in_channels=3,
        out_channels=2,
        init_filters=16,
        norm=("GROUP", {"num_groups": 8}),
        mode="FUSION",
        conv_mode='orignal',
        local_query_model="orignal_dinner",
        stage=4,
        mamba_act="silu",
        up_mode="SRCM",
        up_conv_mode='deepwise',
        blocks_down=(1, 2, 2, 4),
        blocks_up=(1, 1, 1),
        resdiual=False,
        diff_abs="later",
    ).to(device)
    x = torch.randn(1, 3, 256, 256).to(device)
    y = model(x, x)
    print(y.shape)