├── README.md ├── data ├── create_dataset.m ├── test_input.csv ├── test_target.csv ├── train_input.csv ├── train_target.csv └── x.csv ├── efficient_kan ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-312.pyc │ ├── kan.cpython-312.pyc │ └── mlp.cpython-312.pyc ├── kan.py └── mlp.py ├── model ├── kan.pth └── mlp.pth ├── test_both.py ├── train_kan.py └── train_mlp.py /README.md: -------------------------------------------------------------------------------- 1 | # KAN 2 | This repository demonstrates the application of efficient Kolmogorov-Arnold Networks (KAN) in a curve fitting (regression) task. The original KAN can be found [here](https://github.com/KindXiaoming/pykan), while the original efficient KAN can be found [here](https://github.com/Blealtan/efficient-kan). Another similar example of Lorentzian curve fitting using KAN can be found at [here](https://github.com/JianpanHuang/CEST-KAN). 3 | 4 | The curve function here is: y = a·sin(b·x)+c·cos(d·x), x = 0:0.2:10. [a, b, c, d] used here are randomly sampled from [0.1:0.1:2]. 5 | 6 | **You may change it to whatever function you would like to fit.** 7 | 8 | The training dataset was created using the matlab code ‘create_dataset.m’ under 'Data' folder. 9 | 10 | Network specifics: size(inputlayer, hiddenlayer, outputlayer) = [51, 100, 4]. 11 | 12 | The input is curve values y with a length of 51, and the output is the coefficients vector [a, b, c, d] with a length of 4, as shown below. 13 | 14 | image 15 | 16 | The loss curves of KAN and MLP after training for 30 epochs are as follows: 17 | 18 | image 19 | 20 | The predicted curves by MLP and KAN after training for 30 epochs are as follows: 21 | 22 | image 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /data/create_dataset.m: -------------------------------------------------------------------------------- 1 | clear all; 2 | ds = 50000; % data size 3 | x=0:0.2:10; 4 | for n = 1:1 5 | a = randsample(20,1)/10; 6 | b = randsample(20,1)/10; 7 | c = randsample(20,1)/10; 8 | d = randsample(20,1)/10; 9 | y = a*sin(b*x)+c*cos(d*x); % y function 10 | y = awgn(y,30,"measured"); 11 | input(:,n) = y; 12 | target(:,n) = [a,b,c,d]; 13 | end 14 | figure, plot(x,y) 15 | inputs = input'; 16 | targets = target'; 17 | csvwrite('test_input.csv', inputs); 18 | csvwrite('test_targe.csv', targets); -------------------------------------------------------------------------------- /data/test_input.csv: -------------------------------------------------------------------------------- 1 | 0.090444,0.24442,0.3608,0.43557,0.40378,0.34929,0.21787,0.05023,-0.13111,-0.27953,-0.39331,-0.4645,-0.43468,-0.36531,-0.22132,-0.044746,0.12369,0.29327,0.41923,0.45267,0.46752,0.39642,0.2375,0.044496,-0.13596,-0.31966,-0.41109,-0.47811,-0.46265,-0.39113,-0.23068,-0.05577,0.13939,0.30816,0.42755,0.48659,0.4764,0.37028,0.24512,0.051559,-0.13376,-0.31073,-0.43301,-0.51184,-0.48944,-0.39325,-0.23851,-0.043004,0.13439,0.32145,0.43623 2 | -------------------------------------------------------------------------------- /data/test_target.csv: -------------------------------------------------------------------------------- 1 | 0.4,2,0.1,1.8 2 | -------------------------------------------------------------------------------- /data/x.csv: -------------------------------------------------------------------------------- 1 | 0,0.2,0.4,0.6,0.8,1,1.2,1.4,1.6,1.8,2,2.2,2.4,2.6,2.8,3,3.2,3.4,3.6,3.8,4,4.2,4.4,4.6,4.8,5,5.2,5.4,5.6,5.8,6,6.2,6.4,6.6,6.8,7,7.2,7.4,7.6,7.8,8,8.2,8.4,8.6,8.8,9,9.2,9.4,9.6,9.8,10 2 | -------------------------------------------------------------------------------- /efficient_kan/__init__.py: -------------------------------------------------------------------------------- 1 | from .kan import KANLinear, KAN 2 | 3 | __all__ = ["KANLinear", "KAN"] 4 | -------------------------------------------------------------------------------- /efficient_kan/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpanHuang/KAN/951a22574136c25ba7710e9127b5e13d78381264/efficient_kan/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /efficient_kan/__pycache__/kan.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpanHuang/KAN/951a22574136c25ba7710e9127b5e13d78381264/efficient_kan/__pycache__/kan.cpython-312.pyc -------------------------------------------------------------------------------- /efficient_kan/__pycache__/mlp.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpanHuang/KAN/951a22574136c25ba7710e9127b5e13d78381264/efficient_kan/__pycache__/mlp.cpython-312.pyc -------------------------------------------------------------------------------- /efficient_kan/kan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | 5 | 6 | class KANLinear(torch.nn.Module): 7 | def __init__( 8 | self, 9 | in_features, 10 | out_features, 11 | grid_size=5, 12 | spline_order=3, 13 | scale_noise=0.1, 14 | scale_base=1.0, 15 | scale_spline=1.0, 16 | enable_standalone_scale_spline=True, 17 | base_activation=torch.nn.SiLU, 18 | grid_eps=0.02, 19 | grid_range=[-1, 1], 20 | ): 21 | super(KANLinear, self).__init__() 22 | self.in_features = in_features 23 | self.out_features = out_features 24 | self.grid_size = grid_size 25 | self.spline_order = spline_order 26 | 27 | h = (grid_range[1] - grid_range[0]) / grid_size 28 | grid = ( 29 | ( 30 | torch.arange(-spline_order, grid_size + spline_order + 1) * h 31 | + grid_range[0] 32 | ) 33 | .expand(in_features, -1) 34 | .contiguous() 35 | ) 36 | self.register_buffer("grid", grid) 37 | 38 | self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) 39 | self.spline_weight = torch.nn.Parameter( 40 | torch.Tensor(out_features, in_features, grid_size + spline_order) 41 | ) 42 | if enable_standalone_scale_spline: 43 | self.spline_scaler = torch.nn.Parameter( 44 | torch.Tensor(out_features, in_features) 45 | ) 46 | 47 | self.scale_noise = scale_noise 48 | self.scale_base = scale_base 49 | self.scale_spline = scale_spline 50 | self.enable_standalone_scale_spline = enable_standalone_scale_spline 51 | self.base_activation = base_activation() 52 | self.grid_eps = grid_eps 53 | 54 | self.reset_parameters() 55 | 56 | def reset_parameters(self): 57 | torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base) 58 | with torch.no_grad(): 59 | noise = ( 60 | ( 61 | torch.rand(self.grid_size + 1, self.in_features, self.out_features) 62 | - 1 / 2 63 | ) 64 | * self.scale_noise 65 | / self.grid_size 66 | ) 67 | self.spline_weight.data.copy_( 68 | (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) 69 | * self.curve2coeff( 70 | self.grid.T[self.spline_order : -self.spline_order], 71 | noise, 72 | ) 73 | ) 74 | if self.enable_standalone_scale_spline: 75 | # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) 76 | torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline) 77 | 78 | def b_splines(self, x: torch.Tensor): 79 | """ 80 | Compute the B-spline bases for the given input tensor. 81 | 82 | Args: 83 | x (torch.Tensor): Input tensor of shape (batch_size, in_features). 84 | 85 | Returns: 86 | torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). 87 | """ 88 | assert x.dim() == 2 and x.size(1) == self.in_features 89 | 90 | grid: torch.Tensor = ( 91 | self.grid 92 | ) # (in_features, grid_size + 2 * spline_order + 1) 93 | x = x.unsqueeze(-1) 94 | bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) 95 | for k in range(1, self.spline_order + 1): 96 | bases = ( 97 | (x - grid[:, : -(k + 1)]) 98 | / (grid[:, k:-1] - grid[:, : -(k + 1)]) 99 | * bases[:, :, :-1] 100 | ) + ( 101 | (grid[:, k + 1 :] - x) 102 | / (grid[:, k + 1 :] - grid[:, 1:(-k)]) 103 | * bases[:, :, 1:] 104 | ) 105 | 106 | assert bases.size() == ( 107 | x.size(0), 108 | self.in_features, 109 | self.grid_size + self.spline_order, 110 | ) 111 | return bases.contiguous() 112 | 113 | def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): 114 | """ 115 | Compute the coefficients of the curve that interpolates the given points. 116 | 117 | Args: 118 | x (torch.Tensor): Input tensor of shape (batch_size, in_features). 119 | y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). 120 | 121 | Returns: 122 | torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). 123 | """ 124 | assert x.dim() == 2 and x.size(1) == self.in_features 125 | assert y.size() == (x.size(0), self.in_features, self.out_features) 126 | 127 | A = self.b_splines(x).transpose( 128 | 0, 1 129 | ) # (in_features, batch_size, grid_size + spline_order) 130 | B = y.transpose(0, 1) # (in_features, batch_size, out_features) 131 | solution = torch.linalg.lstsq( 132 | A, B 133 | ).solution # (in_features, grid_size + spline_order, out_features) 134 | result = solution.permute( 135 | 2, 0, 1 136 | ) # (out_features, in_features, grid_size + spline_order) 137 | 138 | assert result.size() == ( 139 | self.out_features, 140 | self.in_features, 141 | self.grid_size + self.spline_order, 142 | ) 143 | return result.contiguous() 144 | 145 | @property 146 | def scaled_spline_weight(self): 147 | return self.spline_weight * ( 148 | self.spline_scaler.unsqueeze(-1) 149 | if self.enable_standalone_scale_spline 150 | else 1.0 151 | ) 152 | 153 | def forward(self, x: torch.Tensor): 154 | assert x.dim() == 2 and x.size(1) == self.in_features 155 | 156 | base_output = F.linear(self.base_activation(x), self.base_weight) 157 | spline_output = F.linear( 158 | self.b_splines(x).view(x.size(0), -1), 159 | self.scaled_spline_weight.view(self.out_features, -1), 160 | ) 161 | return base_output + spline_output 162 | 163 | @torch.no_grad() 164 | def update_grid(self, x: torch.Tensor, margin=0.01): 165 | assert x.dim() == 2 and x.size(1) == self.in_features 166 | batch = x.size(0) 167 | 168 | splines = self.b_splines(x) # (batch, in, coeff) 169 | splines = splines.permute(1, 0, 2) # (in, batch, coeff) 170 | orig_coeff = self.scaled_spline_weight # (out, in, coeff) 171 | orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) 172 | unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out) 173 | unreduced_spline_output = unreduced_spline_output.permute( 174 | 1, 0, 2 175 | ) # (batch, in, out) 176 | 177 | # sort each channel individually to collect data distribution 178 | x_sorted = torch.sort(x, dim=0)[0] 179 | grid_adaptive = x_sorted[ 180 | torch.linspace( 181 | 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device 182 | ) 183 | ] 184 | 185 | uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size 186 | grid_uniform = ( 187 | torch.arange( 188 | self.grid_size + 1, dtype=torch.float32, device=x.device 189 | ).unsqueeze(1) 190 | * uniform_step 191 | + x_sorted[0] 192 | - margin 193 | ) 194 | 195 | grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive 196 | grid = torch.concatenate( 197 | [ 198 | grid[:1] 199 | - uniform_step 200 | * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), 201 | grid, 202 | grid[-1:] 203 | + uniform_step 204 | * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), 205 | ], 206 | dim=0, 207 | ) 208 | 209 | self.grid.copy_(grid.T) 210 | self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output)) 211 | 212 | def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): 213 | """ 214 | Compute the regularization loss. 215 | 216 | This is a dumb simulation of the original L1 regularization as stated in the 217 | paper, since the original one requires computing absolutes and entropy from the 218 | expanded (batch, in_features, out_features) intermediate tensor, which is hidden 219 | behind the F.linear function if we want an memory efficient implementation. 220 | 221 | The L1 regularization is now computed as mean absolute value of the spline 222 | weights. The authors implementation also includes this term in addition to the 223 | sample-based regularization. 224 | """ 225 | l1_fake = self.spline_weight.abs().mean(-1) 226 | regularization_loss_activation = l1_fake.sum() 227 | p = l1_fake / regularization_loss_activation 228 | regularization_loss_entropy = -torch.sum(p * p.log()) 229 | return ( 230 | regularize_activation * regularization_loss_activation 231 | + regularize_entropy * regularization_loss_entropy 232 | ) 233 | 234 | 235 | class KAN(torch.nn.Module): 236 | def __init__( 237 | self, 238 | layers_hidden, 239 | grid_size=5, 240 | spline_order=3, 241 | scale_noise=0.1, 242 | scale_base=1.0, 243 | scale_spline=1.0, 244 | base_activation=torch.nn.SiLU, 245 | grid_eps=0.02, 246 | grid_range=[-1, 1], 247 | ): 248 | super(KAN, self).__init__() 249 | self.grid_size = grid_size 250 | self.spline_order = spline_order 251 | 252 | self.layers = torch.nn.ModuleList() 253 | for in_features, out_features in zip(layers_hidden, layers_hidden[1:]): 254 | self.layers.append( 255 | KANLinear( 256 | in_features, 257 | out_features, 258 | grid_size=grid_size, 259 | spline_order=spline_order, 260 | scale_noise=scale_noise, 261 | scale_base=scale_base, 262 | scale_spline=scale_spline, 263 | base_activation=base_activation, 264 | grid_eps=grid_eps, 265 | grid_range=grid_range, 266 | ) 267 | ) 268 | 269 | def forward(self, x: torch.Tensor, update_grid=False): 270 | for layer in self.layers: 271 | if update_grid: 272 | layer.update_grid(x) 273 | x = layer(x) 274 | return x 275 | 276 | def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): 277 | return sum( 278 | layer.regularization_loss(regularize_activation, regularize_entropy) 279 | for layer in self.layers 280 | ) 281 | -------------------------------------------------------------------------------- /efficient_kan/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | # Define MLP model architecture 6 | class MLP(nn.Module): 7 | def __init__(self, input_size, hidden_size, output_size): 8 | super(MLP, self).__init__() 9 | self.fc1 = nn.Linear(input_size, hidden_size) 10 | self.fc2 = nn.Linear(hidden_size, output_size) 11 | 12 | def forward(self, x): 13 | x = torch.relu(self.fc1(x)) 14 | x = self.fc2(x) 15 | return x -------------------------------------------------------------------------------- /model/kan.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpanHuang/KAN/951a22574136c25ba7710e9127b5e13d78381264/model/kan.pth -------------------------------------------------------------------------------- /model/mlp.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpanHuang/KAN/951a22574136c25ba7710e9127b5e13d78381264/model/mlp.pth -------------------------------------------------------------------------------- /test_both.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from tqdm import tqdm 5 | import numpy as np 6 | from torch.utils.data import DataLoader, Dataset, TensorDataset 7 | from sklearn.model_selection import train_test_split 8 | import pandas as pd 9 | import matplotlib.pyplot as plt 10 | from efficient_kan import KAN 11 | from efficient_kan.mlp import MLP 12 | 13 | # Load data 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | input = pd.read_csv('./data/test_input.csv', header=None) 16 | target = pd.read_csv('./data/test_target.csv', header=None) 17 | x = pd.read_csv('./data/x.csv', header=None) 18 | input_tensor = torch.Tensor(np.float32(np.array(input))) 19 | target_tensor = torch.Tensor(np.float32(np.array(target))) 20 | x = np.float32(np.array(x)) 21 | # print(input_arr) 22 | # print(target_arr) 23 | 24 | # Load MLP model 25 | model_mlp = MLP(51, 100, 4) 26 | model_mlp.to(device) 27 | model_mlp.load_state_dict(torch.load('./model/mlp.pth')) 28 | output_mlp = model_mlp(input_tensor) 29 | output_mlp_np = output_mlp.detach().numpy() 30 | # print(output_kan) 31 | 32 | # Load KAN model 33 | model_kan = KAN([51, 100, 4]) 34 | model_kan.to(device) 35 | model_kan.load_state_dict(torch.load('./model/kan.pth')) 36 | output_kan = model_kan(input_tensor) 37 | output_kan_np = output_kan.detach().numpy() 38 | print(output_kan_np) 39 | 40 | 41 | # Plot 42 | y_target = target_tensor[0][0]*np.sin(target_tensor[0][1]*x)+target_tensor[0][2]*np.cos(target_tensor[0][3]*x) 43 | y_mlp = output_mlp_np[0][0]*np.sin(output_mlp_np[0][1]*x)+output_mlp_np[0][2]*np.cos(output_mlp_np[0][3]*x) 44 | y_kan = output_kan_np[0][0]*np.sin(output_kan_np[0][1]*x)+output_kan_np[0][2]*np.cos(output_kan_np[0][3]*x) 45 | plt.plot(x,y_target,'o-',color='b',label="GT") 46 | plt.plot(x,y_mlp,'o-',color='g',label="MLP") 47 | plt.plot(x,y_kan,'o-',color='r',label="KAN") 48 | plt.xlabel("x",fontsize=16) 49 | plt.ylabel("y",fontsize=16) 50 | plt.title("y = a*sin(b*x)+c*cos(d*x)") 51 | plt.show() 52 | 53 | 54 | -------------------------------------------------------------------------------- /train_kan.py: -------------------------------------------------------------------------------- 1 | # from common.public import public # type: ignore 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from tqdm import tqdm 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Dataset, TensorDataset 9 | from sklearn.model_selection import train_test_split 10 | import pandas as pd 11 | import matplotlib.pyplot as plt 12 | from efficient_kan import KAN 13 | 14 | # Load data 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | input = pd.read_csv('./data/train_input.csv', header=None) 17 | target = pd.read_csv('./data/train_target.csv', header=None) 18 | input_arr = np.float32(np.array(input)) 19 | target_arr = np.float32(np.array(target)) 20 | train_input, val_input, train_target, val_target = train_test_split(input_arr, target_arr, 21 | test_size=0.2, 22 | random_state=42) 23 | train_input_tensor = torch.tensor(train_input) 24 | train_target_tensor = torch.tensor(train_target) 25 | val_input_tensor = torch.tensor(val_input) 26 | val_target_tensor = torch.tensor(val_target) 27 | 28 | # print(train_data.shape, val_data.shape, train_targets.shape, val_targets.shape) 29 | train_dataset = torch.utils.data.TensorDataset(train_input_tensor, train_target_tensor) 30 | val_dataset = torch.utils.data.TensorDataset(val_input_tensor, val_target_tensor) 31 | train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) 32 | val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) 33 | 34 | # Define model 35 | model = KAN([51, 100, 4]) 36 | model.to(device) 37 | # Define optimizer 38 | # optimizer = optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-4) 39 | optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4) 40 | # Define learning rate scheduler 41 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8) 42 | 43 | # Define loss 44 | # loss_func = nn.CrossEntropyLoss() 45 | loss_func = nn.MSELoss() 46 | train_loss_all = [] 47 | val_loss_all = [] 48 | losses = [] 49 | for epoch in range(30): 50 | # Train 51 | train_loss = 0 52 | train_num = 0 53 | model.train() 54 | with tqdm(train_loader) as pbar: 55 | running_loss = 0.0 56 | for i, (input, target) in enumerate(pbar): 57 | input = input.view(-1, 51).to(device) 58 | optimizer.zero_grad() 59 | output = model(input) 60 | loss = loss_func(output, target.to(device)) 61 | optimizer.zero_grad() 62 | loss.backward() 63 | optimizer.step() 64 | train_loss += loss.item()*input.size(0) 65 | train_num += input.size(0) 66 | # print('Epoch %d loss: %.3f' % (epoch + 1, running_loss / len(trainloader))) 67 | pbar.set_postfix(lr=optimizer.param_groups[0]['lr']) 68 | train_loss_all.append(train_loss/train_num) 69 | pbar.set_postfix(loss=train_loss/train_num) 70 | 71 | # Validation 72 | model.eval() 73 | val_loss = 0 74 | val_accuracy = 0 75 | val_num = 0 76 | with torch.no_grad(): 77 | for input, target in val_loader: 78 | input = input.view(-1, 51).to(device) 79 | output = model(input) 80 | val_loss += loss_func(output, target.to(device)).item()*input.size(0) 81 | val_num += input.size(0) 82 | val_loss_all.append(val_loss/val_num) 83 | # val_accuracy /= len(valloader) 84 | 85 | # Update learning rate 86 | scheduler.step() 87 | 88 | print( 89 | f"Epoch {epoch + 1}, Val Loss: {val_loss/val_num}" 90 | ) 91 | 92 | # Save the trained model 93 | torch.save(model.state_dict(), "./model/kan.pth") 94 | 95 | # Plot the loss values against the number of epochs 96 | fig, ax = plt.subplots() 97 | ax.plot(range(1, epoch + 2), train_loss_all, label='Train Loss') 98 | ax.plot(range(1, epoch + 2), val_loss_all, label='Val Loss') 99 | ax.set_title('Loss Curves') 100 | ax.set_xlabel('Epochs') 101 | ax.set_ylabel('Loss') 102 | ax.legend() 103 | plt.show() 104 | -------------------------------------------------------------------------------- /train_mlp.py: -------------------------------------------------------------------------------- 1 | # from common.public import public # type: ignore 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from tqdm import tqdm 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Dataset, TensorDataset 9 | from sklearn.model_selection import train_test_split 10 | import pandas as pd 11 | import matplotlib.pyplot as plt 12 | from efficient_kan.mlp import MLP 13 | 14 | # Load data 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | input = pd.read_csv('./data/train_input.csv', header=None) 17 | target = pd.read_csv('./data/train_target.csv', header=None) 18 | input_arr = np.float32(np.array(input)) 19 | target_arr = np.float32(np.array(target)) 20 | train_input, val_input, train_target, val_target = train_test_split(input_arr, target_arr, 21 | test_size=0.2, 22 | random_state=42) 23 | train_input_tensor = torch.tensor(train_input) 24 | train_target_tensor = torch.tensor(train_target) 25 | val_input_tensor = torch.tensor(val_input) 26 | val_target_tensor = torch.tensor(val_target) 27 | 28 | # print(train_data.shape, val_data.shape, train_targets.shape, val_targets.shape) 29 | train_dataset = torch.utils.data.TensorDataset(train_input_tensor, train_target_tensor) 30 | val_dataset = torch.utils.data.TensorDataset(val_input_tensor, val_target_tensor) 31 | train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) 32 | val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) 33 | 34 | # Define model 35 | model = MLP(51, 100, 4) 36 | model.to(device) 37 | # Define optimizer 38 | # optimizer = optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-4) 39 | optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4) 40 | # Define learning rate scheduler 41 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8) 42 | 43 | # Define loss 44 | # loss_func = nn.CrossEntropyLoss() 45 | loss_func = nn.MSELoss() 46 | train_loss_all = [] 47 | val_loss_all = [] 48 | losses = [] 49 | for epoch in range(30): 50 | # Train 51 | train_loss = 0 52 | train_num = 0 53 | model.train() 54 | with tqdm(train_loader) as pbar: 55 | running_loss = 0.0 56 | for i, (input, target) in enumerate(pbar): 57 | input = input.view(-1, 51).to(device) 58 | optimizer.zero_grad() 59 | output = model(input) 60 | loss = loss_func(output, target.to(device)) 61 | optimizer.zero_grad() 62 | loss.backward() 63 | optimizer.step() 64 | train_loss += loss.item()*input.size(0) 65 | train_num += input.size(0) 66 | # print('Epoch %d loss: %.3f' % (epoch + 1, running_loss / len(trainloader))) 67 | pbar.set_postfix(lr=optimizer.param_groups[0]['lr']) 68 | train_loss_all.append(train_loss/train_num) 69 | pbar.set_postfix(loss=train_loss/train_num) 70 | 71 | # Validation 72 | model.eval() 73 | val_loss = 0 74 | val_accuracy = 0 75 | val_num = 0 76 | with torch.no_grad(): 77 | for input, target in val_loader: 78 | input = input.view(-1, 51).to(device) 79 | output = model(input) 80 | val_loss += loss_func(output, target.to(device)).item()*input.size(0) 81 | val_num += input.size(0) 82 | val_loss_all.append(val_loss/val_num) 83 | # val_accuracy /= len(valloader) 84 | 85 | # Update learning rate 86 | scheduler.step() 87 | 88 | print( 89 | f"Epoch {epoch + 1}, Val Loss: {val_loss/val_num}" 90 | ) 91 | 92 | # Save the trained model 93 | torch.save(model.state_dict(), "./model/kan.pth") 94 | 95 | # Plot the loss values against the number of epochs 96 | fig, ax = plt.subplots() 97 | ax.plot(range(1, epoch + 2), train_loss_all, label='Train Loss') 98 | ax.plot(range(1, epoch + 2), val_loss_all, label='Val Loss') 99 | ax.set_title('Loss Curves') 100 | ax.set_xlabel('Epochs') 101 | ax.set_ylabel('Loss') 102 | ax.legend() 103 | plt.show() 104 | --------------------------------------------------------------------------------