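Chen Wei: The idea behind the KAN architecture comes from the Kolmogorov-Arnold representation theorem. MLPs have fixed activation functions on nodes ("neurons"), whereas KANs have learnable activation functions on edges ("weights"). In data fitting and PDE solving, a smaller KAN can achieve better accuracy than a larger MLP. Compared with MLPs, KANs are also more interpretable, which makes them suitable as auxiliary models in mathematics and physics research, helping to discover more fundamental numerical laws.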


Preface

The Kolmogorov-Arnold representation theorem

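The Kolmogorov-Arnold theorem, also called the Kolmogorov-Arnold representation theorem, was first proposed by the Soviet mathematician Andrey Kolmogorov and further developed in 1957 by his student Vladimir Arnold. The original motivation was to study how a multivariate function can be represented by a set of simpler functions, a basic question in mathematics and theoretical computer science. It also partially answers the 13th of Hilbert's famous 23 problems: whether a degree-7 equation can be solved using compositions of addition, subtraction, multiplication, division, and algebraic functions of at most two variables. The Kolmogorov-Arnold theorem is stated in the broader setting of continuous functions rather than in Hilbert's original framework of algebraic equations, so it counts only as a partial solution.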

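Specifically, the Kolmogorov-Arnold theorem states that for any continuous function $f(x_1,\ldots,x_n)$ defined on the closed domain $[0,1]^n$, there exist $2n+1$ one-dimensional continuous outer functions $\Phi_q$ and a family of one-dimensional continuous inner functions $\phi_{q,p}$ such that the multivariate function can be written as:

$$f(x_1,\ldots,x_n)=\sum_{q=1}^{2n+1}\Phi_q\left(\sum_{p=1}^{n}\phi_{q,p}(x_p)\right)$$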
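To make the shape of this formula concrete, here is a small sketch of my own (it is not from the theorem's proof or from the KAN paper). It writes the product $x_1 x_2$ as a sum of outer functions applied to sums of inner functions, using the identity $x_1 x_2 = \frac{(x_1+x_2)^2}{4} - \frac{(x_1-x_2)^2}{4}$. The names `phi`, `Phi` and `f_composed` are hypothetical, and only two of the $2n+1=5$ allowed terms are needed here; the remaining ones can be taken as zero.

```python
import math
import random


def phi(q, p, x):
    """Inner univariate functions phi_{q,p} (a hand-picked choice for this example)."""
    if q == 2 and p == 2:
        return -x
    return x


def Phi(q, u):
    """Outer univariate functions Phi_q (a hand-picked choice for this example)."""
    return u * u / 4 if q == 1 else -u * u / 4


def f_composed(x1, x2):
    """Evaluate sum_q Phi_q(sum_p phi_{q,p}(x_p)) for n = 2 with two non-zero terms."""
    xs = (x1, x2)
    return sum(
        Phi(q, sum(phi(q, p, xs[p - 1]) for p in (1, 2)))
        for q in (1, 2)
    )


# Check the composition against the plain product on a few random points.
random.seed(0)
for _ in range(5):
    a, b = random.random(), random.random()
    assert math.isclose(f_composed(a, b), a * b, abs_tol=1e-12)
```

This is only a friendly special case: for a general continuous function the theorem guarantees that such inner and outer functions exist, but says nothing about them being as simple as these.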

efficient-kan: https://github.com/Blealtan/efficient-kan/

Project overview

Project overview (Readme.md)

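This repository contains an efficient implementation of the Kolmogorov-Arnold Network (KAN). The original KAN implementation has performance problems, mainly because it has to expand all intermediate variables in order to apply the different activation functions. The new implementation reformulates the computation to reduce memory consumption and make the computation more efficient.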

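Main improvements

  1. Computation optimization

    • Problem in the original: in every layer, the original KAN expands the input into a tensor of shape (batch_size, out_features, in_features) in order to apply the activation functions. This leads to high memory consumption.
    • New implementation: the input is activated by the different basis functions (B-splines), and the activated results are then combined linearly. This reformulation significantly reduces memory consumption and makes the forward and backward passes plain matrix multiplications.
  2. Sparsification

    • Problem in the original: the original KAN uses an L1 regularization defined on the input samples, which requires non-linear operations on a tensor of shape (batch_size, out_features, in_features) and is therefore incompatible with the reformulation.
    • New implementation: L1 regularization is applied to the weights instead, which is the more common practice in neural networks and is compatible with the reformulation. The original authors' implementation also includes this kind of regularization, so it may help performance, but further experiments are needed to confirm this.
  3. Learnable activation functions

    • Problem in the original: besides the learnable activation functions (B-splines), the original KAN also has a learnable scale on each activation function.
    • New implementation: the option enable_standalone_scale_spline, which defaults to True, includes this feature. Disabling it makes the model more efficient but may hurt performance. More experiments are needed to verify this.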
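The first point, the reformulation into a matrix multiplication, can be illustrated with a minimal sketch. Assumptions on my side: the basis functions here are Gaussian bumps standing in for B-splines (any fixed basis shows the same algebra), and the tensor names and sizes are made up for the example. The sketch compares the "expanded" per-edge computation with a single `F.linear` call on flattened bases.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, in_features, out_features, n_basis = 4, 3, 2, 8

x = torch.rand(batch, in_features) * 2 - 1

# Stand-in basis functions (Gaussian bumps, NOT B-splines): (batch, in, n_basis)
centers = torch.linspace(-1, 1, n_basis)
bases = torch.exp(-((x.unsqueeze(-1) - centers) ** 2) / 0.1)

# One coefficient vector per edge (out, in): (out, in, n_basis)
weight = torch.randn(out_features, in_features, n_basis)

# Expanded form: materialize every edge activation, shape (batch, out, in),
# then sum over the incoming edges of each output node.
per_edge = torch.einsum("bik,oik->boi", bases, weight)
expanded_out = per_edge.sum(dim=-1)

# Reformulated form: flatten (in, n_basis) and do one matrix multiplication,
# so the (batch, out, in) tensor is never built.
matmul_out = F.linear(
    bases.reshape(batch, -1),
    weight.reshape(out_features, -1),
)

assert torch.allclose(expanded_out, matmul_out, atol=1e-4)
```

Both paths compute the same sum over input features and basis functions; the second one just lets the matrix multiplication do the reduction.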

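Changelog

  • 2024-05-04: @xiaol hinted that the constant initialization of the base weight parameters can be a problem on the MNIST dataset. The base_weight and spline_scaler matrices are now initialized with kaiming_uniform_, the same way nn.Linear is initialized. This improves MNIST accuracy dramatically (from about 20% to about 97%), but it is not clear whether this is a good choice in general.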
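A quick sketch of the difference between the two initializations, assuming a scale of 1.0 (the listing further down multiplies `a` by `scale_base` or `scale_spline`); the tensor sizes are arbitrary. With `a=math.sqrt(5)`, `kaiming_uniform_` draws from a uniform distribution bounded by `1 / sqrt(fan_in)`, whereas a constant initialization gives every edge the same starting value.

```python
import math
import torch

torch.manual_seed(0)
out_features, in_features = 16, 64

# Constant initialization: every entry is identical, so there is no spread.
constant_w = torch.empty(out_features, in_features)
torch.nn.init.constant_(constant_w, 1.0)

# nn.Linear-style initialization: uniform in [-bound, bound].
kaiming_w = torch.empty(out_features, in_features)
torch.nn.init.kaiming_uniform_(kaiming_w, a=math.sqrt(5))

# For a = sqrt(5) the bound simplifies to 1 / sqrt(fan_in).
bound = 1.0 / math.sqrt(in_features)

print("constant std:", constant_w.std().item())
print("kaiming  std:", kaiming_w.std().item())
print("kaiming max |w| within bound:", kaiming_w.abs().max().item() <= bound)
```

This does not explain the MNIST jump by itself; it only shows that the new initialization breaks the symmetry between edges and scales with the input width.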

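Summary

  • Main improvement: reformulating the computation significantly reduces memory consumption and improves computational efficiency.
  • Sparsification: L1 regularization is applied to the weights instead, which is compatible with the new computation.
  • Learnable activation functions: an option controls whether the learnable scale on each activation function is enabled.
  • Initialization: changing how the base weight parameters are initialized significantly improves performance on MNIST.

These changes make the KAN implementation more efficient, and it performs better on some datasets. Further experiments are still needed to verify how well the changes hold up across different tasks and datasets.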
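For the weight-based regularization, here is a minimal standalone sketch that mirrors the `regularization_loss` method in the listing below. The tensor `spline_weight` and its sizes are made-up stand-ins, and both regularization coefficients are set to 1.0 as in the method's defaults.

```python
import math
import torch

torch.manual_seed(0)
# Hypothetical spline coefficients: (out_features, in_features, n_coeff)
spline_weight = torch.randn(2, 3, 8)

# Mean absolute coefficient per edge: a weight-based proxy for edge "strength".
l1_per_edge = spline_weight.abs().mean(-1)  # (out, in)
l1_total = l1_per_edge.sum()

# Normalize the edge strengths into a distribution and take its entropy.
p = l1_per_edge / l1_total
entropy = -torch.sum(p * p.log())

loss = 1.0 * l1_total + 1.0 * entropy

# The entropy is largest when all edges are equally strong.
max_entropy = math.log(p.numel())
print(l1_total.item(), entropy.item(), max_entropy)
```

The L1 term pushes coefficients towards zero, and the entropy term favors a few dominant edges over many equally weak ones, which is the sparsification the summary refers to.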

import torch
import torch.nn.functional as F
import math


class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
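The `b_splines` method in the listing is the part I found hardest to read, so here is a standalone single-feature sketch of the same recursion (the Cox-de Boor recursion over a uniform, extended knot grid). The function name `bspline_bases` and the test points are my own; the grid construction copies the `__init__` defaults (`grid_size=5`, `spline_order=3`, range [-1, 1]). The check at the end uses a standard B-spline property: inside the grid range the bases are non-negative and sum to 1.

```python
import torch


def bspline_bases(x, grid, order):
    """B-spline bases for one feature.

    x: (batch,) inputs; grid: (n_knots,) increasing knots.
    Returns a tensor of shape (batch, n_knots - order - 1).
    """
    x = x.unsqueeze(-1)
    # Order 0: indicator of the knot interval that contains x.
    bases = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)
    # Raise the order step by step, blending neighbouring bases.
    for k in range(1, order + 1):
        left = (x - grid[: -(k + 1)]) / (grid[k:-1] - grid[: -(k + 1)]) * bases[:, :-1]
        right = (grid[k + 1 :] - x) / (grid[k + 1 :] - grid[1:-k]) * bases[:, 1:]
        bases = left + right
    return bases


grid_size, order = 5, 3
h = 2.0 / grid_size
# Uniform grid on [-1, 1], extended by `order` knots on each side.
grid = torch.arange(-order, grid_size + order + 1, dtype=torch.float64) * h - 1.0

x = torch.linspace(-1.0, 0.99, 50, dtype=torch.float64)
bases = bspline_bases(x, grid, order)

print(bases.shape)  # (50, grid_size + order)
assert torch.all(bases >= 0)
assert torch.allclose(bases.sum(-1), torch.ones_like(x))
```

Each row of `bases` is what gets multiplied by the spline coefficients of one edge, which is why the coefficient tensor has `grid_size + spline_order` entries per edge.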

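Postscript

I find this quite interesting: the old point weights are replaced by line weights. Anyone who is interested could probably churn out quite a few papers from it.
I once tried training on 2D images with linear functions, and the results were always poor. Only today did I realize that the problem was weight initialization. Using kaiming_uniform_ improved accuracy enormously; credit, once again, goes to Kaiming He.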