import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat


class MambaBlock(nn.Module):
    """
    Mamba block for graph processing.

    Device-safe, pure-PyTorch implementation: all learnable tensors are
    registered as nn.Parameters, so they follow the module when it is moved
    with .to() / .cuda().
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)

        if dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)
        else:
            self.dt_rank = dt_rank

        # Input projection: produces both the SSM branch and the gating branch
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias)

        # Depthwise convolution for local patterns (made causal by trimming in forward)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            bias=True,
        )

        # SSM parameter projections: x -> (delta, B, C), then dt_rank -> d_inner
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # Initialize A (state evolution matrix), stored as log for positivity/stability
        A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32),
                   'n -> d n', d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Output projection back to d_model
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias)

        # Activation
        self.act = nn.SiLU()

        # Initialize parameters
        self._init_parameters()

    def _init_parameters(self):
        """Initialize parameters with proper scaling."""
        # Initialize the dt projection bias specially
        dt_init_std = self.dt_rank ** -0.5 * self.d_state
        with torch.no_grad():
            self.dt_proj.bias.uniform_(-dt_init_std, dt_init_std)

        # Initialize the other projections
        nn.init.xavier_uniform_(self.in_proj.weight)
        nn.init.xavier_uniform_(self.x_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)

    def forward(self, x):
        """
        x: (batch, length, d_model)
        Returns: (batch, length, d_model)
        """
        _, length, _ = x.shape

        # Parameters (A_log, D, projections) are registered on this module and
        # follow it to the correct device; no per-forward reassignment is needed
        # (reassigning self.A_log = self.A_log.to(device) would replace a
        # Parameter with a plain Tensor and raise a TypeError).

        # Input projection and split into SSM branch (x) and gate branch (z)
        xz = self.in_proj(x)        # (batch, length, 2 * d_inner)
        x, z = xz.chunk(2, dim=-1)  # each: (batch, length, d_inner)

        # Causal depthwise convolution (trim right-side padding back to `length`)
        x = rearrange(x, 'b l d -> b d l')
        x = self.conv1d(x)[:, :, :length]
        x = rearrange(x, 'b d l -> b l d')
        x = self.act(x)

        # Selective state-space scan
        y = self.selective_scan(x)

        # Gating
        y = y * self.act(z)

        # Output projection
        out = self.out_proj(y)

        return out

    def selective_scan(self, u):
        """Selective scan operation - core of Mamba."""
        # Compute ∆, B, C from the input
        x_dbl = self.x_proj(u)  # (batch, length, dt_rank + 2 * d_state)
        delta, B, C = torch.split(
            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
        )

        # Softplus ensures delta > 0
        delta = F.softplus(self.dt_proj(delta))  # (batch, length, d_inner)

        return self._selective_scan_pytorch(u, delta, B, C)

    def _selective_scan_pytorch(self, u, delta, B, C):
        """Pure-PyTorch selective scan (sequential over the length dimension)."""
        batch, length, d_inner = u.shape
        device = u.device

        # No-op when the module is already on the input's device
        A_log = self.A_log.to(device)
        D = self.D.to(device)

        # Discretize: A_bar = exp(delta * A), B_bar * u = delta * B * u
        deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(A_log)))      # (batch, length, d_inner, d_state)
        deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # (batch, length, d_inner, d_state)

        # Initialize state and scan sequentially over time
        x = torch.zeros((batch, d_inner, self.d_state), device=device, dtype=u.dtype)
        ys = []
        for i in range(length):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = torch.einsum('bdn,bn->bd', x, C[:, i])
            ys.append(y)

        y = torch.stack(ys, dim=1)  # (batch, length, d_inner)

        # Add skip connection
        y = y + u * D

        return y
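

# Minimal usage sketch (illustrative only): the batch size, sequence length, and
# model width below are assumed values, chosen just to demonstrate the expected
# (batch, length, d_model) -> (batch, length, d_model) shape contract of the block.
if __name__ == "__main__":
    torch.manual_seed(0)
    block = MambaBlock(d_model=64, d_state=16, d_conv=4, expand=2)
    x = torch.randn(2, 128, 64)  # (batch=2, length=128, d_model=64)
    y = block(x)
    print(y.shape)               # torch.Size([2, 128, 64])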