import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class KANLayer(nn.Module):
    """

    KAN-layer using b-splines basis based on this paper:



    https://arxiv.org/abs/2404.19756



    This architecture fundamentally differs from MLPs by replacing fixed activation functions 

    with learnable univariate functions represented as B-spline bases 

    by decomposing multivariate functions into sums of univariate function.





    Using the model presented in the original paper turned out to be very impractical 

    due to the need to integrate many dependencies that contradict each other within the current project,

    so KAN-layer was rewritten from scratch with a focus on a more stable implementation on PyTorch with CUDA support:



    https://github.com/Blealtan/efficient-kan  



    Where:

        

    - B-spline parameterization enables continuous, piecewise-polynomial functions;

    - Grid-based routing follows the theoretical foundation of Kolmogorov's theorem;

    - Dual-path architecture (base + spline) enhances model expressivity;

    - Normalization of B-spline bases and grid perturbation thresholding (grid_eps) to prevent division-by-zero errors.

    """
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        grid_size: int = 5,
        spline_order: int = 3,
        scale_noise: float = 0.1,
        scale_base: float = 1.0,
        scale_spline: float = 1.0,
        enable_standalone_scale_spline: bool = True,
        base_activation: nn.Module = nn.SiLU,
        grid_eps: float = 0.02,
        grid_range: list = [-1, 1],
    ):
        super(KANLayer, self).__init__()
        self.in_features = in_features                                          # Size of input vector
        self.out_features = out_features                                        # Output vector size
        self.grid_size = grid_size                                              # Number of knot intervals for B-spline basis
        self.spline_order = spline_order                                        # Degree of B-spline polynomials (k=3: cubic splines)
        self.scale_noise = scale_noise                                          # Noise scaling for numerical stability during training
        self.scale_base = scale_base                                            # Linear transformation scaling
        self.scale_spline = scale_spline                                        # Spline path scaling
        self.enable_standalone_scale_spline = enable_standalone_scale_spline    # Optional standalone scaling mechanism for spline weights
        self.base_activation = base_activation()                                # Base activation function (SiLU chosen for its smoothness properties)
        self.grid_eps = grid_eps                                                # Grid perturbation threshold for numerical stability

        # B-spline Grid Construction
        h = (grid_range[1] - grid_range[0]) / grid_size
        # Grid: grid_size + 2 * spline_order + 1 points
        # Extended grid with boundary padding for B-spline continuity
        grid = torch.arange(-spline_order, grid_size + spline_order + 1, dtype=torch.float) * h + grid_range[0]
        grid = grid.unsqueeze(0).expand(in_features, -1).contiguous()  # (in_features, grid_size + 2*spline_order + 1)
        self.register_buffer("grid", grid)

        # Linear transformation equivalent to traditional neural networks (Base Weight)
        self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.base_bias = nn.Parameter(torch.Tensor(out_features))

        # Learnable B-spline coefficients (Spline Weight):
        # (out_features, in_features, grid_size + spline_order)
        self.spline_weight = nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order))
        if enable_standalone_scale_spline:
            # Initialize scale in 1.0
            self.spline_scaler = nn.Parameter(torch.ones(out_features, in_features))

        # LayerNormalization for outputs
        self.norm_base = nn.LayerNorm(out_features)
        self.norm_spline = nn.LayerNorm(out_features)
        # Initialize parameters
        self.reset_parameters()

    def reset_parameters(self):
        """

        Parameter initialization strategy combining:

        - base_weight: Kaiming initialization for base weights with SiLU gain adjustment;

        - base_bias: Zero initialization for biases;

        - spline_weight: Small random initialization for spline weights;

        - spline_scaler: ones (if it's standalone).

        """
        # Gain adjustment for SiLU activation
        gain = math.sqrt(2.0)  # SiLU ~ ReLU gain
        nn.init.kaiming_uniform_(self.base_weight, a=0, mode='fan_in', nonlinearity='relu')
        self.base_weight.data.mul_(self.scale_base * gain)
        nn.init.zeros_(self.base_bias)

        # Small random initialization for spline weights to break symmetry
        nn.init.uniform_(self.spline_weight, -self.scale_noise, self.scale_noise)

        if self.enable_standalone_scale_spline:
            # Spline scalers are already initialized to ones in __init__ (identity scaling),
            # so there is nothing to reset here
            pass

    def b_splines(self, x: torch.Tensor) -> torch.Tensor:
        """

        Compute B-spline basis functions using Cox-de Boor recursion.

        Normalization ensures partition-of-unity property for numerical stability

        Args:

            x: Input tensor of shape (N, in_features)

        Returns:

            bases: B-spline basis tensor of shape (N, in_features, grid_size + spline_order)

        """
        N = x.shape[0]
        # Expand grid to match input dimensions:
        # grid: (in_features, grid_points) -> (N, in_features, grid_points)
        grid = self.grid.unsqueeze(0).expand(N, -1, -1)  # (N, in_features, G)
        x_exp = x.unsqueeze(2)  # (N, in_features, 1)

        # Initial basis (zeroth order)
        bases = ((x_exp >= grid[:, :, :-1]) & (x_exp < grid[:, :, 1:])).to(x.dtype)  # (N, in_features, G-1)

        # Cox-de Boor Recursive B-spline construction
        for k in range(1, self.spline_order + 1):
            left_num = x_exp - grid[:, :, :-(k + 1)]
            left_den = grid[:, :, k:-1] - grid[:, :, :-(k + 1)] + 1e-8
            term1 = (left_num / left_den) * bases[:, :, :-1]

            right_num = grid[:, :, k + 1:] - x_exp
            right_den = grid[:, :, k + 1:] - grid[:, :, 1:-k] + 1e-8
            term2 = (right_num / right_den) * bases[:, :, 1:]
            bases = term1 + term2  # (N, in_features, grid_size + spline_order)

        # Normalize so the bases sum to 1 (partition of unity) for numerical stability
        bases = bases / (bases.sum(dim=2, keepdim=True) + 1e-8)
        return bases.contiguous()  # (N, in_features, grid_size + spline_order)

    @property
    def scaled_spline_weight(self) -> torch.Tensor:
        """

        Apply scaling to spline weights if standalone scaling is enabled (adapting scaling mechnism)

        """
        if self.enable_standalone_scale_spline:
            # (out_features, in_features, grid_size + spline_order) *
            # (out_features, in_features, 1) -> (out_features, in_features, grid_size + spline_order)
            return self.spline_weight * self.spline_scaler.unsqueeze(-1)
        else:
            return self.spline_weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """

        Forward pass combining base and spline paths:

        - Base path: Linear transformation with SiLU activation

        - Spline path: B-spline basis expansion with learned coefficients

        - Path combination through summation

        """
        orig_shape = x.shape
        x_flat = x.reshape(-1, self.in_features)  # (N, in_features)

        # Base path: standard linear transformation applied to the activated input
        base_act = self.base_activation(x_flat)  # (N, in_features)
        base_lin = F.linear(base_act, self.base_weight, self.base_bias)  # (N, out_features)
        base_out = self.norm_base(base_lin)  # (N, out_features)

        # Spline path: B-spline basis expansion
        bspline = self.b_splines(x_flat)  # (N, in_features, grid_size + spline_order)
        bspline_flat = bspline.view(x_flat.size(0), -1)  # (N, in_features * (grid_size + spline_order))

        # Efficient weight application through linear operation:
        # (out_features, in_features, S) -> (out_features, in_features * S)
        spline_w_flat = self.scaled_spline_weight.view(self.out_features, -1)  # (out_features, in_features * S)
        spline_lin = F.linear(bspline_flat, spline_w_flat)  # (N, out_features)
        spline_out = self.norm_spline(spline_lin)  # (N, out_features)

        # Combine paths with residual connection-like behavior
        out = base_out + spline_out  # (N, out_features)
        out = out.view(*orig_shape[:-1], self.out_features)  # Restore original shape

        return out
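

# Minimal usage sketch (arbitrary example sizes, not a test suite): runs the layer
# on random inputs, checks output shapes, and verifies the partition-of-unity
# property of the normalized B-spline bases.
if __name__ == "__main__":
    torch.manual_seed(0)

    layer = KANLayer(in_features=16, out_features=32, grid_size=5, spline_order=3)

    # 2-D input: (batch, in_features); sampled within grid_range = [-1, 1]
    x = torch.empty(8, 16).uniform_(-1, 1)
    y = layer(x)
    print("output shape:", tuple(y.shape))  # expected: (8, 32)

    # Normalized B-spline bases should sum to ~1 per input feature
    bases = layer.b_splines(x)  # (8, 16, grid_size + spline_order)
    print("mean basis sum (should be ~1):", bases.sum(dim=2).mean().item())

    # Extra leading dimensions are flattened and restored, e.g. (batch, seq, features)
    y_seq = layer(torch.randn(4, 10, 16))
    print("sequence output shape:", tuple(y_seq.shape))  # expected: (4, 10, 32)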