├── ConvKAN.py ├── Demo.ipynb ├── KAN_Implementations ├── Efficient_KAN │ ├── __pycache__ │ │ ├── efficient_kan.cpython-312.pyc │ │ └── kan.cpython-312.pyc │ └── efficient_kan.py ├── Fast_KAN │ ├── __pycache__ │ │ └── fast_kan.cpython-312.pyc │ └── fast_kan.py └── Original_KAN │ ├── LBFGS.py │ ├── Symbolic_KANLayer.py │ ├── __pycache__ │ ├── LBFGS.cpython-312.pyc │ ├── Symbolic_KANLayer.cpython-312.pyc │ ├── original_kan.cpython-312.pyc │ ├── spline.cpython-312.pyc │ └── utils.cpython-312.pyc │ ├── original_kan.py │ ├── spline.py │ └── utils.py ├── LICENSE ├── README.md └── requirements.txt /ConvKAN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | from enum import Enum 6 | import warnings 7 | sys.path.append("KAN_Implementations/Efficient_KAN") 8 | sys.path.append("KAN_Implementations/Original_KAN") 9 | sys.path.append("KAN_Implementations/Fast_KAN") 10 | from efficient_kan import Efficient_KANLinear 11 | from original_kan import KAN 12 | from fast_kan import Fast_KANLinear 13 | 14 | class ConvKAN(nn.Module): 15 | 16 | def __init__(self, 17 | in_channels, 18 | out_channels, 19 | kernel_size, 20 | stride=1, 21 | padding=0, 22 | grid_size=5, 23 | spline_order=3, 24 | scale_noise=0.1, 25 | scale_base=1.0, 26 | scale_spline=1.0, 27 | enable_standalone_scale_spline=True, 28 | base_activation=torch.nn.SiLU(), 29 | grid_eps=0.02, 30 | grid_range=[-1, 1], 31 | sp_trainable=True, 32 | sb_trainable=True, 33 | bias_trainable=True, 34 | symbolic_enabled=True, 35 | device="cpu", 36 | version= "Efficient", 37 | ): 38 | super(ConvKAN, self).__init__() 39 | 40 | self.version = version 41 | self.in_channels = in_channels 42 | self.out_channels = out_channels 43 | self.kernel_size = kernel_size 44 | self.stride = stride 45 | self.padding = padding 46 | 47 | self.unfold = nn.Unfold(kernel_size, padding=padding, stride=stride) 48 | 49 | self.linear = None 50 | 51 | if self.version == "Efficient": 52 | self.linear = Efficient_KANLinear( 53 | in_features = in_channels * kernel_size * kernel_size, 54 | out_features = out_channels, 55 | grid_size=grid_size, 56 | spline_order=spline_order, 57 | scale_noise=scale_noise, 58 | scale_base=scale_base, 59 | scale_spline=scale_spline, 60 | enable_standalone_scale_spline=enable_standalone_scale_spline, 61 | base_activation=base_activation, 62 | grid_eps=grid_eps, 63 | grid_range=grid_range, 64 | ) 65 | warnings.warn('Warning: Efficient KAN implementation does not support the following parameters: [sp_trainable, sb_trainable, device]') 66 | elif self.version == "Original": 67 | self.linear = KAN( 68 | width = [in_channels * kernel_size * kernel_size, out_channels], 69 | grid = grid_size, 70 | k = spline_order, 71 | noise_scale = scale_noise, 72 | noise_scale_base = scale_base, 73 | base_fun = base_activation, 74 | symbolic_enabled=symbolic_enabled, 75 | bias_trainable = bias_trainable, 76 | grid_eps = grid_eps, 77 | grid_range = grid_range, 78 | sp_trainable = sp_trainable, 79 | sb_trainable = sb_trainable, 80 | device = device, 81 | ) 82 | 83 | elif self.version == "Fast": 84 | self.linear = Fast_KANLinear( 85 | input_dim = in_channels * kernel_size * kernel_size, 86 | output_dim = out_channels, 87 | num_grids=grid_size, 88 | spline_weight_init_scale=scale_spline, 89 | base_activation=base_activation, 90 | grid_min = grid_range[0], 91 | grid_max = grid_range[1], 92 | ) 93 | warnings.warn('Warning: Fast KAN implementation does not support 
the following parameters: [scale_noise, scale_base, enable_standalone_scale_spline, grid_eps, sp_trainable, sb_trainable, device]') 94 | 95 | 96 | 97 | def forward(self, x): 98 | 99 | assert x.dim() == 4 100 | batch_size, in_channels, height, width = x.size() 101 | assert in_channels == self.in_channels 102 | 103 | # Unfold the input tensor to extract flattened sliding blocks from a batched input tensor. 104 | # Input: [batch_size, in_channels, height, width] 105 | # Output: [batch_size, in_channels*kernel_size*kernel_size, num_patches] 106 | patches = self.unfold(x) 107 | 108 | # Transpose to have the patches dimension last. 109 | # Input: [batch_size, in_channels*kernel_size*kernel_size, num_patches] 110 | # Output: [batch_size, num_patches, in_channels*kernel_size*kernel_size] 111 | patches = patches.transpose(1, 2) 112 | 113 | # Reshape the patches to fit the linear layer input requirements. 114 | # Input: [batch_size, num_patches, in_channels*kernel_size*kernel_size] 115 | # Output: [batch_size*num_patches, in_channels*kernel_size*kernel_size] 116 | patches = patches.reshape(-1, in_channels * self.kernel_size * self.kernel_size) 117 | 118 | # Apply the linear layer to each patch. 119 | # Input: [batch_size*num_patches, in_channels*kernel_size*kernel_size] 120 | # Output: [batch_size*num_patches, out_channels] 121 | out = self.linear(patches) 122 | 123 | # Reshape the output to the normal format. 124 | # Input: [batch_size*num_patches, out_channels] 125 | # Output: [batch_size, num_patches, out_channels] 126 | out = out.view(batch_size, -1, out.size(-1)) 127 | 128 | # Calculate the height and width of the output (standard convolution arithmetic; e.g. a 32x32 input with kernel_size=3, padding=1, stride=1 gives (32 + 2*1 - 3)//1 + 1 = 32, as in Demo.ipynb). 129 | out_height = (height + 2*self.padding - self.kernel_size) // self.stride + 1 130 | out_width = (width + 2*self.padding - self.kernel_size) // self.stride + 1 131 | 132 | # Transpose back to have the channel dimension in the second position. 
133 | # Input: [batch_size, num_patches, out_channels] 134 | # Output: [batch_size, out_channels, num_patches] 135 | out = out.transpose(1, 2) 136 | 137 | # Reshape the output to the final shape 138 | # Input: [batch_size, out_channels, num_patches] 139 | # Output: [batch_size, out_channels, out_height, out_width] 140 | out = out.view(batch_size, self.out_channels, out_height, out_width) 141 | 142 | return out 143 | -------------------------------------------------------------------------------- /Demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**ConvKAN using Efficient KAN Implementation**" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "torch.Size([1024, 27])\n", 20 | "Input: torch.Size([1, 3, 32, 32])\n", 21 | "Output: torch.Size([1, 4, 32, 32])\n" 22 | ] 23 | }, 24 | { 25 | "name": "stderr", 26 | "output_type": "stream", 27 | "text": [ 28 | "/Users/omarrayyann/Documents/KAN-Conv2D/ConvKAN.py:65: UserWarning: Warning: Efficient KAN implementation does not support the following parameters: [sp_trainable, sb_trainable, device]\n", 29 | " warnings.warn('Warning: Efficient KAN implementation does not support the following parameters: [sp_trainable, sb_trainable, device]')\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "import torch\n", 35 | "from ConvKAN import ConvKAN\n", 36 | "\n", 37 | "conv = ConvKAN(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1, version=\"Efficient\", grid_size=3)\n", 38 | "conv.parameters()\n", 39 | "x = torch.rand((1,3,32,32))\n", 40 | "y = conv(x)\n", 41 | "\n", 42 | "print(\"Input: \", x.shape)\n", 43 | "print(\"Output: \", y.shape)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "torch.Size([1024, 27])\n", 56 | "dsadsa\n", 57 | "Dasidbasd\n" 58 | ] 59 | }, 60 | { 61 | "ename": "ValueError", 62 | "evalue": "too many values to unpack (expected 4)", 63 | "output_type": "error", 64 | "traceback": [ 65 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 66 | "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", 67 | "Cell \u001b[0;32mIn[2], line 7\u001b[0m\n\u001b[1;32m 4\u001b[0m conv \u001b[38;5;241m=\u001b[39m ConvKAN(in_channels\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, out_channels\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m, kernel_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, stride\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, version\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOriginal\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 6\u001b[0m x \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrand((\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m3\u001b[39m,\u001b[38;5;241m32\u001b[39m,\u001b[38;5;241m32\u001b[39m))\n\u001b[0;32m----> 7\u001b[0m y \u001b[38;5;241m=\u001b[39m \u001b[43mconv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m 
\u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInput: \u001b[39m\u001b[38;5;124m\"\u001b[39m, x\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOutput: \u001b[39m\u001b[38;5;124m\"\u001b[39m, y\u001b[38;5;241m.\u001b[39mshape)\n", 68 | "File \u001b[0;32m~/Documents/KAN-Conv2D/ConvKAN.py:122\u001b[0m, in \u001b[0;36mConvKAN.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;66;03m# Apply the linear layer to each patch.\u001b[39;00m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;66;03m# Input: [batch_size*num_patches, in_channels*kernel_size*kernel_size]\u001b[39;00m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;66;03m# Output: [batch_size*num_patches, out_channels\u001b[39;00m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;28mprint\u001b[39m(patches\u001b[38;5;241m.\u001b[39mshape)\n\u001b[0;32m--> 122\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpatches\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 124\u001b[0m \u001b[38;5;66;03m# Reshape the output to the normal format\u001b[39;00m\n\u001b[1;32m 125\u001b[0m \u001b[38;5;66;03m# Input: [batch_size*num_patches, out_channels]\u001b[39;00m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;66;03m# Output: [batch_size, num_patches, out_channels]\u001b[39;00m\n\u001b[1;32m 127\u001b[0m out \u001b[38;5;241m=\u001b[39m out\u001b[38;5;241m.\u001b[39mview(batch_size, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, out\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)) \n", 69 | "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", 70 | "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks 
\u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", 71 | "File \u001b[0;32m~/Documents/KAN-Conv2D/KAN_Implementations/Original_KAN/original_kan.py:677\u001b[0m, in \u001b[0;36mKAN.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 674\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m l \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdepth):\n\u001b[1;32m 675\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDasidbasd\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 677\u001b[0m x_numerical, preacts, postacts_numerical, postspline \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact_fun[l](x)\n\u001b[1;32m 679\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msymbolic_enabled \u001b[38;5;241m==\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m 680\u001b[0m x_symbolic, postacts_symbolic \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msymbolic_fun[l](x)\n", 72 | "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 4)" 73 | ] 74 | } 75 | ], 76 | "source": [ 77 | "import torch\n", 78 | "from ConvKAN import ConvKAN\n", 79 | "\n", 80 | "conv = ConvKAN(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1, version=\"Original\")\n", 81 | "\n", 82 | "x = torch.rand((1,3,32,32))\n", 83 | "y = conv.forward(x)\n", 84 | "\n", 85 | "print(\"Input: \", x.shape)\n", 86 | "print(\"Output: \", y.shape)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "**ConvKAN using Fast KAN Implementation**" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "Input: torch.Size([1, 3, 32, 32])\n", 106 | "Output: torch.Size([1, 4, 32, 32])\n" 107 | ] 108 | }, 109 | { 110 | "name": "stderr", 111 | "output_type": "stream", 112 | "text": [ 113 | "/Users/omarrayyann/Documents/KAN-Conv2D/ConvKAN.py:90: UserWarning: Warning: Fast KAN implementation does not support the following parameters: [scale_noise, scale_base, enable_standalone_scale_spline, grid_eps, sp_trainable, sb_trainable, device]\n", 114 | " warnings.warn('Warning: Fast KAN implementation does not support the following parameters: [scale_noise, scale_base, enable_standalone_scale_spline, grid_eps, sp_trainable, sb_trainable, device]')\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "import torch\n", 120 | "from ConvKAN import ConvKAN\n", 121 | "\n", 122 | "conv = 
ConvKAN(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1, version=\"Fast\")\n", 123 | "\n", 124 | "x = torch.rand((1,3,32,32))\n", 125 | "y = conv(x)\n", 126 | "\n", 127 | "print(\"Input: \", x.shape)\n", 128 | "print(\"Output: \", y.shape)" 129 | ] 130 | } 131 | ], 132 | "metadata": { 133 | "kernelspec": { 134 | "display_name": "Python 3", 135 | "language": "python", 136 | "name": "python3" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.12.2" 149 | } 150 | }, 151 | "nbformat": 4, 152 | "nbformat_minor": 2 153 | } 154 | -------------------------------------------------------------------------------- /KAN_Implementations/Efficient_KAN/__pycache__/efficient_kan.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarrayyann/KAN-Conv2D/e9e69c8e73ac653b2d152761deb1886ff2aaf7d5/KAN_Implementations/Efficient_KAN/__pycache__/efficient_kan.cpython-312.pyc -------------------------------------------------------------------------------- /KAN_Implementations/Efficient_KAN/__pycache__/kan.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarrayyann/KAN-Conv2D/e9e69c8e73ac653b2d152761deb1886ff2aaf7d5/KAN_Implementations/Efficient_KAN/__pycache__/kan.cpython-312.pyc -------------------------------------------------------------------------------- /KAN_Implementations/Efficient_KAN/efficient_kan.py: -------------------------------------------------------------------------------- 1 | # Efficient KAN Implementation 2 | # From: https://github.com/Blealtan/efficient-kan 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import math 7 | 8 | class Efficient_KANLinear(torch.nn.Module): 9 | def __init__( 10 | self, 11 | in_features, 12 | out_features, 13 | grid_size=5, 14 | spline_order=3, 15 | scale_noise=0.1, 16 | scale_base=1.0, 17 | scale_spline=1.0, 18 | enable_standalone_scale_spline=True, 19 | base_activation=torch.nn.SiLU, 20 | grid_eps=0.02, 21 | grid_range=[-1, 1], 22 | ): 23 | super(Efficient_KANLinear, self).__init__() 24 | self.in_features = in_features 25 | self.out_features = out_features 26 | self.grid_size = grid_size 27 | self.spline_order = spline_order 28 | 29 | h = (grid_range[1] - grid_range[0]) / grid_size 30 | grid = ( 31 | ( 32 | torch.arange(-spline_order, grid_size + spline_order + 1) * h 33 | + grid_range[0] 34 | ) 35 | .expand(in_features, -1) 36 | .contiguous() 37 | ) 38 | self.register_buffer("grid", grid) 39 | 40 | self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) 41 | self.spline_weight = torch.nn.Parameter( 42 | torch.Tensor(out_features, in_features, grid_size + spline_order) 43 | ) 44 | if enable_standalone_scale_spline: 45 | self.spline_scaler = torch.nn.Parameter( 46 | torch.Tensor(out_features, in_features) 47 | ) 48 | 49 | self.scale_noise = scale_noise 50 | self.scale_base = scale_base 51 | self.scale_spline = scale_spline 52 | self.enable_standalone_scale_spline = enable_standalone_scale_spline 53 | self.base_activation = base_activation 54 | self.grid_eps = grid_eps 55 | 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * 
self.scale_base) 60 | with torch.no_grad(): 61 | noise = ( 62 | ( 63 | torch.rand(self.grid_size + 1, self.in_features, self.out_features) 64 | - 1 / 2 65 | ) 66 | * self.scale_noise 67 | / self.grid_size 68 | ) 69 | self.spline_weight.data.copy_( 70 | (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) 71 | * self.curve2coeff( 72 | self.grid.T[self.spline_order : -self.spline_order], 73 | noise, 74 | ) 75 | ) 76 | if self.enable_standalone_scale_spline: 77 | # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) 78 | torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline) 79 | 80 | def b_splines(self, x: torch.Tensor): 81 | """ 82 | Compute the B-spline bases for the given input tensor. 83 | 84 | Args: 85 | x (torch.Tensor): Input tensor of shape (batch_size, in_features). 86 | 87 | Returns: 88 | torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). 89 | """ 90 | assert x.dim() == 2 and x.size(1) == self.in_features 91 | 92 | grid: torch.Tensor = ( 93 | self.grid 94 | ) # (in_features, grid_size + 2 * spline_order + 1) 95 | x = x.unsqueeze(-1) 96 | bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) 97 | for k in range(1, self.spline_order + 1): 98 | bases = ( 99 | (x - grid[:, : -(k + 1)]) 100 | / (grid[:, k:-1] - grid[:, : -(k + 1)]) 101 | * bases[:, :, :-1] 102 | ) + ( 103 | (grid[:, k + 1 :] - x) 104 | / (grid[:, k + 1 :] - grid[:, 1:(-k)]) 105 | * bases[:, :, 1:] 106 | ) 107 | 108 | assert bases.size() == ( 109 | x.size(0), 110 | self.in_features, 111 | self.grid_size + self.spline_order, 112 | ) 113 | return bases.contiguous() 114 | 115 | def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): 116 | """ 117 | Compute the coefficients of the curve that interpolates the given points. 118 | 119 | Args: 120 | x (torch.Tensor): Input tensor of shape (batch_size, in_features). 121 | y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). 122 | 123 | Returns: 124 | torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). 
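Note: for each input feature, this solves the batched least-squares problem A @ C = B, where A = b_splines(x) restricted to that feature (shape (batch_size, grid_size + spline_order)) and B = y (shape (batch_size, out_features)), via torch.linalg.lstsq below; permuting the solution yields one coefficient vector per (out_feature, in_feature) pair.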
125 | """ 126 | assert x.dim() == 2 and x.size(1) == self.in_features 127 | assert y.size() == (x.size(0), self.in_features, self.out_features) 128 | 129 | A = self.b_splines(x).transpose( 130 | 0, 1 131 | ) # (in_features, batch_size, grid_size + spline_order) 132 | B = y.transpose(0, 1) # (in_features, batch_size, out_features) 133 | solution = torch.linalg.lstsq( 134 | A, B 135 | ).solution # (in_features, grid_size + spline_order, out_features) 136 | result = solution.permute( 137 | 2, 0, 1 138 | ) # (out_features, in_features, grid_size + spline_order) 139 | 140 | assert result.size() == ( 141 | self.out_features, 142 | self.in_features, 143 | self.grid_size + self.spline_order, 144 | ) 145 | return result.contiguous() 146 | 147 | @property 148 | def scaled_spline_weight(self): 149 | return self.spline_weight * ( 150 | self.spline_scaler.unsqueeze(-1) 151 | if self.enable_standalone_scale_spline 152 | else 1.0 153 | ) 154 | 155 | def forward(self, x: torch.Tensor): 156 | assert x.size(-1) == self.in_features 157 | original_shape = x.shape 158 | x = x.view(-1, self.in_features) 159 | 160 | base_output = F.linear(self.base_activation(x), self.base_weight) 161 | spline_output = F.linear( 162 | self.b_splines(x).view(x.size(0), -1), 163 | self.scaled_spline_weight.view(self.out_features, -1), 164 | ) 165 | output = base_output + spline_output 166 | 167 | output = output.view(*original_shape[:-1], self.out_features) 168 | return output 169 | 170 | @torch.no_grad() 171 | def update_grid(self, x: torch.Tensor, margin=0.01): 172 | assert x.dim() == 2 and x.size(1) == self.in_features 173 | batch = x.size(0) 174 | 175 | splines = self.b_splines(x) # (batch, in, coeff) 176 | splines = splines.permute(1, 0, 2) # (in, batch, coeff) 177 | orig_coeff = self.scaled_spline_weight # (out, in, coeff) 178 | orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) 179 | unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out) 180 | unreduced_spline_output = unreduced_spline_output.permute( 181 | 1, 0, 2 182 | ) # (batch, in, out) 183 | 184 | # sort each channel individually to collect data distribution 185 | x_sorted = torch.sort(x, dim=0)[0] 186 | grid_adaptive = x_sorted[ 187 | torch.linspace( 188 | 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device 189 | ) 190 | ] 191 | 192 | uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size 193 | grid_uniform = ( 194 | torch.arange( 195 | self.grid_size + 1, dtype=torch.float32, device=x.device 196 | ).unsqueeze(1) 197 | * uniform_step 198 | + x_sorted[0] 199 | - margin 200 | ) 201 | 202 | grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive 203 | grid = torch.concatenate( 204 | [ 205 | grid[:1] 206 | - uniform_step 207 | * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), 208 | grid, 209 | grid[-1:] 210 | + uniform_step 211 | * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), 212 | ], 213 | dim=0, 214 | ) 215 | 216 | self.grid.copy_(grid.T) 217 | self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output)) 218 | 219 | def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): 220 | """ 221 | Compute the regularization loss. 
222 | 223 | This is a dumb simulation of the original L1 regularization as stated in the 224 | paper, since the original one requires computing absolutes and entropy from the 225 | expanded (batch, in_features, out_features) intermediate tensor, which is hidden 226 | behind the F.linear function if we want a memory-efficient implementation. 227 | 228 | The L1 regularization is now computed as the mean absolute value of the spline 229 | weights. The author's implementation also includes this term in addition to the 230 | sample-based regularization. 231 | """ 232 | l1_fake = self.spline_weight.abs().mean(-1) 233 | regularization_loss_activation = l1_fake.sum() 234 | p = l1_fake / regularization_loss_activation 235 | regularization_loss_entropy = -torch.sum(p * p.log()) 236 | return ( 237 | regularize_activation * regularization_loss_activation 238 | + regularize_entropy * regularization_loss_entropy 239 | ) -------------------------------------------------------------------------------- /KAN_Implementations/Fast_KAN/__pycache__/fast_kan.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarrayyann/KAN-Conv2D/e9e69c8e73ac653b2d152761deb1886ff2aaf7d5/KAN_Implementations/Fast_KAN/__pycache__/fast_kan.cpython-312.pyc -------------------------------------------------------------------------------- /KAN_Implementations/Fast_KAN/fast_kan.py: -------------------------------------------------------------------------------- 1 | # Fast KAN Implementation 2 | # From: https://github.com/ZiyaoLi/fast-kan 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | from typing import * 9 | 10 | class SplineLinear(nn.Linear): 11 | def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None: 12 | self.init_scale = init_scale 13 | super().__init__(in_features, out_features, bias=False, **kw) 14 | 15 | def reset_parameters(self) -> None: 16 | nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale) 17 | 18 | class RadialBasisFunction(nn.Module): 19 | def __init__( 20 | self, 21 | grid_min: float = -2., 22 | grid_max: float = 2., 23 | num_grids: int = 8, 24 | denominator: float = None, # larger denominators lead to smoother basis 25 | ): 26 | super().__init__() 27 | grid = torch.linspace(grid_min, grid_max, num_grids) 28 | self.grid = torch.nn.Parameter(grid, requires_grad=False) 29 | self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1) 30 | 31 | def forward(self, x): 32 | return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2) 33 | 34 | class Fast_KANLinear(nn.Module): 35 | def __init__( 36 | self, 37 | input_dim: int, 38 | output_dim: int, 39 | grid_min: float = -2., 40 | grid_max: float = 2., 41 | num_grids: int = 8, 42 | use_base_update: bool = True, 43 | base_activation = F.silu, 44 | spline_weight_init_scale: float = 0.1, 45 | ) -> None: 46 | super().__init__() 47 | self.layernorm = nn.LayerNorm(input_dim) 48 | self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids) 49 | self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale) 50 | self.use_base_update = use_base_update 51 | if use_base_update: 52 | self.base_activation = base_activation 53 | self.base_linear = nn.Linear(input_dim, output_dim) 54 | 55 | def forward(self, x, time_benchmark=False): 56 | if not time_benchmark: 57 | spline_basis = self.rbf(self.layernorm(x)) 58 | else: 59 | spline_basis = self.rbf(x) 60 
| ret = self.spline_linear(spline_basis.view(*spline_basis.shape[:-2], -1)) 61 | if self.use_base_update: 62 | base = self.base_linear(self.base_activation(x)) 63 | ret = ret + base 64 | return ret 65 | 66 | 67 | class FastKAN(nn.Module): 68 | def __init__( 69 | self, 70 | layers_hidden: List[int], 71 | grid_min: float = -2., 72 | grid_max: float = 2., 73 | num_grids: int = 8, 74 | use_base_update: bool = True, 75 | base_activation = F.silu, 76 | spline_weight_init_scale: float = 0.1, 77 | ) -> None: 78 | super().__init__() 79 | self.layers = nn.ModuleList([ 80 | Fast_KANLinear( 81 | in_dim, out_dim, 82 | grid_min=grid_min, 83 | grid_max=grid_max, 84 | num_grids=num_grids, 85 | use_base_update=use_base_update, 86 | base_activation=base_activation, 87 | spline_weight_init_scale=spline_weight_init_scale, 88 | ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:]) 89 | ]) 90 | 91 | def forward(self, x): 92 | for layer in self.layers: 93 | x = layer(x) 94 | return x 95 | 96 | 97 | class AttentionWithFastKANTransform(nn.Module): 98 | 99 | def __init__( 100 | self, 101 | q_dim: int, 102 | k_dim: int, 103 | v_dim: int, 104 | head_dim: int, 105 | num_heads: int, 106 | gating: bool = True, 107 | ): 108 | super(AttentionWithFastKANTransform, self).__init__() 109 | 110 | self.num_heads = num_heads 111 | total_dim = head_dim * self.num_heads 112 | self.gating = gating 113 | self.linear_q = Fast_KANLinear(q_dim, total_dim) 114 | self.linear_k = Fast_KANLinear(k_dim, total_dim) 115 | self.linear_v = Fast_KANLinear(v_dim, total_dim) 116 | self.linear_o = Fast_KANLinear(total_dim, q_dim) 117 | self.linear_g = None 118 | if self.gating: 119 | self.linear_g = Fast_KANLinear(q_dim, total_dim) 120 | # precompute the 1/sqrt(head_dim) 121 | self.norm = head_dim**-0.5 122 | 123 | def forward( 124 | self, 125 | q: torch.Tensor, 126 | k: torch.Tensor, 127 | v: torch.Tensor, 128 | bias: torch.Tensor = None, # additive attention bias 129 | ) -> torch.Tensor: 130 | 131 | wq = self.linear_q(q).view(*q.shape[:-1], 1, self.num_heads, -1) * self.norm # *q1hc 132 | wk = self.linear_k(k).view(*k.shape[:-2], 1, k.shape[-2], self.num_heads, -1) # *1khc 133 | att = (wq * wk).sum(-1).softmax(-2) # *qkh 134 | del wq, wk 135 | if bias is not None: 136 | att = att + bias[..., None] 137 | 138 | wv = self.linear_v(v).view(*v.shape[:-2],1, v.shape[-2], self.num_heads, -1) # *1khc 139 | o = (att[..., None] * wv).sum(-3) # *qhc 140 | del att, wv 141 | 142 | o = o.view(*o.shape[:-2], -1) # *q(hc) 143 | 144 | if self.linear_g is not None: 145 | # gating, use raw query input 146 | g = self.linear_g(q) 147 | o = torch.sigmoid(g) * o 148 | 149 | # merge heads 150 | o = self.linear_o(o) 151 | return o -------------------------------------------------------------------------------- /KAN_Implementations/Original_KAN/LBFGS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import reduce 3 | from torch.optim import Optimizer 4 | 5 | __all__ = ['LBFGS'] 6 | 7 | def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): 8 | # ported from https://github.com/torch/optim/blob/master/polyinterp.lua 9 | # Compute bounds of interpolation area 10 | if bounds is not None: 11 | xmin_bound, xmax_bound = bounds 12 | else: 13 | xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1) 14 | 15 | # Code for most common case: cubic interpolation of 2 points 16 | # w/ function and derivative values for both 17 | # Solution in this case (where x2 is the farthest point): 18 | # 
d1 = g1 + g2 - 3*(f1-f2)/(x1-x2); 19 | # d2 = sqrt(d1^2 - g1*g2); 20 | # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2)); 21 | # t_new = min(max(min_pos,xmin_bound),xmax_bound); 22 | d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2) 23 | d2_square = d1**2 - g1 * g2 24 | if d2_square >= 0: 25 | d2 = d2_square.sqrt() 26 | if x1 <= x2: 27 | min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)) 28 | else: 29 | min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)) 30 | return min(max(min_pos, xmin_bound), xmax_bound) 31 | else: 32 | return (xmin_bound + xmax_bound) / 2. 33 | 34 | 35 | def _strong_wolfe(obj_func, 36 | x, 37 | t, 38 | d, 39 | f, 40 | g, 41 | gtd, 42 | c1=1e-4, 43 | c2=0.9, 44 | tolerance_change=1e-9, 45 | max_ls=25): 46 | # ported from https://github.com/torch/optim/blob/master/lswolfe.lua 47 | d_norm = d.abs().max() 48 | g = g.clone(memory_format=torch.contiguous_format) 49 | # evaluate objective and gradient using initial step 50 | f_new, g_new = obj_func(x, t, d) 51 | ls_func_evals = 1 52 | gtd_new = g_new.dot(d) 53 | 54 | # bracket an interval containing a point satisfying the Wolfe criteria 55 | t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd 56 | done = False 57 | ls_iter = 0 58 | while ls_iter < max_ls: 59 | # check conditions 60 | if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev): 61 | bracket = [t_prev, t] 62 | bracket_f = [f_prev, f_new] 63 | bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] 64 | bracket_gtd = [gtd_prev, gtd_new] 65 | break 66 | 67 | if abs(gtd_new) <= -c2 * gtd: 68 | bracket = [t] 69 | bracket_f = [f_new] 70 | bracket_g = [g_new] 71 | done = True 72 | break 73 | 74 | if gtd_new >= 0: 75 | bracket = [t_prev, t] 76 | bracket_f = [f_prev, f_new] 77 | bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] 78 | bracket_gtd = [gtd_prev, gtd_new] 79 | break 80 | 81 | # interpolate 82 | min_step = t + 0.01 * (t - t_prev) 83 | max_step = t * 10 84 | tmp = t 85 | t = _cubic_interpolate( 86 | t_prev, 87 | f_prev, 88 | gtd_prev, 89 | t, 90 | f_new, 91 | gtd_new, 92 | bounds=(min_step, max_step)) 93 | 94 | # next step 95 | t_prev = tmp 96 | f_prev = f_new 97 | g_prev = g_new.clone(memory_format=torch.contiguous_format) 98 | gtd_prev = gtd_new 99 | f_new, g_new = obj_func(x, t, d) 100 | ls_func_evals += 1 101 | gtd_new = g_new.dot(d) 102 | ls_iter += 1 103 | 104 | # reached max number of iterations? 105 | if ls_iter == max_ls: 106 | bracket = [0, t] 107 | bracket_f = [f, f_new] 108 | bracket_g = [g, g_new] 109 | 110 | # zoom phase: we now have a point satisfying the criteria, or 111 | # a bracket around it. 
We refine the bracket until we find the 112 | # exact point satisfying the criteria 113 | insuf_progress = False 114 | # find high and low points in bracket 115 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) 116 | while not done and ls_iter < max_ls: 117 | # line-search bracket is so small 118 | if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: 119 | break 120 | 121 | # compute new trial value 122 | t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0], 123 | bracket[1], bracket_f[1], bracket_gtd[1]) 124 | 125 | # test that we are making sufficient progress: 126 | # in case `t` is so close to boundary, we mark that we are making 127 | # insufficient progress, and if 128 | # + we have made insufficient progress in the last step, or 129 | # + `t` is at one of the boundary, 130 | # we will move `t` to a position which is `0.1 * len(bracket)` 131 | # away from the nearest boundary point. 132 | eps = 0.1 * (max(bracket) - min(bracket)) 133 | if min(max(bracket) - t, t - min(bracket)) < eps: 134 | # interpolation close to boundary 135 | if insuf_progress or t >= max(bracket) or t <= min(bracket): 136 | # evaluate at 0.1 away from boundary 137 | if abs(t - max(bracket)) < abs(t - min(bracket)): 138 | t = max(bracket) - eps 139 | else: 140 | t = min(bracket) + eps 141 | insuf_progress = False 142 | else: 143 | insuf_progress = True 144 | else: 145 | insuf_progress = False 146 | 147 | # Evaluate new point 148 | f_new, g_new = obj_func(x, t, d) 149 | ls_func_evals += 1 150 | gtd_new = g_new.dot(d) 151 | ls_iter += 1 152 | 153 | if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]: 154 | # Armijo condition not satisfied or not lower than lowest point 155 | bracket[high_pos] = t 156 | bracket_f[high_pos] = f_new 157 | bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) 158 | bracket_gtd[high_pos] = gtd_new 159 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0) 160 | else: 161 | if abs(gtd_new) <= -c2 * gtd: 162 | # Wolfe conditions satisfied 163 | done = True 164 | elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0: 165 | # old low becomes new high 166 | bracket[high_pos] = bracket[low_pos] 167 | bracket_f[high_pos] = bracket_f[low_pos] 168 | bracket_g[high_pos] = bracket_g[low_pos] 169 | bracket_gtd[high_pos] = bracket_gtd[low_pos] 170 | 171 | # new point becomes new low 172 | bracket[low_pos] = t 173 | bracket_f[low_pos] = f_new 174 | bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) 175 | bracket_gtd[low_pos] = gtd_new 176 | 177 | # return stuff 178 | t = bracket[low_pos] 179 | f_new = bracket_f[low_pos] 180 | g_new = bracket_g[low_pos] 181 | return f_new, g_new, t, ls_func_evals 182 | 183 | 184 | 185 | class LBFGS(Optimizer): 186 | """Implements L-BFGS algorithm. 187 | 188 | Heavily inspired by `minFunc 189 | `_. 190 | 191 | .. warning:: 192 | This optimizer doesn't support per-parameter options and parameter 193 | groups (there can be only one). 194 | 195 | .. warning:: 196 | Right now all parameters have to be on a single device. This will be 197 | improved in the future. 198 | 199 | .. note:: 200 | This is a very memory intensive optimizer (it requires additional 201 | ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory 202 | try reducing the history size, or use a different algorithm. 
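For example, a model with 10 million float32 parameters holds 40 MB of parameters, so the default ``history_size=100`` implies roughly ``40 MB * 101`` (about 4 GB) of additional optimizer state.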
203 | 204 | Args: 205 | lr (float): learning rate (default: 1) 206 | max_iter (int): maximal number of iterations per optimization step 207 | (default: 20) 208 | max_eval (int): maximal number of function evaluations per optimization 209 | step (default: max_iter * 1.25). 210 | tolerance_grad (float): termination tolerance on first order optimality 211 | (default: 1e-7). 212 | tolerance_change (float): termination tolerance on function 213 | value/parameter changes (default: 1e-9). 214 | history_size (int): update history size (default: 100). 215 | line_search_fn (str): either 'strong_wolfe' or None (default: None). 216 | """ 217 | 218 | def __init__(self, 219 | params, 220 | lr=1, 221 | max_iter=20, 222 | max_eval=None, 223 | tolerance_grad=1e-7, 224 | tolerance_change=1e-9, 225 | tolerance_ys=1e-32, 226 | history_size=100, 227 | line_search_fn=None): 228 | if max_eval is None: 229 | max_eval = max_iter * 5 // 4 230 | defaults = dict( 231 | lr=lr, 232 | max_iter=max_iter, 233 | max_eval=max_eval, 234 | tolerance_grad=tolerance_grad, 235 | tolerance_change=tolerance_change, 236 | tolerance_ys=tolerance_ys, 237 | history_size=history_size, 238 | line_search_fn=line_search_fn) 239 | super().__init__(params, defaults) 240 | 241 | if len(self.param_groups) != 1: 242 | raise ValueError("LBFGS doesn't support per-parameter options " 243 | "(parameter groups)") 244 | 245 | self._params = self.param_groups[0]['params'] 246 | self._numel_cache = None 247 | 248 | def _numel(self): 249 | if self._numel_cache is None: 250 | self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0) 251 | return self._numel_cache 252 | 253 | def _gather_flat_grad(self): 254 | views = [] 255 | for p in self._params: 256 | if p.grad is None: 257 | view = p.new(p.numel()).zero_() 258 | elif p.grad.is_sparse: 259 | view = p.grad.to_dense().view(-1) 260 | else: 261 | view = p.grad.view(-1) 262 | views.append(view) 263 | return torch.cat(views, 0) 264 | 265 | def _add_grad(self, step_size, update): 266 | offset = 0 267 | for p in self._params: 268 | numel = p.numel() 269 | # view as to avoid deprecated pointwise semantics 270 | p.add_(update[offset:offset + numel].view_as(p), alpha=step_size) 271 | offset += numel 272 | assert offset == self._numel() 273 | 274 | def _clone_param(self): 275 | return [p.clone(memory_format=torch.contiguous_format) for p in self._params] 276 | 277 | def _set_param(self, params_data): 278 | for p, pdata in zip(self._params, params_data): 279 | p.copy_(pdata) 280 | 281 | def _directional_evaluate(self, closure, x, t, d): 282 | self._add_grad(t, d) 283 | loss = float(closure()) 284 | flat_grad = self._gather_flat_grad() 285 | self._set_param(x) 286 | return loss, flat_grad 287 | 288 | 289 | @torch.no_grad() 290 | def step(self, closure): 291 | """Perform a single optimization step. 292 | 293 | Args: 294 | closure (Callable): A closure that reevaluates the model 295 | and returns the loss. 
296 | """ 297 | assert len(self.param_groups) == 1 298 | 299 | # Make sure the closure is always called with grad enabled 300 | closure = torch.enable_grad()(closure) 301 | 302 | group = self.param_groups[0] 303 | lr = group['lr'] 304 | max_iter = group['max_iter'] 305 | max_eval = group['max_eval'] 306 | tolerance_grad = group['tolerance_grad'] 307 | tolerance_change = group['tolerance_change'] 308 | tolerance_ys = group['tolerance_ys'] 309 | line_search_fn = group['line_search_fn'] 310 | history_size = group['history_size'] 311 | 312 | # NOTE: LBFGS has only global state, but we register it as state for 313 | # the first param, because this helps with casting in load_state_dict 314 | state = self.state[self._params[0]] 315 | state.setdefault('func_evals', 0) 316 | state.setdefault('n_iter', 0) 317 | 318 | # evaluate initial f(x) and df/dx 319 | orig_loss = closure() 320 | loss = float(orig_loss) 321 | current_evals = 1 322 | state['func_evals'] += 1 323 | 324 | flat_grad = self._gather_flat_grad() 325 | opt_cond = flat_grad.abs().max() <= tolerance_grad 326 | 327 | # optimal condition 328 | if opt_cond: 329 | return orig_loss 330 | 331 | # tensors cached in state (for tracing) 332 | d = state.get('d') 333 | t = state.get('t') 334 | old_dirs = state.get('old_dirs') 335 | old_stps = state.get('old_stps') 336 | ro = state.get('ro') 337 | H_diag = state.get('H_diag') 338 | prev_flat_grad = state.get('prev_flat_grad') 339 | prev_loss = state.get('prev_loss') 340 | 341 | n_iter = 0 342 | # optimize for a max of max_iter iterations 343 | while n_iter < max_iter: 344 | # keep track of nb of iterations 345 | n_iter += 1 346 | state['n_iter'] += 1 347 | 348 | ############################################################ 349 | # compute gradient descent direction 350 | ############################################################ 351 | if state['n_iter'] == 1: 352 | d = flat_grad.neg() 353 | old_dirs = [] 354 | old_stps = [] 355 | ro = [] 356 | H_diag = 1 357 | else: 358 | # do lbfgs update (update memory) 359 | y = flat_grad.sub(prev_flat_grad) 360 | s = d.mul(t) 361 | ys = y.dot(s) # y*s 362 | if ys > tolerance_ys: 363 | # updating memory 364 | if len(old_dirs) == history_size: 365 | # shift history by one (limited-memory) 366 | old_dirs.pop(0) 367 | old_stps.pop(0) 368 | ro.pop(0) 369 | 370 | # store new direction/step 371 | old_dirs.append(y) 372 | old_stps.append(s) 373 | ro.append(1. 
/ ys) 374 | 375 | # update scale of initial Hessian approximation 376 | H_diag = ys / y.dot(y) # (y*y) 377 | 378 | # compute the approximate (L-BFGS) inverse Hessian 379 | # multiplied by the gradient 380 | num_old = len(old_dirs) 381 | 382 | if 'al' not in state: 383 | state['al'] = [None] * history_size 384 | al = state['al'] 385 | 386 | # iteration in L-BFGS loop collapsed to use just one buffer 387 | q = flat_grad.neg() 388 | for i in range(num_old - 1, -1, -1): 389 | al[i] = old_stps[i].dot(q) * ro[i] 390 | q.add_(old_dirs[i], alpha=-al[i]) 391 | 392 | # multiply by initial Hessian 393 | # r/d is the final direction 394 | d = r = torch.mul(q, H_diag) 395 | for i in range(num_old): 396 | be_i = old_dirs[i].dot(r) * ro[i] 397 | r.add_(old_stps[i], alpha=al[i] - be_i) 398 | 399 | if prev_flat_grad is None: 400 | prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format) 401 | else: 402 | prev_flat_grad.copy_(flat_grad) 403 | prev_loss = loss 404 | 405 | ############################################################ 406 | # compute step length 407 | ############################################################ 408 | # reset initial guess for step size 409 | if state['n_iter'] == 1: 410 | t = min(1., 1. / flat_grad.abs().sum()) * lr 411 | else: 412 | t = lr 413 | 414 | # directional derivative 415 | gtd = flat_grad.dot(d) # g * d 416 | 417 | # directional derivative is below tolerance 418 | if gtd > -tolerance_change: 419 | break 420 | 421 | # optional line search: user function 422 | ls_func_evals = 0 423 | if line_search_fn is not None: 424 | # perform line search, using user function 425 | if line_search_fn != "strong_wolfe": 426 | raise RuntimeError("only 'strong_wolfe' is supported") 427 | else: 428 | x_init = self._clone_param() 429 | 430 | def obj_func(x, t, d): 431 | return self._directional_evaluate(closure, x, t, d) 432 | 433 | loss, flat_grad, t, ls_func_evals = _strong_wolfe( 434 | obj_func, x_init, t, d, loss, flat_grad, gtd) 435 | self._add_grad(t, d) 436 | opt_cond = flat_grad.abs().max() <= tolerance_grad 437 | else: 438 | # no line search, simply move with fixed-step 439 | self._add_grad(t, d) 440 | if n_iter != max_iter: 441 | # re-evaluate function only if not in last iteration 442 | # the reason we do this: in a stochastic setting, 443 | # no use to re-evaluate that function here 444 | with torch.enable_grad(): 445 | loss = float(closure()) 446 | flat_grad = self._gather_flat_grad() 447 | opt_cond = flat_grad.abs().max() <= tolerance_grad 448 | ls_func_evals = 1 449 | 450 | # update func eval 451 | current_evals += ls_func_evals 452 | state['func_evals'] += ls_func_evals 453 | 454 | ############################################################ 455 | # check conditions 456 | ############################################################ 457 | if n_iter == max_iter: 458 | break 459 | 460 | if current_evals >= max_eval: 461 | break 462 | 463 | # optimal condition 464 | if opt_cond: 465 | break 466 | 467 | # lack of progress 468 | if d.mul(t).abs().max() <= tolerance_change: 469 | break 470 | 471 | if abs(loss - prev_loss) < tolerance_change: 472 | break 473 | 474 | state['d'] = d 475 | state['t'] = t 476 | state['old_dirs'] = old_dirs 477 | state['old_stps'] = old_stps 478 | state['ro'] = ro 479 | state['H_diag'] = H_diag 480 | state['prev_flat_grad'] = prev_flat_grad 481 | state['prev_loss'] = prev_loss 482 | 483 | return orig_loss 484 | -------------------------------------------------------------------------------- 
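For reference, this LBFGS class keeps the closure-based interface of torch.optim.LBFGS: step() takes a closure that re-evaluates the model and returns the loss. A minimal usage sketch (hypothetical model and data; the sys.path line mirrors the pattern used in ConvKAN.py):

import sys
import torch
sys.path.append("KAN_Implementations/Original_KAN")
from LBFGS import LBFGS

model = torch.nn.Linear(3, 1)                  # hypothetical stand-in model
x, y = torch.randn(32, 3), torch.randn(32, 1)  # hypothetical data
optimizer = LBFGS(model.parameters(), lr=1.0, max_iter=20, line_search_fn="strong_wolfe")

def closure():
    optimizer.zero_grad()                      # reset gradients before re-evaluating
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()                            # step() may call this closure several times
    return loss

loss = optimizer.step(closure)

--------------------------------------------------------------------------------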
/KAN_Implementations/Original_KAN/Symbolic_KANLayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import sympy 5 | from utils import * 6 | 7 | 8 | class Symbolic_KANLayer(nn.Module): 9 | ''' 10 | Symbolic_KANLayer class 11 | 12 | Attributes: 13 | ----------- 14 | in_dim: int 15 | input dimension 16 | out_dim: int 17 | output dimension 18 | funs: 2D array of torch functions (or lambda functions) 19 | symbolic functions (torch) 20 | funs_name: 2D array of str 21 | names of symbolic functions 22 | funs_sympy: 2D array of sympy functions (or lambda functions) 23 | symbolic functions (sympy) 24 | affine: 3D array of floats 25 | affine transformations of inputs and outputs 26 | 27 | Methods: 28 | -------- 29 | __init__(): 30 | initialize a Symbolic_KANLayer 31 | forward(): 32 | forward 33 | get_subset(): 34 | get subset of the KANLayer (used for pruning) 35 | fix_symbolic(): 36 | fix an activation function to be symbolic 37 | ''' 38 | def __init__(self, in_dim=3, out_dim=2, device='cpu'): 39 | ''' 40 | initialize a Symbolic_KANLayer (activation functions are initialized to be identity functions) 41 | 42 | Args: 43 | ----- 44 | in_dim : int 45 | input dimension 46 | out_dim : int 47 | output dimension 48 | device : str 49 | device 50 | 51 | Returns: 52 | -------- 53 | self 54 | 55 | Example 56 | ------- 57 | >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=3) 58 | >>> len(sb.funs), len(sb.funs[0]) 59 | (3, 3) 60 | ''' 61 | super(Symbolic_KANLayer, self).__init__() 62 | self.out_dim = out_dim 63 | self.in_dim = in_dim 64 | self.mask = torch.nn.Parameter(torch.zeros(out_dim, in_dim, device=device)).requires_grad_(False) 65 | # torch 66 | self.funs = [[lambda x: x for i in range(self.in_dim)] for j in range(self.out_dim)] 67 | # name 68 | self.funs_name = [['' for i in range(self.in_dim)] for j in range(self.out_dim)] 69 | # sympy 70 | self.funs_sympy = [['' for i in range(self.in_dim)] for j in range(self.out_dim)] 71 | 72 | self.affine = torch.nn.Parameter(torch.zeros(out_dim, in_dim, 4, device=device)) 73 | # c*f(a*x+b)+d 74 | 75 | self.device = device 76 | 77 | def forward(self, x): 78 | ''' 79 | forward 80 | 81 | Args: 82 | ----- 83 | x : 2D array 84 | inputs, shape (batch, input dimension) 85 | 86 | Returns: 87 | -------- 88 | y : 2D array 89 | outputs, shape (batch, output dimension) 90 | postacts : 3D array 91 | activations after activation functions but before summing on nodes 92 | 93 | Example 94 | ------- 95 | >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=5) 96 | >>> x = torch.normal(0,1,size=(100,3)) 97 | >>> y, postacts = sb(x) 98 | >>> y.shape, postacts.shape 99 | (torch.Size([100, 5]), torch.Size([100, 5, 3])) 100 | ''' 101 | 102 | batch = x.shape[0] 103 | postacts = [] 104 | 105 | for i in range(self.in_dim): 106 | postacts_ = [] 107 | for j in range(self.out_dim): 108 | xij = self.affine[j,i,2]*self.funs[j][i](self.affine[j,i,0]*x[:,[i]]+self.affine[j,i,1])+self.affine[j,i,3] 109 | postacts_.append(self.mask[j][i]*xij) 110 | postacts.append(torch.stack(postacts_)) 111 | 112 | postacts = torch.stack(postacts) 113 | postacts = postacts.permute(2,1,0,3)[:,:,:,0] 114 | y = torch.sum(postacts, dim=2) 115 | 116 | return y, postacts 117 | 118 | 119 | def get_subset(self, in_id, out_id): 120 | ''' 121 | get a smaller Symbolic_KANLayer from a larger Symbolic_KANLayer (used for pruning) 122 | 123 | Args: 124 | ----- 125 | in_id : list 126 | id of selected input neurons 127 | out_id : list 128 | id of 
selected output neurons 129 | 130 | Returns: 131 | -------- 132 | sbb : Symbolic_KANLayer 133 | 134 | Example 135 | ------- 136 | >>> sb_large = Symbolic_KANLayer(in_dim=10, out_dim=10) 137 | >>> sb_small = sb_large.get_subset([0,9],[1,2,3]) 138 | >>> sb_small.in_dim, sb_small.out_dim 139 | (2, 3) 140 | ''' 141 | sbb = Symbolic_KANLayer(self.in_dim, self.out_dim, device=self.device) 142 | sbb.in_dim = len(in_id) 143 | sbb.out_dim = len(out_id) 144 | sbb.mask.data = self.mask.data[out_id][:,in_id] 145 | sbb.funs = [[self.funs[j][i] for i in in_id] for j in out_id] 146 | sbb.funs_sympy = [[self.funs_sympy[j][i] for i in in_id] for j in out_id] 147 | sbb.funs_name = [[self.funs_name[j][i] for i in in_id] for j in out_id] 148 | sbb.affine.data = self.affine.data[out_id][:,in_id] 149 | return sbb 150 | 151 | 152 | def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-10,10), b_range=(-10,10), verbose=True): 153 | ''' 154 | fix an activation function to be symbolic 155 | 156 | Args: 157 | ----- 158 | i : int 159 | the id of input neuron 160 | j : int 161 | the id of output neuron 162 | fun_name : str 163 | the name of the symbolic functions 164 | x : 1D array 165 | preactivations 166 | y : 1D array 167 | postactivations 168 | a_range : tuple 169 | sweeping range of a 170 | b_range : tuple 171 | sweeping range of b 172 | verbose : bool 173 | print more information if True 174 | 175 | Returns: 176 | -------- 177 | r2 (coefficient of determination) 178 | 179 | Example 1 180 | --------- 181 | >>> # when x & y are not provided. Affine parameters are set to a = 1, b = 0, c = 1, d = 0 182 | >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2) 183 | >>> sb.fix_symbolic(2,1,'sin') 184 | >>> print(sb.funs_name) 185 | >>> print(sb.affine) 186 | [['', '', ''], ['', '', 'sin']] 187 | Parameter containing: 188 | tensor([[0., 0., 0., 0.], 189 | [0., 0., 0., 0.], 190 | [1., 0., 1., 0.]], requires_grad=True) 191 | Example 2 192 | --------- 193 | >>> # when x & y are provided, fit_params() is called to find the best fit coefficients 194 | >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2) 195 | >>> batch = 100 196 | >>> x = torch.linspace(-1,1,steps=batch) 197 | >>> noises = torch.normal(0,1,(batch,)) * 0.02 198 | >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises 199 | >>> sb.fix_symbolic(2,1,'sin',x,y) 200 | >>> print(sb.funs_name) 201 | >>> print(sb.affine[1,2,:].data) 202 | r2 is 0.9999701976776123 203 | [['', '', ''], ['', '', 'sin']] 204 | tensor([2.9981, 1.9997, 5.0039, 0.6978]) 205 | ''' 206 | if isinstance(fun_name,str): 207 | fun = SYMBOLIC_LIB[fun_name][0] 208 | fun_sympy = SYMBOLIC_LIB[fun_name][1] 209 | self.funs_sympy[j][i] = fun_sympy 210 | self.funs_name[j][i] = fun_name 211 | if x is None or y is None: 212 | # initialize from just fun 213 | self.funs[j][i] = fun 214 | if random == False: 215 | self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.]) 216 | else: 217 | self.affine.data[j][i] = torch.rand(4,) * 2 - 1 218 | return None 219 | else: 220 | # initialize from x & y and fun 221 | params, r2 = fit_params(x,y,fun, a_range=a_range, b_range=b_range, verbose=verbose, device=self.device) 222 | self.funs[j][i] = fun 223 | self.affine.data[j][i] = params 224 | return r2 225 | else: 226 | # if fun_name itself is a function 227 | fun = fun_name 228 | fun_sympy = fun_name 229 | self.funs_sympy[j][i] = fun_sympy 230 | self.funs_name[j][i] = "anonymous" 231 | 232 | self.funs[j][i] = fun 233 | if random == False: 234 | self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.]) 235 | else: 236 | 
self.affine.data[j][i] = torch.rand(4,) * 2 - 1 237 | return None 238 | -------------------------------------------------------------------------------- /KAN_Implementations/Original_KAN/__pycache__/LBFGS.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarrayyann/KAN-Conv2D/e9e69c8e73ac653b2d152761deb1886ff2aaf7d5/KAN_Implementations/Original_KAN/__pycache__/LBFGS.cpython-312.pyc -------------------------------------------------------------------------------- /KAN_Implementations/Original_KAN/__pycache__/Symbolic_KANLayer.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarrayyann/KAN-Conv2D/e9e69c8e73ac653b2d152761deb1886ff2aaf7d5/KAN_Implementations/Original_KAN/__pycache__/Symbolic_KANLayer.cpython-312.pyc -------------------------------------------------------------------------------- /KAN_Implementations/Original_KAN/__pycache__/original_kan.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarrayyann/KAN-Conv2D/e9e69c8e73ac653b2d152761deb1886ff2aaf7d5/KAN_Implementations/Original_KAN/__pycache__/original_kan.cpython-312.pyc -------------------------------------------------------------------------------- /KAN_Implementations/Original_KAN/__pycache__/spline.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarrayyann/KAN-Conv2D/e9e69c8e73ac653b2d152761deb1886ff2aaf7d5/KAN_Implementations/Original_KAN/__pycache__/spline.cpython-312.pyc -------------------------------------------------------------------------------- /KAN_Implementations/Original_KAN/__pycache__/utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omarrayyann/KAN-Conv2D/e9e69c8e73ac653b2d152761deb1886ff2aaf7d5/KAN_Implementations/Original_KAN/__pycache__/utils.cpython-312.pyc -------------------------------------------------------------------------------- /KAN_Implementations/Original_KAN/original_kan.py: -------------------------------------------------------------------------------- 1 | # Original KAN Implementation 2 | # From: https://github.com/KindXiaoming/pykan 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | from spline import * 8 | from Symbolic_KANLayer import * 9 | from LBFGS import * 10 | import os 11 | import glob 12 | import matplotlib.pyplot as plt 13 | from tqdm import tqdm 14 | import random 15 | import copy 16 | 17 | class KANLayer(nn.Module): 18 | """ 19 | KANLayer class 20 | 21 | 22 | Attributes: 23 | ----------- 24 | in_dim: int 25 | input dimension 26 | out_dim: int 27 | output dimension 28 | size: int 29 | the number of splines = input dimension * output dimension 30 | k: int 31 | the piecewise polynomial order of splines 32 | grid: 2D torch.float 33 | grid points 34 | noises: 2D torch.float 35 | injected noises to splines at initialization (to break degeneracy) 36 | coef: 2D torch.tensor 37 | coefficients of B-spline bases 38 | scale_base: 1D torch.float 39 | magnitude of the residual function b(x) 40 | scale_sp: 1D torch.float 41 | magnitude of the spline function spline(x) 42 | base_fun: fun 43 | residual function b(x) 44 | mask: 1D torch.float 45 | mask of spline functions. 
setting some element of the mask to zero means setting the corresponding activation to the zero function. 46 | grid_eps: float in [0,1] 47 | a hyperparameter used in update_grid_from_samples. When grid_eps = 0, the grid is uniform; when grid_eps = 1, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. 48 | weight_sharing: 1D tensor int 49 | allow spline activations to share parameters 50 | lock_counter: int 51 | counts how many activation functions are locked (weight sharing) 52 | lock_id: 1D torch.int 53 | the id of activation functions that are locked 54 | device: str 55 | device 56 | 57 | Methods: 58 | -------- 59 | __init__(): 60 | initialize a KANLayer 61 | forward(): 62 | forward 63 | update_grid_from_samples(): 64 | update grids based on samples' incoming activations 65 | initialize_grid_from_parent(): 66 | initialize grids from another model 67 | get_subset(): 68 | get subset of the KANLayer (used for pruning) 69 | lock(): 70 | lock several activation functions to share parameters 71 | unlock(): 72 | unlock already locked activation functions 73 | """ 74 | 75 | def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, device='cpu', full_output=False): 76 | ''' 77 | initialize a KANLayer 78 | 79 | Args: 80 | ----- 81 | in_dim : int 82 | input dimension. Default: 3. 83 | out_dim : int 84 | output dimension. Default: 2. 85 | num : int 86 | the number of grid intervals = G. Default: 5. 87 | k : int 88 | the order of piecewise polynomial. Default: 3. 89 | noise_scale : float 90 | the scale of noise injected at initialization. Default: 0.1. 91 | scale_base : float 92 | the scale of the residual function b(x). Default: 1.0. 93 | scale_sp : float 94 | the scale of the spline function spline(x). Default: 1.0. 95 | base_fun : function 96 | residual function b(x). Default: torch.nn.SiLU() 97 | grid_eps : float 98 | When grid_eps = 0, the grid is uniform; when grid_eps = 1, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02. 99 | grid_range : list/np.array of shape (2,) 100 | setting the range of grids. Default: [-1,1]. 101 | sp_trainable : bool 102 | If True, scale_sp is trainable. Default: True. 103 | sb_trainable : bool 104 | If True, scale_base is trainable. Default: True.
105 | device : str 106 | device 107 | 108 | Returns: 109 | -------- 110 | self 111 | 112 | Example 113 | ------- 114 | >>> model = KANLayer(in_dim=3, out_dim=5) 115 | >>> (model.in_dim, model.out_dim) 116 | (3, 5) 117 | ''' 118 | super(KANLayer, self).__init__() 119 | # size 120 | self.size = size = out_dim * in_dim 121 | self.out_dim = out_dim 122 | self.in_dim = in_dim 123 | self.num = num 124 | self.k = k 125 | self.full_output = full_output 126 | 127 | # shape: (size, num) 128 | self.grid = torch.einsum('i,j->ij', torch.ones(size, device=device), torch.linspace(grid_range[0], grid_range[1], steps=num + 1, device=device)) 129 | self.grid = torch.nn.Parameter(self.grid).requires_grad_(False) 130 | noises = (torch.rand(size, self.grid.shape[1]) - 1 / 2) * noise_scale / num 131 | noises = noises.to(device) 132 | # shape: (size, coef) 133 | self.coef = torch.nn.Parameter(curve2coef(self.grid, noises, self.grid, k, device)) 134 | if isinstance(scale_base, float): 135 | self.scale_base = torch.nn.Parameter(torch.ones(size, device=device) * scale_base).requires_grad_(sb_trainable) # make scale trainable 136 | else: 137 | self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base).to(device)).requires_grad_(sb_trainable) 138 | self.scale_sp = torch.nn.Parameter(torch.ones(size, device=device) * scale_sp).requires_grad_(sp_trainable) # make scale trainable 139 | self.base_fun = base_fun 140 | 141 | self.mask = torch.nn.Parameter(torch.ones(size, device=device)).requires_grad_(False) 142 | self.grid_eps = grid_eps 143 | self.weight_sharing = torch.arange(size) 144 | self.lock_counter = 0 145 | self.lock_id = torch.zeros(size) 146 | self.device = device 147 | 148 | def forward(self, x): 149 | ''' 150 | KANLayer forward given input x 151 | 152 | Args: 153 | ----- 154 | x : 2D torch.float 155 | inputs, shape (number of samples, input dimension) 156 | 157 | Returns: 158 | -------- 159 | y : 2D torch.float 160 | outputs, shape (number of samples, output dimension) 161 | preacts : 3D torch.float 162 | fan out x into activations, shape (number of samples, output dimension, input dimension) 163 | postacts : 3D torch.float 164 | the outputs of activation functions with preacts as inputs 165 | postspline : 3D torch.float 166 | the outputs of spline functions with preacts as inputs 167 | 168 | Example 169 | ------- 170 | >>> model = KANLayer(in_dim=3, out_dim=5) 171 | >>> x = torch.normal(0,1,size=(100,3)) 172 | >>> y, preacts, postacts, postspline = model(x) 173 | >>> y.shape, preacts.shape, postacts.shape, postspline.shape 174 | (torch.Size([100, 5]), 175 | torch.Size([100, 5, 3]), 176 | torch.Size([100, 5, 3]), 177 | torch.Size([100, 5, 3])) 178 | ''' 179 | batch = x.shape[0] 180 | # x: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim) 181 | x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, device=self.device)).reshape(batch, self.size).permute(1, 0) 182 | preacts = x.permute(1, 0).clone().reshape(batch, self.out_dim, self.in_dim) 183 | base = self.base_fun(x).permute(1, 0) # shape (batch, size) 184 | y = coef2curve(x_eval=x, grid=self.grid[self.weight_sharing], coef=self.coef[self.weight_sharing], k=self.k, device=self.device) # shape (size, batch) 185 | y = y.permute(1, 0) # shape (batch, size) 186 | postspline = y.clone().reshape(batch, self.out_dim, self.in_dim) 187 | y = self.scale_base.unsqueeze(dim=0) * base + self.scale_sp.unsqueeze(dim=0) * y 188 | y = self.mask[None, :] * y 189 | postacts = y.clone().reshape(batch, self.out_dim, self.in_dim) 190 |
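# KAN node summation: output neuron j sums its incoming edge activations, y_j = sum_i phi_{ij}(x_i), where each edge activation is phi_{ij}(x) = mask_{ij} * (scale_base_{ij} * b(x) + scale_sp_{ij} * spline_{ij}(x))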
y = torch.sum(y.reshape(batch, self.out_dim, self.in_dim), dim=2) # shape (batch, out_dim) 191 | # y shape: (batch, out_dim); preacts shape: (batch, out_dim, in_dim) 192 | # postspline shape: (batch, out_dim, in_dim); postacts: (batch, out_dim, in_dim) 193 | # postspline is for extension; postacts is for visualization 194 | if self.full_output: 195 | return y, preacts, postacts, postspline 196 | else: 197 | return y 198 | 199 | def update_grid_from_samples(self, x): 200 | ''' 201 | update grid from samples 202 | 203 | Args: 204 | ----- 205 | x : 2D torch.float 206 | inputs, shape (number of samples, input dimension) 207 | 208 | Returns: 209 | -------- 210 | None 211 | 212 | Example 213 | ------- 214 | >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) 215 | >>> print(model.grid.data) 216 | >>> x = torch.linspace(-3,3,steps=100)[:,None] 217 | >>> model.update_grid_from_samples(x) 218 | >>> print(model.grid.data) 219 | tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]]) 220 | tensor([[-3.0002, -1.7882, -0.5763, 0.6357, 1.8476, 3.0002]]) 221 | ''' 222 | batch = x.shape[0] 223 | x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0) 224 | x_pos = torch.sort(x, dim=1)[0] 225 | y_eval = coef2curve(x_pos, self.grid, self.coef, self.k, device=self.device) 226 | num_interval = self.grid.shape[1] - 1 227 | ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] 228 | grid_adaptive = x_pos[:, ids] 229 | margin = 0.01 230 | grid_uniform = torch.cat([grid_adaptive[:, [0]] - margin + (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) * a for a in np.linspace(0, 1, num=self.grid.shape[1])], dim=1) 231 | self.grid.data = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive 232 | self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k, device=self.device) 233 | 234 | def initialize_grid_from_parent(self, parent, x): 235 | ''' 236 | update grid from a parent KANLayer & samples 237 | 238 | Args: 239 | ----- 240 | parent : KANLayer 241 | a parent KANLayer (whose grid is usually coarser than the current model) 242 | x : 2D torch.float 243 | inputs, shape (number of samples, input dimension) 244 | 245 | Returns: 246 | -------- 247 | None 248 | 249 | Example 250 | ------- 251 | >>> batch = 100 252 | >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) 253 | >>> print(parent_model.grid.data) 254 | >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3) 255 | >>> x = torch.normal(0,1,size=(batch, 1)) 256 | >>> model.initialize_grid_from_parent(parent_model, x) 257 | >>> print(model.grid.data) 258 | tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]]) 259 | tensor([[-1.0000, -0.8000, -0.6000, -0.4000, -0.2000, 0.0000, 0.2000, 0.4000, 260 | 0.6000, 0.8000, 1.0000]]) 261 | ''' 262 | batch = x.shape[0] 263 | # preacts: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim) 264 | x_eval = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0) 265 | x_pos = parent.grid 266 | sp2 = KANLayer(in_dim=1, out_dim=self.size, k=1, num=x_pos.shape[1] - 1, scale_base=0., device=self.device, full_output=True) 267 | sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1, device=self.device) 268 | y_eval = coef2curve(x_eval, parent.grid, parent.coef, parent.k, device=self.device) 269 | percentile = torch.linspace(-1, 1, self.num + 1).to(self.device) 270 | self.grid.data = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0) 271 | self.coef.data
= curve2coef(x_eval, y_eval, self.grid, self.k, self.device) 272 | 273 | def get_subset(self, in_id, out_id): 274 | ''' 275 | get a smaller KANLayer from a larger KANLayer (used for pruning) 276 | 277 | Args: 278 | ----- 279 | in_id : list 280 | id of selected input neurons 281 | out_id : list 282 | id of selected output neurons 283 | 284 | Returns: 285 | -------- 286 | spb : KANLayer 287 | 288 | Example 289 | ------- 290 | >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3) 291 | >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3]) 292 | >>> kanlayer_small.in_dim, kanlayer_small.out_dim 293 | (2, 3) 294 | ''' 295 | spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun, device=self.device) 296 | spb.grid.data = self.grid.reshape(self.out_dim, self.in_dim, spb.num + 1)[out_id][:, in_id].reshape(-1, spb.num + 1) 297 | spb.coef.data = self.coef.reshape(self.out_dim, self.in_dim, spb.coef.shape[1])[out_id][:, in_id].reshape(-1, spb.coef.shape[1]) 298 | spb.scale_base.data = self.scale_base.reshape(self.out_dim, self.in_dim)[out_id][:, in_id].reshape(-1, ) 299 | spb.scale_sp.data = self.scale_sp.reshape(self.out_dim, self.in_dim)[out_id][:, in_id].reshape(-1, ) 300 | spb.mask.data = self.mask.reshape(self.out_dim, self.in_dim)[out_id][:, in_id].reshape(-1, ) 301 | 302 | spb.in_dim = len(in_id) 303 | spb.out_dim = len(out_id) 304 | spb.size = spb.in_dim * spb.out_dim 305 | return spb 306 | 307 | def lock(self, ids): 308 | ''' 309 | lock activation functions to share parameters based on ids 310 | 311 | Args: 312 | ----- 313 | ids : list 314 | list of ids of activation functions 315 | 316 | Returns: 317 | -------- 318 | None 319 | 320 | Example 321 | ------- 322 | >>> model = KANLayer(in_dim=3, out_dim=3, num=5, k=3) 323 | >>> print(model.weight_sharing.reshape(3,3)) 324 | >>> model.lock([[0,0],[1,2],[2,1]]) # set (0,0),(1,2),(2,1) functions to be the same 325 | >>> print(model.weight_sharing.reshape(3,3)) 326 | tensor([[0, 1, 2], 327 | [3, 4, 5], 328 | [6, 7, 8]]) 329 | tensor([[0, 1, 2], 330 | [3, 4, 0], 331 | [6, 0, 8]]) 332 | ''' 333 | self.lock_counter += 1 334 | # ids: [[i1,j1],[i2,j2],[i3,j3],...] 335 | for i in range(len(ids)): 336 | if i != 0: 337 | self.weight_sharing[ids[i][1] * self.in_dim + ids[i][0]] = ids[0][1] * self.in_dim + ids[0][0] 338 | self.lock_id[ids[i][1] * self.in_dim + ids[i][0]] = self.lock_counter 339 | 340 | def unlock(self, ids): 341 | ''' 342 | unlock activation functions 343 | 344 | Args: 345 | ----- 346 | ids : list 347 | list of ids of activation functions 348 | 349 | Returns: 350 | -------- 351 | None 352 | 353 | Example 354 | ------- 355 | >>> model = KANLayer(in_dim=3, out_dim=3, num=5, k=3) 356 | >>> model.lock([[0,0],[1,2],[2,1]]) # set (0,0),(1,2),(2,1) functions to be the same 357 | >>> print(model.weight_sharing.reshape(3,3)) 358 | >>> model.unlock([[0,0],[1,2],[2,1]]) # unlock the locked functions 359 | >>> print(model.weight_sharing.reshape(3,3)) 360 | tensor([[0, 1, 2], 361 | [3, 4, 0], 362 | [6, 0, 8]]) 363 | tensor([[0, 1, 2], 364 | [3, 4, 5], 365 | [6, 7, 8]]) 366 | ''' 367 | # check ids are locked 368 | num = len(ids) 369 | locked = True 370 | for i in range(num): 371 | locked *= (self.weight_sharing[ids[i][1] * self.in_dim + ids[i][0]] == self.weight_sharing[ids[0][1] * self.in_dim + ids[0][0]]) 372 | if locked == False: 373 | print("they are not locked. 
unlock failed.") 374 | return 0 375 | for i in range(len(ids)): 376 | self.weight_sharing[ids[i][1] * self.in_dim + ids[i][0]] = ids[i][1] * self.in_dim + ids[i][0] 377 | self.lock_id[ids[i][1] * self.in_dim + ids[i][0]] = 0 378 | self.lock_counter -= 1 379 | 380 | class KAN(nn.Module): 381 | ''' 382 | KAN class 383 | 384 | Attributes: 385 | ----------- 386 | biases: a list of nn.Linear() 387 | biases are added on nodes (in principle, biases can be absorbed into activation functions. However, we still have them for better optimization) 388 | act_fun: a list of KANLayer 389 | KANLayers 390 | depth: int 391 | depth of KAN 392 | width: list 393 | number of neurons in each layer. e.g., [2,5,5,3] means 2D inputs, 3D outputs, with 2 hidden layers of 5 neurons each. 394 | grid: int 395 | the number of grid intervals 396 | k: int 397 | the order of piecewise polynomial 398 | base_fun: fun 399 | residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x) 400 | symbolic_fun: a list of Symbolic_KANLayer 401 | Symbolic_KANLayers 402 | symbolic_enabled: bool 403 | If False, the symbolic front is not computed (to save time). Default: True. 404 | 405 | Methods: 406 | -------- 407 | __init__(): 408 | initialize a KAN 409 | initialize_from_another_model(): 410 | initialize a KAN from another KAN (with the same shape, but potentially different grids) 411 | update_grid_from_samples(): 412 | update spline grids based on samples 413 | initialize_grid_from_another_model(): 414 | initialize KAN grids from another KAN 415 | forward(): 416 | forward 417 | set_mode(): 418 | set the mode of an activation function: 'n' for numeric, 's' for symbolic, 'ns' for combined (they are visualized differently in plot(): 'n' as black, 's' as red, 'ns' as purple). 419 | fix_symbolic(): 420 | fix an activation function to be symbolic 421 | suggest_symbolic(): 422 | suggest the symbolic candidates of a numeric spline-based activation function 423 | lock(): 424 | lock activation functions to share parameters 425 | unlock(): 426 | unlock locked activations 427 | get_range(): 428 | get the input and output ranges of an activation function 429 | plot(): 430 | plot the diagram of KAN 431 | train(): 432 | train KAN 433 | prune(): 434 | prune KAN 435 | remove_edge(): 436 | remove some edge of KAN 437 | remove_node(): 438 | remove some node of KAN 439 | auto_symbolic(): 440 | automatically fit all splines to be symbolic functions 441 | symbolic_formula(): 442 | obtain the symbolic formula of the KAN network 443 | ''' 444 | 445 | def __init__(self, width=None, grid=3, k=3, noise_scale=0.1, noise_scale_base=0.1, base_fun=torch.nn.SiLU(), symbolic_enabled=True, bias_trainable=True, grid_eps=1.0, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, 446 | device='cpu', seed=0): 447 | ''' 448 | initialize a KAN model 449 | 450 | Args: 451 | ----- 452 | width : list of int 453 | :math:`[n_0, n_1, .., n_{L-1}]` specifies the number of neurons in each layer (including inputs/outputs) 454 | grid : int 455 | number of grid intervals. Default: 3. 456 | k : int 457 | order of piecewise polynomial. Default: 3. 458 | noise_scale : float 459 | initial injected noise to spline. Default: 0.1. 460 | base_fun : fun 461 | the residual function b(x). Default: torch.nn.SiLU(). 462 | symbolic_enabled : bool 463 | compute or skip symbolic computations (for efficiency). By default: True. 464 | bias_trainable : bool 465 | bias parameters are updated or not.
By default: True 466 | grid_eps : float 467 | When grid_eps = 0, the grid is uniform; when grid_eps = 1, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. Default: 1.0. 468 | grid_range : list/np.array of shape (2,) 469 | setting the range of grids. Default: [-1,1]. 470 | sp_trainable : bool 471 | If True, scale_sp is trainable. Default: True. 472 | sb_trainable : bool 473 | If True, scale_base is trainable. Default: True. 474 | device : str 475 | device 476 | seed : int 477 | random seed 478 | 479 | Returns: 480 | -------- 481 | self 482 | 483 | Example 484 | ------- 485 | >>> model = KAN(width=[2,5,1], grid=5, k=3) 486 | >>> (model.act_fun[0].in_dim, model.act_fun[0].out_dim), (model.act_fun[1].in_dim, model.act_fun[1].out_dim) 487 | ((2, 5), (5, 1)) 488 | ''' 489 | super(KAN, self).__init__() 490 | 491 | torch.manual_seed(seed) 492 | np.random.seed(seed) 493 | random.seed(seed) 494 | 495 | ### initializing the numerical front ### 496 | 497 | self.biases = [] 498 | self.act_fun = [] 499 | self.depth = len(width) - 1 500 | self.width = width 501 | 502 | for l in range(self.depth): 503 | # splines 504 | scale_base = 1 / np.sqrt(width[l]) + (torch.randn(width[l] * width[l + 1], ) * 2 - 1) * noise_scale_base 505 | sp_batch = KANLayer(in_dim=width[l], out_dim=width[l + 1], num=grid, k=k, noise_scale=noise_scale, scale_base=scale_base, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, 506 | sb_trainable=sb_trainable, device=device, full_output=True) 507 | self.act_fun.append(sp_batch) 508 | 509 | # bias 510 | bias = nn.Linear(width[l + 1], 1, bias=False, device=device).requires_grad_(bias_trainable) 511 | bias.weight.data *= 0. 512 | self.biases.append(bias) 513 | 514 | self.biases = nn.ModuleList(self.biases) 515 | self.act_fun = nn.ModuleList(self.act_fun) 516 | 517 | self.grid = grid 518 | self.k = k 519 | self.base_fun = base_fun 520 | 521 | ### initializing the symbolic front ### 522 | self.symbolic_fun = [] 523 | for l in range(self.depth): 524 | sb_batch = Symbolic_KANLayer(in_dim=width[l], out_dim=width[l + 1], device=device) 525 | self.symbolic_fun.append(sb_batch) 526 | 527 | self.symbolic_fun = nn.ModuleList(self.symbolic_fun) 528 | self.symbolic_enabled = symbolic_enabled 529 | 530 | self.device = device 531 | 532 | def initialize_from_another_model(self, another_model, x): 533 | ''' 534 | initialize from a parent model. The parent has the same width as the current model but may have different grids.
535 | 536 | Args: 537 | ----- 538 | another_model : KAN 539 | the parent model used to initialize the current model 540 | x : 2D torch.float 541 | inputs, shape (batch, input dimension) 542 | 543 | Returns: 544 | -------- 545 | self : KAN 546 | 547 | Example 548 | ------- 549 | >>> model_coarse = KAN(width=[2,5,1], grid=5, k=3) 550 | >>> model_fine = KAN(width=[2,5,1], grid=10, k=3) 551 | >>> print(model_fine.act_fun[0].coef[0][0].data) 552 | >>> x = torch.normal(0,1,size=(100,2)) 553 | >>> model_fine.initialize_from_another_model(model_coarse, x); 554 | >>> print(model_fine.act_fun[0].coef[0][0].data) 555 | tensor(-0.0030) 556 | tensor(0.0506) 557 | ''' 558 | another_model(x.to(another_model.device)) # get activations 559 | batch = x.shape[0] 560 | 561 | self.initialize_grid_from_another_model(another_model, x.to(another_model.device)) 562 | 563 | for l in range(self.depth): 564 | spb = self.act_fun[l] 565 | spb_parent = another_model.act_fun[l] 566 | 567 | # spb = spb_parent 568 | preacts = another_model.spline_preacts[l] 569 | postsplines = another_model.spline_postsplines[l] 570 | self.act_fun[l].coef.data = curve2coef(preacts.reshape(batch, spb.size).permute(1, 0), postsplines.reshape(batch, spb.size).permute(1, 0), spb.grid, k=spb.k, device=self.device) 571 | spb.scale_base.data = spb_parent.scale_base.data 572 | spb.scale_sp.data = spb_parent.scale_sp.data 573 | spb.mask.data = spb_parent.mask.data 574 | # print(spb.mask.data, self.act_fun[l].mask.data) 575 | 576 | for l in range(self.depth): 577 | self.biases[l].weight.data = another_model.biases[l].weight.data 578 | 579 | for l in range(self.depth): 580 | self.symbolic_fun[l] = another_model.symbolic_fun[l] 581 | 582 | return self 583 | 584 | def update_grid_from_samples(self, x): 585 | ''' 586 | update grid from samples 587 | 588 | Args: 589 | ----- 590 | x : 2D torch.float 591 | inputs, shape (batch, input dimension) 592 | 593 | Returns: 594 | -------- 595 | None 596 | 597 | Example 598 | ------- 599 | >>> model = KAN(width=[2,5,1], grid=5, k=3) 600 | >>> print(model.act_fun[0].grid[0].data) 601 | >>> x = torch.rand(100,2)*5 602 | >>> model.update_grid_from_samples(x) 603 | >>> print(model.act_fun[0].grid[0].data) 604 | tensor([-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]) 605 | tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809]) 606 | ''' 607 | for l in range(self.depth): 608 | self.forward(x) 609 | self.act_fun[l].update_grid_from_samples(self.acts[l]) 610 | 611 | def initialize_grid_from_another_model(self, model, x): 612 | ''' 613 | initialize grid from a parent model 614 | 615 | Args: 616 | ----- 617 | model : KAN 618 | parent model 619 | x : 2D torch.float 620 | inputs, shape (batch, input dimension) 621 | 622 | Returns: 623 | -------- 624 | None 625 | 626 | Example 627 | ------- 628 | >>> model_parent = KAN(width=[1,1], grid=5, k=3) 629 | >>> model_parent.act_fun[0].grid.data = torch.linspace(-2,2,steps=6)[None,:] 630 | >>> x = torch.linspace(-2,2,steps=1001)[:,None] 631 | >>> model = KAN(width=[1,1], grid=5, k=3) 632 | >>> print(model.act_fun[0].grid.data) 633 | >>> model = model.initialize_from_another_model(model_parent, x) 634 | >>> print(model.act_fun[0].grid.data) 635 | tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]]) 636 | tensor([[-2.0000, -1.2000, -0.4000, 0.4000, 1.2000, 2.0000]]) 637 | ''' 638 | model(x) 639 | for l in range(self.depth): 640 | self.act_fun[l].initialize_grid_from_parent(model.act_fun[l], model.acts[l]) 641 | 642 | def forward(self, x): 643 | ''' 644 | KAN forward 645 | 
646 | Args: 647 | ----- 648 | x : 2D torch.float 649 | inputs, shape (batch, input dimension) 650 | 651 | Returns: 652 | -------- 653 | y : 2D torch.float 654 | outputs, shape (batch, output dimension) 655 | 656 | Example 657 | ------- 658 | >>> model = KAN(width=[2,5,3], grid=5, k=3) 659 | >>> x = torch.normal(0,1,size=(100,2)) 660 | >>> model(x).shape 661 | torch.Size([100, 3]) 662 | ''' 663 | 664 | self.acts = [] # shape ([batch, n0], [batch, n1], ..., [batch, n_L]) 665 | self.spline_preacts = [] 666 | self.spline_postsplines = [] 667 | self.spline_postacts = [] 668 | self.acts_scale = [] 669 | self.acts_scale_std = [] 670 | # self.neurons_scale = [] 671 | self.acts.append(x) # acts shape: (batch, width[l]) 672 | 673 | for l in range(self.depth): 674 | 675 | x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x) 676 | 677 | if self.symbolic_enabled == True: 678 | x_symbolic, postacts_symbolic = self.symbolic_fun[l](x) 679 | else: 680 | x_symbolic = 0. 681 | postacts_symbolic = 0. 682 | 683 | x = x_numerical + x_symbolic 684 | postacts = postacts_numerical + postacts_symbolic 685 | 686 | # self.neurons_scale.append(torch.mean(torch.abs(x), dim=0)) 687 | grid_reshape = self.act_fun[l].grid.reshape(self.width[l + 1], self.width[l], -1) 688 | input_range = grid_reshape[:, :, -1] - grid_reshape[:, :, 0] + 1e-4 689 | output_range = torch.mean(torch.abs(postacts), dim=0) 690 | self.acts_scale.append(output_range / input_range) 691 | self.acts_scale_std.append(torch.std(postacts, dim=0)) 692 | self.spline_preacts.append(preacts.detach()) 693 | self.spline_postacts.append(postacts.detach()) 694 | self.spline_postsplines.append(postspline.detach()) 695 | 696 | x = x + self.biases[l].weight 697 | self.acts.append(x) 698 | 699 | return x 700 | 701 | def set_mode(self, l, i, j, mode, mask_n=None): 702 | ''' 703 | set (l,i,j) activation to have mode 704 | 705 | Args: 706 | ----- 707 | l : int 708 | layer index 709 | i : int 710 | input neuron index 711 | j : int 712 | output neuron index 713 | mode : str 714 | 'n' (numeric) or 's' (symbolic) or 'ns' (combined) 715 | mask_n : None or float) 716 | magnitude of the numeric front 717 | 718 | Returns: 719 | -------- 720 | None 721 | ''' 722 | if mode == "s": 723 | mask_n = 0.; 724 | mask_s = 1. 725 | elif mode == "n": 726 | mask_n = 1.; 727 | mask_s = 0. 728 | elif mode == "sn" or mode == "ns": 729 | if mask_n == None: 730 | mask_n = 1. 731 | else: 732 | mask_n = mask_n 733 | mask_s = 1. 734 | else: 735 | mask_n = 0.; 736 | mask_s = 0. 737 | 738 | self.act_fun[l].mask.data[j * self.act_fun[l].in_dim + i] = mask_n 739 | self.symbolic_fun[l].mask.data[j, i] = mask_s 740 | 741 | def fix_symbolic(self, l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True, random=False): 742 | ''' 743 | set (l,i,j) activation to be symbolic (specified by fun_name) 744 | 745 | Args: 746 | ----- 747 | l : int 748 | layer index 749 | i : int 750 | input neuron index 751 | j : int 752 | output neuron index 753 | fun_name : str 754 | function name 755 | fit_params_bool : bool 756 | obtaining affine parameters through fitting (True) or setting default values (False) 757 | a_range : tuple 758 | sweeping range of a 759 | b_range : tuple 760 | sweeping range of b 761 | verbose : bool 762 | If True, more information is printed. 
763 | random : bool 764 | initialize affine parameters randomly or as [1,0,1,0] 765 | 766 | Returns: 767 | -------- 768 | None or r2 (coefficient of determination) 769 | 770 | Example 1 771 | --------- 772 | >>> # when fit_params_bool = False 773 | >>> model = KAN(width=[2,5,1], grid=5, k=3) 774 | >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False) 775 | >>> print(model.act_fun[0].mask.reshape(2,5)) 776 | >>> print(model.symbolic_fun[0].mask.reshape(2,5)) 777 | tensor([[1., 1., 1., 1., 1.], 778 | [1., 1., 0., 1., 1.]]) 779 | tensor([[0., 0., 0., 0., 0.], 780 | [0., 0., 1., 0., 0.]]) 781 | 782 | Example 2 783 | --------- 784 | >>> # when fit_params_bool = True 785 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.) 786 | >>> x = torch.normal(0,1,size=(100,2)) 787 | >>> model(x) # obtain activations (otherwise model does not have attributes acts) 788 | >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True) 789 | >>> print(model.act_fun[0].mask.reshape(2,5)) 790 | >>> print(model.symbolic_fun[0].mask.reshape(2,5)) 791 | r2 is 0.8131332993507385 792 | r2 is not very high, please double check if you are choosing the correct symbolic function. 793 | tensor([[1., 1., 1., 1., 1.], 794 | [1., 1., 0., 1., 1.]]) 795 | tensor([[0., 0., 0., 0., 0.], 796 | [0., 0., 1., 0., 0.]]) 797 | ''' 798 | self.set_mode(l, i, j, mode="s") 799 | if not fit_params_bool: 800 | self.symbolic_fun[l].fix_symbolic(i, j, fun_name, verbose=verbose, random=random) 801 | return None 802 | else: 803 | x = self.acts[l][:, i] 804 | y = self.spline_postacts[l][:, j, i] 805 | r2 = self.symbolic_fun[l].fix_symbolic(i, j, fun_name, x, y, a_range=a_range, b_range=b_range, verbose=verbose) 806 | return r2 807 | 808 | def unfix_symbolic(self, l, i, j): 809 | ''' 810 | unfix the (l,i,j) activation function. 811 | ''' 812 | self.set_mode(l, i, j, mode="n") 813 | 814 | def unfix_symbolic_all(self): 815 | ''' 816 | unfix all activation functions. 817 | ''' 818 | for l in range(len(self.width) - 1): 819 | for i in range(self.width[l]): 820 | for j in range(self.width[l + 1]): 821 | self.unfix_symbolic(l, i, j) 822 | 823 | def lock(self, l, ids): 824 | ''' 825 | lock ids in the l-th layer to be the same function 826 | 827 | Args: 828 | ----- 829 | l : int 830 | layer index 831 | ids : 2D list 832 | :math:`[[i_1,j_1],[i_2,j_2],...]` set :math:`(l,i_1,j_1), (l,i_2,j_2), ...` to be the same function 833 | 834 | Returns: 835 | -------- 836 | None 837 | 838 | Example 839 | ------- 840 | >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.) 841 | >>> print(model.act_fun[0].weight_sharing.reshape(3,2)) 842 | >>> model.lock(0,[[1,0],[1,1]]) 843 | >>> print(model.act_fun[0].weight_sharing.reshape(3,2)) 844 | tensor([[0, 1], 845 | [2, 3], 846 | [4, 5]]) 847 | tensor([[0, 1], 848 | [2, 1], 849 | [4, 5]]) 850 | ''' 851 | self.act_fun[l].lock(ids) 852 | 853 | def unlock(self, l, ids): 854 | ''' 855 | unlock ids in the l-th layer (release weight sharing) 856 | 857 | Args: 858 | ----- 859 | l : int 860 | layer index 861 | ids : 2D list 862 | [[i1,j1],[i2,j2],...] set (l,i1,j1), (l,i2,j2), ... to be unlocked 863 | 864 | Example: 865 | -------- 866 | >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
867 | >>> model.lock(0,[[1,0],[1,1]]) 868 | >>> print(model.act_fun[0].weight_sharing.reshape(3,2)) 869 | >>> model.unlock(0,[[1,0],[1,1]]) 870 | >>> print(model.act_fun[0].weight_sharing.reshape(3,2)) 871 | tensor([[0, 1], 872 | [2, 1], 873 | [4, 5]]) 874 | tensor([[0, 1], 875 | [2, 3], 876 | [4, 5]]) 877 | ''' 878 | self.act_fun[l].unlock(ids) 879 | 880 | def get_range(self, l, i, j, verbose=True): 881 | ''' 882 | Get the input range and output range of the (l,i,j) activation 883 | 884 | Args: 885 | ----- 886 | l : int 887 | layer index 888 | i : int 889 | input neuron index 890 | j : int 891 | output neuron index 892 | 893 | Returns: 894 | -------- 895 | x_min : float 896 | minimum of input 897 | x_max : float 898 | maximum of input 899 | y_min : float 900 | minimum of output 901 | y_max : float 902 | maximum of output 903 | 904 | Example 905 | ------- 906 | >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.) 907 | >>> x = torch.normal(0,1,size=(100,2)) 908 | >>> model(x) # do a forward pass to obtain model.acts 909 | >>> model.get_range(0,0,0) 910 | x range: [-2.13 , 2.75 ] 911 | y range: [-0.50 , 1.83 ] 912 | (tensor(-2.1288), tensor(2.7498), tensor(-0.5042), tensor(1.8275)) 913 | ''' 914 | x = self.spline_preacts[l][:, j, i] 915 | y = self.spline_postacts[l][:, j, i] 916 | x_min = torch.min(x) 917 | x_max = torch.max(x) 918 | y_min = torch.min(y) 919 | y_max = torch.max(y) 920 | if verbose: 921 | print('x range: [' + '%.2f' % x_min, ',', '%.2f' % x_max, ']') 922 | print('y range: [' + '%.2f' % y_min, ',', '%.2f' % y_max, ']') 923 | return x_min, x_max, y_min, y_max 924 | 925 | def plot(self, folder="./figures", beta=3, mask=False, mode="supervised", scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None): 926 | ''' 927 | plot KAN 928 | 929 | Args: 930 | ----- 931 | folder : str 932 | the folder to store pngs 933 | beta : float 934 | positive number. controls the transparency of each activation. transparency = tanh(beta*l1). 935 | mask : bool 936 | If True, plot with mask (need to run prune() first to obtain mask). If False (by default), plot all activation functions. 937 | mode : str 938 | "supervised" or "unsupervised". If "supervised", l1 is measured by absolute value (not subtracting mean); if "unsupervised", l1 is measured by standard deviation (subtracting mean). 939 | scale : float 940 | controls the size of the diagram 941 | in_vars: None or list of str 942 | the name(s) of input variables 943 | out_vars: None or list of str 944 | the name(s) of output variables 945 | title: None or str 946 | title 947 | 948 | Returns: 949 | -------- 950 | Figure 951 | 952 | Example 953 | ------- 954 | >>> # see more interactive examples in demos 955 | >>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0) 956 | >>> x = torch.normal(0,1,size=(100,2)) 957 | >>> model(x) # do a forward pass to obtain model.acts 958 | >>> model.plot() 959 | ''' 960 | if not os.path.exists(folder): 961 | os.makedirs(folder) 962 | # matplotlib.use('Agg') 963 | depth = len(self.width) - 1 964 | for l in range(depth): 965 | w_large = 2.0 966 | for i in range(self.width[l]): 967 | for j in range(self.width[l + 1]): 968 | rank = torch.argsort(self.acts[l][:, i]) 969 | fig, ax = plt.subplots(figsize=(w_large, w_large)) 970 | 971 | num = rank.shape[0] 972 | 973 | symbol_mask = self.symbolic_fun[l].mask[j][i] 974 | numerical_mask = self.act_fun[l].mask.reshape(self.width[l + 1], self.width[l])[j][i] 975 | if symbol_mask > 0.
and numerical_mask > 0.: 976 | color = 'purple' 977 | alpha_mask = 1 978 | if symbol_mask > 0. and numerical_mask == 0.: 979 | color = "red" 980 | alpha_mask = 1 981 | if symbol_mask == 0. and numerical_mask > 0.: 982 | color = "black" 983 | alpha_mask = 1 984 | if symbol_mask == 0. and numerical_mask == 0.: 985 | color = "white" 986 | alpha_mask = 0 987 | 988 | if tick == True: 989 | ax.tick_params(axis="y", direction="in", pad=-22, labelsize=50) 990 | ax.tick_params(axis="x", direction="in", pad=-15, labelsize=50) 991 | x_min, x_max, y_min, y_max = self.get_range(l, i, j, verbose=False) 992 | plt.xticks([x_min, x_max], ['%2.f' % x_min, '%2.f' % x_max]) 993 | plt.yticks([y_min, y_max], ['%2.f' % y_min, '%2.f' % y_max]) 994 | else: 995 | plt.xticks([]) 996 | plt.yticks([]) 997 | if alpha_mask == 1: 998 | plt.gca().patch.set_edgecolor('black') 999 | else: 1000 | plt.gca().patch.set_edgecolor('white') 1001 | plt.gca().patch.set_linewidth(1.5) 1002 | # plt.axis('off') 1003 | 1004 | plt.plot(self.acts[l][:, i][rank].cpu().detach().numpy(), self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, lw=5) 1005 | if sample == True: 1006 | plt.scatter(self.acts[l][:, i][rank].cpu().detach().numpy(), self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, s=400 * scale ** 2) 1007 | plt.gca().spines[:].set_color(color) 1008 | 1009 | lock_id = self.act_fun[l].lock_id[j * self.width[l] + i].long().item() 1010 | if lock_id > 0: 1011 | im = plt.imread(f'{folder}/lock.png') 1012 | newax = fig.add_axes([0.15, 0.7, 0.15, 0.15]) 1013 | plt.text(500, 400, lock_id, fontsize=15) 1014 | newax.imshow(im) 1015 | newax.axis('off') 1016 | 1017 | plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=400) 1018 | plt.close() 1019 | 1020 | def score2alpha(score): 1021 | return np.tanh(beta * score) 1022 | 1023 | if mode == "supervised": 1024 | alpha = [score2alpha(score.cpu().detach().numpy()) for score in self.acts_scale] 1025 | elif mode == "unsupervised": 1026 | alpha = [score2alpha(score.cpu().detach().numpy()) for score in self.acts_scale_std] 1027 | 1028 | # draw skeleton 1029 | width = np.array(self.width) 1030 | A = 1 1031 | y0 = 0.4 # 0.4 1032 | 1033 | # plt.figure(figsize=(5,5*(neuron_depth-1)*y0)) 1034 | neuron_depth = len(width) 1035 | min_spacing = A / np.maximum(np.max(width), 5) 1036 | 1037 | max_neuron = np.max(width) 1038 | max_num_weights = np.max(width[:-1] * width[1:]) 1039 | y1 = 0.4 / np.maximum(max_num_weights, 3) 1040 | 1041 | fig, ax = plt.subplots(figsize=(10 * scale, 10 * scale * (neuron_depth - 1) * y0)) 1042 | # fig, ax = plt.subplots(figsize=(5,5*(neuron_depth-1)*y0)) 1043 | 1044 | # plot scatters and lines 1045 | for l in range(neuron_depth): 1046 | n = width[l] 1047 | spacing = A / n 1048 | for i in range(n): 1049 | plt.scatter(1 / (2 * n) + i / n, l * y0, s=min_spacing ** 2 * 10000 * scale ** 2, color='black') 1050 | 1051 | if l < neuron_depth - 1: 1052 | # plot connections 1053 | n_next = width[l + 1] 1054 | N = n * n_next 1055 | for j in range(n_next): 1056 | id_ = i * n_next + j 1057 | 1058 | symbol_mask = self.symbolic_fun[l].mask[j][i] 1059 | numerical_mask = self.act_fun[l].mask.reshape(self.width[l + 1], self.width[l])[j][i] 1060 | if symbol_mask == 1. and numerical_mask == 1.: 1061 | color = 'purple' 1062 | alpha_mask = 1. 1063 | if symbol_mask == 1. and numerical_mask == 0.: 1064 | color = "red" 1065 | alpha_mask = 1. 1066 | if symbol_mask == 0. and numerical_mask == 1.: 1067 | color = "black" 1068 | alpha_mask = 1. 
1069 | if symbol_mask == 0. and numerical_mask == 0.: 1070 | color = "white" 1071 | alpha_mask = 0. 1072 | if mask == True: 1073 | plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N], [l * y0, (l + 1 / 2) * y0 - y1], color=color, lw=2 * scale, alpha=alpha[l][j][i] * self.mask[l][i].item() * self.mask[l + 1][j].item()) 1074 | plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next], [(l + 1 / 2) * y0 + y1, (l + 1) * y0], color=color, lw=2 * scale, alpha=alpha[l][j][i] * self.mask[l][i].item() * self.mask[l + 1][j].item()) 1075 | else: 1076 | plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N], [l * y0, (l + 1 / 2) * y0 - y1], color=color, lw=2 * scale, alpha=alpha[l][j][i] * alpha_mask) 1077 | plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next], [(l + 1 / 2) * y0 + y1, (l + 1) * y0], color=color, lw=2 * scale, alpha=alpha[l][j][i] * alpha_mask) 1078 | 1079 | plt.xlim(0, 1) 1080 | plt.ylim(-0.1 * y0, (neuron_depth - 1 + 0.1) * y0) 1081 | 1082 | # -- Transformation functions 1083 | DC_to_FC = ax.transData.transform 1084 | FC_to_NFC = fig.transFigure.inverted().transform 1085 | # -- Take data coordinates and transform them to normalized figure coordinates 1086 | DC_to_NFC = lambda x: FC_to_NFC(DC_to_FC(x)) 1087 | 1088 | plt.axis('off') 1089 | 1090 | # plot splines 1091 | for l in range(neuron_depth - 1): 1092 | n = width[l] 1093 | for i in range(n): 1094 | n_next = width[l + 1] 1095 | N = n * n_next 1096 | for j in range(n_next): 1097 | id_ = i * n_next + j 1098 | im = plt.imread(f'{folder}/sp_{l}_{i}_{j}.png') 1099 | left = DC_to_NFC([1 / (2 * N) + id_ / N - y1, 0])[0] 1100 | right = DC_to_NFC([1 / (2 * N) + id_ / N + y1, 0])[0] 1101 | bottom = DC_to_NFC([0, (l + 1 / 2) * y0 - y1])[1] 1102 | up = DC_to_NFC([0, (l + 1 / 2) * y0 + y1])[1] 1103 | newax = fig.add_axes([left, bottom, right - left, up - bottom]) 1104 | # newax = fig.add_axes([1/(2*N)+id_/N-y1, (l+1/2)*y0-y1, y1, y1], anchor='NE') 1105 | if mask == False: 1106 | newax.imshow(im, alpha=alpha[l][j][i]) 1107 | else: 1108 | ### make sure to run model.prune() first to compute mask ### 1109 | newax.imshow(im, alpha=alpha[l][j][i] * self.mask[l][i].item() * self.mask[l + 1][j].item()) 1110 | newax.axis('off') 1111 | 1112 | if in_vars != None: 1113 | n = self.width[0] 1114 | for i in range(n): 1115 | plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, in_vars[i], fontsize=40 * scale, horizontalalignment='center', verticalalignment='center') 1116 | 1117 | if out_vars != None: 1118 | n = self.width[-1] 1119 | for i in range(n): 1120 | plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), y0 * (len(self.width) - 1) + 0.1, out_vars[i], fontsize=40 * scale, horizontalalignment='center', verticalalignment='center') 1121 | 1122 | if title != None: 1123 | plt.gcf().get_axes()[0].text(0.5, y0 * (len(self.width) - 1) + 0.2, title, fontsize=40 * scale, horizontalalignment='center', verticalalignment='center') 1124 | 1125 | def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., stop_grid_update_step=50, batch=-1, 1126 | small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu'): 1127 | ''' 1128 | training 1129 | 1130 | Args: 1131 | ----- 1132 | dataset : dic 1133 | contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label'] 1134 
| opt : str 1135 | "LBFGS" or "Adam" 1136 | steps : int 1137 | training steps 1138 | log : int 1139 | logging frequency 1140 | lamb : float 1141 | overall penalty strength 1142 | lamb_l1 : float 1143 | l1 penalty strength 1144 | lamb_entropy : float 1145 | entropy penalty strength 1146 | lamb_coef : float 1147 | coefficient magnitude penalty strength 1148 | lamb_coefdiff : float 1149 | difference of nearby coefficients (smoothness) penalty strength 1150 | update_grid : bool 1151 | If True, update grid regularly before stop_grid_update_step 1152 | grid_update_num : int 1153 | the number of grid updates before stop_grid_update_step 1154 | stop_grid_update_step : int 1155 | no grid updates after this training step 1156 | batch : int 1157 | batch size, if -1 then full. 1158 | small_mag_threshold : float 1159 | threshold to determine large or small numbers (may want to apply larger penalty to smaller numbers) 1160 | small_reg_factor : float 1161 | penalty strength applied to small factors relative to large factors 1162 | device : str 1163 | device 1164 | save_fig_freq : int 1165 | save figure every (save_fig_freq) steps 1166 | 1167 | Returns: 1168 | -------- 1169 | results : dict 1170 | results['train_loss'], 1D array of training losses (RMSE) 1171 | results['test_loss'], 1D array of test losses (RMSE) 1172 | results['reg'], 1D array of regularization 1173 | 1174 | Example 1175 | ------- 1176 | >>> # for interactive examples, please see demos 1177 | >>> from utils import create_dataset 1178 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0) 1179 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 1180 | >>> dataset = create_dataset(f, n_var=2) 1181 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); 1182 | >>> model.plot() 1183 | ''' 1184 | 1185 | def reg(acts_scale): 1186 | 1187 | def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor): 1188 | return (x < th) * x * factor + (x > th) * (x + (factor - 1) * th) 1189 | 1190 | reg_ = 0.
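# sparsity regularization over per-edge activation scales (as in the KAN paper): for each layer,
# l1 = sum(nonlinear(|phi|)) and entropy = -sum(p * log2 p) with p = |phi| / sum(|phi|);
# nonlinear() applies the stronger small_reg_factor penalty to magnitudes below small_mag_threshold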
1191 | for i in range(len(acts_scale)): 1192 | vec = acts_scale[i].reshape(-1, ) 1193 | 1194 | p = vec / torch.sum(vec) 1195 | l1 = torch.sum(nonlinear(vec)) 1196 | entropy = - torch.sum(p * torch.log2(p + 1e-4)) 1197 | reg_ += lamb_l1 * l1 + lamb_entropy * entropy # both l1 and entropy 1198 | 1199 | # regularize coefficient to encourage spline to be zero 1200 | for i in range(len(self.act_fun)): 1201 | coeff_l1 = torch.sum(torch.mean(torch.abs(self.act_fun[i].coef), dim=1)) 1202 | coeff_diff_l1 = torch.sum(torch.mean(torch.abs(torch.diff(self.act_fun[i].coef)), dim=1)) 1203 | reg_ += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1 1204 | 1205 | return reg_ 1206 | 1207 | pbar = tqdm(range(steps), desc='description', ncols=100) 1208 | 1209 | if loss_fn == None: 1210 | loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2) 1211 | else: 1212 | loss_fn = loss_fn_eval = loss_fn 1213 | 1214 | grid_update_freq = int(stop_grid_update_step / grid_update_num) 1215 | 1216 | if opt == "Adam": 1217 | optimizer = torch.optim.Adam(self.parameters(), lr=lr) 1218 | elif opt == "LBFGS": 1219 | optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32) 1220 | 1221 | results = {} 1222 | results['train_loss'] = [] 1223 | results['test_loss'] = [] 1224 | results['reg'] = [] 1225 | if metrics != None: 1226 | for i in range(len(metrics)): 1227 | results[metrics[i].__name__] = [] 1228 | 1229 | if batch == -1 or batch > dataset['train_input'].shape[0]: 1230 | batch_size = dataset['train_input'].shape[0] 1231 | batch_size_test = dataset['test_input'].shape[0] 1232 | else: 1233 | batch_size = batch 1234 | batch_size_test = batch 1235 | 1236 | global train_loss, reg_ 1237 | 1238 | def closure(): 1239 | global train_loss, reg_ 1240 | optimizer.zero_grad() 1241 | pred = self.forward(dataset['train_input'][train_id].to(device)) 1242 | if sglr_avoid == True: 1243 | id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0] 1244 | train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device)) 1245 | else: 1246 | train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device)) 1247 | reg_ = reg(self.acts_scale) 1248 | objective = train_loss + lamb * reg_ 1249 | objective.backward() 1250 | return objective 1251 | 1252 | if save_fig: 1253 | if not os.path.exists(img_folder): 1254 | os.makedirs(img_folder) 1255 | 1256 | for _ in pbar: 1257 | 1258 | train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False) 1259 | test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False) 1260 | 1261 | if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid: 1262 | self.update_grid_from_samples(dataset['train_input'][train_id].to(device)) 1263 | 1264 | if opt == "LBFGS": 1265 | optimizer.step(closure) 1266 | 1267 | if opt == "Adam": 1268 | pred = self.forward(dataset['train_input'][train_id].to(device)) 1269 | if sglr_avoid == True: 1270 | id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0] 1271 | train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device)) 1272 | else: 1273 | train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device)) 1274 | reg_ = reg(self.acts_scale) 1275 | loss = train_loss + lamb * reg_ 1276 | optimizer.zero_grad() 1277 | loss.backward() 1278 | optimizer.step() 1279 | 1280 | test_loss = 
loss_fn_eval(self.forward(dataset['test_input'][test_id].to(device)), dataset['test_label'][test_id].to(device)) 1281 | 1282 | if _ % log == 0: 1283 | pbar.set_description("train loss: %.2e | test loss: %.2e | reg: %.2e " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy())) 1284 | 1285 | if metrics is not None: 1286 | for i in range(len(metrics)): 1287 | results[metrics[i].__name__].append(metrics[i]().item()) 1288 | 1289 | results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy()) 1290 | results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy()) 1291 | results['reg'].append(reg_.cpu().detach().numpy()) 1292 | 1293 | if save_fig and _ % save_fig_freq == 0: 1294 | self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta) 1295 | plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight', dpi=200) 1296 | plt.close() 1297 | 1298 | return results 1299 | 1300 | def prune(self, threshold=1e-2, mode="auto", active_neurons_id=None): 1301 | ''' 1302 | pruning KAN on the node level. If all of a node's incoming or all of its outgoing connections are small, the node will be pruned away. 1303 | 1304 | Args: 1305 | ----- 1306 | threshold : float 1307 | the threshold used to determine whether a node is small enough 1308 | mode : str 1309 | "auto" or "manual". If "auto", the threshold will be used to automatically prune away nodes. If "manual", active_neurons_id is needed to specify which neurons are kept (others are thrown away). 1310 | active_neurons_id : list of id lists 1311 | For example, [[0,1],[0,2,3]] means keeping neurons 0/1 in the 1st hidden layer and neurons 0/2/3 in the 2nd hidden layer. Pruning input and output neurons is not supported yet.
1312 | 1313 | Returns: 1314 | -------- 1315 | model2 : KAN 1316 | pruned model 1317 | 1318 | Example 1319 | ------- 1320 | >>> # for more interactive examples, please see demos 1321 | >>> from utils import create_dataset 1322 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0) 1323 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 1324 | >>> dataset = create_dataset(f, n_var=2) 1325 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); 1326 | >>> model.prune() 1327 | >>> model.plot(mask=True) 1328 | ''' 1329 | mask = [torch.ones(self.width[0], )] 1330 | active_neurons = [list(range(self.width[0]))] 1331 | for i in range(len(self.acts_scale) - 1): 1332 | if mode == "auto": 1333 | in_important = torch.max(self.acts_scale[i], dim=1)[0] > threshold 1334 | out_important = torch.max(self.acts_scale[i + 1], dim=0)[0] > threshold 1335 | overall_important = in_important * out_important 1336 | elif mode == "manual": 1337 | overall_important = torch.zeros(self.width[i + 1], dtype=torch.bool) 1338 | overall_important[active_neurons_id[i + 1]] = True 1339 | mask.append(overall_important.float()) 1340 | active_neurons.append(torch.where(overall_important == True)[0]) 1341 | active_neurons.append(list(range(self.width[-1]))) 1342 | mask.append(torch.ones(self.width[-1], )) 1343 | 1344 | self.mask = mask # this is neuron mask for the whole model 1345 | 1346 | # update act_fun[l].mask 1347 | for l in range(len(self.acts_scale) - 1): 1348 | for i in range(self.width[l + 1]): 1349 | if i not in active_neurons[l + 1]: 1350 | self.remove_node(l + 1, i) 1351 | 1352 | model2 = KAN(copy.deepcopy(self.width), self.grid, self.k, base_fun=self.base_fun, device=self.device) 1353 | model2.load_state_dict(self.state_dict()) 1354 | for i in range(len(self.acts_scale)): 1355 | if i < len(self.acts_scale) - 1: 1356 | model2.biases[i].weight.data = model2.biases[i].weight.data[:, active_neurons[i + 1]] 1357 | 1358 | model2.act_fun[i] = model2.act_fun[i].get_subset(active_neurons[i], active_neurons[i + 1]) 1359 | model2.width[i] = len(active_neurons[i]) 1360 | model2.symbolic_fun[i] = self.symbolic_fun[i].get_subset(active_neurons[i], active_neurons[i + 1]) 1361 | 1362 | return model2 1363 | 1364 | def remove_edge(self, l, i, j): 1365 | ''' 1366 | remove activation phi(l,i,j) (set its mask to zero) 1367 | 1368 | Args: 1369 | ----- 1370 | l : int 1371 | layer index 1372 | i : int 1373 | input neuron index 1374 | j : int 1375 | output neuron index 1376 | 1377 | Returns: 1378 | -------- 1379 | None 1380 | ''' 1381 | self.act_fun[l].mask[j * self.width[l] + i] = 0. 1382 | 1383 | def remove_node(self, l, i): 1384 | ''' 1385 | remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero) 1386 | 1387 | Args: 1388 | ----- 1389 | l : int 1390 | layer index 1391 | i : int 1392 | neuron index 1393 | 1394 | Returns: 1395 | -------- 1396 | None 1397 | ''' 1398 | self.act_fun[l - 1].mask[i * self.width[l - 1] + torch.arange(self.width[l - 1])] = 0. 1399 | self.act_fun[l].mask[torch.arange(self.width[l + 1]) * self.width[l] + i] = 0. 1400 | self.symbolic_fun[l - 1].mask[i, :] *= 0. 1401 | self.symbolic_fun[l].mask[:, i] *= 0.
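# illustrative sketch (not from the upstream pykan source): edges and nodes can be ablated directly, e.g.
#   model = KAN(width=[2, 5, 1], grid=5, k=3)
#   model.remove_edge(0, 0, 1)  # zero the numeric mask for phi(l=0, i=0, j=1)
#   model.remove_node(1, 2)     # zero every mask touching hidden neuron 2 of layer 1
# note the flat numeric masks are indexed as j * width[l] + i, while symbolic masks are 2D [out, in]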
1402 | 1403 | def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True): 1404 | '''suggest the symbolic candidates of phi(l,i,j) 1405 | 1406 | Args: 1407 | ----- 1408 | l : int 1409 | layer index 1410 | i : int 1411 | input neuron index 1412 | j : int 1413 | output neuron index 1414 | lib : dict 1415 | library of symbolic bases. If lib = None, the global default library will be used. 1416 | topk : int 1417 | display the top k symbolic functions (according to r2) 1418 | verbose : bool 1419 | If True, more information will be printed. 1420 | 1421 | Returns: 1422 | -------- 1423 | best_name, best_fun, best_r2 (the top-1 symbolic candidate) 1424 | 1425 | Example 1426 | ------- 1427 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0) 1428 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 1429 | >>> dataset = create_dataset(f, n_var=2) 1430 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); 1431 | >>> model = model.prune() 1432 | >>> model(dataset['train_input']) 1433 | >>> model.suggest_symbolic(0,0,0) 1434 | function , r2 1435 | sin , 0.9994412064552307 1436 | gaussian , 0.9196369051933289 1437 | tanh , 0.8608126044273376 1438 | sigmoid , 0.8578218817710876 1439 | arctan , 0.842217743396759 1440 | ''' 1441 | r2s = [] 1442 | 1443 | if lib is None: 1444 | symbolic_lib = SYMBOLIC_LIB 1445 | else: 1446 | symbolic_lib = {} 1447 | for item in lib: 1448 | symbolic_lib[item] = SYMBOLIC_LIB[item] 1449 | 1450 | for (name, fun) in symbolic_lib.items(): 1451 | r2 = self.fix_symbolic(l, i, j, name, a_range=a_range, b_range=b_range, verbose=False) 1452 | r2s.append(r2.item()) 1453 | 1454 | self.unfix_symbolic(l, i, j) 1455 | 1456 | topk = np.minimum(topk, len(symbolic_lib)) 1457 | sorted_ids = np.argsort(r2s)[::-1][:topk] 1458 | r2s = np.array(r2s)[sorted_ids][:topk] 1459 | if verbose: 1460 | print('function', ',', 'r2') 1461 | for i in range(topk): 1462 | print(list(symbolic_lib.items())[sorted_ids[i]][0], ',', r2s[i]) 1463 | 1464 | best_name = list(symbolic_lib.items())[sorted_ids[0]][0] 1465 | best_fun = list(symbolic_lib.items())[sorted_ids[0]][1] 1466 | best_r2 = r2s[0] 1467 | return best_name, best_fun, best_r2 1468 | 1469 | def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1): 1470 | ''' 1471 | automatic symbolic regression: using top 1 suggestion from suggest_symbolic to replace splines with symbolic activations 1472 | 1473 | Args: 1474 | ----- 1475 | lib : None or a list of function names 1476 | the symbolic library 1477 | verbose : int 1478 | verbosity 1479 | 1480 | Returns: 1481 | -------- 1482 | None (print suggested symbolic formulas) 1483 | 1484 | Example 1 1485 | --------- 1486 | >>> # default library 1487 | >>> from utils import create_dataset 1488 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0) 1489 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 1490 | >>> dataset = create_dataset(f, n_var=2) 1491 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); 1492 | >>> model = model.prune() 1493 | >>> model(dataset['train_input']) 1494 | >>> model.auto_symbolic() 1495 | fixing (0,0,0) with sin, r2=0.9994837045669556 1496 | fixing (0,1,0) with cosh, r2=0.9978033900260925 1497 | fixing (1,0,0) with arctan, r2=0.9997088313102722 1498 | 1499 | Example 2 1500 | --------- 1501 | >>> # customized library 1502 | >>> from utils import create_dataset 1503 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0) 1504 | >>> f = lambda x:
torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 1505 | >>> dataset = create_dataset(f, n_var=2) 1506 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); 1507 | >>> model = model.prune() 1508 | >>> model(dataset['train_input']) 1509 | >>> model.auto_symbolic(lib=['exp','sin','x^2']) 1510 | fixing (0,0,0) with sin, r2=0.999411404132843 1511 | fixing (0,1,0) with x^2, r2=0.9962921738624573 1512 | fixing (1,0,0) with exp, r2=0.9980258941650391 1513 | ''' 1514 | for l in range(len(self.width) - 1): 1515 | for i in range(self.width[l]): 1516 | for j in range(self.width[l + 1]): 1517 | if self.symbolic_fun[l].mask[j, i] > 0.: 1518 | print(f'skipping ({l},{i},{j}) since already symbolic') 1519 | else: 1520 | name, fun, r2 = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False) 1521 | self.fix_symbolic(l, i, j, name, verbose=verbose > 1) 1522 | if verbose >= 1: 1523 | print(f'fixing ({l},{i},{j}) with {name}, r2={r2}') 1524 | 1525 | def symbolic_formula(self, floating_digit=2, var=None, normalizer=None, simplify=False, output_normalizer=None): 1526 | ''' 1527 | obtain the symbolic formula 1528 | 1529 | Args: 1530 | ----- 1531 | floating_digit : int 1532 | the number of digits to display 1533 | var : list of str 1534 | the name of variables (if not provided, by default using ['x_1', 'x_2', ...]) 1535 | normalizer : [mean array (floats), std array (floats)] 1536 | the normalization applied to inputs 1537 | simplify : bool 1538 | If True, simplify the equation at each step (usually quite slow), so it is False by default. 1539 | output_normalizer: [mean array (floats), std array (floats)] 1540 | the normalization applied to outputs 1541 | 1542 | Returns: 1543 | -------- 1544 | symbolic formula : sympy function 1545 | 1546 | Example 1547 | ------- 1548 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0, grid_eps=0.02) 1549 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 1550 | >>> dataset = create_dataset(f, n_var=2) 1551 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); 1552 | >>> model = model.prune() 1553 | >>> model(dataset['train_input']) 1554 | >>> model.auto_symbolic(lib=['exp','sin','x^2']) 1555 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.00, update_grid=False); 1556 | >>> model.symbolic_formula() 1557 | ''' 1558 | symbolic_acts = [] 1559 | x = [] 1560 | 1561 | def ex_round(ex1, floating_digit=floating_digit): 1562 | ex2 = ex1 1563 | for a in sympy.preorder_traversal(ex1): 1564 | if isinstance(a, sympy.Float): 1565 | ex2 = ex2.subs(a, round(a, floating_digit)) 1566 | return ex2 1567 | 1568 | # define variables 1569 | if var is None: 1570 | for ii in range(1, self.width[0] + 1): 1571 | exec(f"x{ii} = sympy.Symbol('x_{ii}')") 1572 | exec(f"x.append(x{ii})") 1573 | else: 1574 | x = [sympy.symbols(var_) for var_ in var] 1575 | 1576 | x0 = x 1577 | 1578 | if normalizer is not None: 1579 | mean = normalizer[0] 1580 | std = normalizer[1] 1581 | x = [(x[i] - mean[i]) / std[i] for i in range(len(x))] 1582 | 1583 | symbolic_acts.append(x) 1584 | 1585 | for l in range(len(self.width) - 1): 1586 | y = [] 1587 | for j in range(self.width[l + 1]): 1588 | yj = 0.
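# each symbolic edge contributes c * f(a * x_i + b) + d; output j below sums these over inputs i, and the node bias is added after the loop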
1589 | for i in range(self.width[l]): 1590 | a, b, c, d = self.symbolic_fun[l].affine[j, i] 1591 | sympy_fun = self.symbolic_fun[l].funs_sympy[j][i] 1592 | try: 1593 | yj += c * sympy_fun(a * x[i] + b) + d 1594 | except Exception: 1595 | print('make sure all activations have been converted to symbolic formulas first!') 1596 | return 1597 | if simplify == True: 1598 | y.append(sympy.simplify(yj + self.biases[l].weight.data[0, j])) 1599 | else: 1600 | y.append(yj + self.biases[l].weight.data[0, j]) 1601 | 1602 | x = y 1603 | symbolic_acts.append(x) 1604 | 1605 | if output_normalizer is not None: 1606 | output_layer = symbolic_acts[-1] 1607 | means = output_normalizer[0] 1608 | stds = output_normalizer[1] 1609 | 1610 | assert len(output_layer) == len(means), 'output_normalizer does not match the output layer' 1611 | assert len(output_layer) == len(stds), 'output_normalizer does not match the output layer' 1612 | 1613 | output_layer = [(output_layer[i] * stds[i] + means[i]) for i in range(len(output_layer))] 1614 | symbolic_acts[-1] = output_layer 1615 | 1616 | 1617 | 1618 | self.symbolic_acts = [[ex_round(symbolic_acts[l][i]) for i in range(len(symbolic_acts[l]))] for l in range(len(symbolic_acts))] 1619 | 1620 | out_dim = len(symbolic_acts[-1]) 1621 | return [ex_round(symbolic_acts[-1][i]) for i in range(len(symbolic_acts[-1]))], x0 1622 | 1623 | def clear_ckpts(self, folder='./model_ckpt'): 1624 | ''' 1625 | clear all checkpoints 1626 | 1627 | Args: 1628 | ----- 1629 | folder : str 1630 | the folder that stores checkpoints 1631 | 1632 | Returns: 1633 | -------- 1634 | None 1635 | ''' 1636 | if os.path.exists(folder): 1637 | files = glob.glob(folder + '/*') 1638 | for f in files: 1639 | os.remove(f) 1640 | else: 1641 | os.makedirs(folder) 1642 | 1643 | def save_ckpt(self, name, folder='./model_ckpt'): 1644 | ''' 1645 | save the current model as a checkpoint 1646 | 1647 | Args: 1648 | ----- 1649 | name : str 1650 | the name of the checkpoint to be saved 1651 | folder : str 1652 | the folder that stores checkpoints 1653 | 1654 | Returns: 1655 | -------- 1656 | None 1657 | ''' 1658 | 1659 | if not os.path.exists(folder): 1660 | os.makedirs(folder) 1661 | 1662 | torch.save(self.state_dict(), folder + '/' + name) 1663 | print('saving this model to', folder + '/' + name) 1664 | 1665 | def load_ckpt(self, name, folder='./model_ckpt'): 1666 | ''' 1667 | load a checkpoint into the current model 1668 | 1669 | Args: 1670 | ----- 1671 | name : str 1672 | the name of the checkpoint to be loaded 1673 | folder : str 1674 | the folder that stores checkpoints 1675 | 1676 | Returns: 1677 | -------- 1678 | None 1679 | ''' 1680 | self.load_state_dict(torch.load(folder + '/' + name)) 1681 | -------------------------------------------------------------------------------- /KAN_Implementations/Original_KAN/spline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def B_batch(x, grid, k=0, extend=True, device='cpu'): 5 | ''' 6 | evaluate x on B-spline bases 7 | 8 | Args: 9 | ----- 10 | x : 2D torch.tensor 11 | inputs, shape (number of splines, number of samples) 12 | grid : 2D torch.tensor 13 | grids, shape (number of splines, number of grid points) 14 | k : int 15 | the piecewise polynomial order of splines. 16 | extend : bool 17 | If True, k points are extended on both ends. If False, no extension (zero boundary condition).
Default: True 18 | device : str 19 | device 20 | 21 | Returns: 22 | -------- 23 | spline values : 3D torch.tensor 24 | shape (number of splines, number of B-spline bases (coefficients), number of samples). The number of B-spline bases = number of grid points + k - 1. 25 | 26 | Example 27 | ------- 28 | >>> num_spline = 5 29 | >>> num_sample = 100 30 | >>> num_grid_interval = 10 31 | >>> k = 3 32 | >>> x = torch.normal(0,1,size=(num_spline, num_sample)) 33 | >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1)) 34 | >>> B_batch(x, grids, k=k).shape 35 | torch.Size([5, 13, 100]) 36 | ''' 37 | 38 | # x shape: (size, x); grid shape: (size, grid) 39 | def extend_grid(grid, k_extend=0): 40 | # pad k points to the left and right of the grid 41 | # grid shape: (batch, grid) 42 | h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) 43 | 44 | for i in range(k_extend): 45 | grid = torch.cat([grid[:, [0]] - h, grid], dim=1) 46 | grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) 47 | grid = grid.to(device) 48 | return grid 49 | 50 | if extend == True: 51 | grid = extend_grid(grid, k_extend=k) 52 | 53 | grid = grid.unsqueeze(dim=2).to(device) 54 | x = x.unsqueeze(dim=1).to(device) 55 | 56 | if k == 0: 57 | value = (x >= grid[:, :-1]) * (x < grid[:, 1:]) 58 | else: 59 | # Cox-de Boor recursion: build order-k bases from order-(k-1) bases 60 | B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False, device=device) 61 | value = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)]) * B_km1[:, :-1] + ( grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * B_km1[:, 1:] 62 | return value 63 | 64 | 65 | def coef2curve(x_eval, grid, coef, k, device="cpu"): 66 | ''' 67 | converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over the B-spline basis). 68 | 69 | Args: 70 | ----- 71 | x_eval : 2D torch.tensor 72 | shape (number of splines, number of samples) 73 | grid : 2D torch.tensor 74 | shape (number of splines, number of grid points) 75 | coef : 2D torch.tensor 76 | shape (number of splines, number of coef params). number of coef params = number of grid intervals + k 77 | k : int 78 | the piecewise polynomial order of splines. 79 | device : str 80 | device 81 | 82 | Returns: 83 | -------- 84 | y_eval : 2D torch.tensor 85 | shape (number of splines, number of samples) 86 | 87 | Example 88 | ------- 89 | >>> num_spline = 5 90 | >>> num_sample = 100 91 | >>> num_grid_interval = 10 92 | >>> k = 3 93 | >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample)) 94 | >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1)) 95 | >>> coef = torch.normal(0,1,size=(num_spline, num_grid_interval+k)) 96 | >>> coef2curve(x_eval, grids, coef, k=k).shape 97 | torch.Size([5, 100]) 98 | ''' 99 | # x_eval: (size, batch), grid: (size, grid), coef: (size, coef) 100 | # coef: (size, coef), B_batch: (size, coef, batch), summed over coef 101 | if coef.dtype != x_eval.dtype: 102 | coef = coef.to(x_eval.dtype) 103 | y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device)) 104 | return y_eval 105 | 106 | 107 | def curve2coef(x_eval, y_eval, grid, k, device="cpu"): 108 | ''' 109 | converting B-spline curves to B-spline coefficients using least squares.
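For each spline, the coefficients solve the linear least-squares problem min over coef of ||B(x_eval)^T coef - y_eval||^2, where B is the basis matrix returned by B_batch; the returned coef has shape (number of splines, number of grid intervals + k).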
110 | 111 | Args: 112 | ----- 113 | x_eval : 2D torch.tensor 114 | shape (number of splines, number of samples) 115 | y_eval : 2D torch.tensor 116 | shape (number of splines, number of samples) 117 | grid : 2D torch.tensor 118 | shape (number of splines, number of grid points) 119 | k : int 120 | the piecewise polynomial order of splines. 121 | device : str 122 | device 123 | 124 | Example 125 | ------- 126 | >>> num_spline = 5 127 | >>> num_sample = 100 128 | >>> num_grid_interval = 10 129 | >>> k = 3 130 | >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample)) 131 | >>> y_eval = torch.normal(0,1,size=(num_spline, num_sample)) 132 | >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1)) 133 | >>> curve2coef(x_eval, y_eval, grids, k=k).shape torch.Size([5, 13]) 134 | ''' 135 | # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar 136 | mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1) 137 | # coef = torch.linalg.lstsq(mat, y_eval.unsqueeze(dim=2)).solution[:, :, 0] 138 | coef = torch.linalg.lstsq(mat.to(device), y_eval.unsqueeze(dim=2).to(device), 139 | driver='gelsy' if device == 'cpu' else 'gels').solution[:, :, 0] 140 | return coef.to(device) 141 | -------------------------------------------------------------------------------- /KAN_Implementations/Original_KAN/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.linear_model import LinearRegression 4 | import sympy 5 | 6 | # sigmoid = sympy.Function('sigmoid') 7 | # name: (torch implementation, sympy implementation) 8 | SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x), 9 | 'x^2': (lambda x: x**2, lambda x: x**2), 10 | 'x^3': (lambda x: x**3, lambda x: x**3), 11 | 'x^4': (lambda x: x**4, lambda x: x**4), 12 | '1/x': (lambda x: 1/x, lambda x: 1/x), 13 | '1/x^2': (lambda x: 1/x**2, lambda x: 1/x**2), 14 | '1/x^3': (lambda x: 1/x**3, lambda x: 1/x**3), 15 | '1/x^4': (lambda x: 1/x**4, lambda x: 1/x**4), 16 | 'sqrt': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x)), 17 | '1/sqrt(x)': (lambda x: 1/torch.sqrt(x), lambda x: 1/sympy.sqrt(x)), 18 | 'exp': (lambda x: torch.exp(x), lambda x: sympy.exp(x)), 19 | 'log': (lambda x: torch.log(x), lambda x: sympy.log(x)), 20 | 'abs': (lambda x: torch.abs(x), lambda x: sympy.Abs(x)), 21 | 'sin': (lambda x: torch.sin(x), lambda x: sympy.sin(x)), 22 | 'tan': (lambda x: torch.tan(x), lambda x: sympy.tan(x)), 23 | 'tanh': (lambda x: torch.tanh(x), lambda x: sympy.tanh(x)), 24 | 'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid')), 25 | #'relu': (lambda x: torch.relu(x), relu), 26 | 'sgn': (lambda x: torch.sign(x), lambda x: sympy.sign(x)), 27 | 'arcsin': (lambda x: torch.arcsin(x), lambda x: sympy.arcsin(x)), 28 | 'arctan': (lambda x: torch.arctan(x), lambda x: sympy.atan(x)), 29 | 'arctanh': (lambda x: torch.arctanh(x), lambda x: sympy.atanh(x)), 30 | '0': (lambda x: x*0, lambda x: x*0), 31 | 'gaussian': (lambda x: torch.exp(-x**2), lambda x: sympy.exp(-x**2)), 32 | 'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x)), 33 | #'logcosh': (lambda x: torch.log(torch.cosh(x)), lambda x: sympy.log(sympy.cosh(x))), 34 | #'cosh^2': (lambda x: torch.cosh(x)**2, lambda x: sympy.cosh(x)**2), 35 | } 36 | 37 | def create_dataset(f, 38 | n_var=2, 39 | ranges = [-1,1], 40 | train_num=1000, 41 | test_num=1000, 42 | normalize_input=False, 43 | normalize_label=False, 44 | device='cpu', 45 | seed=0): 46 | ''' 47 | create dataset 48 | 49 | Args: 50 | ----- 51 | f
: function 52 | the symbolic formula used to create the synthetic dataset n_var : int the number of input variables. Default: 2. 53 | ranges : list or np.array; shape (2,) or (n_var, 2) 54 | the range of input variables. Default: [-1,1]. 55 | train_num : int 56 | the number of training samples. Default: 1000. 57 | test_num : int 58 | the number of test samples. Default: 1000. 59 | normalize_input : bool 60 | If True, apply normalization to inputs. Default: False. 61 | normalize_label : bool 62 | If True, apply normalization to labels. Default: False. 63 | device : str 64 | device. Default: 'cpu'. 65 | seed : int 66 | random seed. Default: 0. 67 | 68 | Returns: 69 | -------- 70 | dataset : dict 71 | Train/test inputs/labels are dataset['train_input'], dataset['train_label'], 72 | dataset['test_input'], dataset['test_label'] 73 | 74 | Example 75 | ------- 76 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 77 | >>> dataset = create_dataset(f, n_var=2, train_num=100) 78 | >>> dataset['train_input'].shape 79 | torch.Size([100, 2]) 80 | ''' 81 | 82 | np.random.seed(seed) 83 | torch.manual_seed(seed) 84 | 85 | if len(np.array(ranges).shape) == 1: 86 | ranges = np.array(ranges * n_var).reshape(n_var,2) 87 | else: 88 | ranges = np.array(ranges) 89 | 90 | train_input = torch.zeros(train_num, n_var) 91 | test_input = torch.zeros(test_num, n_var) 92 | for i in range(n_var): 93 | train_input[:,i] = torch.rand(train_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0] 94 | test_input[:,i] = torch.rand(test_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0] 95 | 96 | 97 | train_label = f(train_input) 98 | test_label = f(test_input) 99 | 100 | 101 | def normalize(data, mean, std): 102 | return (data-mean)/std 103 | 104 | if normalize_input == True: 105 | mean_input = torch.mean(train_input, dim=0, keepdim=True) 106 | std_input = torch.std(train_input, dim=0, keepdim=True) 107 | train_input = normalize(train_input, mean_input, std_input) 108 | test_input = normalize(test_input, mean_input, std_input) 109 | 110 | if normalize_label == True: 111 | mean_label = torch.mean(train_label, dim=0, keepdim=True) 112 | std_label = torch.std(train_label, dim=0, keepdim=True) 113 | train_label = normalize(train_label, mean_label, std_label) 114 | test_label = normalize(test_label, mean_label, std_label) 115 | 116 | dataset = {} 117 | dataset['train_input'] = train_input.to(device) 118 | dataset['test_input'] = test_input.to(device) 119 | 120 | dataset['train_label'] = train_label.to(device) 121 | dataset['test_label'] = test_label.to(device) 122 | 123 | return dataset 124 | 125 | 126 | 127 | def fit_params(x, y, fun, a_range=(-10,10), b_range=(-10,10), grid_number=101, iteration=3, verbose=True, device='cpu'): 128 | ''' 129 | fit a, b, c, d such that 130 | 131 | .. math:: 132 | |y-(cf(ax+b)+d)|^2 133 | 134 | is minimized. Both x and y are 1D arrays. Sweep a and b, find the best fitted model.
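The fit runs in two stages: a and b are found by a grid sweep that maximizes the R^2 between f(a*x+b) and y, with the sweep range zoomed in over several iterations, and c and d then follow from an ordinary linear regression of y on f(a_best*x + b_best).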
135 | 136 | Args: 137 | ----- 138 | x : 1D array 139 | x values 140 | y : 1D array 141 | y values 142 | fun : function 143 | symbolic function 144 | a_range : tuple 145 | sweeping range of a 146 | b_range : tuple 147 | sweeping range of b 148 | grid_number : int 149 | number of steps along a and b 150 | iteration : int 151 | number of times the sweeping ranges are zoomed in 152 | verbose : bool 153 | print extra information if True 154 | device : str 155 | device 156 | 157 | Returns: 158 | -------- 159 | a_best : float 160 | best fitted a 161 | b_best : float 162 | best fitted b 163 | c_best : float 164 | best fitted c 165 | d_best : float 166 | best fitted d 167 | r2_best : float 168 | best r2 (coefficient of determination) 169 | 170 | Example 171 | ------- 172 | >>> num = 100 173 | >>> x = torch.linspace(-1,1,steps=num) 174 | >>> noises = torch.normal(0,1,(num,)) * 0.02 175 | >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises 176 | >>> fit_params(x, y, torch.sin) 177 | r2 is 0.9999727010726929 178 | (tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000)) 179 | ''' 180 | # fit a, b, c, d such that y=c*fun(a*x+b)+d; both x and y are 1D arrays. 181 | # sweep a and b, choose the best fitted model 182 | for _ in range(iteration): 183 | a_ = torch.linspace(a_range[0], a_range[1], steps=grid_number, device=device) 184 | b_ = torch.linspace(b_range[0], b_range[1], steps=grid_number, device=device) 185 | a_grid, b_grid = torch.meshgrid(a_, b_, indexing='ij') 186 | post_fun = fun(a_grid[None,:,:] * x[:,None,None] + b_grid[None,:,:]) 187 | x_mean = torch.mean(post_fun, dim=[0], keepdim=True) 188 | y_mean = torch.mean(y, dim=[0], keepdim=True) 189 | numerator = torch.sum((post_fun - x_mean)*(y-y_mean)[:,None,None], dim=0)**2 190 | denominator = torch.sum((post_fun - x_mean)**2, dim=0)*torch.sum((y - y_mean)[:,None,None]**2, dim=0) 191 | r2 = numerator/(denominator+1e-4) 192 | r2 = torch.nan_to_num(r2) 193 | 194 | 195 | best_id = torch.argmax(r2) 196 | a_id, b_id = torch.div(best_id, grid_number, rounding_mode='floor'), best_id % grid_number 197 | 198 | 199 | if a_id == 0 or a_id == grid_number - 1 or b_id == 0 or b_id == grid_number - 1: 200 | if _ == 0 and verbose == True: 201 | print('Best value at boundary.') 202 | if a_id == 0: 203 | a_range = [a_[0], a_[1]] 204 | if a_id == grid_number - 1: 205 | a_range = [a_[-2], a_[-1]] 206 | if b_id == 0: 207 | b_range = [b_[0], b_[1]] 208 | if b_id == grid_number - 1: 209 | b_range = [b_[-2], b_[-1]] 210 | 211 | else: 212 | a_range = [a_[a_id-1], a_[a_id+1]] 213 | b_range = [b_[b_id-1], b_[b_id+1]] 214 | 215 | a_best = a_[a_id] 216 | b_best = b_[b_id] 217 | post_fun = fun(a_best * x + b_best) 218 | r2_best = r2[a_id, b_id] 219 | 220 | if verbose == True: 221 | print(f"r2 is {r2_best}") 222 | if r2_best < 0.9: 223 | print('r2 is not very high, please double check if you are choosing the correct symbolic function.') 224 | 225 | post_fun = torch.nan_to_num(post_fun) 226 | reg = LinearRegression().fit(post_fun[:,None].detach().cpu().numpy(), y.detach().cpu().numpy()) 227 | c_best = torch.from_numpy(reg.coef_)[0].to(device) 228 | d_best = torch.from_numpy(np.array(reg.intercept_)).to(device) 229 | return torch.stack([a_best, b_best, c_best, d_best]), r2_best 230 | 231 | 232 | 233 | def add_symbolic(name, fun): 234 | ''' 235 | add a symbolic function to the library 236 | 237 | Args: 238 | ----- 239 | name : str 240 | name of the function 241 | fun : function 242 | torch function or lambda function 243 | 244 | Returns: 245 | -------- 246 | None 247 | 248 | Example 249 | ------- 250 | >>>
print(SYMBOLIC_LIB['Bessel']) 251 | KeyError: 'Bessel' 252 | >>> add_symbolic('Bessel', torch.special.bessel_j0) 253 | >>> print(SYMBOLIC_LIB['Bessel']) 254 | (<built-in method bessel_j0 of type object at 0x...>, Bessel) 255 | ''' 256 | globals()[name] = sympy.Function(name)  # register the sympy symbol directly (no exec needed) 257 | SYMBOLIC_LIB[name] = (fun, globals()[name]) 258 | 259 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Omar Rayyan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convolutional 2D KAN Implementation 2 | 3 | This repository contains 3 drop-in convolutional KAN replacements. Each works on top of a different KAN implementation: 4 | 5 | 1. [Efficient implementation of Kolmogorov-Arnold Network (KAN)](https://github.com/Blealtan/efficient-kan) 6 | 2. [Original KAN implementation](https://github.com/KindXiaoming/pykan) 7 | 3. [Fast KAN implementation](https://github.com/ZiyaoLi/fast-kan) 8 | 9 | # Installation 10 | ```bash 11 | git clone git@github.com:omarrayyann/KAN-Conv2D.git 12 | cd KAN-Conv2D 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | # Usage 17 | 18 | You should be able to just replace ```torch.nn.Conv2d()``` with ```ConvKAN()``` 19 | 20 | ```python3 21 | 22 | from ConvKAN import ConvKAN 23 | 24 | # Implementation built on the efficient KAN Implementation (https://github.com/Blealtan/efficient-kan) 25 | conv = ConvKAN(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1, version="Efficient") 26 | 27 | # Implementation built on the original KAN Implementation (https://github.com/KindXiaoming/pykan) 28 | conv = ConvKAN(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1, version="Original") 29 | 30 | # Implementation built on the fast KAN Implementation (https://github.com/ZiyaoLi/fast-kan) 31 | conv = ConvKAN(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1, version="Fast") 32 | 33 | ``` 34 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | sympy  # imported by KAN_Implementations/Original_KAN/utils.py 4 | scikit-learn  # imported by KAN_Implementations/Original_KAN/utils.py --------------------------------------------------------------------------------