import math

import torch
import torch.nn as nn
import torch.nn.functional as F

|
class KANLayer(nn.Module):
    """
    KAN layer using a B-spline basis, based on this paper:

    https://arxiv.org/abs/2404.19756

    The architecture differs fundamentally from MLPs: fixed activation functions are replaced
    with learnable univariate functions represented in a B-spline basis, decomposing
    multivariate functions into sums of univariate functions.

    Using the model presented in the original paper turned out to be impractical, since it
    requires integrating many dependencies that conflict with each other within the current
    project, so the KAN layer was rewritten from scratch with a focus on a more stable
    PyTorch implementation with CUDA support, following:

    https://github.com/Blealtan/efficient-kan

    Key design points:

    - B-spline parameterization enables continuous, piecewise-polynomial functions;
    - grid-based routing follows the theoretical foundation of Kolmogorov's theorem;
    - the dual-path architecture (base + spline) enhances model expressivity;
    - normalization of the B-spline bases and the grid perturbation threshold (grid_eps)
      guard against division-by-zero errors.

    A usage sketch is provided at the bottom of this file.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        grid_size: int = 5,
        spline_order: int = 3,
        scale_noise: float = 0.1,
        scale_base: float = 1.0,
        scale_spline: float = 1.0,
        enable_standalone_scale_spline: bool = True,
        base_activation: nn.Module = nn.SiLU,
        grid_eps: float = 0.02,
        grid_range: list = [-1, 1],
    ):
        super(KANLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        h = (grid_range[1] - grid_range[0]) / grid_size
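
        # The knot vector extends `spline_order` intervals beyond each end of grid_range,
        # giving grid_size + 2 * spline_order + 1 knots per input feature. This keeps the
        # Cox-de Boor recursion in b_splines() well defined near the boundaries and yields
        # grid_size + spline_order basis functions per input.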
        grid = torch.arange(-spline_order, grid_size + spline_order + 1, dtype=torch.float) * h + grid_range[0]
        grid = grid.unsqueeze(0).expand(in_features, -1).contiguous()
        self.register_buffer("grid", grid)

        self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.base_bias = nn.Parameter(torch.Tensor(out_features))

        self.spline_weight = nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order))
        if enable_standalone_scale_spline:
            self.spline_scaler = nn.Parameter(torch.ones(out_features, in_features))

        self.norm_base = nn.LayerNorm(out_features)
        self.norm_spline = nn.LayerNorm(out_features)

        self.reset_parameters()

    def reset_parameters(self):
        """
        Parameter initialization strategy:
        - base_weight: Kaiming (fan-in) initialization, rescaled by scale_base and a sqrt(2)
          gain to account for the SiLU base activation;
        - base_bias: zero initialization;
        - spline_weight: small uniform random initialization in [-scale_noise, scale_noise];
        - spline_scaler: ones (already initialized in __init__ when standalone scaling is enabled).
        """
        gain = math.sqrt(2.0)
        nn.init.kaiming_uniform_(self.base_weight, a=0, mode='fan_in', nonlinearity='relu')
        self.base_weight.data.mul_(self.scale_base * gain)
        nn.init.zeros_(self.base_bias)

        nn.init.uniform_(self.spline_weight, -self.scale_noise, self.scale_noise)

        if self.enable_standalone_scale_spline:
            # spline_scaler is already set to ones in __init__; nothing more to do here.
            pass

    def b_splines(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the B-spline basis functions via the Cox-de Boor recursion.
        The bases are renormalized to sum to one along the basis dimension
        (partition of unity) for numerical stability.

        Args:
            x: input tensor of shape (N, in_features).

        Returns:
            bases: B-spline basis tensor of shape (N, in_features, grid_size + spline_order).
        """
        N = x.shape[0]

        grid = self.grid.unsqueeze(0).expand(N, -1, -1)
        x_exp = x.unsqueeze(2)

        bases = ((x_exp >= grid[:, :, :-1]) & (x_exp < grid[:, :, 1:])).to(x.dtype)

        for k in range(1, self.spline_order + 1):
            left_num = x_exp - grid[:, :, :-(k + 1)]
            left_den = grid[:, :, k:-1] - grid[:, :, :-(k + 1)] + 1e-8
            term1 = (left_num / left_den) * bases[:, :, :-1]

            right_num = grid[:, :, k + 1:] - x_exp
            right_den = grid[:, :, k + 1:] - grid[:, :, 1:-k] + 1e-8
            term2 = (right_num / right_den) * bases[:, :, 1:]
            bases = term1 + term2

        bases = bases / (bases.sum(dim=2, keepdim=True) + 1e-8)
        return bases.contiguous()

    @property
    def scaled_spline_weight(self) -> torch.Tensor:
        """
        Apply per-(output, input) scaling to the spline weights when standalone
        scaling is enabled (adaptive scaling mechanism).
        """
        if self.enable_standalone_scale_spline:
            return self.spline_weight * self.spline_scaler.unsqueeze(-1)
        else:
            return self.spline_weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass combining the base and spline paths:
        - base path: linear transformation of the SiLU-activated input;
        - spline path: B-spline basis expansion with learned coefficients;
        - both paths are layer-normalized and combined by summation.
        """
        orig_shape = x.shape
        x_flat = x.reshape(-1, self.in_features)

        base_act = self.base_activation(x_flat)
        base_lin = F.linear(base_act, self.base_weight, self.base_bias)
        base_out = self.norm_base(base_lin)

        bspline = self.b_splines(x_flat)
        bspline_flat = bspline.view(x_flat.size(0), -1)

        spline_w_flat = self.scaled_spline_weight.view(self.out_features, -1)
        spline_lin = F.linear(bspline_flat, spline_w_flat)
        spline_out = self.norm_spline(spline_lin)

        out = base_out + spline_out
        out = out.view(*orig_shape[:-1], self.out_features)

        return out
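

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the seed, batch sizes and feature sizes
    # below are arbitrary assumptions, not values taken from the surrounding project.
    torch.manual_seed(0)

    layer = KANLayer(in_features=16, out_features=32)

    x = torch.randn(8, 16)      # (batch, in_features)
    y = layer(x)                # (batch, out_features)
    print(y.shape)              # torch.Size([8, 32])

    # Extra leading dimensions are supported, e.g. (batch, seq_len, in_features).
    x_seq = torch.randn(4, 10, 16)
    y_seq = layer(x_seq)
    print(y_seq.shape)          # torch.Size([4, 10, 32])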