# serpent/core/mamba_block.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import math


class MambaBlock(nn.Module):
"""
Production-ready Mamba block for graph processing
Device-safe implementation with optimizations
"""
def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
if dt_rank == "auto":
self.dt_rank = math.ceil(self.d_model / 16)
else:
self.dt_rank = dt_rank
# Linear projections
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias)
        # Depthwise convolution for local patterns (causal once trimmed to
        # the input length in forward())
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
kernel_size=d_conv,
groups=self.d_inner,
padding=d_conv - 1,
bias=True,
)
# SSM parameters
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
        # State matrix A, stored as A_log so that A = -exp(A_log) stays
        # strictly negative (S4D-real style init: row d is [1, ..., d_state])
A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32), 'n -> d n', d=self.d_inner)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.d_inner))
# Output projection
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias)
# Activation
self.act = nn.SiLU()
# Initialize parameters
self._init_parameters()

    def _init_parameters(self):
"""Initialize parameters with proper scaling"""
# Initialize dt projection specially
dt_init_std = self.dt_rank**-0.5 * self.d_state
with torch.no_grad():
self.dt_proj.bias.uniform_(-dt_init_std, dt_init_std)
# Initialize other projections
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.xavier_uniform_(self.x_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)

    def forward(self, x):
"""
x: (batch, length, d_model)
Returns: (batch, length, d_model)
"""
batch, length, _ = x.shape
        # A_log and D are registered nn.Parameters and move with the module,
        # so no per-call device transfer is needed (reassigning a Parameter
        # attribute with a plain tensor would raise a TypeError)
# Input projection and split
xz = self.in_proj(x) # (batch, length, 2 * d_inner)
x, z = xz.chunk(2, dim=-1) # Each: (batch, length, d_inner)
# Convolution
x = rearrange(x, 'b l d -> b d l')
x = self.conv1d(x)[:, :, :length]
x = rearrange(x, 'b d l -> b l d')
x = self.act(x)
# SSM
y = self.selective_scan(x)
# Gating
y = y * self.act(z)
# Output projection
out = self.out_proj(y)
return out

    def selective_scan(self, u):
"""Selective scan operation - core of Mamba"""
# Compute ∆, B, C
x_dbl = self.x_proj(u) # (batch, length, dt_rank + 2*d_state)
delta, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
# Softplus ensures delta > 0
delta = F.softplus(self.dt_proj(delta)) # (batch, length, d_inner)
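        # Δ, B, and C are all computed from the input u itself; this input
        # dependence is what makes the scan "selective"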
return self._selective_scan_pytorch(u, delta, B, C)

    def _selective_scan_pytorch(self, u, delta, B, C):
"""PyTorch implementation of selective scan - device safe"""
batch, length, d_inner = u.shape
device = u.device
        # Local references; the parameters already live on the module's device
        A_log = self.A_log
        D = self.D
# Discretize
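        # Zero-order hold on A, Euler on B: Ā = exp(Δ·A) with A = -exp(A_log),
        # so every entry of Ā lies in (0, 1) and the recurrence is stable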
deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(A_log))) # (batch, length, d_inner, d_state)
deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1) # (batch, length, d_inner, d_state)
# Initialize state
x = torch.zeros((batch, d_inner, self.d_state), device=device, dtype=u.dtype)
ys = []
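        # Sequential recurrence x_t = Ā_t * x_{t-1} + B̄_t·u_t, y_t = C_t · x_t;
        # an O(length) Python loop (the official CUDA kernel fuses this scan)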
for i in range(length):
x = deltaA[:, i] * x + deltaB_u[:, i]
y = torch.einsum('bdn,bn->bd', x, C[:, i])
ys.append(y)
y = torch.stack(ys, dim=1) # (batch, length, d_inner)
# Add skip connection
y = y + u * D
return y
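

if __name__ == "__main__":
    # Minimal smoke test, a sketch for illustration only: the sizes below are
    # arbitrary assumptions, not values taken from the repository.
    torch.manual_seed(0)
    block = MambaBlock(d_model=64, d_state=16, d_conv=4, expand=2)
    x = torch.randn(2, 32, 64)  # (batch, length, d_model)
    y = block(x)
    assert y.shape == x.shape  # the block is shape-preserving
    print(y.shape)  # torch.Size([2, 32, 64])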