import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import math


class MambaBlock(nn.Module):
    """
    Mamba block for graph processing.

    Device-safe, pure-PyTorch reference implementation; the selective scan
    below is a sequential loop rather than a fused CUDA kernel.
    """
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        if dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)
        else:
            self.dt_rank = dt_rank

        # Linear projections
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias)

        # Convolution for local patterns
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            bias=True,
        )

        # SSM parameters
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # Initialize A (state evolution matrix)
        A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32), 'n -> d n', d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias)

        # Activation
        self.act = nn.SiLU()

        # Initialize parameters
        self._init_parameters()
    def _init_parameters(self):
        """Initialize parameters with proper scaling"""
        # Initialize dt projection specially
        dt_init_std = self.dt_rank**-0.5 * self.d_state
        with torch.no_grad():
            self.dt_proj.bias.uniform_(-dt_init_std, dt_init_std)

        # Initialize other projections
        nn.init.xavier_uniform_(self.in_proj.weight)
        nn.init.xavier_uniform_(self.x_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
    def forward(self, x):
        """
        x: (batch, length, d_model)
        Returns: (batch, length, d_model)
        """
        batch, length, _ = x.shape

        # Note: parameters (A_log, D, ...) follow the module's device via
        # `module.to(device)`; reassigning an nn.Parameter attribute with a
        # plain tensor would raise a TypeError.

        # Input projection and split
        xz = self.in_proj(x)  # (batch, length, 2 * d_inner)
        x, z = xz.chunk(2, dim=-1)  # Each: (batch, length, d_inner)

        # Convolution (causal padding is trimmed back to `length`)
        x = rearrange(x, 'b l d -> b d l')
        x = self.conv1d(x)[:, :, :length]
        x = rearrange(x, 'b d l -> b l d')
        x = self.act(x)

        # SSM
        y = self.selective_scan(x)

        # Gating
        y = y * self.act(z)

        # Output projection
        out = self.out_proj(y)
        return out
    def selective_scan(self, u):
        """Selective scan operation - core of Mamba"""
        # Compute input-dependent ∆, B, C
        x_dbl = self.x_proj(u)  # (batch, length, dt_rank + 2*d_state)
        delta, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)

        # Softplus ensures delta > 0
        delta = F.softplus(self.dt_proj(delta))  # (batch, length, d_inner)
        return self._selective_scan_pytorch(u, delta, B, C)
    def _selective_scan_pytorch(self, u, delta, B, C):
        """PyTorch implementation of selective scan - device safe"""
        batch, length, d_inner = u.shape
        device = u.device

        # Ensure A_log and D are on the correct device (local copies only)
        A_log = self.A_log.to(device)
        D = self.D.to(device)

        # Discretize (zero-order hold): deltaA = exp(delta * A)
        deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(A_log)))      # (batch, length, d_inner, d_state)
        deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # (batch, length, d_inner, d_state)

        # Sequential recurrence: x_t = deltaA_t * x_{t-1} + deltaB_u_t
        x = torch.zeros((batch, d_inner, self.d_state), device=device, dtype=u.dtype)
        ys = []
        for i in range(length):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = torch.einsum('bdn,bn->bd', x, C[:, i])  # project state with C_t
            ys.append(y)
        y = torch.stack(ys, dim=1)  # (batch, length, d_inner)

        # Add skip connection
        y = y + u * D
        return y
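

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file; the hyperparameters
# here are illustrative assumptions). It runs a random batch through the
# block and checks the residual-friendly contract: output shape == input
# shape.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # `module.to(device)` moves every parameter, including A_log and D
    block = MambaBlock(d_model=64, d_state=16, d_conv=4, expand=2).to(device)

    x = torch.randn(2, 32, 64, device=device)  # (batch, length, d_model)
    y = block(x)
    assert y.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(y.shape)}"
    print("MambaBlock forward OK:", tuple(y.shape))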