import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import math

class MambaBlock(nn.Module):
    """
    Production-ready Mamba block for graph processing
    Device-safe implementation with optimizations
    """
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        
        if dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)
        else:
            self.dt_rank = dt_rank
            
        # Linear projections
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias)
        
        # Depthwise convolution for local context; left padding of d_conv - 1
        # plus truncation back to `length` in forward() makes it causal
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            bias=True,
        )
        
        # SSM parameters
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
        
        # Initialize A (state evolution matrix) with the S4D-real pattern
        # A[d, n] = n for n = 1..d_state, stored as A_log so the effective
        # A = -exp(A_log) stays strictly negative (stable dynamics)
        A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32), 'n -> d n', d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))  # skip-connection weight
        
        # Output projection
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias)
        
        # Activation
        self.act = nn.SiLU()
        
        # Initialize parameters
        self._init_parameters()
        
    def _init_parameters(self):
        """Initialize parameters with proper scaling"""
        # Initialize dt projection specially
        dt_init_std = self.dt_rank**-0.5 * self.d_state
        with torch.no_grad():
            self.dt_proj.bias.uniform_(-dt_init_std, dt_init_std)
            
        # Initialize other projections
        nn.init.xavier_uniform_(self.in_proj.weight)
        nn.init.xavier_uniform_(self.x_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        
    def forward(self, x):
        """
        x: (batch, length, d_model)
        Returns: (batch, length, d_model)
        """
        batch, length, _ = x.shape

        # A_log and D are registered parameters, so they already sit on the
        # module's device after nn.Module.to(); no per-forward reassignment is
        # needed (reassigning a Parameter with a plain .to() tensor would
        # raise a TypeError whenever the device actually changes).
        
        # Input projection and split
        xz = self.in_proj(x)  # (batch, length, 2 * d_inner)
        x, z = xz.chunk(2, dim=-1)  # Each: (batch, length, d_inner)
        
        # Convolution
        x = rearrange(x, 'b l d -> b d l')
        x = self.conv1d(x)[:, :, :length]
        x = rearrange(x, 'b d l -> b l d')
        x = self.act(x)
        
        # SSM
        y = self.selective_scan(x)
        
        # Gating
        y = y * self.act(z)
        
        # Output projection
        out = self.out_proj(y)
        
        return out
    
    def selective_scan(self, u):
        """Selective scan operation - core of Mamba"""
        
        # Compute ∆, B, C
        x_dbl = self.x_proj(u)  # (batch, length, dt_rank + 2*d_state)
        delta, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        
        # Softplus ensures delta > 0
        delta = F.softplus(self.dt_proj(delta))  # (batch, length, d_inner)
        
        return self._selective_scan_pytorch(u, delta, B, C)
    
    def _selective_scan_pytorch(self, u, delta, B, C):
        """PyTorch implementation of selective scan - device safe"""
        batch, length, d_inner = u.shape
        device = u.device
        
        # Defensive device alignment for the parameters (normally a no-op
        # after nn.Module.to(device))
        A_log = self.A_log.to(device)
        D = self.D.to(device)
        
        # Discretize the continuous-time SSM (zero-order-hold style):
        #   A_bar = exp(delta * A) with A = -exp(A_log) < 0 for stability,
        #   B_bar * u ≈ delta * B * u (first-order approximation)
        deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(A_log)))  # (batch, length, d_inner, d_state)
        deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # (batch, length, d_inner, d_state)
        
        # Initialize state
        x = torch.zeros((batch, d_inner, self.d_state), device=device, dtype=u.dtype)
        ys = []
        
        # Sequential scan over time (plain Python loop; optimized Mamba
        # implementations replace this with a fused parallel-scan kernel)
        for i in range(length):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = torch.einsum('bdn,bn->bd', x, C[:, i])
            ys.append(y)
        
        y = torch.stack(ys, dim=1)  # (batch, length, d_inner)
        
        # Add skip connection
        y = y + u * D
        
        return y
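

# Minimal usage sketch (illustrative only; the batch size, sequence length and
# d_model below are assumptions, not part of the module definition). The block
# maps (batch, length, d_model) -> (batch, length, d_model), so it can be
# stacked with residual connections like a Transformer layer.
if __name__ == "__main__":
    torch.manual_seed(0)
    block = MambaBlock(d_model=64, d_state=16, d_conv=4, expand=2)
    x = torch.randn(2, 128, 64)    # (batch, length, d_model)
    y = block(x)
    assert y.shape == x.shape      # shape-preserving, suitable for residual stacking
    print(y.shape)                 # torch.Size([2, 128, 64])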