├── README.md
├── data
│   ├── create_dataset.m
│   ├── test_input.csv
│   ├── test_target.csv
│   ├── train_input.csv
│   ├── train_target.csv
│   └── x.csv
├── efficient_kan
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   ├── kan.cpython-312.pyc
│   │   └── mlp.cpython-312.pyc
│   ├── kan.py
│   └── mlp.py
├── model
│   ├── kan.pth
│   └── mlp.pth
├── test_both.py
├── train_kan.py
└── train_mlp.py
/README.md:
--------------------------------------------------------------------------------
1 | # KAN
2 | This repository demonstrates the application of efficient Kolmogorov-Arnold Networks (KAN) to a curve-fitting (regression) task. The original KAN can be found [here](https://github.com/KindXiaoming/pykan), while the original efficient KAN can be found [here](https://github.com/Blealtan/efficient-kan). Another similar example of Lorentzian curve fitting using KAN can be found [here](https://github.com/JianpanHuang/CEST-KAN).
3 |
4 | The curve function here is y = a·sin(b·x) + c·cos(d·x), sampled at x = 0:0.2:10 (51 points). The coefficients [a, b, c, d] are randomly drawn from 0.1:0.1:2.
5 |
6 | **You may change it to whatever function you would like to fit.**
7 |
8 | The training dataset was created using the MATLAB script `create_dataset.m` in the `data` folder.
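
If you prefer not to use MATLAB, a minimal Python sketch of the same sampling scheme might look like the following (NumPy only; the `awgn` noise step is only approximated here, so treat it as an illustration rather than the exact dataset generator):

```python
import numpy as np

x = np.linspace(0, 10, 51)               # matches x = 0:0.2:10 in data/x.csv
coeff_grid = np.arange(0.1, 2.01, 0.1)   # 0.1:0.1:2
a, b, c, d = np.random.choice(coeff_grid, size=4)
y = a * np.sin(b * x) + c * np.cos(d * x)

# rough equivalent of awgn(y, 30, "measured"): additive Gaussian noise at ~30 dB SNR
noise_std = np.sqrt(np.mean(y ** 2) / 10 ** (30 / 10))
y_noisy = y + np.random.normal(scale=noise_std, size=y.shape)

print(y_noisy.shape)   # (51,) -> one network input
print([a, b, c, d])    # the corresponding regression target
```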
9 |
10 | Network architecture: (input layer, hidden layer, output layer) = [51, 100, 4].
11 |
12 | The input is the curve values y with a length of 51, and the output is the coefficient vector [a, b, c, d] with a length of 4, as shown below.
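
In other words, one noisy curve of length 51 goes in and four coefficients come out. A condensed sketch of the inference step (essentially what `test_both.py` does, assuming a trained checkpoint already exists at `./model/kan.pth`):

```python
import numpy as np
import pandas as pd
import torch
from efficient_kan import KAN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = KAN([51, 100, 4]).to(device)     # input 51, hidden 100, output 4
model.load_state_dict(torch.load("./model/kan.pth", map_location=device))
model.eval()

curve = pd.read_csv("./data/test_input.csv", header=None).values.astype(np.float32)
with torch.no_grad():
    coeffs = model(torch.from_numpy(curve).to(device))   # shape (1, 4)
print(coeffs.cpu().numpy())              # predicted [a, b, c, d]
```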
13 |
14 |
15 |
16 | The loss curves of KAN and MLP after training for 30 epochs are as follows:
17 |
18 |
19 |
20 | The predicted curves by MLP and KAN after training for 30 epochs are as follows:
21 |
--------------------------------------------------------------------------------
/data/create_dataset.m:
--------------------------------------------------------------------------------
1 | clear all;
2 | ds = 50000; % data size
3 | x=0:0.2:10;
4 | for n = 1:ds
5 | a = randsample(20,1)/10;
6 | b = randsample(20,1)/10;
7 | c = randsample(20,1)/10;
8 | d = randsample(20,1)/10;
9 | y = a*sin(b*x)+c*cos(d*x); % y function
10 | y = awgn(y,30,"measured"); % add white Gaussian noise at 30 dB SNR
11 | input(:,n) = y;
12 | target(:,n) = [a,b,c,d];
13 | end
14 | figure, plot(x,y)
15 | inputs = input';
16 | targets = target';
17 | csvwrite('test_input.csv', inputs);
18 | csvwrite('test_target.csv', targets);
--------------------------------------------------------------------------------
/data/test_input.csv:
--------------------------------------------------------------------------------
1 | 0.090444,0.24442,0.3608,0.43557,0.40378,0.34929,0.21787,0.05023,-0.13111,-0.27953,-0.39331,-0.4645,-0.43468,-0.36531,-0.22132,-0.044746,0.12369,0.29327,0.41923,0.45267,0.46752,0.39642,0.2375,0.044496,-0.13596,-0.31966,-0.41109,-0.47811,-0.46265,-0.39113,-0.23068,-0.05577,0.13939,0.30816,0.42755,0.48659,0.4764,0.37028,0.24512,0.051559,-0.13376,-0.31073,-0.43301,-0.51184,-0.48944,-0.39325,-0.23851,-0.043004,0.13439,0.32145,0.43623
2 |
--------------------------------------------------------------------------------
/data/test_target.csv:
--------------------------------------------------------------------------------
1 | 0.4,2,0.1,1.8
2 |
--------------------------------------------------------------------------------
/data/x.csv:
--------------------------------------------------------------------------------
1 | 0,0.2,0.4,0.6,0.8,1,1.2,1.4,1.6,1.8,2,2.2,2.4,2.6,2.8,3,3.2,3.4,3.6,3.8,4,4.2,4.4,4.6,4.8,5,5.2,5.4,5.6,5.8,6,6.2,6.4,6.6,6.8,7,7.2,7.4,7.6,7.8,8,8.2,8.4,8.6,8.8,9,9.2,9.4,9.6,9.8,10
2 |
--------------------------------------------------------------------------------
/efficient_kan/__init__.py:
--------------------------------------------------------------------------------
1 | from .kan import KANLinear, KAN
2 |
3 | __all__ = ["KANLinear", "KAN"]
4 |
--------------------------------------------------------------------------------
/efficient_kan/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JianpanHuang/KAN/951a22574136c25ba7710e9127b5e13d78381264/efficient_kan/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/efficient_kan/__pycache__/kan.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JianpanHuang/KAN/951a22574136c25ba7710e9127b5e13d78381264/efficient_kan/__pycache__/kan.cpython-312.pyc
--------------------------------------------------------------------------------
/efficient_kan/__pycache__/mlp.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JianpanHuang/KAN/951a22574136c25ba7710e9127b5e13d78381264/efficient_kan/__pycache__/mlp.cpython-312.pyc
--------------------------------------------------------------------------------
/efficient_kan/kan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import math
4 |
5 |
6 | class KANLinear(torch.nn.Module):
7 | def __init__(
8 | self,
9 | in_features,
10 | out_features,
11 | grid_size=5,
12 | spline_order=3,
13 | scale_noise=0.1,
14 | scale_base=1.0,
15 | scale_spline=1.0,
16 | enable_standalone_scale_spline=True,
17 | base_activation=torch.nn.SiLU,
18 | grid_eps=0.02,
19 | grid_range=[-1, 1],
20 | ):
21 | super(KANLinear, self).__init__()
22 | self.in_features = in_features
23 | self.out_features = out_features
24 | self.grid_size = grid_size
25 | self.spline_order = spline_order
26 |
27 | h = (grid_range[1] - grid_range[0]) / grid_size
28 | grid = (
29 | (
30 | torch.arange(-spline_order, grid_size + spline_order + 1) * h
31 | + grid_range[0]
32 | )
33 | .expand(in_features, -1)
34 | .contiguous()
35 | )
36 | self.register_buffer("grid", grid)
37 |
38 | self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
39 | self.spline_weight = torch.nn.Parameter(
40 | torch.Tensor(out_features, in_features, grid_size + spline_order)
41 | )
42 | if enable_standalone_scale_spline:
43 | self.spline_scaler = torch.nn.Parameter(
44 | torch.Tensor(out_features, in_features)
45 | )
46 |
47 | self.scale_noise = scale_noise
48 | self.scale_base = scale_base
49 | self.scale_spline = scale_spline
50 | self.enable_standalone_scale_spline = enable_standalone_scale_spline
51 | self.base_activation = base_activation()
52 | self.grid_eps = grid_eps
53 |
54 | self.reset_parameters()
55 |
56 | def reset_parameters(self):
57 | torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
58 | with torch.no_grad():
59 | noise = (
60 | (
61 | torch.rand(self.grid_size + 1, self.in_features, self.out_features)
62 | - 1 / 2
63 | )
64 | * self.scale_noise
65 | / self.grid_size
66 | )
67 | self.spline_weight.data.copy_(
68 | (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
69 | * self.curve2coeff(
70 | self.grid.T[self.spline_order : -self.spline_order],
71 | noise,
72 | )
73 | )
74 | if self.enable_standalone_scale_spline:
75 | # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
76 | torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
77 |
78 | def b_splines(self, x: torch.Tensor):
79 | """
80 | Compute the B-spline bases for the given input tensor.
81 |
82 | Args:
83 | x (torch.Tensor): Input tensor of shape (batch_size, in_features).
84 |
85 | Returns:
86 | torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
87 | """
88 | assert x.dim() == 2 and x.size(1) == self.in_features
89 |
90 | grid: torch.Tensor = (
91 | self.grid
92 | ) # (in_features, grid_size + 2 * spline_order + 1)
93 | x = x.unsqueeze(-1)
94 | bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
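# Cox-de Boor recursion: starting from the order-0 indicator bases above, raise
# the spline order by one per iteration up to self.spline_order.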
95 | for k in range(1, self.spline_order + 1):
96 | bases = (
97 | (x - grid[:, : -(k + 1)])
98 | / (grid[:, k:-1] - grid[:, : -(k + 1)])
99 | * bases[:, :, :-1]
100 | ) + (
101 | (grid[:, k + 1 :] - x)
102 | / (grid[:, k + 1 :] - grid[:, 1:(-k)])
103 | * bases[:, :, 1:]
104 | )
105 |
106 | assert bases.size() == (
107 | x.size(0),
108 | self.in_features,
109 | self.grid_size + self.spline_order,
110 | )
111 | return bases.contiguous()
112 |
113 | def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
114 | """
115 | Compute the coefficients of the curve that interpolates the given points.
116 |
117 | Args:
118 | x (torch.Tensor): Input tensor of shape (batch_size, in_features).
119 | y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
120 |
121 | Returns:
122 | torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
123 | """
124 | assert x.dim() == 2 and x.size(1) == self.in_features
125 | assert y.size() == (x.size(0), self.in_features, self.out_features)
126 |
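# For each input feature, solve a batched linear least-squares problem for the
# spline coefficients that best reproduce y from the B-spline bases.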
127 | A = self.b_splines(x).transpose(
128 | 0, 1
129 | ) # (in_features, batch_size, grid_size + spline_order)
130 | B = y.transpose(0, 1) # (in_features, batch_size, out_features)
131 | solution = torch.linalg.lstsq(
132 | A, B
133 | ).solution # (in_features, grid_size + spline_order, out_features)
134 | result = solution.permute(
135 | 2, 0, 1
136 | ) # (out_features, in_features, grid_size + spline_order)
137 |
138 | assert result.size() == (
139 | self.out_features,
140 | self.in_features,
141 | self.grid_size + self.spline_order,
142 | )
143 | return result.contiguous()
144 |
145 | @property
146 | def scaled_spline_weight(self):
147 | return self.spline_weight * (
148 | self.spline_scaler.unsqueeze(-1)
149 | if self.enable_standalone_scale_spline
150 | else 1.0
151 | )
152 |
153 | def forward(self, x: torch.Tensor):
154 | assert x.dim() == 2 and x.size(1) == self.in_features
155 |
156 | base_output = F.linear(self.base_activation(x), self.base_weight)
157 | spline_output = F.linear(
158 | self.b_splines(x).view(x.size(0), -1),
159 | self.scaled_spline_weight.view(self.out_features, -1),
160 | )
161 | return base_output + spline_output
162 |
163 | @torch.no_grad()
164 | def update_grid(self, x: torch.Tensor, margin=0.01):
165 | assert x.dim() == 2 and x.size(1) == self.in_features
166 | batch = x.size(0)
167 |
168 | splines = self.b_splines(x) # (batch, in, coeff)
169 | splines = splines.permute(1, 0, 2) # (in, batch, coeff)
170 | orig_coeff = self.scaled_spline_weight # (out, in, coeff)
171 | orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
172 | unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)
173 | unreduced_spline_output = unreduced_spline_output.permute(
174 | 1, 0, 2
175 | ) # (batch, in, out)
176 |
177 | # sort each channel individually to collect data distribution
178 | x_sorted = torch.sort(x, dim=0)[0]
179 | grid_adaptive = x_sorted[
180 | torch.linspace(
181 | 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
182 | )
183 | ]
184 |
185 | uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
186 | grid_uniform = (
187 | torch.arange(
188 | self.grid_size + 1, dtype=torch.float32, device=x.device
189 | ).unsqueeze(1)
190 | * uniform_step
191 | + x_sorted[0]
192 | - margin
193 | )
194 |
195 | grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
196 | grid = torch.concatenate(
197 | [
198 | grid[:1]
199 | - uniform_step
200 | * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
201 | grid,
202 | grid[-1:]
203 | + uniform_step
204 | * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
205 | ],
206 | dim=0,
207 | )
208 |
209 | self.grid.copy_(grid.T)
210 | self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
211 |
212 | def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
213 | """
214 | Compute the regularization loss.
215 |
216 | This is a simplified stand-in for the original L1 regularization described in the
217 | paper, since the original one requires computing absolutes and entropy from the
218 | expanded (batch, in_features, out_features) intermediate tensor, which is hidden
219 | behind the F.linear function if we want a memory-efficient implementation.
220 |
221 | The L1 regularization is instead computed as the mean absolute value of the spline
222 | weights. The authors' implementation also includes this term in addition to the
223 | sample-based regularization.
224 | """
225 | l1_fake = self.spline_weight.abs().mean(-1)
226 | regularization_loss_activation = l1_fake.sum()
227 | p = l1_fake / regularization_loss_activation
228 | regularization_loss_entropy = -torch.sum(p * p.log())
229 | return (
230 | regularize_activation * regularization_loss_activation
231 | + regularize_entropy * regularization_loss_entropy
232 | )
233 |
234 |
235 | class KAN(torch.nn.Module):
236 | def __init__(
237 | self,
238 | layers_hidden,
239 | grid_size=5,
240 | spline_order=3,
241 | scale_noise=0.1,
242 | scale_base=1.0,
243 | scale_spline=1.0,
244 | base_activation=torch.nn.SiLU,
245 | grid_eps=0.02,
246 | grid_range=[-1, 1],
247 | ):
248 | super(KAN, self).__init__()
249 | self.grid_size = grid_size
250 | self.spline_order = spline_order
251 |
252 | self.layers = torch.nn.ModuleList()
253 | for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
254 | self.layers.append(
255 | KANLinear(
256 | in_features,
257 | out_features,
258 | grid_size=grid_size,
259 | spline_order=spline_order,
260 | scale_noise=scale_noise,
261 | scale_base=scale_base,
262 | scale_spline=scale_spline,
263 | base_activation=base_activation,
264 | grid_eps=grid_eps,
265 | grid_range=grid_range,
266 | )
267 | )
268 |
269 | def forward(self, x: torch.Tensor, update_grid=False):
270 | for layer in self.layers:
271 | if update_grid:
272 | layer.update_grid(x)
273 | x = layer(x)
274 | return x
275 |
276 | def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
277 | return sum(
278 | layer.regularization_loss(regularize_activation, regularize_entropy)
279 | for layer in self.layers
280 | )
281 |
--------------------------------------------------------------------------------
/efficient_kan/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 |
5 | # Define MLP model architecture
6 | class MLP(nn.Module):
7 | def __init__(self, input_size, hidden_size, output_size):
8 | super(MLP, self).__init__()
9 | self.fc1 = nn.Linear(input_size, hidden_size)
10 | self.fc2 = nn.Linear(hidden_size, output_size)
11 |
12 | def forward(self, x):
13 | x = torch.relu(self.fc1(x))
14 | x = self.fc2(x)
15 | return x
--------------------------------------------------------------------------------
/model/kan.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JianpanHuang/KAN/951a22574136c25ba7710e9127b5e13d78381264/model/kan.pth
--------------------------------------------------------------------------------
/model/mlp.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JianpanHuang/KAN/951a22574136c25ba7710e9127b5e13d78381264/model/mlp.pth
--------------------------------------------------------------------------------
/test_both.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
10 | from efficient_kan import KAN
11 | from efficient_kan.mlp import MLP
12 |
13 | # Load data
14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 | input = pd.read_csv('./data/test_input.csv', header=None)
16 | target = pd.read_csv('./data/test_target.csv', header=None)
17 | x = pd.read_csv('./data/x.csv', header=None)
18 | input_tensor = torch.tensor(input.values, dtype=torch.float32).to(device)  # same device as the models
19 | target_arr = np.float32(np.array(target))  # kept as NumPy for plotting
20 | x = np.float32(np.array(x)).flatten()
21 | # print(input_arr)
22 | # print(target_arr)
23 |
24 | # Load MLP model
25 | model_mlp = MLP(51, 100, 4)
26 | model_mlp.to(device)
27 | model_mlp.load_state_dict(torch.load('./model/mlp.pth', map_location=device))
28 | model_mlp.eval()
29 | output_mlp_np = model_mlp(input_tensor).detach().cpu().numpy()
30 | # print(output_mlp_np)
31 |
32 | # Load KAN model
33 | model_kan = KAN([51, 100, 4])
34 | model_kan.to(device)
35 | model_kan.load_state_dict(torch.load('./model/kan.pth', map_location=device))
36 | model_kan.eval()
37 | output_kan_np = model_kan(input_tensor).detach().cpu().numpy()
38 | print(output_kan_np)
39 |
40 |
41 | # Plot
42 | y_target = target_arr[0][0]*np.sin(target_arr[0][1]*x)+target_arr[0][2]*np.cos(target_arr[0][3]*x)
43 | y_mlp = output_mlp_np[0][0]*np.sin(output_mlp_np[0][1]*x)+output_mlp_np[0][2]*np.cos(output_mlp_np[0][3]*x)
44 | y_kan = output_kan_np[0][0]*np.sin(output_kan_np[0][1]*x)+output_kan_np[0][2]*np.cos(output_kan_np[0][3]*x)
45 | plt.plot(x,y_target,'o-',color='b',label="GT")
46 | plt.plot(x,y_mlp,'o-',color='g',label="MLP")
47 | plt.plot(x,y_kan,'o-',color='r',label="KAN")
48 | plt.xlabel("x",fontsize=16)
49 | plt.ylabel("y",fontsize=16)
50 | plt.title("y = a*sin(b*x)+c*cos(d*x)")
51 | plt.legend()
52 | plt.show()
--------------------------------------------------------------------------------
/train_kan.py:
--------------------------------------------------------------------------------
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from tqdm import tqdm
7 | import numpy as np
8 | from torch.utils.data import DataLoader, Dataset, TensorDataset
9 | from sklearn.model_selection import train_test_split
10 | import pandas as pd
11 | import matplotlib.pyplot as plt
12 | from efficient_kan import KAN
13 |
14 | # Load data
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 | input = pd.read_csv('./data/train_input.csv', header=None)
17 | target = pd.read_csv('./data/train_target.csv', header=None)
18 | input_arr = np.float32(np.array(input))
19 | target_arr = np.float32(np.array(target))
20 | train_input, val_input, train_target, val_target = train_test_split(input_arr, target_arr,
21 | test_size=0.2,
22 | random_state=42)
23 | train_input_tensor = torch.tensor(train_input)
24 | train_target_tensor = torch.tensor(train_target)
25 | val_input_tensor = torch.tensor(val_input)
26 | val_target_tensor = torch.tensor(val_target)
27 |
28 | # print(train_data.shape, val_data.shape, train_targets.shape, val_targets.shape)
29 | train_dataset = torch.utils.data.TensorDataset(train_input_tensor, train_target_tensor)
30 | val_dataset = torch.utils.data.TensorDataset(val_input_tensor, val_target_tensor)
31 | train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
32 | val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
33 |
34 | # Define model
35 | model = KAN([51, 100, 4])
36 | model.to(device)
37 | # Define optimizer
38 | # optimizer = optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-4)
39 | optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
40 | # Define learning rate scheduler
41 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)
42 |
43 | # Define loss
44 | # loss_func = nn.CrossEntropyLoss()
45 | loss_func = nn.MSELoss()
46 | train_loss_all = []
47 | val_loss_all = []
48 | losses = []
49 | for epoch in range(30):
50 | # Train
51 | train_loss = 0
52 | train_num = 0
53 | model.train()
54 | with tqdm(train_loader) as pbar:
55 | running_loss = 0.0
56 | for i, (input, target) in enumerate(pbar):
57 | input = input.view(-1, 51).to(device)
58 | optimizer.zero_grad()
59 | output = model(input)
60 | loss = loss_func(output, target.to(device))
62 | loss.backward()
63 | optimizer.step()
64 | train_loss += loss.item()*input.size(0)
65 | train_num += input.size(0)
66 | # print('Epoch %d loss: %.3f' % (epoch + 1, running_loss / len(trainloader)))
67 | pbar.set_postfix(lr=optimizer.param_groups[0]['lr'])
68 | train_loss_all.append(train_loss/train_num)
69 | pbar.set_postfix(loss=train_loss/train_num)
70 |
71 | # Validation
72 | model.eval()
73 | val_loss = 0
74 | val_accuracy = 0
75 | val_num = 0
76 | with torch.no_grad():
77 | for input, target in val_loader:
78 | input = input.view(-1, 51).to(device)
79 | output = model(input)
80 | val_loss += loss_func(output, target.to(device)).item()*input.size(0)
81 | val_num += input.size(0)
82 | val_loss_all.append(val_loss/val_num)
83 | # val_accuracy /= len(valloader)
84 |
85 | # Update learning rate
86 | scheduler.step()
87 |
88 | print(
89 | f"Epoch {epoch + 1}, Val Loss: {val_loss/val_num}"
90 | )
91 |
92 | # Save the trained model
93 | torch.save(model.state_dict(), "./model/kan.pth")
94 |
95 | # Plot the loss values against the number of epochs
96 | fig, ax = plt.subplots()
97 | ax.plot(range(1, epoch + 2), train_loss_all, label='Train Loss')
98 | ax.plot(range(1, epoch + 2), val_loss_all, label='Val Loss')
99 | ax.set_title('Loss Curves')
100 | ax.set_xlabel('Epochs')
101 | ax.set_ylabel('Loss')
102 | ax.legend()
103 | plt.show()
104 |
--------------------------------------------------------------------------------
/train_mlp.py:
--------------------------------------------------------------------------------
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from tqdm import tqdm
7 | import numpy as np
8 | from torch.utils.data import DataLoader, Dataset, TensorDataset
9 | from sklearn.model_selection import train_test_split
10 | import pandas as pd
11 | import matplotlib.pyplot as plt
12 | from efficient_kan.mlp import MLP
13 |
14 | # Load data
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 | input = pd.read_csv('./data/train_input.csv', header=None)
17 | target = pd.read_csv('./data/train_target.csv', header=None)
18 | input_arr = np.float32(np.array(input))
19 | target_arr = np.float32(np.array(target))
20 | train_input, val_input, train_target, val_target = train_test_split(input_arr, target_arr,
21 | test_size=0.2,
22 | random_state=42)
23 | train_input_tensor = torch.tensor(train_input)
24 | train_target_tensor = torch.tensor(train_target)
25 | val_input_tensor = torch.tensor(val_input)
26 | val_target_tensor = torch.tensor(val_target)
27 |
28 | # print(train_data.shape, val_data.shape, train_targets.shape, val_targets.shape)
29 | train_dataset = torch.utils.data.TensorDataset(train_input_tensor, train_target_tensor)
30 | val_dataset = torch.utils.data.TensorDataset(val_input_tensor, val_target_tensor)
31 | train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
32 | val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
33 |
34 | # Define model
35 | model = MLP(51, 100, 4)
36 | model.to(device)
37 | # Define optimizer
38 | # optimizer = optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-4)
39 | optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
40 | # Define learning rate scheduler
41 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)
42 |
43 | # Define loss
44 | # loss_func = nn.CrossEntropyLoss()
45 | loss_func = nn.MSELoss()
46 | train_loss_all = []
47 | val_loss_all = []
48 | losses = []
49 | for epoch in range(30):
50 | # Train
51 | train_loss = 0
52 | train_num = 0
53 | model.train()
54 | with tqdm(train_loader) as pbar:
55 | running_loss = 0.0
56 | for i, (input, target) in enumerate(pbar):
57 | input = input.view(-1, 51).to(device)
58 | optimizer.zero_grad()
59 | output = model(input)
60 | loss = loss_func(output, target.to(device))
62 | loss.backward()
63 | optimizer.step()
64 | train_loss += loss.item()*input.size(0)
65 | train_num += input.size(0)
66 | # print('Epoch %d loss: %.3f' % (epoch + 1, running_loss / len(trainloader)))
67 | pbar.set_postfix(lr=optimizer.param_groups[0]['lr'])
68 | train_loss_all.append(train_loss/train_num)
69 | pbar.set_postfix(loss=train_loss/train_num)
70 |
71 | # Validation
72 | model.eval()
73 | val_loss = 0
74 | val_accuracy = 0
75 | val_num = 0
76 | with torch.no_grad():
77 | for input, target in val_loader:
78 | input = input.view(-1, 51).to(device)
79 | output = model(input)
80 | val_loss += loss_func(output, target.to(device)).item()*input.size(0)
81 | val_num += input.size(0)
82 | val_loss_all.append(val_loss/val_num)
83 | # val_accuracy /= len(valloader)
84 |
85 | # Update learning rate
86 | scheduler.step()
87 |
88 | print(
89 | f"Epoch {epoch + 1}, Val Loss: {val_loss/val_num}"
90 | )
91 |
92 | # Save the trained model
93 | torch.save(model.state_dict(), "./model/mlp.pth")
94 |
95 | # Plot the loss values against the number of epochs
96 | fig, ax = plt.subplots()
97 | ax.plot(range(1, epoch + 2), train_loss_all, label='Train Loss')
98 | ax.plot(range(1, epoch + 2), val_loss_all, label='Val Loss')
99 | ax.set_title('Loss Curves')
100 | ax.set_xlabel('Epochs')
101 | ax.set_ylabel('Loss')
102 | ax.legend()
103 | plt.show()
104 |
--------------------------------------------------------------------------------