├── efficient_kan
│   ├── __init__.py
│   └── kan.py
├── README.md
├── LICENSE
├── training.py
└── mnist.py

/efficient_kan/__init__.py:
--------------------------------------------------------------------------------
from .kan import KANLinear, KAN

__all__ = ["KANLinear", "KAN"]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Kolmogorov-Arnold-Networks-KANs-mlx

A rudimentary implementation of efficient KAN (https://github.com/Blealtan/efficient-kan) using the MLX framework. Feel free to optimize it for speed or to implement other KAN variants; for example, there are convolutional networks built from KANs, such as https://github.com/AntonioTepsich/Convolutional-KANs.

You should have MLX installed (`pip install mlx`).
Basic usage: `python training.py`.
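
A minimal usage sketch (the batch size and layer widths below are illustrative; `KAN` takes a list of layer widths, exactly as in `training.py`):

```python
import mlx.core as mx
from efficient_kan import KAN

# 28*28 inputs -> 64 hidden units -> 10 classes, as in training.py.
model = KAN([28 * 28, 64, 10])

# Forward pass on a dummy batch of 32 flattened images.
x = mx.random.uniform(shape=(32, 28 * 28))
logits = model(x)
print(logits.shape)  # (32, 10)
```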
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Peter Obi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
import argparse
import time
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import mnist
from efficient_kan import KAN

def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y))

def eval_fn(model, X, y):
    return mx.mean(mx.argmax(model(X), axis=1) == y)

def batch_iterate(batch_size, X, y):
    perm = mx.array(np.random.permutation(y.size))
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]

def main(args):
    seed = 0
    hidden_dim = 64
    num_classes = 10
    batch_size = 64
    num_epochs = 10
    learning_rate = 1e-3

    np.random.seed(seed)
    mx.random.seed(seed)  # seed MLX too, so weight initialization is reproducible

    # Load the data
    train_images, train_labels, test_images, test_labels = map(
        mx.array, getattr(mnist, args.dataset)()
    )

    # Build the model
    model = KAN([28 * 28, hidden_dim, num_classes])
    mx.eval(model.parameters())

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.AdamW(learning_rate=learning_rate, weight_decay=1e-4)

    for e in range(num_epochs):
        tic = time.perf_counter()

        model.train()
        for X, y in batch_iterate(batch_size, train_images, train_labels):
            loss, grads = loss_and_grad_fn(model, X, y)
            optimizer.update(model, grads)
            mx.eval(model.parameters(), optimizer.state)

        model.eval()
        accuracy = eval_fn(model, test_images, test_labels)

        toc = time.perf_counter()
        print(
            f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
            f" Time {toc - tic:.3f} (s)"
        )

if __name__ == "__main__":
    parser = argparse.ArgumentParser("Train KAN on MNIST with MLX.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="mnist",
        choices=["mnist", "fashion_mnist"],
        help="The dataset to use.",
    )
    args = parser.parse_args()

    if not args.gpu:
        mx.set_default_device(mx.cpu)

    main(args)
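
# Note: KAN also exposes regularization_loss(), which this script leaves
# unused. A hedged sketch of folding it into the objective (the 1e-5 weight
# is illustrative, not tuned):
#
# def loss_fn(model, X, y):
#     ce = mx.mean(nn.losses.cross_entropy(model(X), y))
#     return ce + 1e-5 * model.regularization_loss()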
--------------------------------------------------------------------------------
/mnist.py:
--------------------------------------------------------------------------------
# Copyright © 2023 Apple Inc.

import gzip
import os
import pickle
from urllib import request

import numpy as np


def mnist(
    save_dir="/tmp", base_url="http://yann.lecun.com/exdb/mnist/", filename="mnist.pkl"
):
    """
    Load the MNIST dataset in 4 tensors: train images, train labels,
    test images, and test labels.

    Checks `save_dir` for already downloaded data, otherwise downloads.

    Download code modified from:
    https://github.com/hsjeong5/MNIST-for-Numpy
    """

    def download_and_save(save_file):
        filename = [
            ["training_images", "train-images-idx3-ubyte.gz"],
            ["test_images", "t10k-images-idx3-ubyte.gz"],
            ["training_labels", "train-labels-idx1-ubyte.gz"],
            ["test_labels", "t10k-labels-idx1-ubyte.gz"],
        ]

        mnist = {}
        for name in filename:
            out_file = os.path.join("/tmp", name[1])
            request.urlretrieve(base_url + name[1], out_file)
        for name in filename[:2]:
            out_file = os.path.join("/tmp", name[1])
            with gzip.open(out_file, "rb") as f:
                mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
                    -1, 28 * 28
                )
        for name in filename[-2:]:
            out_file = os.path.join("/tmp", name[1])
            with gzip.open(out_file, "rb") as f:
                mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
        with open(save_file, "wb") as f:
            pickle.dump(mnist, f)

    save_file = os.path.join(save_dir, filename)
    if not os.path.exists(save_file):
        download_and_save(save_file)
    with open(save_file, "rb") as f:
        mnist = pickle.load(f)

    def preproc(x):
        return x.astype(np.float32) / 255.0

    mnist["training_images"] = preproc(mnist["training_images"])
    mnist["test_images"] = preproc(mnist["test_images"])
    return (
        mnist["training_images"],
        mnist["training_labels"].astype(np.uint32),
        mnist["test_images"],
        mnist["test_labels"].astype(np.uint32),
    )


def fashion_mnist(save_dir="/tmp"):
    return mnist(
        save_dir,
        base_url="http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
        filename="fashion_mnist.pkl",
    )


if __name__ == "__main__":
    train_x, train_y, test_x, test_y = mnist()
    assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
    assert train_y.shape == (60000,), "Wrong training set size"
    assert test_x.shape == (10000, 28 * 28), "Wrong test set size"
    assert test_y.shape == (10000,), "Wrong test set size"
--------------------------------------------------------------------------------
/efficient_kan/kan.py:
--------------------------------------------------------------------------------
import mlx.core as mx
import mlx.nn as nn


class KANLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector over grid_range, extended by spline_order
        # knots on each side.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            mx.arange(-spline_order, grid_size + spline_order + 1) * h
            + grid_range[0]
        ).reshape(-1, 1)  # column vector of knots
        self.grid = mx.tile(grid, (1, in_features))  # one knot vector per input

        self.base_weight = mx.random.uniform(shape=(out_features, in_features))
        self.spline_weight = mx.random.uniform(
            shape=(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = mx.random.uniform(shape=(out_features, in_features))

        # scale_noise is kept for parity with the PyTorch original; the
        # noise-based spline initialization is not implemented here.
        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        # Initialize base_weight with random values scaled by scale_base
        self.base_weight = mx.random.uniform(shape=self.base_weight.shape) * self.scale_base

        # Initialize spline_weight with random values scaled by scale_spline
        self.spline_weight = mx.random.uniform(shape=self.spline_weight.shape) * self.scale_spline

        if self.enable_standalone_scale_spline:
            # Initialize spline_scaler with random values scaled by scale_spline
            self.spline_scaler = mx.random.uniform(shape=self.spline_scaler.shape) * self.scale_spline
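
    # Worked example with the defaults (grid_size=5, spline_order=3,
    # grid_range=[-1, 1]): h = 2 / 5 = 0.4 and the knot vector is
    # mx.arange(-3, 9) * 0.4 - 1, i.e. 12 knots spanning [-2.2, 2.2].
    # b_splines() below then yields grid_size + spline_order = 8 basis
    # functions per input feature, matching the last axis of spline_weight.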

    def b_splines(self, x):
        """Evaluate the B-spline bases at x via the Cox-de Boor recursion."""
        assert x.ndim == 2 and x.shape[1] == self.in_features

        x = x[:, :, None]  # (batch, in_features, 1)
        grid_reshaped = self.grid[:-1].T[None, :, :]  # (1, in_features, n_knots - 1)
        next_grid_reshaped = self.grid[1:].T[None, :, :]

        # Order-0 bases: indicator of the knot interval containing x.
        bases = mx.logical_and(x >= grid_reshaped, x < next_grid_reshaped).astype(x.dtype)
        # Raise the spline order one step at a time up to spline_order.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - self.grid[:-(k + 1)].T[None, :, :]) /
                (self.grid[k:-1].T[None, :, :] - self.grid[:-(k + 1)].T[None, :, :]) *
                bases[:, :, :-1]
            ) + (
                (self.grid[(k + 1):].T[None, :, :] - x) /
                (self.grid[(k + 1):].T[None, :, :] - self.grid[1:(-k)].T[None, :, :]) *
                bases[:, :, 1:]
            )
        assert bases.shape == (x.shape[0], self.in_features, self.grid_size + self.spline_order)
        return bases

    def curve2coeff(self, x, y):
        """Fit spline coefficients so the splines pass through (x, y) samples."""
        assert x.ndim == 2 and x.shape[1] == self.in_features
        assert y.shape == (x.shape[0], self.in_features, self.out_features)

        A = self.b_splines(x).transpose(1, 0, 2)  # (in_features, batch, n_basis)
        B = y.transpose(1, 0, 2)  # (in_features, batch, out_features)
        # Batched least squares via the normal equations. (The PyTorch
        # original uses torch.linalg.lstsq, which MLX does not provide;
        # MLX linear-algebra ops run on the CPU stream.)
        AT = A.transpose(0, 2, 1)
        solution = mx.matmul(
            mx.linalg.inv(mx.matmul(AT, A), stream=mx.cpu), mx.matmul(AT, B)
        )  # (in_features, n_basis, out_features)
        # Permute to (out_features, in_features, n_basis); the original
        # transpose(0, 2, 1) produced the wrong axis order.
        result = solution.transpose(2, 0, 1)

        assert result.shape == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler[:, :, None]
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def __call__(self, x):
        assert x.ndim == 2 and x.shape[1] == self.in_features
        base_output = mx.matmul(self.base_activation(x), self.base_weight.T)
        spline_output = mx.matmul(
            self.b_splines(x).reshape(x.shape[0], -1),
            self.scaled_spline_weight.reshape(self.out_features, -1).T,
        )
        return base_output + spline_output
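
    # Shape walk-through of __call__ above, for a batch of N inputs:
    #   base path:   (N, in) @ (in, out)            -> (N, out)
    #   b_splines:   (N, in, k) flattened to (N, in * k), where
    #                k = grid_size + spline_order
    #   spline path: (N, in * k) @ (in * k, out)    -> (N, out)
    # The two paths are summed, so each edge combines a SiLU "base"
    # transform with a learned B-spline transform.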

    def update_grid(self, x, margin=0.01):
        assert x.ndim == 2 and x.shape[1] == self.in_features
        batch = x.shape[0]

        # Evaluate the current splines so the coefficients can be refit
        # against the same outputs once the grid has moved.
        splines = self.b_splines(x).transpose(1, 0, 2)  # (in, batch, n_basis)
        orig_coeff = self.scaled_spline_weight.transpose(1, 2, 0)  # (in, n_basis, out)
        unreduced_spline_output = mx.matmul(splines, orig_coeff).transpose(1, 0, 2)

        # Adaptive knots follow the empirical quantiles of x; uniform knots
        # cover its range plus a margin. grid_eps blends the two.
        x_sorted = mx.sort(x, axis=0)
        grid_adaptive = x_sorted[
            mx.linspace(0, batch - 1, self.grid_size + 1, dtype=mx.int32)
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            mx.arange(self.grid_size + 1).reshape(-1, 1).astype(mx.float32)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Extend the knot vector by spline_order knots on each side.
        grid = mx.concatenate(
            [
                grid[:1]
                - uniform_step
                * mx.arange(self.spline_order, 0, -1).reshape(-1, 1),
                grid,
                grid[-1:]
                + uniform_step
                * mx.arange(1, self.spline_order + 1).reshape(-1, 1),
            ],
            axis=0,
        )

        # Plain reassignment: mx.array has no Theano-style set_value(), and
        # grid is already (n_knots, in_features) like self.grid, so the
        # original .T was spurious.
        self.grid = grid
        self.spline_weight = self.curve2coeff(x, unreduced_spline_output)

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        # L1-style penalty on the mean magnitude of the spline coefficients,
        # plus an entropy term that spreads activation across edges.
        l1_fake = mx.abs(self.spline_weight).mean(axis=-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -mx.sum(p * mx.log(p))
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super().__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = []
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def __call__(self, x, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        # Sum over layers; sum() works for any number of layers, whereas
        # mx.add(*...) only accepts exactly two operands.
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
--------------------------------------------------------------------------------