import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class MambaBlock(nn.Module):
    """
    Mamba (selective state-space) block for graph processing.

    Pure-PyTorch, device-safe implementation of the selective scan.
    """
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)

        # Rank of the low-rank projection used to produce the step size (dt).
        if dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)
        else:
            self.dt_rank = dt_rank

        # Input projection: produces both the SSM branch (x) and the gate (z).
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias)

        # Depthwise causal convolution over the sequence dimension.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            bias=True,
        )

        # Projections producing the input-dependent SSM parameters (dt, B, C).
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # State matrix A (stored as a log for positivity) and skip parameter D.
        A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32), 'n -> d n', d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Output projection back to the model dimension.
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias)

        self.act = nn.SiLU()

        self._init_parameters()

    def _init_parameters(self):
        """Initialize parameters with proper scaling."""
        # The dt_proj bias sets the pre-softplus step sizes at initialization.
        dt_init_std = self.dt_rank**-0.5 * self.d_state
        with torch.no_grad():
            self.dt_proj.bias.uniform_(-dt_init_std, dt_init_std)

        nn.init.xavier_uniform_(self.in_proj.weight)
        nn.init.xavier_uniform_(self.x_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)

    def forward(self, x):
        """
        x: (batch, length, d_model)
        Returns: (batch, length, d_model)
        """
        _, length, _ = x.shape

        # A_log and D are registered nn.Parameters, so they already follow
        # module.to(device); reassigning them here would raise a TypeError.

        # Project input and split into the SSM branch (x) and the gate (z).
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Depthwise causal convolution; trim the right padding back to `length`.
        x = rearrange(x, 'b l d -> b d l')
        x = self.conv1d(x)[:, :, :length]
        x = rearrange(x, 'b d l -> b l d')
        x = self.act(x)

        # Selective state-space scan.
        y = self.selective_scan(x)

        # Gate the scan output, then project back to the model dimension.
        y = y * self.act(z)
        out = self.out_proj(y)

        return out

    def selective_scan(self, u):
        """Selective scan operation - the core of Mamba."""
        # Compute the input-dependent SSM parameters from u.
        x_dbl = self.x_proj(u)
        delta, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)

        # The step size delta must be positive; softplus enforces this.
        delta = F.softplus(self.dt_proj(delta))

        return self._selective_scan_pytorch(u, delta, B, C)

    def _selective_scan_pytorch(self, u, delta, B, C):
        """Pure-PyTorch implementation of the selective scan (device safe)."""
        batch, length, d_inner = u.shape
        device = u.device

        # Parameters already live on the module's device; .to() is a cheap safeguard.
        A_log = self.A_log.to(device)
        D = self.D.to(device)

        # Discretize the continuous SSM:
        #   deltaA[t]   = exp(delta[t] * A)          with A = -exp(A_log) < 0
        #   deltaB_u[t] = delta[t] * B[t] * u[t]
        # Both broadcast to shape (batch, length, d_inner, d_state).
        deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(A_log)))
        deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)

        # Sequential recurrence: h[t] = deltaA[t] * h[t-1] + deltaB_u[t],
        # output y[t] = C[t] . h[t] (contracted over the state dimension).
        x = torch.zeros((batch, d_inner, self.d_state), device=device, dtype=u.dtype)
        ys = []
        for i in range(length):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = torch.einsum('bdn,bn->bd', x, C[:, i])
            ys.append(y)
        y = torch.stack(ys, dim=1)

        # Skip connection scaled by D.
        y = y + u * D

        return y
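

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original block): runs MambaBlock on a
# random sequence to sanity-check output shapes. The batch size, sequence
# length, and d_model below are arbitrary illustrative values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    block = MambaBlock(d_model=64)
    x = torch.randn(2, 32, 64)   # (batch, length, d_model)
    y = block(x)
    print(y.shape)               # expected: torch.Size([2, 32, 64])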