from __future__ import annotations

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.segresnet_block import get_conv_layer, get_upsample_layer
from monai.networks.layers.factories import Dropout
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import UpsampleMode

from models.mamba_customer import ConvMamba, M3, PatchEmbed, PatchUnEmbed
from models.Blocks import CAB, SAB, VSSBlock, ShallowFusionAttnBlock

warnings.filterwarnings("ignore")
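
# Depthwise-separable convolution helper: a depthwise Convolution (groups=in_channels)
# followed by a 1x1 pointwise Convolution that mixes channels.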
def get_dwconv_layer(
    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1,
    bias: bool = False
):
    depth_conv = Convolution(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels,
                             strides=stride, kernel_size=kernel_size, bias=bias, conv_only=True, groups=in_channels)
    point_conv = Convolution(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels,
                             strides=stride, kernel_size=1, bias=bias, conv_only=True, groups=1)
    return torch.nn.Sequential(depth_conv, point_conv)
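
# SRCMLayer: flattens (B, C, *spatial) features into a (B, N, C) token sequence, runs a
# ConvMamba block (bimamba_type="v2") with a learnable residual scale, projects the tokens
# to the requested channel count, and restores the original spatial layout.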
class SRCMLayer(nn.Module):
    def __init__(self, input_dim, output_dim, d_state=16, d_conv=4, expand=2, conv_mode='deepwise'):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.norm = nn.LayerNorm(input_dim)
        self.convmamba = ConvMamba(
            d_model=input_dim,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            bimamba_type="v2",
            conv_mode=conv_mode,
        )
        self.proj = nn.Linear(input_dim, output_dim)
        self.skip_scale = nn.Parameter(torch.ones(1))
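
    # Tokens are (B, N, C). The same LayerNorm instance is applied both before the Mamba block
    # and again after the scaled residual connection, prior to the output projection.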
    def forward(self, x):
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        B, C = x.shape[:2]
        assert C == self.input_dim
        n_tokens = x.shape[2:].numel()
        img_dims = x.shape[2:]
        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
        x_norm = self.norm(x_flat)
        x_mamba = self.convmamba(x_norm) + self.skip_scale * x_flat
        x_mamba = self.norm(x_mamba)
        x_mamba = self.proj(x_mamba)
        out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims)
        return out
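
# Builds an SRCMLayer and, when stride > 1, appends a max-pooling step so the layer can act
# as a strided (downsampling) stage in the encoder.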
def get_srcm_layer(
    spatial_dims: int, in_channels: int, out_channels: int, stride: int = 1, conv_mode: str = "deepwise"
):
    srcm_layer = SRCMLayer(input_dim=in_channels, output_dim=out_channels, conv_mode=conv_mode)
    if stride != 1:
        if spatial_dims == 2:
            return nn.Sequential(srcm_layer, nn.MaxPool2d(kernel_size=stride, stride=stride))
        if spatial_dims == 3:
            # Added so 3-D inputs are also downsampled; the original only handled the 2-D case.
            return nn.Sequential(srcm_layer, nn.MaxPool3d(kernel_size=stride, stride=stride))
    return srcm_layer
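
# SRCMBlock: a pre-activation residual block (norm -> act -> SRCM layer, applied twice) whose
# output is added back to the block input; the channel count is preserved.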
class SRCMBlock(nn.Module):

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        norm: tuple | str,
        kernel_size: int = 3,
        conv_mode: str = "deepwise",
        act: tuple | str = ("RELU", {"inplace": True}),
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2 or 3.
            in_channels: number of input channels.
            norm: feature normalization type and arguments.
            kernel_size: convolution kernel size, the value should be an odd number. Defaults to 3.
            conv_mode: convolution mode forwarded to the SRCM layers. Defaults to ``"deepwise"``.
            act: activation type and arguments. Defaults to ``RELU``.
        """
        super().__init__()

        if kernel_size % 2 != 1:
            raise AssertionError("kernel_size should be an odd number.")

        self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
        self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
        self.act = get_act_layer(act)
        self.conv1 = get_srcm_layer(
            spatial_dims, in_channels=in_channels, out_channels=in_channels, conv_mode=conv_mode
        )
        self.conv2 = get_srcm_layer(
            spatial_dims, in_channels=in_channels, out_channels=in_channels, conv_mode=conv_mode
        )

    def forward(self, x):
        identity = x

        x = self.norm1(x)
        x = self.act(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.act(x)
        x = self.conv2(x)

        x += identity

        return x
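
# CSI block: fuses the two temporal feature maps with a ShallowFusionAttnBlock, forms their
# absolute difference, tokenizes all three streams with PatchEmbed, refines the difference
# tokens through the M3 module, and un-embeds the result back to a feature map.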
class CSI(nn.Module):
    def __init__(self, dim):
        super(CSI, self).__init__()
        self.shallow_fusion_attn = ShallowFusionAttnBlock(dim)
        self.m3 = M3(dim)
        self.vss = VSSBlock(hidden_dim=dim)
        self.patch_embed = PatchEmbed(in_chans=dim, embed_dim=dim)
        self.patch_unembed = PatchUnEmbed(in_chans=dim, embed_dim=dim)

    def forward(self, I1, I2, h, w):
        I1_fuse, I2_fuse = self.shallow_fusion_attn(I1, I2, h, w)
        fusion = torch.abs(I1_fuse - I2_fuse)
        I1_token = self.patch_embed(I1_fuse)
        I2_token = self.patch_embed(I2_fuse)
        fusion_token = self.patch_embed(fusion)
        test_h, test_w = fusion.shape[2], fusion.shape[3]
        fusion_token, _ = self.m3(I1_token, I2_token, fusion_token, test_h, test_w)
        fusion_out = self.patch_unembed(fusion_token, (h, w))
        return fusion_out
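
# STNR: Siamese change-detection network. A shared SRCM encoder processes both temporal inputs;
# per-level features are fused with CSI blocks (mode="FUSION") or by absolute difference, and the
# decoder combines CAB/SAB attention with SRCM blocks to produce the change map. Note that the
# decoder's channel-reduction convolutions are nn.Conv2d, so the decode path assumes 2-D inputs.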
class STNR(nn.Module):
    def __init__(
        self,
        spatial_dims: int = 2,
        init_filters: int = 16,
        in_channels: int = 1,
        out_channels: int = 2,
        conv_mode: str = "deepwise",
        local_query_model="orignal_dinner",
        dropout_prob: float | None = None,
        act: tuple | str = ("RELU", {"inplace": True}),
        norm: tuple | str = ("GROUP", {"num_groups": 8}),
        norm_name: str = "",
        num_groups: int = 8,
        use_conv_final: bool = True,
        blocks_down: tuple = (1, 2, 2, 4),
        blocks_up: tuple = (1, 1, 1),
        mode: str = "",
        up_mode="ResMamba",
        up_conv_mode="deepwise",
        resdiual=False,
        stage=4,
        diff_abs="later",
        mamba_act="silu",
        upsample_mode: UpsampleMode | str = UpsampleMode.NONTRAINABLE,
    ):
        super().__init__()

        if spatial_dims not in (2, 3):
            raise ValueError("`spatial_dims` can only be 2 or 3.")
        self.mode = mode
        self.stage = stage
        self.up_conv_mode = up_conv_mode
        self.mamba_act = mamba_act
        self.resdiual = resdiual
        self.up_mode = up_mode
        self.diff_abs = diff_abs
        self.conv_mode = conv_mode
        self.local_query_model = local_query_model
        self.spatial_dims = spatial_dims
        self.init_filters = init_filters
        self.channels_list = [self.init_filters, self.init_filters * 2, self.init_filters * 4, self.init_filters * 8]
        self.in_channels = in_channels
        self.blocks_down = blocks_down
        self.blocks_up = blocks_up
        print(self.blocks_up)
        self.dropout_prob = dropout_prob
        self.act = act
        self.act_mod = get_act_layer(act)
        if norm_name:
            if norm_name.lower() != "group":
                raise ValueError(f"Deprecating option 'norm_name={norm_name}', please use 'norm' instead.")
            norm = ("group", {"num_groups": num_groups})
        self.norm = norm
        print(self.norm)
        self.upsample_mode = UpsampleMode(upsample_mode)
        self.use_conv_final = use_conv_final
        self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
        self.srcm_encoder_layers = self._make_srcm_encoder_layers()
        self.srcm_decoder_layers, self.up_samples = self._make_srcm_decoder_layers(up_mode=self.up_mode)
        self.conv_final = self._make_final_conv(out_channels)
        self.fusion_blocks = nn.ModuleList(
            [CSI(self.channels_list[i]) for i in range(self.stage)]
        )
        self.cab_layers = nn.ModuleList([
            CAB(ch) for ch in self.channels_list[::-1][1:]
        ])
        self.sab_layers = nn.ModuleList([
            SAB(kernel_size=7) for _ in range(len(self.blocks_up))
        ])
        self.conv_down_layers = nn.ModuleList([
            nn.Conv2d(ch * 2, ch, kernel_size=1, stride=1, padding=0) for ch in self.channels_list[::-1][1:]
        ])
        if dropout_prob is not None:
            self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)
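
    # Encoder: blocks_down[i] SRCMBlocks per level, preceded (for i > 0) by a stride-2 SRCM
    # downsampling layer that doubles the channel count.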
    def _make_srcm_encoder_layers(self):
        srcm_encoder_layers = nn.ModuleList()
        blocks_down, spatial_dims, filters, norm, conv_mode = (
            self.blocks_down, self.spatial_dims, self.init_filters, self.norm, self.conv_mode
        )
        for i, item in enumerate(blocks_down):
            layer_in_channels = filters * 2 ** i
            downsample_mamba = (
                get_srcm_layer(spatial_dims, layer_in_channels // 2, layer_in_channels, stride=2, conv_mode=conv_mode)
                if i > 0
                else nn.Identity()
            )
            down_layer = nn.Sequential(
                downsample_mamba,
                *[SRCMBlock(spatial_dims, layer_in_channels, norm=norm, act=self.act, conv_mode=conv_mode) for _ in range(item)]
            )
            srcm_encoder_layers.append(down_layer)
        return srcm_encoder_layers
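
    # Decoder: for each up level, a 1x1 conv halves the channels and an upsample layer doubles
    # the spatial size, followed by blocks_up[i] decoder blocks.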
    def _make_srcm_decoder_layers(self, up_mode):
        srcm_decoder_layers, up_samples = nn.ModuleList(), nn.ModuleList()
        upsample_mode, blocks_up, spatial_dims, filters, norm = (
            self.upsample_mode,
            self.blocks_up,
            self.spatial_dims,
            self.init_filters,
            self.norm,
        )
        if up_mode == 'SRCM':
            Block_up = SRCMBlock
        else:
            # Guard added: the original fell through with Block_up undefined for other modes.
            raise ValueError(f"Unsupported up_mode '{up_mode}', expected 'SRCM'.")
        n_up = len(blocks_up)
        for i in range(n_up):
            sample_in_channels = filters * 2 ** (n_up - i)
            srcm_decoder_layers.append(
                nn.Sequential(
                    *[
                        Block_up(spatial_dims, sample_in_channels // 2, norm=norm, act=self.act, conv_mode=self.up_conv_mode)
                        for _ in range(blocks_up[i])
                    ]
                )
            )
            up_samples.append(
                nn.Sequential(
                    *[
                        get_conv_layer(spatial_dims, sample_in_channels, sample_in_channels // 2, kernel_size=1),
                        get_upsample_layer(spatial_dims, sample_in_channels // 2, upsample_mode=upsample_mode),
                    ]
                )
            )
        return srcm_decoder_layers, up_samples

    def _make_final_conv(self, out_channels: int):
        return nn.Sequential(
            get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.init_filters),
            self.act_mod,
            get_conv_layer(self.spatial_dims, self.init_filters, out_channels, kernel_size=1, bias=True),
        )
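
    # Shared encoder pass: initial convolution (plus optional dropout), then each encoder stage;
    # the per-stage outputs are collected for the skip connections.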
    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        x = self.convInit(x)
        if self.dropout_prob is not None:
            x = self.dropout(x)
        down_x = []

        for down in self.srcm_encoder_layers:
            x = down(x)
            down_x.append(x)

        return x, down_x
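
    # Decoder pass: upsample and add the skip feature, then run two parallel branches (the CAB and
    # SAB attention modules on one side, the SRCM decoder blocks on the other) and merge them with
    # a 1x1 convolution.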
    def decode(self, x: torch.Tensor, down_x: list[torch.Tensor]) -> torch.Tensor:
        for i, (up, upl) in enumerate(zip(self.up_samples, self.srcm_decoder_layers)):
            skip = down_x[i + 1]
            x_up = up(x) + skip
            x_cab = self.cab_layers[i](x_up) * x_up
            x_sab = self.sab_layers[i](x_cab) * x_cab
            x_srcm = upl(x_up)
            combined_out = torch.cat([x_sab, x_srcm], dim=1)
            final_out = self.conv_down_layers[i](combined_out)
            x = final_out
        if self.use_conv_final:
            x = self.conv_final(x)
        return x
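
    # Forward: encode both temporal images with the shared encoder, fuse the per-level features
    # (CSI when mode == "FUSION" and the level index is below `stage`, absolute difference
    # otherwise), then decode from the deepest fused feature using the shallower ones as skips.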
    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x1.shape
        x1, down_x1 = self.encode(x1)
        x2, down_x2 = self.encode(x2)
        down_x = []
        for i in range(len(down_x1)):
            x1_level, x2_level = down_x1[i], down_x2[i]
            H_i, W_i = x1_level.shape[2], x1_level.shape[3]
            if self.diff_abs == "later":
                if self.mode == "FUSION":
                    if i < self.stage:
                        fusion = self.fusion_blocks[i](x1_level, x2_level, H_i, W_i)
                    else:
                        fusion = torch.abs(x1_level - x2_level)
                else:
                    fusion = torch.abs(x1_level - x2_level)
            else:
                fusion = torch.abs(x1_level - x2_level)
            down_x.append(fusion)
        down_x.reverse()
        x = self.decode(down_x[0], down_x)
        return x
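
# Quick smoke test. It assumes a CUDA device and the local `models` package (mamba_customer,
# Blocks) are available; the two inputs are identical here, so it only verifies output shapes.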
if __name__ == "__main__":
    device = "cuda:0"
    CDMamba = STNR(spatial_dims=2, in_channels=3, out_channels=2, init_filters=16, norm=("GROUP", {"num_groups": 8}),
                   mode="FUSION", conv_mode='orignal', local_query_model="orignal_dinner",
                   stage=4, mamba_act="silu", up_mode="SRCM", up_conv_mode='deepwise', blocks_down=(1, 2, 2, 4), blocks_up=(1, 1, 1),
                   resdiual=False, diff_abs="later").to(device)
    x = torch.randn(1, 3, 256, 256).to(device)
    y = CDMamba(x, x)
    print(y.shape)