import os
import time
import math
import copy
from functools import partial
from typing import Optional, Callable, Any
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, trunc_normal_
from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count

DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"

# triton cross scan, 2x speed than pytorch implementation =========================
from rscd.models.backbones.csm_triton import CrossScanTriton, CrossMergeTriton, CrossScanTriton1b1


# pytorch cross scan =============
class CrossScan(torch.autograd.Function):
    """Expand a (B, C, H, W) map into four flattened scan sequences.

    Sequence 0 is the plain H*W flattening, sequence 1 the flattening of the
    H/W-transposed map, and sequences 2 and 3 are sequences 0 and 1 reversed
    along the last axis.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        """Return the (B, 4, C, H*W) stack of scans and remember the input shape."""
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        scans = x.new_empty((B, 4, C, H * W))
        scans[:, 0] = x.flatten(2, 3)
        scans[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        # the last two scans are the first two, read backwards
        scans[:, 2:4] = torch.flip(scans[:, 0:2], dims=[-1])
        return scans

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        """Fold the four per-scan gradients back into one (B, C, H, W) gradient."""
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        seq_len = H * W
        # un-reverse scans 2 and 3 and add them onto scans 0 and 1
        unreversed = ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, seq_len)
        folded = ys[:, 0:2] + unreversed
        # undo the H/W transpose of the second scan before adding it in
        transposed_back = folded[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3)
        grad = folded[:, 0] + transposed_back.contiguous().view(B, -1, seq_len)
        return grad.view(B, -1, H, W)


class CrossMerge(torch.autograd.Function):
    """Sum four scan-ordered outputs (B, K, D, H, W) into one (B, D, H*W) sequence.

    The forward pass applies the same folding as ``CrossScan.backward`` and the
    backward pass the same expansion as ``CrossScan.forward``.
    """

    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        """Return the merged (B, D, H*W) tensor and remember (H, W)."""
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        flat = ys.view(B, K, D, -1)
        unreversed = flat[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        folded = flat[:, 0:2] + unreversed
        transposed_back = folded[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3)
        return folded[:, 0] + transposed_back.contiguous().view(B, D, -1)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        """Expand the (B, C, L) gradient into the four scan orders, shaped (B, 4, C, H, W)."""
        # B, D, L = x.shape
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        grads = x.new_empty((B, 4, C, L))
        grads[:, 0] = x
        grads[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        grads[:, 2:4] = torch.flip(grads[:, 0:2], dims=[-1])
        return grads.view(B, 4, C, H, W)


# these are for
# ablations =============
class CrossScan_Ab_2direction(torch.autograd.Function):
    """Ablation scan: the plain flattening twice, followed by its reversal twice."""

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        """Return a (B, 4, C, H*W) stack: two forward copies, then two reversed copies."""
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        forward_pair = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
        return torch.cat([forward_pair, forward_pair.flip(dims=[-1])], dim=1)

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        """Un-reverse the last two gradients, add all four, and reshape to (B, C, H, W)."""
        B, C, H, W = ctx.shape
        seq_len = H * W
        unreversed = ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, seq_len)
        folded = ys[:, 0:2] + unreversed
        return folded.sum(1).view(B, -1, H, W)


class CrossMerge_Ab_2direction(torch.autograd.Function):
    """Ablation merge matching ``CrossScan_Ab_2direction``."""

    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        """Sum the four (un-reversed where needed) sequences into (B, D, H*W)."""
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        flat = ys.view(B, K, D, -1)
        unreversed = flat[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        folded = flat[:, 0:2] + unreversed
        return folded.contiguous().sum(1)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        """Replicate the (B, C, L) gradient into two forward and two reversed copies."""
        H, W = ctx.shape
        B, C, L = x.shape
        forward_pair = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
        expanded = torch.cat([forward_pair, forward_pair.flip(dims=[-1])], dim=1)
        return expanded.view(B, 4, C, H, W)


class CrossScan_Ab_1direction(torch.autograd.Function):
    """Ablation scan: the plain flattening repeated four times."""

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        """Return four identical copies of the flattened input, shaped (B, 4, C, H*W)."""
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        return x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        """Sum the four per-copy gradients into a single (B, C, H, W) gradient."""
        B, C, H, W = ctx.shape
        per_copy = ys.view(B, 4, -1, H, W)
        return per_copy.sum(1)


class CrossMerge_Ab_1direction(torch.autograd.Function):
    """Ablation merge matching ``CrossScan_Ab_1direction``."""

    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        """Sum the four sequences element-wise into (B, C, H*W)."""
        B, K, C, H, W = ys.shape
        ctx.shape = (B, C, H, W)
        flat = ys.view(B, 4, -1, H * W)
        return flat.sum(1)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        """Copy the incoming gradient to each of the four inputs, shaped (B, 4, C, H, W)."""
        B, C, H, W = ctx.shape
        return x.view(B, 1, C, H, W).repeat(1, 4, 1, 1, 1)


# import selective scan ==============================
# The CUDA extensions are optional; a failed import is deliberately ignored here.
try:
    import selective_scan_cuda_oflex
except Exception as e:
    ...
    # print(f"WARNING: can not import selective_scan_cuda_oflex.", flush=True)
    # print(e, flush=True)

try:
    import selective_scan_cuda_core
except Exception as e:
    ...
# print(f"WARNING: can not import selective_scan_cuda_core.", flush=True)
# print(e, flush=True)

# Optional CUDA extension; a failed import is deliberately ignored here.
try:
    import selective_scan_cuda
except Exception as e:
    ...
    # print(f"WARNING: can not import selective_scan_cuda.", flush=True)
    # print(e, flush=True)


def check_nan_inf(tag: str, x: torch.Tensor, enable=True):
    """Debug helper: report inf/nan in ``x`` and drop into ``pdb``.

    Does nothing when ``enable`` is falsy or when ``x`` holds neither inf
    nor nan. Otherwise prints ``tag`` with the two flags and opens the debugger.
    """
    if not enable:
        return
    if torch.isinf(x).any() or torch.isnan(x).any():
        print(tag, torch.isinf(x).any(), torch.isnan(x).any(), flush=True)
        import pdb
        pdb.set_trace()


# fvcore flops =======================================
def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
    """Closed-form FLOP estimate for one selective-scan call.

    Operand layout noted by the original author:
        u: r(B D L)
        delta: r(B D L)
        A: r(D N)
        B: r(B N L)
        C: r(B N L)
        D: r(D)
        z: r(B D L)
        delta_bias: r(D), fp32

    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]

    Returns ``9 * B * L * D * N`` plus ``B * D * L`` for each of ``with_D`` and
    ``with_Z`` that is set. ``with_complex`` must be falsy.
    """
    assert not with_complex
    # https://github.com/state-spaces/mamba/issues/110
    total = 9 * B * L * D * N
    # the D skip term and the z gate each cost one extra B*D*L
    for enabled in (with_D, with_Z):
        if enabled:
            total += B * D * L
    return total


# this is only for selective_scan_ref...
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): """ u: r(B D L) delta: r(B D L) A: r(D N) B: r(B N L) C: r(B N L) D: r(D) z: r(B D L) delta_bias: r(D), fp32 ignores: [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] """ import numpy as np # fvcore.nn.jit_handles def get_flops_einsum(input_shapes, equation): np_arrs = [np.zeros(s) for s in input_shapes] optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] for line in optim.split("\n"): if "optimized flop" in line.lower(): # divided by 2 because we count MAC (multiply-add counted as one flop) flop = float(np.floor(float(line.split(":")[-1]) / 2)) return flop assert not with_complex flops = 0 # below code flops = 0 flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") if with_Group: flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") else: flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") in_for_flops = B * D * N if with_Group: in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") else: in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") flops += L * in_for_flops if with_D: flops += B * D * L if with_Z: flops += B * D * L return flops def print_jit_input_names(inputs): print("input params: ", end=" ", flush=True) try: for i in range(10): print(inputs[i].debugName(), end=" ", flush=True) except Exception as e: pass print("", flush=True) # cross selective scan =============================== # comment all checks if inside cross_selective_scan class SelectiveScanMamba(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True): ctx.delta_softplus = delta_softplus out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus) ctx.save_for_backward(u, 
delta, A, B, C, D, delta_bias, x) return out @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, dout, *args): u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, False ) return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None) class SelectiveScanCore(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True): ctx.delta_softplus = delta_softplus out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1) ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) return out @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, dout, *args): u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd( u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 ) return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None) class SelectiveScanOflex(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True): ctx.delta_softplus = delta_softplus out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex) ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) return out @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, dout, *args): u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd( u, delta, A, B, C, D, delta_bias, dout, x, 
ctx.delta_softplus, 1 ) return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None) # ============= # Note: we did not use csm_triton in and before vssm1_0230, we used pytorch version ! # Note: we did not use no_einsum in and before vssm1_0230, we used einsum version ! def cross_selective_scan( x: torch.Tensor=None, x_proj_weight: torch.Tensor=None, x_proj_bias: torch.Tensor=None, dt_projs_weight: torch.Tensor=None, dt_projs_bias: torch.Tensor=None, A_logs: torch.Tensor=None, Ds: torch.Tensor=None, delta_softplus = True, out_norm: torch.nn.Module=None, out_norm_shape="v0", channel_first=False, # ============================== to_dtype=True, # True: final out to dtype force_fp32=False, # True: input fp32 # ============================== nrows = -1, # for SelectiveScanNRow; 0: auto; -1: disable; backnrows = -1, # for SelectiveScanNRow; 0: auto; -1: disable; ssoflex=True, # True: out fp32 in SSOflex; else, SSOflex is the same as SSCore # ============================== SelectiveScan=None, CrossScan=CrossScan, CrossMerge=CrossMerge, no_einsum=False, # replace einsum with linear or conv1d to raise throughput dt_low_rank=True, ): # out_norm: whatever fits (B, L, C); LayerNorm; Sigmoid; Softmax(dim=1);... 
B, D, H, W = x.shape D, N = A_logs.shape K, D, R = dt_projs_weight.shape L = H * W if nrows == 0: if D % 4 == 0: nrows = 4 elif D % 3 == 0: nrows = 3 elif D % 2 == 0: nrows = 2 else: nrows = 1 if backnrows == 0: if D % 4 == 0: backnrows = 4 elif D % 3 == 0: backnrows = 3 elif D % 2 == 0: backnrows = 2 else: backnrows = 1 def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True): return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, backnrows, ssoflex) if (not dt_low_rank): x_dbl = F.conv1d(x.view(B, -1, L), x_proj_weight.view(-1, D, 1), bias=(x_proj_bias.view(-1) if x_proj_bias is not None else None), groups=K) dts, Bs, Cs = torch.split(x_dbl.view(B, -1, L), [D, 4 * N, 4 * N], dim=1) xs = CrossScan.apply(x) dts = CrossScan.apply(dts) elif no_einsum: xs = CrossScan.apply(x) x_dbl = F.conv1d(xs.view(B, -1, L), x_proj_weight.view(-1, D, 1), bias=(x_proj_bias.view(-1) if x_proj_bias is not None else None), groups=K) dts, Bs, Cs = torch.split(x_dbl.view(B, K, -1, L), [R, N, N], dim=2) dts = F.conv1d(dts.contiguous().view(B, -1, L), dt_projs_weight.view(K * D, -1, 1), groups=K) else: xs = CrossScan.apply(x) x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight) if x_proj_bias is not None: x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1) dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight) xs = xs.view(B, -1, L) dts = dts.contiguous().view(B, -1, L) As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state) Bs = Bs.contiguous().view(B, K, N, L) Cs = Cs.contiguous().view(B, K, N, L) Ds = Ds.to(torch.float) # (K * c) delta_bias = dt_projs_bias.view(-1).to(torch.float) if force_fp32: xs = xs.to(torch.float) dts = dts.to(torch.float) Bs = Bs.to(torch.float) Cs = Cs.to(torch.float) ys: torch.Tensor = selective_scan( xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus ).view(B, K, -1, H, W) y: torch.Tensor = CrossMerge.apply(ys) if 
channel_first: y = y.view(B, -1, H, W) if out_norm_shape in ["v1"]: y = out_norm(y) else: y = out_norm(y.permute(0, 2, 3, 1)) y = y.permute(0, 3, 1, 2) return (y.to(x.dtype) if to_dtype else y) if out_norm_shape in ["v1"]: # (B, C, H, W) y = out_norm(y.view(B, -1, H, W)).permute(0, 2, 3, 1) # (B, H, W, C) else: # (B, L, C) y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) y = out_norm(y).view(B, H, W, -1) return (y.to(x.dtype) if to_dtype else y) def selective_scan_flop_jit(inputs, outputs): print_jit_input_names(inputs) B, D, L = inputs[0].type().sizes() N = inputs[2].type().sizes()[1] flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False) return flops # ===================================================== # we have this class as linear and conv init differ from each other # this function enable loading from both conv2d or linear class Linear2d(nn.Linear): def forward(self, x: torch.Tensor): # B, C, H, W = x.shape return F.conv2d(x, self.weight[:, :, None, None], self.bias) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view(self.weight.shape) return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) class LayerNorm2d(nn.LayerNorm): def forward(self, x: torch.Tensor): x = x.permute(0, 2, 3, 1) x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.permute(0, 3, 1, 2) return x class PatchMerging2D(nn.Module): def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.reduction = nn.Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False) self.norm = norm_layer(4 * dim) @staticmethod def _patch_merging_pad(x: torch.Tensor): H, W, _ = x.shape[-3:] if (W % 2 != 0) or (H % 2 != 0): x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) x0 = x[..., 0::2, 0::2, :] # 
... H/2 W/2 C x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C return x def forward(self, x): x = self._patch_merging_pad(x) x = self.norm(x) x = self.reduction(x) return x class Permute(nn.Module): def __init__(self, *args): super().__init__() self.args = args def forward(self, x: torch.Tensor): return x.permute(*self.args) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features Linear = Linear2d if channels_first else nn.Linear self.fc1 = Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class gMlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False): super().__init__() self.channel_first = channels_first out_features = out_features or in_features hidden_features = hidden_features or in_features Linear = Linear2d if channels_first else nn.Linear self.fc1 = Linear(in_features, 2 * hidden_features) self.act = act_layer() self.fc2 = Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x: torch.Tensor): x = self.fc1(x) x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) x = self.fc2(x * self.act(z)) x = self.drop(x) return x # ===================================================== class SS2D(nn.Module): def __init__( self, # basic dims =========== d_model=96, d_state=16, ssm_ratio=2.0, dt_rank="auto", act_layer=nn.SiLU, # dwconv =============== d_conv=3, # < 2 means no conv conv_bias=True, # ====================== dropout=0.0, 
bias=False, # dt init ============== dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, initialize="v0", # ====================== forward_type="v2", channel_first=False, # ====================== **kwargs, ): kwargs.update( d_model=d_model, d_state=d_state, ssm_ratio=ssm_ratio, dt_rank=dt_rank, act_layer=act_layer, d_conv=d_conv, conv_bias=conv_bias, dropout=dropout, bias=bias, dt_min=dt_min, dt_max=dt_max, dt_init=dt_init, dt_scale=dt_scale, dt_init_floor=dt_init_floor, initialize=initialize, forward_type=forward_type, channel_first=channel_first, ) # only used to run previous version if forward_type.startswith("v0"): self.__initv0__(seq=("seq" in forward_type), **kwargs) return elif forward_type.startswith("xv"): self.__initxv__(**kwargs) return else: self.__initv2__(**kwargs) return # only used to run previous version def __initv0__( self, # basic dims =========== d_model=96, d_state=16, ssm_ratio=2.0, dt_rank="auto", # ====================== dropout=0.0, # ====================== seq=False, force_fp32=True, **kwargs, ): if "channel_first" in kwargs: assert not kwargs["channel_first"] act_layer = nn.SiLU dt_min = 0.001 dt_max = 0.1 dt_init = "random" dt_scale = 1.0 dt_init_floor = 1e-4 bias = False conv_bias = True d_conv = 3 k_group = 4 factory_kwargs = {"device": None, "dtype": None} super().__init__() d_inner = int(ssm_ratio * d_model) dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank self.forward = self.forwardv0 if seq: self.forward = partial(self.forwardv0, seq=True) if not force_fp32: self.forward = partial(self.forwardv0, force_fp32=False) # in proj ============================ self.in_proj = nn.Linear(d_model, d_inner * 2, bias=bias, **factory_kwargs) self.act: nn.Module = act_layer() self.conv2d = nn.Conv2d( in_channels=d_inner, out_channels=d_inner, groups=d_inner, bias=conv_bias, kernel_size=d_conv, padding=(d_conv - 1) // 2, **factory_kwargs, ) # x proj ============================ self.x_proj = [ 
nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False, **factory_kwargs) for _ in range(k_group) ] self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) del self.x_proj # dt proj ============================ self.dt_projs = [ self.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) for _ in range(k_group) ] self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner) del self.dt_projs # A, D ======================================= self.A_logs = self.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N) self.Ds = self.D_init(d_inner, copies=k_group, merge=True) # (K * D) # out proj ======================================= self.out_norm = nn.LayerNorm(d_inner) self.out_proj = nn.Linear(d_inner, d_model, bias=bias, **factory_kwargs) self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() def __initv2__( self, # basic dims =========== d_model=96, d_state=16, ssm_ratio=2.0, dt_rank="auto", act_layer=nn.SiLU, # dwconv =============== d_conv=3, # < 2 means no conv conv_bias=True, # ====================== dropout=0.0, bias=False, # dt init ============== dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, initialize="v0", # ====================== forward_type="v2", channel_first=False, # ====================== **kwargs, ): factory_kwargs = {"device": None, "dtype": None} super().__init__() d_inner = int(ssm_ratio * d_model) dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank self.d_conv = d_conv self.channel_first = channel_first Linear = Linear2d if channel_first else nn.Linear self.forward = self.forwardv2 # tags for forward_type ============================== def checkpostfix(tag, value): ret = value[-len(tag):] == tag if ret: value = value[:-len(tag)] return ret, value self.disable_force32, forward_type = checkpostfix("no32", forward_type) self.disable_z, forward_type = checkpostfix("noz", forward_type) self.disable_z_act, forward_type = checkpostfix("nozact", forward_type) # softmax | sigmoid | dwconv | norm =========================== self.out_norm_shape = "v1" if forward_type[-len("none"):] == "none": forward_type = forward_type[:-len("none")] self.out_norm = nn.Identity() elif forward_type[-len("dwconv3"):] == "dwconv3": forward_type = forward_type[:-len("dwconv3")] self.out_norm = nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False) elif forward_type[-len("softmax"):] == "softmax": forward_type = forward_type[:-len("softmax")] class SoftmaxSpatial(nn.Softmax): def forward(self, x: torch.Tensor): B, C, H, W = x.shape return super().forward(x.view(B, C, -1)).view(B, C, H, W) self.out_norm = SoftmaxSpatial(dim=-1) elif forward_type[-len("sigmoid"):] == "sigmoid": forward_type = forward_type[:-len("sigmoid")] self.out_norm = nn.Sigmoid() elif channel_first: 
self.out_norm = LayerNorm2d(d_inner) else: self.out_norm_shape = "v0" self.out_norm = nn.LayerNorm(d_inner) # forward_type debug ======================================= FORWARD_TYPES = dict( v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanMamba), v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanCore), v3=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex), v31d=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, CrossScan=CrossScan_Ab_1direction, CrossMerge=CrossMerge_Ab_1direction, ), v32d=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, CrossScan=CrossScan_Ab_2direction, CrossMerge=CrossMerge_Ab_2direction, ), v4=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, no_einsum=True, CrossScan=CrossScanTriton, CrossMerge=CrossMergeTriton), # =============================== v1=partial(self.forward_corev2, force_fp32=True, SelectiveScan=SelectiveScanOflex), ) self.forward_core = FORWARD_TYPES.get(forward_type, None) k_group = 4 # in proj ======================================= d_proj = d_inner if self.disable_z else (d_inner * 2) self.in_proj = Linear(d_model, d_proj, bias=bias, **factory_kwargs) self.act: nn.Module = act_layer() # conv ======================================= if d_conv > 1: self.conv2d = nn.Conv2d( in_channels=d_inner, out_channels=d_inner, groups=d_inner, bias=conv_bias, kernel_size=d_conv, padding=(d_conv - 1) // 2, **factory_kwargs, ) # x proj ============================ self.x_proj = [ nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False, **factory_kwargs) for _ in range(k_group) ] self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) del self.x_proj # out proj ======================================= self.out_proj = Linear(d_inner, d_model, bias=bias, **factory_kwargs) self.dropout 
= nn.Dropout(dropout) if dropout > 0. else nn.Identity() if initialize in ["v0"]: # dt proj ============================ self.dt_projs = [ self.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) for _ in range(k_group) ] self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner) del self.dt_projs # A, D ======================================= self.A_logs = self.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N) self.Ds = self.D_init(d_inner, copies=k_group, merge=True) # (K * D) elif initialize in ["v1"]: # simple init dt_projs, A_logs, Ds self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) self.A_logs = nn.Parameter(torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank))) self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner))) elif initialize in ["v2"]: # simple init dt_projs, A_logs, Ds self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) self.A_logs = nn.Parameter(torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((k_group, d_inner, dt_rank))) self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, d_inner))) def __initxv__( self, # basic dims =========== d_model=96, d_state=16, ssm_ratio=2.0, dt_rank="auto", act_layer=nn.SiLU, # dwconv =============== d_conv=3, # < 2 means no conv conv_bias=True, # ====================== dropout=0.0, bias=False, # dt init ============== dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, initialize="v0", # ====================== forward_type="v2", channel_first=False, # ====================== **kwargs, ): factory_kwargs = {"device": None, "dtype": None} 
super().__init__() d_inner = int(ssm_ratio * d_model) dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank self.d_conv = d_conv self.channel_first = channel_first self.d_state = d_state self.dt_rank = dt_rank self.d_inner = d_inner Linear = Linear2d if channel_first else nn.Linear self.forward = self.forwardxv # tags for forward_type ============================== def checkpostfix(tag, value): ret = value[-len(tag):] == tag if ret: value = value[:-len(tag)] return ret, value self.disable_force32, forward_type = checkpostfix("no32", forward_type) # softmax | sigmoid | dwconv | norm =========================== self.out_norm_shape = "v1" if forward_type[-len("none"):] == "none": forward_type = forward_type[:-len("none")] self.out_norm = nn.Identity() elif forward_type[-len("dwconv3"):] == "dwconv3": forward_type = forward_type[:-len("dwconv3")] self.out_norm = nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False) elif forward_type[-len("softmax"):] == "softmax": forward_type = forward_type[:-len("softmax")] class SoftmaxSpatial(nn.Softmax): def forward(self, x: torch.Tensor): B, C, H, W = x.shape return super().forward(x.view(B, C, -1)).view(B, C, H, W) self.out_norm = SoftmaxSpatial(dim=-1) elif forward_type[-len("sigmoid"):] == "sigmoid": forward_type = forward_type[:-len("sigmoid")] self.out_norm = nn.Sigmoid() elif channel_first: self.out_norm = LayerNorm2d(d_inner) else: self.out_norm_shape = "v0" self.out_norm = nn.LayerNorm(d_inner) k_group = 4 # in proj ======================================= self.out_act: nn.Module = nn.Identity() # 0309 -> 0319 needs to be rerun... 
if False: # change Conv2d to Linear2d Next if forward_type.startswith("xv1"): self.in_proj = nn.Conv2d(d_model, d_inner + dt_rank + 8 * d_state, 1, bias=bias, **factory_kwargs) if forward_type.startswith("xv2"): self.in_proj = nn.Conv2d(d_model, d_inner + d_inner + 8 * d_state, 1, bias=bias, **factory_kwargs) self.forward = partial(self.forwardxv, mode="xv2") del self.dt_projs_weight if forward_type.startswith("xv3"): self.forward = partial(self.forwardxv, mode="xv3") self.in_proj = nn.Conv2d(d_model, d_inner + 4 * dt_rank + 8 * d_state, 1, bias=bias, **factory_kwargs) if forward_type.startswith("xv4"): self.forward = partial(self.forwardxv, mode="xv3") self.in_proj = nn.Conv2d(d_model, d_inner + 4 * dt_rank + 8 * d_state, 1, bias=bias, **factory_kwargs) self.out_act = nn.GELU() if forward_type.startswith("xv5"): self.in_proj = nn.Conv2d(d_model, d_inner + d_inner + 8 * d_state, 1, bias=bias, **factory_kwargs) self.forward = partial(self.forwardxv, mode="xv2") del self.dt_projs_weight self.out_act = nn.GELU() if forward_type.startswith("xv6"): self.forward = partial(self.forwardxv, mode="xv1") self.in_proj = nn.Conv2d(d_model, d_inner + dt_rank + 8 * d_state, 1, bias=bias, **factory_kwargs) self.out_act = nn.GELU() # to see if Linear2d and nn.Conv2d differ, as they will be inited differ if forward_type.startswith("xv61"): self.forward = partial(self.forwardxv, mode="xv1") self.in_proj = Linear2d(d_model, d_inner + dt_rank + 8 * d_state, bias=bias, **factory_kwargs) self.out_act = nn.GELU() if forward_type.startswith("xv7"): self.forward = partial(self.forwardxv, mode="xv1", omul=True) self.in_proj = Linear2d(d_model, d_inner + dt_rank + 8 * d_state, bias=bias, **factory_kwargs) self.out_act = nn.GELU() if True: omul, forward_type = checkpostfix("mul", forward_type) if omul: self.omul = nn.Identity() oact, forward_type = checkpostfix("act", forward_type) self.out_act = nn.GELU() if oact else nn.Identity() if forward_type.startswith("xv1a"): self.forward = 
partial(self.forwardxv, mode="xv1a", omul=omul) self.in_proj = Linear2d(d_model, d_inner + dt_rank + 8 * d_state, bias=bias, **factory_kwargs) if forward_type.startswith("xv2a"): self.forward = partial(self.forwardxv, mode="xv2a", omul=omul) self.in_proj = Linear2d(d_model, d_inner + d_inner + 8 * d_state,bias=bias, **factory_kwargs) if forward_type.startswith("xv3a"): self.forward = partial(self.forwardxv, mode="xv3a", omul=omul) self.in_proj = Linear2d(d_model, d_inner + 4 * dt_rank + 8 * d_state,bias=bias, **factory_kwargs) # conv ======================================= if d_conv > 1: self.conv2d = nn.Conv2d( in_channels=d_model, out_channels=d_model, groups=d_model, bias=conv_bias, kernel_size=d_conv, padding=(d_conv - 1) // 2, **factory_kwargs, ) self.act: nn.Module = act_layer() # out proj ======================================= self.out_proj = Linear(d_inner, d_model, bias=bias, **factory_kwargs) self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity() if initialize in ["v0"]: # dt proj ============================ self.dt_projs = [ self.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) for _ in range(k_group) ] self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner) del self.dt_projs # A, D ======================================= self.A_logs = self.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N) self.Ds = self.D_init(d_inner, copies=k_group, merge=True) # (K * D) elif initialize in ["v1"]: # simple init dt_projs, A_logs, Ds self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) self.A_logs = nn.Parameter(torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank))) self.dt_projs_bias = nn.Parameter(torch.randn((k_group, 
d_inner))) elif initialize in ["v2"]: # simple init dt_projs, A_logs, Ds self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) self.A_logs = nn.Parameter(torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((k_group, d_inner, dt_rank))) self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, d_inner))) if forward_type.startswith("xv2"): del self.dt_projs_weight @staticmethod def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization dt_init_std = dt_rank**-0.5 * dt_scale if dt_init == "constant": nn.init.constant_(dt_proj.weight, dt_init_std) elif dt_init == "random": nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max dt = torch.exp( torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): dt_proj.bias.copy_(inv_dt) # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit # dt_proj.bias._no_reinit = True return dt_proj @staticmethod def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): # S4D real initialization A = repeat( torch.arange(1, d_state + 1, dtype=torch.float32, device=device), "n -> d n", d=d_inner, ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 if copies > 0: A_log = repeat(A_log, "d n -> r d n", r=copies) if merge: A_log = A_log.flatten(0, 1) A_log = nn.Parameter(A_log) A_log._no_weight_decay = True return A_log @staticmethod def D_init(d_inner, copies=-1, device=None, 
# only used to run previous version
def forwardv0(self, x: torch.Tensor, SelectiveScan = SelectiveScanMamba, seq=False, force_fp32=True, **kwargs):
    """Legacy forward pass: gated four-direction selective scan.

    Pipeline, as written below: ``in_proj`` -> split into ``x`` / gate ``z``
    on the last dim -> activation on ``z`` -> permute ``x`` to channel-first
    -> ``conv2d`` + activation -> build four scan sequences -> project to
    ``dts`` / ``Bs`` / ``Cs`` -> selective scan -> sum the four directions
    back into one map -> ``out_norm`` -> multiply by ``z`` -> ``out_proj``
    -> ``dropout``.

    Args:
        x: input tensor; it is chunked on its last dim and then permuted with
            ``(0, 3, 1, 2)``, i.e. it is treated as channel-last.
        SelectiveScan: autograd function whose ``.apply`` performs the scan.
        seq: when true, run the directions one by one instead of in a single
            batched call.
        force_fp32: cast ``xs``, ``dts``, ``Bs``, ``Cs`` to float32 before
            scanning.
        **kwargs: accepted and ignored.

    Returns:
        ``self.dropout(self.out_proj(y))``.
    """
    x = self.in_proj(x)
    x, z = x.chunk(2, dim=-1) # (b, h, w, d)
    z = self.act(z)
    x = x.permute(0, 3, 1, 2).contiguous()
    x = self.conv2d(x) # (b, d, h, w)
    x = self.act(x)

    # Thin adapter that fixes the trailing positional flag to False.
    def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1):
        return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, False)

    # NOTE: D is rebound three times; after these lines it holds
    # dt_projs_weight.shape[1], and N / K / R come from the parameter shapes.
    B, D, H, W = x.shape
    D, N = self.A_logs.shape
    K, D, R = self.dt_projs_weight.shape
    L = H * W

    # Two flattenings of the map (as-is and with H/W transposed) ...
    x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
    # ... plus their reversed versions give the four scan sequences.
    xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)

    x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
    # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
    dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
    dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)

    xs = xs.view(B, -1, L) # (b, k * d, l)
    dts = dts.contiguous().view(B, -1, L) # (b, k * d, l)
    Bs = Bs.contiguous() # (b, k, d_state, l)
    Cs = Cs.contiguous() # (b, k, d_state, l)

    As = -torch.exp(self.A_logs.float()) # (k * d, d_state)
    Ds = self.Ds.float() # (k * d)
    dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)

    # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4
    # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1
    to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)

    if force_fp32:
        xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)

    if seq:
        # One scan call per direction; the loop bound is the literal 4, not K.
        out_y = []
        for i in range(4):
            yi = selective_scan(
                xs.view(B, K, -1, L)[:, i], dts.view(B, K, -1, L)[:, i],
                As.view(K, -1, N)[i], Bs[:, i].unsqueeze(1), Cs[:, i].unsqueeze(1), Ds.view(K, -1)[i],
                delta_bias=dt_projs_bias.view(K, -1)[i],
                delta_softplus=True,
            ).view(B, -1, L)
            out_y.append(yi)
        out_y = torch.stack(out_y, dim=1)
    else:
        # All directions in one batched call.
        out_y = selective_scan(
            xs, dts, As, Bs, Cs, Ds,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
        ).view(B, K, -1, L)
    assert out_y.dtype == torch.float

    # Undo the flip for directions 2 and 3, undo the H/W transpose for
    # directions 1 and 3, then sum all four.
    inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
    wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
    invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
    y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y

    y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C)
    y = self.out_norm(y).view(B, H, W, -1)

    y = y * z  # gate with the activated z branch
    out = self.dropout(self.out_proj(y))
    return out
def forward_corev2(self, x: torch.Tensor, cross_selective_scan=cross_selective_scan, **kwargs):
    """Run ``cross_selective_scan`` on ``x`` with this module's SSM parameters.

    ``out_norm`` and ``out_norm_shape`` are optional attributes; they fall
    back to ``None`` and ``"v0"`` when the module does not define them.
    Extra keyword arguments are forwarded unchanged.
    """
    # Positional arguments following x, in the order the scan expects them;
    # the ``None`` slot is passed through exactly as in the original call.
    scan_params = (
        self.x_proj_weight,
        None,
        self.dt_projs_weight,
        self.dt_projs_bias,
        self.A_logs,
        self.Ds,
    )
    norm_options = dict(
        out_norm=getattr(self, "out_norm", None),
        out_norm_shape=getattr(self, "out_norm_shape", "v0"),
    )
    return cross_selective_scan(
        x,
        *scan_params,
        delta_softplus=True,
        channel_first=self.channel_first,
        **norm_options,
        **kwargs,
    )


def forwardv2(self, x: torch.Tensor, **kwargs):
    """Forward pass: project, optionally gate, convolve, scan, project out.

    NOTE(review): the source had lost its indentation; the activation on
    ``x`` is taken to run unconditionally after the optional depth-wise
    conv — confirm against the reference implementation.
    """
    has_dconv = self.d_conv > 1
    gated = not self.disable_z

    x = self.in_proj(x)
    if gated:
        # Split on the channel axis, which depends on the memory layout.
        channel_axis = 1 if self.channel_first else -1
        x, z = x.chunk(2, dim=channel_axis)  # (b, h, w, d)
        if not self.disable_z_act:
            z = self.act(z)

    if not self.channel_first:
        x = x.permute(0, 3, 1, 2).contiguous()
    if has_dconv:
        x = self.conv2d(x)  # (b, d, h, w)
    x = self.act(x)

    y = self.forward_core(x)
    if gated:
        y = y * z
    return self.dropout(self.out_proj(y))
def forwardxv(self, x: torch.Tensor, mode="xv1a", omul=False, **kwargs):
    """Forward pass for the "xv" variants, where ``in_proj`` emits u/dt/B/C jointly.

    The input is made channel-first, optionally passed through ``conv2d`` +
    activation, then ``in_proj`` produces one tensor that is split on dim 1
    into the scan input ``us``, the step-size features ``dts`` and the
    ``Bs`` / ``Cs`` groups. Each piece is expanded into ``K = 4`` scan
    directions, scanned with ``SelectiveScanOflex``, merged with
    ``CrossMergeTriton``, normalised, activated and projected out.

    Args:
        x: input feature map; read as ``(B, C, H, W)`` when
            ``self.channel_first`` is true, otherwise as ``(B, H, W, C)``.
        mode: one of ``"xv1a"``, ``"xv2a"``, ``"xv3a"`` (current) or
            ``"xv1"``, ``"xv2"``, ``"xv3"`` (legacy, prints a deprecation
            message). They differ in the split sizes used for ``dts`` and in
            whether ``dts`` goes through the grouped ``dt_projs_weight`` conv.
        omul: when true, multiply the output by the un-scanned ``_us``.
        **kwargs: accepted and ignored.

    Returns:
        ``self.dropout(self.out_proj(y))``.

    NOTE(review): the source had lost its indentation; nesting below is
    reconstructed (in particular ``self.act`` is taken to run only when the
    conv is present) — confirm against the reference implementation.
    NOTE(review): a ``mode`` outside the six handled values falls through
    both ``else: ...`` branches and leaves ``us`` / ``dts`` / ``Bs`` / ``Cs``
    undefined.
    """
    B, C, H, W = x.shape
    if not self.channel_first:
        B, H, W, C = x.shape
    L = H * W
    K = 4  # number of scan directions
    # dt_projs_weight may have been deleted by __init__ for the "xv2" types.
    dt_projs_weight = getattr(self, "dt_projs_weight", None)
    A_logs = self.A_logs
    dt_projs_bias = self.dt_projs_bias
    force_fp32 = False
    delta_softplus = True
    out_norm_shape = getattr(self, "out_norm_shape", "v0")
    out_norm = self.out_norm
    to_dtype = True
    Ds = self.Ds

    to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)

    # Adapter fixing the three trailing positional arguments to (1, 1, True).
    def selective_scan(u, delta, A, B, C, D, delta_bias, delta_softplus):
        return SelectiveScanOflex.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, 1, True)

    if not self.channel_first:
        x = x.permute(0, 3, 1, 2).contiguous()

    if self.d_conv > 1:
        x = self.conv2d(x) # (b, d, h, w)
        x = self.act(x)
    x = self.in_proj(x)

    # ---- legacy modes -------------------------------------------------
    if mode in ["xv1", "xv2", "xv3", "xv7"]:
        print(f"ERROR: MODE {mode} will be deleted in the future, use {mode}a instead.")

    if mode in ["xv1"]:
        _us, dts, Bs, Cs = x.split([self.d_inner, self.dt_rank, 4 * self.d_state, 4 * self.d_state], dim=1)
        us = CrossScanTriton.apply(_us.contiguous()).view(B, -1, L)
        dts = CrossScanTriton.apply(dts.contiguous()).view(B, -1, L)
        # Grouped 1x1 conv: one dt projection per scan direction.
        dts = F.conv1d(dts, dt_projs_weight.view(K * self.d_inner, self.dt_rank, 1), None, groups=K).contiguous().view(B, -1, L)
    elif mode in ["xv2"]:
        # dts already has d_inner channels, so no dt projection is applied.
        _us, dts, Bs, Cs = x.split([self.d_inner, self.d_inner, 4 * self.d_state, 4 * self.d_state], dim=1)
        us = CrossScanTriton.apply(_us.contiguous()).view(B, -1, L)
        dts = CrossScanTriton.apply(dts).contiguous().view(B, -1, L)
    elif mode in ["xv3"]:
        # Separate dt_rank features per direction, scanned one-by-one.
        _us, dts, Bs, Cs = x.split([self.d_inner, 4 * self.dt_rank, 4 * self.d_state, 4 * self.d_state], dim=1)
        us = CrossScanTriton.apply(_us.contiguous()).view(B, -1, L)
        dts = CrossScanTriton1b1.apply(dts.contiguous().view(B, K, -1, H, W))
        dts = F.conv1d(dts.view(B, -1, L), dt_projs_weight.view(K * self.d_inner, self.dt_rank, 1), None, groups=K).contiguous().view(B, -1, L)
    else:
        ...

    # ---- current modes: Bs and Cs are cross-scanned as well -------------
    if mode in ["xv1a"]:
        us, dts, Bs, Cs = x.split([self.d_inner, self.dt_rank, 4 * self.d_state, 4 * self.d_state], dim=1)
        _us = us
        us = CrossScanTriton.apply(us.contiguous()).view(B, 4, -1, L)
        dts = CrossScanTriton.apply(dts.contiguous()).view(B, 4, -1, L)
        Bs = CrossScanTriton1b1.apply(Bs.view(B, 4, -1, H, W).contiguous()).view(B, 4, -1, L)
        Cs = CrossScanTriton1b1.apply(Cs.view(B, 4, -1, H, W).contiguous()).view(B, 4, -1, L)
        dts = F.conv1d(dts.contiguous().view(B, -1, L), dt_projs_weight.view(K * self.d_inner, self.dt_rank, 1), None, groups=K)
        us, dts = us.contiguous().view(B, -1, L), dts
        # Rebinds _us to slice 0 of the scanned tensor (only in this mode).
        _us = us.view(B, K, -1, H, W)[:, 0, :, :, :]
    elif mode in ["xv2a"]:
        us, dts, Bs, Cs = x.split([self.d_inner, self.d_inner, 4 * self.d_state, 4 * self.d_state], dim=1)
        _us = us
        us = CrossScanTriton.apply(us.contiguous()).view(B, 4, -1, L)
        dts = CrossScanTriton.apply(dts.contiguous()).view(B, 4, -1, L)
        Bs = CrossScanTriton1b1.apply(Bs.view(B, 4, -1, H, W).contiguous()).view(B, 4, -1, L)
        Cs = CrossScanTriton1b1.apply(Cs.view(B, 4, -1, H, W).contiguous()).view(B, 4, -1, L)
        us, dts = us.contiguous().view(B, -1, L), dts.contiguous().view(B, -1, L)
    elif mode in ["xv3a"]:
        # Alternative (disabled) formulation that scans dt/B/C as one tensor:
        # us, dtBCs = x.split([self.d_inner, 4 * self.dt_rank + 4 * self.d_state + 4 * self.d_state], dim=1)
        # _us = us
        # us = CrossScanTriton.apply(us.contiguous()).view(B, 4, -1, L)
        # dtBCs = CrossScanTriton1b1.apply(dtBCs.view(B, 4, -1, H, W).contiguous()).view(B, 4, -1, L)
        # dts, Bs, Cs = dtBCs.split([self.dt_rank, self.d_state, self.d_state], dim=2)
        # dts = F.conv1d(dts.contiguous().view(B, -1, L), dt_projs_weight.view(K * self.d_inner, self.dt_rank, 1), None, groups=K)
        # us, dts = us.contiguous().view(B, -1, L), dts

        us, dts, Bs, Cs = x.split([self.d_inner, 4 * self.dt_rank, 4 * self.d_state, 4 * self.d_state], dim=1)
        _us = us
        us = CrossScanTriton.apply(us.contiguous()).view(B, 4, -1, L)
        dts = CrossScanTriton1b1.apply(dts.view(B, 4, -1, H, W).contiguous()).view(B, 4, -1, L)
        Bs = CrossScanTriton1b1.apply(Bs.view(B, 4, -1, H, W).contiguous()).view(B, 4, -1, L)
        Cs = CrossScanTriton1b1.apply(Cs.view(B, 4, -1, H, W).contiguous()).view(B, 4, -1, L)
        dts = F.conv1d(dts.contiguous().view(B, -1, L), dt_projs_weight.view(K * self.d_inner, self.dt_rank, 1), None, groups=K)
        us, dts = us.contiguous().view(B, -1, L), dts
    else:
        ...

    Bs, Cs = Bs.view(B, K, -1, L).contiguous(), Cs.view(B, K, -1, L).contiguous()

    As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state)
    Ds = Ds.to(torch.float) # (K * c)
    delta_bias = dt_projs_bias.view(-1).to(torch.float) # (K * c)

    # force_fp32 is hard-coded to False above, so this cast never runs.
    if force_fp32:
        us, dts, Bs, Cs = to_fp32(us, dts, Bs, Cs)

    ys: torch.Tensor = selective_scan(
        us, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
    ).view(B, K, -1, H, W)

    # Fold the K directional outputs back into a single map.
    y: torch.Tensor = CrossMergeTriton.apply(ys)
    y = y.view(B, -1, H, W)

    # originally:
    # y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C)
    # y = out_norm(y).view(B, H, W, -1)

    if (not self.channel_first) or (out_norm_shape in ["v0"]):
        # The norm is applied on a channel-last view; channel-first modules
        # permute back afterwards.
        y = out_norm(y.permute(0, 2, 3, 1))
        if self.channel_first:
            y = y.permute(0, 3, 1, 2)
    else:
        y = out_norm(y)

    # to_dtype is hard-coded to True: cast back to the dtype of in_proj's output.
    y = (y.to(x.dtype) if to_dtype else y)
    y = self.out_act(y)

    if omul:
        # Multiplicative skip with the pre-scan features, laid out like y.
        y = y * (_us.permute(0, 2, 3, 1) if not self.channel_first else _us)

    out = self.dropout(self.out_proj(y))
    return out
class VSSBlock(nn.Module):
    """Residual block made of an optional SS2D branch and an optional MLP branch.

    Each enabled branch is applied as ``x + drop_path(branch(x))``, with the
    norm either before the branch (default) or after it (``post_norm=True``).
    A branch is disabled by passing a non-positive ``ssm_ratio`` /
    ``mlp_ratio``; with both disabled the block is the identity.

    Args:
        hidden_dim: channel width of the block.
        drop_path: stochastic-depth probability passed to ``DropPath``.
        norm_layer: callable building a norm layer from ``hidden_dim``.
        channel_first: forwarded to ``SS2D`` and to the MLP
            (as ``channels_first``).
        ssm_*: forwarded to ``SS2D`` (state size, expansion ratio, dt rank,
            activation, conv kernel size / bias, dropout, init scheme).
        forward_type: forwarded to ``SS2D``.
        mlp_ratio: hidden width multiplier of the MLP.
        mlp_act_layer, mlp_drop_rate: forwarded to the MLP.
        gmlp: use ``gMlp`` instead of ``Mlp``.
        use_checkpoint: wrap the forward pass in activation checkpointing.
        post_norm: apply the norm after the branch instead of before it.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        norm_layer: nn.Module = nn.LayerNorm,
        channel_first=False,
        # =============================
        ssm_d_state: int = 16,
        ssm_ratio=2.0,
        ssm_dt_rank: Any = "auto",
        ssm_act_layer=nn.SiLU,
        ssm_conv: int = 3,
        ssm_conv_bias=True,
        ssm_drop_rate: float = 0,
        ssm_init="v0",
        forward_type="v2",
        # =============================
        mlp_ratio=4.0,
        mlp_act_layer=nn.GELU,
        mlp_drop_rate: float = 0.0,
        gmlp=False,
        # =============================
        use_checkpoint: bool = False,
        post_norm: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.ssm_branch = ssm_ratio > 0
        self.mlp_branch = mlp_ratio > 0
        self.use_checkpoint = use_checkpoint
        self.post_norm = post_norm

        if self.ssm_branch:
            self.norm = norm_layer(hidden_dim)
            self.op = SS2D(
                d_model=hidden_dim,
                d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                dt_rank=ssm_dt_rank,
                act_layer=ssm_act_layer,
                # ==========================
                d_conv=ssm_conv,
                conv_bias=ssm_conv_bias,
                # ==========================
                dropout=ssm_drop_rate,
                # ==========================
                initialize=ssm_init,
                # ==========================
                forward_type=forward_type,
                channel_first=channel_first,
            )

        self.drop_path = DropPath(drop_path)

        if self.mlp_branch:
            _MLP = Mlp if not gmlp else gMlp
            self.norm2 = norm_layer(hidden_dim)
            mlp_hidden_dim = int(hidden_dim * mlp_ratio)
            self.mlp = _MLP(
                in_features=hidden_dim,
                hidden_features=mlp_hidden_dim,
                act_layer=mlp_act_layer,
                drop=mlp_drop_rate,
                channels_first=channel_first,
            )

    def _forward(self, input: torch.Tensor):
        """Apply the enabled branches in order (SSM, then MLP) and return the result."""
        # Start from the input so a disabled SSM branch cannot leave ``x``
        # unbound (previously an UnboundLocalError when ssm_ratio <= 0).
        x = input
        if self.ssm_branch:
            if self.post_norm:
                x = x + self.drop_path(self.norm(self.op(x)))
            else:
                x = x + self.drop_path(self.op(self.norm(x)))
        if self.mlp_branch:
            if self.post_norm:
                x = x + self.drop_path(self.norm2(self.mlp(x)))  # FFN
            else:
                x = x + self.drop_path(self.mlp(self.norm2(x)))  # FFN
        return x

    def forward(self, input: torch.Tensor):
        """Run the block, through ``checkpoint.checkpoint`` when ``use_checkpoint`` is set."""
        if self.use_checkpoint:
            return checkpoint.checkpoint(self._forward, input)
        return self._forward(input)
class VSSM(nn.Module):
    """Hierarchical backbone of VSS stages followed by a classification head.

    The model is ``patch_embed`` -> ``len(depths)`` stages (each a stack of
    ``VSSBlock`` plus a down-sampler, ``nn.Identity`` for the last stage) ->
    ``classifier`` (norm, optional permute to channel-first, global average
    pool, flatten, linear head).

    Args:
        patch_size: patch size of the stem (``patchembed_version="v2"``
            requires 4).
        in_chans: number of input image channels.
        num_classes: output size of the linear head.
        depths: number of blocks per stage. Only read, never mutated.
        dims: channel width per stage, or a single int that is doubled at
            every stage.
        ssm_*, forward_type, mlp_*, gmlp: forwarded to every ``VSSBlock``;
            activations are given by name (``"silu"``, ``"gelu"``,
            ``"relu"``, ``"sigmoid"``).
        drop_path_rate: maximum stochastic-depth rate; rates grow linearly
            from 0 over all blocks.
        patch_norm: whether the stem applies a norm layer.
        norm_layer: ``"ln"``, ``"ln2d"`` or ``"bn"`` (case-insensitive);
            ``"ln2d"`` / ``"bn"`` switch the model to channel-first tensors.
        downsample_version: ``"v1"``, ``"v2"``, ``"v3"`` or ``"none"``.
        patchembed_version: ``"v1"`` or ``"v2"``.
        use_checkpoint: forwarded to every ``VSSBlock``.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: for an unknown ``norm_layer`` or ``patchembed_version``,
            or an unknown ``downsample_version`` when more than one stage
            is built.
    """

    def __init__(
        self,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        depths=[2, 2, 9, 2],
        dims=[96, 192, 384, 768],
        # =========================
        ssm_d_state=16,
        ssm_ratio=2.0,
        ssm_dt_rank="auto",
        ssm_act_layer="silu",
        ssm_conv=3,
        ssm_conv_bias=True,
        ssm_drop_rate=0.0,
        ssm_init="v0",
        forward_type="v2",
        # =========================
        mlp_ratio=4.0,
        mlp_act_layer="gelu",
        mlp_drop_rate=0.0,
        gmlp=False,
        # =========================
        drop_path_rate=0.1,
        patch_norm=True,
        norm_layer="LN",  # "BN", "LN2D"
        downsample_version: str = "v2",  # "v1", "v2", "v3"
        patchembed_version: str = "v1",  # "v1", "v2"
        use_checkpoint=False,
        **kwargs,
    ):
        super().__init__()
        self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
        self.num_classes = num_classes
        self.num_layers = len(depths)
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.num_features = dims[-1]
        # Copy so the instance never aliases the shared default list (or the
        # caller's list).
        self.dims = list(dims)
        dpr = [r.item() for r in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )

        _ACTLAYERS = dict(
            silu=nn.SiLU,
            gelu=nn.GELU,
            relu=nn.ReLU,
            sigmoid=nn.Sigmoid,
        )

        # The norm layer is always needed (classifier), so reject unknown
        # names here instead of failing later with
        # "'NoneType' object is not callable".
        norm_key = norm_layer.lower()
        if norm_key not in _NORMLAYERS:
            raise ValueError(f"unknown norm_layer {norm_layer!r}; expected one of {sorted(_NORMLAYERS)}")
        norm_layer: nn.Module = _NORMLAYERS[norm_key]
        # Activations may legitimately be unused (branch disabled), so an
        # unknown name still maps to None as before.
        ssm_act_layer: nn.Module = _ACTLAYERS.get(ssm_act_layer.lower(), None)
        mlp_act_layer: nn.Module = _ACTLAYERS.get(mlp_act_layer.lower(), None)

        patch_embed_factories = dict(
            v1=self._make_patch_embed,
            v2=self._make_patch_embed_v2,
        )
        if patchembed_version not in patch_embed_factories:
            raise ValueError(
                f"unknown patchembed_version {patchembed_version!r}; "
                f"expected one of {sorted(patch_embed_factories)}"
            )
        _make_patch_embed = patch_embed_factories[patchembed_version]
        self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer, channel_first=self.channel_first)

        downsample_factories = dict(
            v1=PatchMerging2D,
            v2=self._make_downsample,
            v3=self._make_downsample_v3,
            none=(lambda *_, **_k: None),
        )
        _make_downsample = downsample_factories.get(downsample_version, None)
        # Only stages before the last one build a down-sampler.
        if _make_downsample is None and self.num_layers > 1:
            raise ValueError(
                f"unknown downsample_version {downsample_version!r}; "
                f"expected one of {sorted(downsample_factories)}"
            )

        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            downsample = _make_downsample(
                self.dims[i_layer],
                self.dims[i_layer + 1],
                norm_layer=norm_layer,
                channel_first=self.channel_first,
            ) if (i_layer < self.num_layers - 1) else nn.Identity()

            self.layers.append(self._make_layer(
                dim=self.dims[i_layer],
                # Slice of the global drop-path schedule owned by this stage.
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                use_checkpoint=use_checkpoint,
                norm_layer=norm_layer,
                downsample=downsample,
                channel_first=self.channel_first,
                # =================
                ssm_d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                ssm_dt_rank=ssm_dt_rank,
                ssm_act_layer=ssm_act_layer,
                ssm_conv=ssm_conv,
                ssm_conv_bias=ssm_conv_bias,
                ssm_drop_rate=ssm_drop_rate,
                ssm_init=ssm_init,
                forward_type=forward_type,
                # =================
                mlp_ratio=mlp_ratio,
                mlp_act_layer=mlp_act_layer,
                mlp_drop_rate=mlp_drop_rate,
                gmlp=gmlp,
            ))

        self.classifier = nn.Sequential(OrderedDict(
            norm=norm_layer(self.num_features),  # B,H,W,C
            permute=(Permute(0, 3, 1, 2) if not self.channel_first else nn.Identity()),
            avgpool=nn.AdaptiveAvgPool2d(1),
            flatten=nn.Flatten(1),
            head=nn.Linear(self.num_features, num_classes),
        ))

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """Truncated-normal init for ``nn.Linear`` weights; unit/zero init for ``nn.LayerNorm``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # used in building optimizer
    # @torch.jit.ignore
    # def no_weight_decay(self):
    #     return {}

    # used in building optimizer
    # @torch.jit.ignore
    # def no_weight_decay_keywords(self):
    #     return {}

    @staticmethod
    def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False):
        """Single strided-conv stem: conv -> (permute to channel-last) -> optional norm."""
        # if channel first, then Norm and Output are both channel_first
        return nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),
            (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            (norm_layer(embed_dim) if patch_norm else nn.Identity()),
        )

    @staticmethod
    def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False):
        """Two-conv stem (each 3x3, stride 2) with a GELU in between; requires ``patch_size == 4``."""
        # if channel first, then Norm and Output are both channel_first
        assert patch_size == 4
        # The intermediate permutes are only needed when a channel-last norm
        # sits between the two convolutions.
        skip_inner_permute = channel_first or (not patch_norm)
        return nn.Sequential(
            nn.Conv2d(in_chans, embed_dim // 2, kernel_size=3, stride=2, padding=1),
            (nn.Identity() if skip_inner_permute else Permute(0, 2, 3, 1)),
            (norm_layer(embed_dim // 2) if patch_norm else nn.Identity()),
            (nn.Identity() if skip_inner_permute else Permute(0, 3, 1, 2)),
            nn.GELU(),
            nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
            (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            (norm_layer(embed_dim) if patch_norm else nn.Identity()),
        )

    @staticmethod
    def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False):
        """Down-sampler: 2x2 stride-2 conv followed by a norm."""
        # if channel first, then Norm and Output are both channel_first
        return nn.Sequential(
            (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
            nn.Conv2d(dim, out_dim, kernel_size=2, stride=2),
            (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            norm_layer(out_dim),
        )

    @staticmethod
    def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False):
        """Down-sampler: 3x3 stride-2 padded conv followed by a norm."""
        # if channel first, then Norm and Output are both channel_first
        return nn.Sequential(
            (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
            nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1),
            (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            norm_layer(out_dim),
        )

    @staticmethod
    def _make_layer(
        dim=96,
        drop_path=[0.1, 0.1],
        use_checkpoint=False,
        norm_layer=nn.LayerNorm,
        downsample=nn.Identity(),
        channel_first=False,
        # ===========================
        ssm_d_state=16,
        ssm_ratio=2.0,
        ssm_dt_rank="auto",
        ssm_act_layer=nn.SiLU,
        ssm_conv=3,
        ssm_conv_bias=True,
        ssm_drop_rate=0.0,
        ssm_init="v0",
        forward_type="v2",
        # ===========================
        mlp_ratio=4.0,
        mlp_act_layer=nn.GELU,
        mlp_drop_rate=0.0,
        gmlp=False,
        **kwargs,
    ):
        """Build one stage: ``len(drop_path)`` VSSBlocks followed by ``downsample``.

        ``drop_path`` is only read (one rate per block), never mutated.
        """
        # if channel first, then Norm and Output are both channel_first
        blocks = [
            VSSBlock(
                hidden_dim=dim,
                drop_path=rate,
                norm_layer=norm_layer,
                channel_first=channel_first,
                ssm_d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                ssm_dt_rank=ssm_dt_rank,
                ssm_act_layer=ssm_act_layer,
                ssm_conv=ssm_conv,
                ssm_conv_bias=ssm_conv_bias,
                ssm_drop_rate=ssm_drop_rate,
                ssm_init=ssm_init,
                forward_type=forward_type,
                mlp_ratio=mlp_ratio,
                mlp_act_layer=mlp_act_layer,
                mlp_drop_rate=mlp_drop_rate,
                gmlp=gmlp,
                use_checkpoint=use_checkpoint,
            )
            for rate in drop_path
        ]

        return nn.Sequential(OrderedDict(
            blocks=nn.Sequential(*blocks,),
            downsample=downsample,
        ))

    def forward(self, x: torch.Tensor):
        """Stem -> all stages -> classifier."""
        x = self.patch_embed(x)
        for layer in self.layers:
            x = layer(x)
        x = self.classifier(x)
        return x

    def flops(self, shape=(3, 224, 224)):
        """Count FLOPs of one forward pass on a random input of ``shape``.

        A deep copy of the model is moved to CUDA and put in eval mode, so
        the live model is untouched.

        Returns:
            ``sum(Gflops.values()) * 1e9`` as reported by ``flop_count``.
        """
        # shape = self.__input_shape__[1:]
        supported_ops = {
            "aten::silu": None,  # as relu is in _IGNORED_OPS
            "aten::neg": None,  # as relu is in _IGNORED_OPS
            "aten::exp": None,  # as relu is in _IGNORED_OPS
            "aten::flip": None,  # as permute is in _IGNORED_OPS
            # "prim::PythonOp.CrossScan": None,
            # "prim::PythonOp.CrossMerge": None,
            "prim::PythonOp.SelectiveScanMamba": selective_scan_flop_jit,
            "prim::PythonOp.SelectiveScanOflex": selective_scan_flop_jit,
            "prim::PythonOp.SelectiveScanCore": selective_scan_flop_jit,
            "prim::PythonOp.SelectiveScanNRow": selective_scan_flop_jit,
        }

        model = copy.deepcopy(self)
        model.cuda().eval()

        input = torch.randn((1, *shape), device=next(model.parameters()).device)
        Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)

        del model, input
        # The former second `return` (a formatted "params ... GFLOPs ..."
        # string) was unreachable and has been removed together with the
        # parameter count that only it used.
        return sum(Gflops.values()) * 1e9

    # used to load ckpt from previous training code
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Rename keys of old-format checkpoints in place, then defer to ``nn.Module``.

        Renames applied to keys under ``prefix`` (prefix match, the rest of
        the key is kept):

        * ``patch_embed.proj``  -> ``patch_embed.0``
        * ``patch_embed.norm``  -> ``patch_embed.2``
        * ``layers.<i>.blocks.<j>.ln_1``            -> ``...norm``
        * ``layers.<i>.blocks.<j>.self_attention``  -> ``...op``
        * ``norm`` -> ``classifier.norm``
        * ``head`` -> ``classifier.head``

        The keys are visited once; the previous implementation rescanned the
        whole state dict 20,000 times (100 x 100 index pairs, two renames
        each) and was limited to indices below 100.
        """
        import re  # local import keeps this method self-contained

        block_key = re.compile(r"(layers\.[0-9]+\.blocks\.[0-9]+\.)(ln_1|self_attention)")
        block_targets = {"ln_1": "norm", "self_attention": "op"}
        prefix_renames = (
            ("patch_embed.proj", "patch_embed.0"),
            ("patch_embed.norm", "patch_embed.2"),
            ("norm", "classifier.norm"),
            ("head", "classifier.head"),
        )

        def renamed(local_key):
            # The rules are mutually exclusive, so the first match decides.
            hit = block_key.match(local_key)
            if hit:
                return hit.group(1) + block_targets[hit.group(2)] + local_key[hit.end():]
            for src, dst in prefix_renames:
                if local_key.startswith(src):
                    return dst + local_key[len(src):]
            return local_key

        # Iterate over a snapshot because the dict is modified in the loop.
        for old_key in list(state_dict.keys()):
            if not old_key.startswith(prefix):
                continue
            local_key = old_key[len(prefix):]
            new_local_key = renamed(local_key)
            if new_local_key != local_key:
                state_dict[prefix + new_local_key] = state_dict.pop(old_key)

        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
# compatible with openmmlab
class Backbone_VSSM(VSSM):
    """``VSSM`` used as a multi-scale feature extractor.

    The classification head is removed; instead every stage listed in
    ``out_indices`` gets its own output norm (``outnorm<i>``) and ``forward``
    returns the normalised, channel-first feature map of each such stage.

    Args:
        out_indices: indices of the stages whose features are returned.
        pretrained: optional checkpoint path, loaded after construction.
        norm_layer: ``"ln"``, ``"ln2d"`` or ``"bn"`` (case-insensitive);
            also forwarded to ``VSSM``.
        **kwargs: forwarded to ``VSSM.__init__``.
    """

    def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer="ln", **kwargs):
        kwargs.update(norm_layer=norm_layer)
        super().__init__(**kwargs)
        self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )
        norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None)

        self.out_indices = out_indices
        for i in out_indices:
            # One norm per exported stage, sized to that stage's width.
            self.add_module(f'outnorm{i}', norm_layer(self.dims[i]))

        del self.classifier
        self.load_pretrained(pretrained)

    def load_pretrained(self, ckpt=None, key="model"):
        """Best-effort load of ``ckpt[key]`` into this model (non-strict).

        Does nothing when ``ckpt`` is ``None``. Any failure is reported on
        stdout and otherwise ignored, so construction never fails because of
        a bad checkpoint.
        """
        if ckpt is None:
            return

        try:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            # The context manager closes the handle, which the previous
            # `torch.load(open(ckpt, "rb"), ...)` never did.
            with open(ckpt, "rb") as ckpt_file:
                _ckpt = torch.load(ckpt_file, map_location=torch.device("cpu"))
            print(f"Successfully load ckpt {ckpt}")
            incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False)
            print(incompatibleKeys)
        except Exception as e:
            print(f"Failed loading checkpoint from {ckpt}: {e}")

    def forward(self, x):
        """Return the list of stage features, or the final tensor when ``out_indices`` is empty."""

        def layer_forward(l, x):
            # Keep the pre-downsample features (exported) separate from the
            # downsampled tensor fed to the next stage.
            x = l.blocks(x)
            y = l.downsample(x)
            return x, y

        x = self.patch_embed(x)
        outs = []
        for i, layer in enumerate(self.layers):
            o, x = layer_forward(layer, x)  # (B, H, W, C)
            if i in self.out_indices:
                norm_layer = getattr(self, f'outnorm{i}')
                out = norm_layer(o)
                if not self.channel_first:
                    out = out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        if len(self.out_indices) == 0:
            return x

        return outs