# Copyright (c) 2023, Tri Dao, Albert Gu.
import numbers
from mamba_ssm.modules.mamba_simple import Mamba
import warnings
warnings.filterwarnings("ignore")
from timm.models.layers import DropPath, to_2tuple
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from einops import rearrange, repeat

# Optional accelerated kernels. Each fallback must bind exactly as many names
# as the import would have, otherwise the except-branch itself raises.
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
except ImportError:
    selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


class LightweightModel(nn.Module):
    """Depthwise-separable 3x3 convolution: a per-channel 3x3 conv followed by a 1x1 conv.

    Args:
        in_channels: number of input channels (also the depthwise group count).
        out_channels: number of channels produced by the pointwise conv.
        device, dtype: optional factory arguments forwarded to both ``nn.Conv2d``
            layers; ``None`` keeps the PyTorch defaults.
    """

    def __init__(self, in_channels, out_channels, device=None, dtype=None):
        super(LightweightModel, self).__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        # groups=in_channels makes the 3x3 conv act on every channel independently.
        self.depthwise_conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels,
            **factory_kwargs,
        )
        # The 1x1 conv mixes channels and changes the channel count.
        self.pointwise_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, **factory_kwargs
        )

    def forward(self, x):
        """Apply the depthwise conv, then the pointwise conv. Spatial size is preserved."""
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        return x


class ConvMamba(nn.Module):
    """Bidirectional Mamba mixer with an additive 2-D convolutional branch.

    On the fused ("fast") path, the output is
    ``out_proj(forward_scan + flipped(backward_scan) + local_relation)``, where
    ``local_relation`` is a small two-layer 2-D conv stack applied to the token
    sequence reshaped into an ``h x w`` grid.

    Args:
        d_model: channel width of the input/output tokens.
        d_state, d_conv, expand, dt_rank, dt_min, dt_max, dt_init, dt_scale,
        dt_init_floor, conv_bias, bias: hyper-parameters of the selective-scan
            layers; ``d_inner = int(expand * d_model)`` and ``dt_rank="auto"``
            resolves to ``ceil(d_model / 16)``.
        use_fast_path: use the fused ``mamba_inner_fn*`` kernels when no
            ``inference_params`` are given.
        layer_idx: key under which decoding states are cached; required when
            ``inference_params`` is used.
        device, dtype: factory arguments for the created parameters.
        bimamba_type: must be ``"v2"``; any other value (including the default
            ``"none"``) fails the assertion in the constructor.
        conv_mode: which local branch to build. One of ``"orignal"``,
            ``"orignal_1_5_dmodel"``, ``"orignal_dinner"``, ``"deepwise"``,
            ``"deepwise_dinner"`` (spellings are part of the option names).

    Raises:
        ValueError: if ``conv_mode`` is not one of the supported names.
        NotImplementedError: if ``dt_init`` is neither ``"constant"`` nor ``"random"``.
        AssertionError: if ``bimamba_type != "v2"``.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,
        layer_idx=None,
        device=None,
        dtype=None,
        bimamba_type="none",
        conv_mode = "deepwise"
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.conv_mode = conv_mode
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        self.bimamba_type = bimamba_type

        # Built before in_proj so parameter-creation order matches the original layout.
        self.local_relation = self._build_local_relation(factory_kwargs)

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        # bidirectional: the backward-direction parameters below are always created,
        # so only the "v2" variant is accepted.
        assert bimamba_type == "v2", (
            f'ConvMamba only supports bimamba_type="v2", got {bimamba_type!r}'
        )

        A_b = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_b_log = torch.log(A_b)  # Keep A_b_log in fp32
        self.A_b_log = nn.Parameter(A_b_log)
        self.A_b_log._no_weight_decay = True

        self.conv1d_b = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.x_proj_b = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        # NOTE(review): unlike dt_proj, dt_proj_b keeps the default nn.Linear
        # initialisation (no dt_init / inverse-softplus bias) -- confirm this is intended.
        self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D_b._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def _build_local_relation(self, factory_kwargs):
        """Create the conv -> SiLU -> conv branch selected by ``self.conv_mode``.

        The branch always maps ``d_model`` input channels to ``d_inner`` output
        channels; modes differ in the hidden width and in whether plain 3x3 convs
        or depthwise-separable ``LightweightModel`` layers are used.
        """
        # conv_mode -> (use depthwise-separable layers, hidden channel count)
        variants = {
            "orignal": (False, self.d_model),
            "orignal_1_5_dmodel": (False, int(1.5 * self.d_model)),
            "orignal_dinner": (False, self.d_inner),
            "deepwise": (True, self.d_model),
            "deepwise_dinner": (True, self.d_inner),
        }
        if self.conv_mode not in variants:
            # Fail at construction instead of with an AttributeError inside forward().
            raise ValueError(
                f"Unknown conv_mode {self.conv_mode!r}; expected one of {sorted(variants)}"
            )
        separable, hidden = variants[self.conv_mode]

        def make_layer(c_in, c_out):
            if separable:
                return LightweightModel(in_channels=c_in, out_channels=c_out, **factory_kwargs)
            return nn.Conv2d(
                in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1,
                **factory_kwargs,
            )

        return nn.Sequential(
            make_layer(self.d_model, hidden),
            nn.SiLU(),
            make_layer(hidden, self.d_inner),
        )

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states

        The local conv branch reshapes the L tokens into an ``h x w`` grid with
        ``h = int(sqrt(L))``.
        NOTE(review): this is only the true spatial layout when the token grid is
        square -- confirm against callers that feed non-square feature maps.

        NOTE(review): only the fused "v2" path combines the forward scan, the
        backward scan and the local branch. The non-fused path (``use_fast_path``
        False or ``inference_params`` given) and ``step`` run the forward scan
        alone -- confirm whether that asymmetry is intended.
        """
        batch, seqlen, _ = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if self.use_fast_path and inference_params is None:  # Doesn't support outputting the states
            if self.bimamba_type == "v2":
                # The local branch is only consumed here, so it is only computed here.
                h = int(math.sqrt(seqlen))
                local_relation = self.local_relation(
                    rearrange(hidden_states, "b (h w) d -> b d h w", h=h)
                )
                local_relation = rearrange(local_relation, "b d h w -> b d (h w)")

                A_b = -torch.exp(self.A_b_log.float())
                out = mamba_inner_fn_no_out_proj(
                    xz,
                    self.conv1d.weight,
                    self.conv1d.bias,
                    self.x_proj.weight,
                    self.dt_proj.weight,
                    A,
                    None,  # input-dependent B
                    None,  # input-dependent C
                    self.D.float(),
                    delta_bias=self.dt_proj.bias.float(),
                    delta_softplus=True,
                )
                # Backward direction: scan the reversed sequence with the *_b parameters.
                out_b = mamba_inner_fn_no_out_proj(
                    xz.flip([-1]),
                    self.conv1d_b.weight,
                    self.conv1d_b.bias,
                    self.x_proj_b.weight,
                    self.dt_proj_b.weight,
                    A_b,
                    None,
                    None,
                    self.D_b.float(),
                    delta_bias=self.dt_proj_b.bias.float(),
                    delta_softplus=True,
                )
                # Flip the backward result back into forward order, add the local
                # branch, then apply out_proj once on the sum.
                out = F.linear(
                    rearrange(out + out_b.flip([-1]) + local_relation, "b d l -> b l d"),
                    self.out_proj.weight,
                    self.out_proj.bias,
                )
            else:
                # Only reachable if the constructor assertion was bypassed (e.g. python -O).
                out = mamba_inner_fn(
                    xz,
                    self.conv1d.weight,
                    self.conv1d.bias,
                    self.x_proj.weight,
                    self.dt_proj.weight,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    A,
                    None,  # input-dependent B
                    None,  # input-dependent C
                    self.D.float(),
                    delta_bias=self.dt_proj.bias.float(),
                    delta_softplus=True,
                )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # Slicing x[:, :, -d_conv:] breaks when seqlen < d_conv; F.pad
                # left-pads with zeros in that case and truncates otherwise.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x,
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    self.conv1d.bias,
                    self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        """Decode a single token, updating ``conv_state`` and ``ssm_state`` in place.

        Args:
            hidden_states: tensor whose second dimension must be 1.
            conv_state: rolling window of the last ``d_conv`` conv inputs.
            ssm_state: recurrent state of the forward scan.

        Returns:
            ``(out, conv_state, ssm_state)`` where ``out`` has a length-1 sequence axis.
        """
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Return zero-initialised ``(conv_state, ssm_state)`` buffers for decoding.

        ``max_seqlen`` and ``**kwargs`` are accepted but not used here. When
        ``dtype`` is None the buffers take the dtype of ``conv1d.weight`` and
        ``dt_proj.weight`` respectively.
        """
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        # d_inner (an int) rather than d_model * expand, which is a float for fractional expand.
        conv_state = torch.zeros(
            batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size, self.d_inner, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        """Fetch this layer's ``(conv_state, ssm_state)`` from the cache, creating it if absent.

        States are stored in ``inference_params.key_value_memory_dict`` under
        ``self.layer_idx``. With ``initialize_states=True`` an existing entry is
        zeroed in place before being returned.
        """
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            conv_state = torch.zeros(
                batch_size,
                self.d_inner,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_inner,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state


def to_3d(x):
    """Flatten a ``(b, c, h, w)`` tensor into ``(b, h*w, c)`` tokens."""
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    """Inverse of ``to_3d``: reshape ``(b, h*w, c)`` tokens back to ``(b, c, h, w)``."""
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class WithBias_LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Uses the biased variance and a fixed epsilon of 1e-5.
    """

    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        """Return ``(x - mean) / sqrt(var + 1e-5) * weight + bias`` along the last axis."""
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class BiasFree_LayerNorm(nn.Module):
    """Scale-only normalisation over the last dimension.

    The input is divided by its standard deviation (biased variance, eps 1e-5)
    without subtracting the mean, then multiplied by a learnable weight.
    """

    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        self.weight = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x):
        """Return ``x / sqrt(var + 1e-5) * weight`` along the last axis."""
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class LayerNorm(nn.Module):
    """Channel-wise layer norm accepting either token or image layouts.

    Args:
        dim: size of the normalised (channel) dimension.
        norm_type: ``'BiasFree'`` selects ``BiasFree_LayerNorm``; any other value
            selects ``WithBias_LayerNorm``.
    """

    def __init__(self, dim, norm_type='with_bias'):
        super(LayerNorm, self).__init__()
        if norm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        """Normalise ``x``; 4-D inputs are flattened to tokens and restored afterwards."""
        if len(x.shape) == 4:
            h, w = x.shape[-2:]
            return to_4d(self.body(to_3d(x)), h, w)
        else:
            return self.body(x)


class M3(nn.Module):
    """Fuse a token stream with two auxiliary streams through a Mamba block and a depthwise conv.

    The three inputs are layer-normalised separately, passed to
    ``Mamba(dim, bimamba_type="m3")`` (``fusion`` as the main input, ``I2`` and
    ``I1`` as ``extra_emb1``/``extra_emb2``), and the result is refined by a
    residual 3x3 depthwise convolution on the ``test_h x test_w`` grid.
    """

    def __init__(self, dim):
        super(M3, self).__init__()
        self.multi_modal_mamba_block = Mamba(dim, bimamba_type="m3")
        self.norm1 = LayerNorm(dim, 'with_bias')# fusion
        self.norm2 = LayerNorm(dim, 'with_bias')# I2
        self.norm3 = LayerNorm(dim, 'with_bias')# I1
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

    def forward(self, I1, I2, fusion, test_h, test_w):
        """Return ``(fused_tokens, None)``.

        ``test_h * test_w`` must equal the token count of the Mamba output so
        that it can be reshaped into an image for the depthwise conv.
        """
        fusion = self.norm1(fusion)
        I2 = self.norm2(I2)
        I1 = self.norm3(I1)
        global_f = self.multi_modal_mamba_block(fusion, extra_emb1=I2, extra_emb2=I1)# [B, HW, C]

        B, _, C = global_f.shape
        # reshape (not view) so a non-contiguous Mamba output is also handled.
        fusion = global_f.transpose(1, 2).reshape(B, C, test_h, test_w)
        fusion = (self.dwconv(fusion) + fusion).flatten(2).transpose(1, 2)
        return fusion, None


class PatchEmbed(nn.Module):
    """Flatten a ``[B, C, H, W]`` feature map into ``[B, H*W, C]`` tokens, optionally normalised.

    No projection is applied; ``img_size``, ``patch_size`` and ``in_chans`` are
    only recorded as attributes (together with the derived ``patches_resolution``
    and ``num_patches``).
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super(PatchEmbed, self).__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Map ``[B, C, H, W]`` to ``[B, N, C]`` and apply ``self.norm`` when configured."""
        # x: [B, C, H, W]
        x = x.flatten(2).transpose(1, 2)  # [B, N, C]
        if self.norm is not None:
            x = self.norm(x)
        return x


class PatchUnEmbed(nn.Module):
    """Reshape ``[B, H*W, C]`` tokens back into a ``[B, embed_dim, H, W]`` feature map.

    ``norm_layer`` is accepted for signature symmetry with ``PatchEmbed`` but is not used.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super(PatchUnEmbed, self).__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        """Return ``x`` reshaped to ``(B, embed_dim, x_size[0], x_size[1])``."""
        batch = x.shape[0]
        # reshape (not view) so non-contiguous token tensors are also handled.
        x = x.transpose(1, 2).reshape(batch, self.embed_dim, x_size[0], x_size[1])
        return x


class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        Args:
            dim: width passed to both ``mixer_cls`` and ``norm_cls``.
            mixer_cls: callable building the mixer from ``dim``.
            norm_cls: callable building the norm from ``dim``.
            fused_add_norm: use the fused add+norm kernels; requires the
                ``mamba_ssm`` triton layernorm import to have succeeded.
            residual_in_fp32: keep the residual stream in float32.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))

        Returns:
            ``(hidden_states, residual)``: the mixer output and the updated residual stream.
        """
        if not self.fused_add_norm:
            # Add first, then normalise the running residual.
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate decoding-cache allocation to the wrapped mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)