├── .gitignore
├── 1d_reaction_point_optimization.py
├── 1d_reaction_region_optimization.py
├── 1d_wave_point_optimization.py
├── 1d_wave_region_optimization.py
├── LICENSE
├── README.md
├── convection_point_optimization.py
├── convection_region_optimization.py
├── model_dict.py
├── models
│   ├── FLS.py
│   ├── KAN.py
│   ├── LBFGS.py
│   ├── PINN.py
│   ├── PINNsFormer.py
│   ├── PINNsFormer_Enc_Only.py
│   ├── QRes.py
│   ├── Symbolic_KANLayer.py
│   ├── kan_layer.py
│   ├── spline.py
│   └── utils.py
├── pic
│   ├── algorithm.png
│   ├── comparison.png
│   └── results.png
├── requirements.txt
├── scripts
│   ├── 1d_reaction_point.sh
│   ├── 1d_reaction_region.sh
│   ├── 1d_wave_point.sh
│   ├── 1d_wave_region.sh
│   ├── convection_point.sh
│   └── convection_region.sh
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/1d_reaction_point_optimization.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import matplotlib.pyplot as plt
6 | import random
7 | from torch.optim import LBFGS
8 | from tqdm import tqdm
9 | import argparse
10 | from util import *
11 | from model_dict import get_model
12 |
13 | seed = 0
14 | np.random.seed(seed)
15 | random.seed(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed(seed)
18 |
19 | parser = argparse.ArgumentParser('Training Point Optimization')
20 | parser.add_argument('--model', type=str, default='PINN')
21 | parser.add_argument('--device', type=str, default='cuda:0')
22 | args = parser.parse_args()
23 | device = args.device
24 |
25 | res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101)
26 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)
27 |
28 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
29 | res = make_time_sequence(res, num_step=5, step=1e-4)
30 | b_left = make_time_sequence(b_left, num_step=5, step=1e-4)
31 | b_right = make_time_sequence(b_right, num_step=5, step=1e-4)
32 | b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4)
33 | b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4)
34 |
35 | res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
36 | b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
37 | b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
38 | b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
39 | b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)
40 |
41 | x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2]
42 | x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2]
43 | x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2]
44 | x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2]
45 | x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2]
46 |
47 |
48 | def init_weights(m):
49 | if isinstance(m, nn.Linear):
50 |         torch.nn.init.xavier_uniform_(m.weight)
51 | m.bias.data.fill_(0.01)
52 |
53 |
54 | if args.model == 'KAN':
55 | model = get_model(args).Model(width=[2, 5, 1], grid=5, k=3, grid_eps=1.0, \
56 | noise_scale_base=0.25, device=device).to(device)
57 | elif args.model == 'QRes':
58 | model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=4).to(device)
59 | model.apply(init_weights)
60 | elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
61 | model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
62 | model.apply(init_weights)
63 | else:
64 | model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
65 | model.apply(init_weights)
66 |
67 | optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')
68 |
69 | print(model)
70 | print(get_n_params(model))
71 | loss_track = []
72 |
73 | for i in tqdm(range(1000)):
74 | def closure():
75 | pred_res = model(x_res, t_res)
76 | pred_left = model(x_left, t_left)
77 | pred_right = model(x_right, t_right)
78 | pred_upper = model(x_upper, t_upper)
79 | pred_lower = model(x_lower, t_lower)
80 |
81 | u_x = torch.autograd.grad(pred_res, x_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
82 | create_graph=True)[0]
83 | u_t = torch.autograd.grad(pred_res, t_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
84 | create_graph=True)[0]
85 |
86 | loss_res = torch.mean((u_t - 5 * pred_res * (1 - pred_res)) ** 2)
87 | loss_bc = torch.mean((pred_upper - pred_lower) ** 2)
88 | loss_ic = torch.mean(
89 | (pred_left[:, 0] - torch.exp(- (x_left[:, 0] - torch.pi) ** 2 / (2 * (torch.pi / 4) ** 2))) ** 2)
90 |
91 | loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()])
92 |
93 | loss = loss_res + loss_bc + loss_ic
94 | optim.zero_grad()
95 | loss.backward()
96 | return loss
97 |
98 |
99 | optim.step(closure)
100 |
101 | print('Loss Res: {:.4f}, Loss_BC: {:.4f}, Loss_IC: {:.4f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2]))
102 | print('Train Loss: {:.4f}'.format(np.sum(loss_track[-1])))
103 |
104 | if not os.path.exists('./results/'):
105 | os.makedirs('./results/')
106 | torch.save(model.state_dict(), f'./results/1dreaction_{args.model}_point.pt')
107 |
108 | # Visualize
109 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
110 | res_test = make_time_sequence(res_test, num_step=5, step=1e-4)
111 |
112 | res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
113 | x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]
114 |
115 | with torch.no_grad():
116 | pred = model(x_test, t_test)[:, 0:1]
117 | pred = pred.cpu().detach().numpy()
118 |
119 | pred = pred.reshape(101, 101)
120 |
121 |
122 | def h(x):
123 | return np.exp(- (x - np.pi) ** 2 / (2 * (np.pi / 4) ** 2))
124 |
125 |
126 | def u_ana(x, t):
127 | return h(x) * np.exp(5 * t) / (h(x) * np.exp(5 * t) + 1 - h(x))
128 |
129 |
130 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)
131 | u = u_ana(res_test[:, 0], res_test[:, 1]).reshape(101, 101)
132 |
133 | rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u))
134 | rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u ** 2))
135 |
136 | print('relative L1 error: {:.4f}'.format(rl1))
137 | print('relative L2 error: {:.4f}'.format(rl2))
138 |
139 | plt.figure(figsize=(4, 3))
140 | plt.imshow(pred, aspect='equal')
141 | plt.xlabel('x')
142 | plt.ylabel('t')
143 | plt.title('Predicted u(x,t)')
144 | plt.colorbar()
145 | plt.tight_layout()
146 | plt.axis('off')
147 | plt.savefig(f'./results/1dreaction_{args.model}_point_optimization_pred.pdf', bbox_inches='tight')
148 |
149 | plt.figure(figsize=(4, 3))
150 | plt.imshow(u, aspect='equal')
151 | plt.xlabel('x')
152 | plt.ylabel('t')
153 | plt.title('Exact u(x,t)')
154 | plt.colorbar()
155 | plt.tight_layout()
156 | plt.axis('off')
157 | plt.savefig('./results/1dreaction_exact.pdf', bbox_inches='tight')
158 |
159 | plt.figure(figsize=(4, 3))
160 | plt.imshow(pred - u, aspect='equal', cmap='coolwarm', vmin=-0.15, vmax=0.15)
161 | plt.xlabel('x')
162 | plt.ylabel('t')
163 | plt.title('Prediction Error')
164 | plt.colorbar()
165 | plt.tight_layout()
166 | plt.axis('off')
167 | plt.savefig(f'./results/1dreaction_{args.model}_point_optimization_error.pdf', bbox_inches='tight')
168 |
--------------------------------------------------------------------------------
/1d_reaction_region_optimization.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import matplotlib.pyplot as plt
6 | import random
7 | from torch.optim import LBFGS
8 | from tqdm import tqdm
9 | import argparse
10 | from util import *
11 | from model_dict import get_model
12 |
13 | seed = 0
14 | np.random.seed(seed)
15 | random.seed(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed(seed)
18 |
19 | parser = argparse.ArgumentParser('Training Region Optimization')
20 | parser.add_argument('--model', type=str, default='PINN')
21 | parser.add_argument('--device', type=str, default='cuda:0')
22 | parser.add_argument('--initial_region', type=float, default=1e-4)
23 | parser.add_argument('--sample_num', type=int, default=1)
24 | parser.add_argument('--past_iterations', type=int, default=10)
25 | args = parser.parse_args()
26 | device = args.device
27 |
28 | res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101)
29 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)
30 |
31 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
32 | res = make_time_sequence(res, num_step=5, step=1e-4)
33 | b_left = make_time_sequence(b_left, num_step=5, step=1e-4)
34 | b_right = make_time_sequence(b_right, num_step=5, step=1e-4)
35 | b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4)
36 | b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4)
37 |
38 | res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
39 | b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
40 | b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
41 | b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
42 | b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)
43 |
44 | x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2]
45 | x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2]
46 | x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2]
47 | x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2]
48 | x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2]
49 |
50 |
51 | def init_weights(m):
52 | if isinstance(m, nn.Linear):
53 |         torch.nn.init.xavier_uniform_(m.weight)
54 | m.bias.data.fill_(0.01)
55 |
56 |
57 | if args.model == 'KAN':
58 | model = get_model(args).Model(width=[2, 5, 1], grid=5, k=3, grid_eps=1.0, \
59 | noise_scale_base=0.25, device=device).to(device)
60 | elif args.model == 'QRes':
61 | model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=2).to(device)
62 | model.apply(init_weights)
63 | elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
64 | model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
65 | model.apply(init_weights)
66 | else:
67 | model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
68 | model.apply(init_weights)
69 |
70 | optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')
71 |
72 | print(model)
73 | print(get_n_params(model))
74 | loss_track = []
75 |
76 | # for region optimization
77 | initial_region = args.initial_region
78 | sample_num = args.sample_num
79 | past_iterations = args.past_iterations
80 | gradient_list_overall = []
81 | gradient_list_temp = []
82 | gradient_variance = 1
83 |
84 | for i in tqdm(range(1000)):
85 |
86 | ###### Region Optimization with Monte Carlo Approximation ######
87 | def closure():
88 | x_res_region_sample_list = []
89 | t_res_region_sample_list = []
90 | for i in range(sample_num):
91 | x_region_sample = (torch.rand(x_res.shape).to(x_res.device)) * np.clip(initial_region / gradient_variance,
92 | a_min=0,
93 | a_max=0.01)
94 |             t_region_sample = (torch.rand(t_res.shape).to(t_res.device)) * np.clip(initial_region / gradient_variance,
95 | a_min=0,
96 | a_max=0.01)
97 | x_res_region_sample_list.append(x_res + x_region_sample)
98 | t_res_region_sample_list.append(t_res + t_region_sample)
99 | x_res_region_sample = torch.cat(x_res_region_sample_list, dim=0)
100 | t_res_region_sample = torch.cat(t_res_region_sample_list, dim=0)
101 | pred_res = model(x_res_region_sample, t_res_region_sample)
102 | pred_left = model(x_left, t_left)
103 | pred_right = model(x_right, t_right)
104 | pred_upper = model(x_upper, t_upper)
105 | pred_lower = model(x_lower, t_lower)
106 |
107 | u_x = \
108 | torch.autograd.grad(pred_res, x_res_region_sample, grad_outputs=torch.ones_like(pred_res),
109 | retain_graph=True,
110 | create_graph=True)[0]
111 | u_t = \
112 | torch.autograd.grad(pred_res, t_res_region_sample, grad_outputs=torch.ones_like(pred_res),
113 | retain_graph=True,
114 | create_graph=True)[0]
115 |
116 | loss_res = torch.mean((u_t - 5 * pred_res * (1 - pred_res)) ** 2)
117 | loss_bc = torch.mean((pred_upper - pred_lower) ** 2)
118 | loss_ic = torch.mean(
119 | (pred_left[:, 0] - torch.exp(- (x_left[:, 0] - torch.pi) ** 2 / (2 * (torch.pi / 4) ** 2))) ** 2)
120 |
121 | loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()])
122 |
123 | loss = loss_res + loss_bc + loss_ic
124 | optim.zero_grad()
125 | loss.backward(retain_graph=True)
126 |         gradient_list_temp.append(torch.cat([(p.grad.view(-1)) if p.grad is not None else torch.zeros(1, device=device) for p in
127 |                                              model.parameters()]).cpu().numpy())  # record the flattened gradient for trust region calibration
128 | return loss
129 |
130 |
131 | optim.step(closure)
132 |
133 | ###### Trust Region Calibration ######
134 | gradient_list_overall.append(np.mean(np.array(gradient_list_temp), axis=0))
135 | gradient_list_overall = gradient_list_overall[-past_iterations:]
136 | gradient_list = np.array(gradient_list_overall)
137 | gradient_variance = (
138 | np.std(gradient_list, axis=0) / (np.mean(np.abs(gradient_list), axis=0) + 1e-6)).mean()
139 | gradient_list_temp = []
140 | if gradient_variance == 0:
141 | gradient_variance = 1 # for numerical stability
142 |
143 | print('Loss Res: {:.4f}, Loss_BC: {:.4f}, Loss_IC: {:.4f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2]))
144 | print('Train Loss: {:.4f}'.format(np.sum(loss_track[-1])))
145 |
146 | if not os.path.exists('./results/'):
147 | os.makedirs('./results/')
148 | torch.save(model.state_dict(), f'./results/1dreaction_{args.model}_region.pt')
149 |
150 | # Visualize
151 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
152 | res_test = make_time_sequence(res_test, num_step=5, step=1e-4)
153 |
154 | res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
155 | x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]
156 |
157 | with torch.no_grad():
158 | pred = model(x_test, t_test)[:, 0:1]
159 | pred = pred.cpu().detach().numpy()
160 |
161 | pred = pred.reshape(101, 101)
162 |
163 |
164 | def h(x):
165 | return np.exp(- (x - np.pi) ** 2 / (2 * (np.pi / 4) ** 2))
166 |
167 |
168 | def u_ana(x, t):
169 | return h(x) * np.exp(5 * t) / (h(x) * np.exp(5 * t) + 1 - h(x))
170 |
171 |
172 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)
173 | u = u_ana(res_test[:, 0], res_test[:, 1]).reshape(101, 101)
174 |
175 | rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u))
176 | rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u ** 2))
177 |
178 | print('relative L1 error: {:.4f}'.format(rl1))
179 | print('relative L2 error: {:.4f}'.format(rl2))
180 |
181 | plt.figure(figsize=(4, 3))
182 | plt.imshow(pred, aspect='equal')
183 | plt.xlabel('x')
184 | plt.ylabel('t')
185 | plt.title('Predicted u(x,t)')
186 | plt.colorbar()
187 | plt.tight_layout()
188 | plt.axis('off')
189 | plt.savefig(f'./results/1dreaction_{args.model}_region_optimization_pred.pdf', bbox_inches='tight')
190 |
191 | plt.figure(figsize=(4, 3))
192 | plt.imshow(u, aspect='equal')
193 | plt.xlabel('x')
194 | plt.ylabel('t')
195 | plt.title('Exact u(x,t)')
196 | plt.colorbar()
197 | plt.tight_layout()
198 | plt.axis('off')
199 | plt.savefig('./results/1dreaction_exact.pdf', bbox_inches='tight')
200 |
201 | plt.figure(figsize=(4, 3))
202 | plt.imshow(pred - u, aspect='equal', cmap='coolwarm', vmin=-0.15, vmax=0.15)
203 | plt.xlabel('x')
204 | plt.ylabel('t')
205 | plt.title('Prediction Error')
206 | plt.colorbar()
207 | plt.tight_layout()
208 | plt.axis('off')
209 | plt.savefig(f'./results/1dreaction_{args.model}_region_optimization_error.pdf', bbox_inches='tight')
210 |
--------------------------------------------------------------------------------
/1d_wave_point_optimization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import matplotlib.pyplot as plt
5 | import random
6 | from torch.optim import LBFGS, Adam
7 | from tqdm import tqdm
8 | import argparse
9 | from util import *
10 | from model_dict import get_model
11 |
12 | seed = 0
13 | np.random.seed(seed)
14 | random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 |
18 | parser = argparse.ArgumentParser('Training Point Optimization')
19 | parser.add_argument('--model', type=str, default='PINN')
20 | parser.add_argument('--device', type=str, default='cuda:0')
21 | args = parser.parse_args()
22 | device = args.device
23 |
24 | res, b_left, b_right, b_upper, b_lower = get_data([0, 1], [0, 1], 101, 101)
25 | res_test, _, _, _, _ = get_data([0, 1], [0, 1], 101, 101)
26 |
27 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
28 | res = make_time_sequence(res, num_step=5, step=1e-4)
29 | b_left = make_time_sequence(b_left, num_step=5, step=1e-4)
30 | b_right = make_time_sequence(b_right, num_step=5, step=1e-4)
31 | b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4)
32 | b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4)
33 |
34 | res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
35 | b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
36 | b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
37 | b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
38 | b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)
39 |
40 | x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2]
41 | x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2]
42 | x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2]
43 | x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2]
44 | x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2]
45 |
46 |
47 | def init_weights(m):
48 | if isinstance(m, nn.Linear):
49 |         torch.nn.init.xavier_uniform_(m.weight)
50 | m.bias.data.fill_(0.01)
51 |
52 |
53 | if args.model == 'KAN':
54 | model = get_model(args).Model(width=[2, 5, 5, 1], grid=5, k=3, grid_eps=1.0, \
55 | noise_scale_base=0.25, device=device).to(device)
56 | elif args.model == 'QRes':
57 | model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=4).to(device)
58 | model.apply(init_weights)
59 | elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
60 | model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
61 | model.apply(init_weights)
62 | else:
63 | model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
64 | model.apply(init_weights)
65 |
66 | optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')
67 |
68 | n_params = get_n_params(model)
69 |
70 | print(model)
71 | print(get_n_params(model))
72 |
73 | loss_track = []
74 | pi = torch.tensor(np.pi, dtype=torch.float32, requires_grad=False).to(device)
75 |
76 | for i in tqdm(range(1000)):
77 | def closure():
78 | pred_res = model(x_res, t_res)
79 | pred_left = model(x_left, t_left)
80 | pred_right = model(x_right, t_right)
81 | pred_upper = model(x_upper, t_upper)
82 | pred_lower = model(x_lower, t_lower)
83 |
84 | u_x = torch.autograd.grad(pred_res, x_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
85 | create_graph=True)[0]
86 | u_xx = \
87 | torch.autograd.grad(u_x, x_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
88 | create_graph=True)[0]
89 | u_t = torch.autograd.grad(pred_res, t_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
90 | create_graph=True)[0]
91 | u_tt = \
92 | torch.autograd.grad(u_t, t_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
93 | create_graph=True)[0]
94 |
95 | loss_res = torch.mean((u_tt - 4 * u_xx) ** 2)
96 | loss_bc = torch.mean((pred_upper) ** 2) + torch.mean((pred_lower) ** 2)
97 |
98 | ui_t = torch.autograd.grad(pred_left, t_left, grad_outputs=torch.ones_like(pred_left), retain_graph=True,
99 | create_graph=True)[0]
100 |
101 | loss_ic_1 = torch.mean(
102 | (pred_left[:, 0] - torch.sin(pi * x_left[:, 0]) - 0.5 * torch.sin(3 * pi * x_left[:, 0])) ** 2)
103 | loss_ic_2 = torch.mean((ui_t) ** 2)
104 |
105 | loss_ic = loss_ic_1 + loss_ic_2
106 |
107 | loss_track.append([loss_res.item(), loss_ic.item(), loss_bc.item()])
108 |
109 | loss = loss_res + loss_ic + loss_bc
110 | optim.zero_grad()
111 | loss.backward()
112 | return loss
113 |
114 |
115 | optim.step(closure)
116 |
117 | print('Loss Res: {:.4f}, Loss_IC: {:.4f}, Loss_BC: {:.4f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2]))
118 | print('Train Loss: {:.4f}'.format(np.sum(loss_track[-1])))
119 |
120 | if not os.path.exists('./results/'):
121 | os.makedirs('./results/')
122 |
123 | torch.save(model.state_dict(), f'./results/1dwave_{args.model}_point.pt')
124 |
125 | # Visualize
126 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
127 | res_test = make_time_sequence(res_test, num_step=5, step=1e-4)
128 |
129 | res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
130 | x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]
131 |
132 | with torch.no_grad():
133 | pred = model(x_test, t_test)[:, 0:1]
134 | pred = pred.cpu().detach().numpy()
135 |
136 | pred = pred.reshape(101, 101)
137 |
138 |
139 | def u_ana(x, t):
140 | return np.sin(np.pi * x) * np.cos(2 * np.pi * t) + 0.5 * np.sin(3 * np.pi * x) * np.cos(6 * np.pi * t)
141 |
142 |
143 | res_test, _, _, _, _ = get_data([0, 1], [0, 1], 101, 101)
144 | u = u_ana(res_test[:, 0], res_test[:, 1]).reshape(101, 101)
145 |
146 | rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u))
147 | rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u ** 2))
148 |
149 | print('relative L1 error: {:.4f}'.format(rl1))
150 | print('relative L2 error: {:.4f}'.format(rl2))
151 |
152 | plt.figure(figsize=(4, 3))
153 | plt.imshow(pred, aspect='equal')
154 | plt.xlabel('x')
155 | plt.ylabel('t')
156 | plt.title('Predicted u(x,t)')
157 | plt.colorbar()
158 | plt.tight_layout()
159 | plt.axis('off')
160 | plt.savefig(f'./results/1dwave_{args.model}_point_optimization_pred.pdf', bbox_inches='tight')
161 |
162 | plt.figure(figsize=(4, 3))
163 | plt.imshow(u, aspect='equal')
164 | plt.xlabel('x')
165 | plt.ylabel('t')
166 | plt.title('Exact u(x,t)')
167 | plt.colorbar()
168 | plt.tight_layout()
169 | plt.axis('off')
170 | plt.savefig('./results/1dwave_exact.pdf', bbox_inches='tight')
171 |
172 | plt.figure(figsize=(4, 3))
173 | plt.imshow(pred - u, aspect='equal', cmap='coolwarm', vmin=-0.3, vmax=0.3)
174 | plt.xlabel('x')
175 | plt.ylabel('t')
176 | plt.title('Prediction Error')
177 | plt.colorbar()
178 | plt.tight_layout()
179 | plt.axis('off')
180 | plt.savefig(f'./results/1dwave_{args.model}_point_optimization_error.pdf', bbox_inches='tight')
181 |
--------------------------------------------------------------------------------
/1d_wave_region_optimization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import matplotlib.pyplot as plt
5 | import random
6 | from torch.optim import LBFGS, Adam
7 | from tqdm import tqdm
8 | import argparse
9 | from util import *
10 | from model_dict import get_model
11 |
12 | seed = 0
13 | np.random.seed(seed)
14 | random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 |
18 | parser = argparse.ArgumentParser('Training Region Optimization')
19 | parser.add_argument('--model', type=str, default='PINN')
20 | parser.add_argument('--device', type=str, default='cuda:0')
21 | parser.add_argument('--initial_region', type=float, default=1e-4)
22 | parser.add_argument('--sample_num', type=int, default=1)
23 | parser.add_argument('--past_iterations', type=int, default=10)
24 | args = parser.parse_args()
25 | device = args.device
26 |
27 | res, b_left, b_right, b_upper, b_lower = get_data([0, 1], [0, 1], 101, 101)
28 | res_test, _, _, _, _ = get_data([0, 1], [0, 1], 101, 101)
29 |
30 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
31 | res = make_time_sequence(res, num_step=5, step=1e-4)
32 | b_left = make_time_sequence(b_left, num_step=5, step=1e-4)
33 | b_right = make_time_sequence(b_right, num_step=5, step=1e-4)
34 | b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4)
35 | b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4)
36 |
37 | res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
38 | b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
39 | b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
40 | b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
41 | b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)
42 |
43 | x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2]
44 | x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2]
45 | x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2]
46 | x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2]
47 | x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2]
48 |
49 |
50 | def init_weights(m):
51 | if isinstance(m, nn.Linear):
52 |         torch.nn.init.xavier_uniform_(m.weight)
53 | m.bias.data.fill_(0.01)
54 |
55 |
56 | if args.model == 'KAN':
57 | model = get_model(args).Model(width=[2, 5, 5, 1], grid=5, k=3, grid_eps=1.0, \
58 | noise_scale_base=0.25, device=device).to(device)
59 | elif args.model == 'QRes':
60 | model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=2).to(device)
61 | model.apply(init_weights)
62 | elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
63 | model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
64 | model.apply(init_weights)
65 | else:
66 | model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
67 | model.apply(init_weights)
68 |
69 | optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')
70 |
71 | n_params = get_n_params(model)
72 |
73 | print(model)
74 | print(get_n_params(model))
75 | loss_track = []
76 | pi = torch.tensor(np.pi, dtype=torch.float32, requires_grad=False).to(device)
77 |
78 | # for region optimization
79 | initial_region = args.initial_region
80 | sample_num = args.sample_num
81 | past_iterations = args.past_iterations
82 | gradient_list_overall = []
83 | gradient_list_temp = []
84 | gradient_variance = 1
85 |
86 | for i in tqdm(range(1000)):
87 |
88 | ###### Region Optimization with Monte Carlo Approximation ######
89 | def closure():
90 | x_res_region_sample_list = []
91 | t_res_region_sample_list = []
92 | for i in range(sample_num):
93 | x_region_sample = (torch.rand(x_res.shape).to(x_res.device)) * np.clip(initial_region / gradient_variance,
94 | a_min=0,
95 | a_max=0.01)
96 |             t_region_sample = (torch.rand(t_res.shape).to(t_res.device)) * np.clip(initial_region / gradient_variance,
97 | a_min=0,
98 | a_max=0.01)
99 | x_res_region_sample_list.append(x_res + x_region_sample)
100 | t_res_region_sample_list.append(t_res + t_region_sample)
101 | x_res_region_sample = torch.cat(x_res_region_sample_list, dim=0)
102 | t_res_region_sample = torch.cat(t_res_region_sample_list, dim=0)
103 | pred_res = model(x_res_region_sample, t_res_region_sample)
104 | pred_left = model(x_left, t_left)
105 | pred_right = model(x_right, t_right)
106 | pred_upper = model(x_upper, t_upper)
107 | pred_lower = model(x_lower, t_lower)
108 |
109 | u_x = \
110 | torch.autograd.grad(pred_res, x_res_region_sample, grad_outputs=torch.ones_like(pred_res),
111 | retain_graph=True,
112 | create_graph=True)[0]
113 | u_xx = torch.autograd.grad(u_x, x_res_region_sample, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
114 | create_graph=True)[0]
115 | u_t = \
116 | torch.autograd.grad(pred_res, t_res_region_sample, grad_outputs=torch.ones_like(pred_res),
117 | retain_graph=True,
118 | create_graph=True)[0]
119 | u_tt = torch.autograd.grad(u_t, t_res_region_sample, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
120 | create_graph=True)[0]
121 |
122 | loss_res = torch.mean((u_tt - 4 * u_xx) ** 2)
123 | loss_bc = torch.mean((pred_upper) ** 2) + torch.mean((pred_lower) ** 2)
124 |
125 | ui_t = torch.autograd.grad(pred_left, t_left, grad_outputs=torch.ones_like(pred_left), retain_graph=True,
126 | create_graph=True)[0]
127 |
128 | loss_ic_1 = torch.mean(
129 | (pred_left[:, 0] - torch.sin(pi * x_left[:, 0]) - 0.5 * torch.sin(3 * pi * x_left[:, 0])) ** 2)
130 | loss_ic_2 = torch.mean((ui_t) ** 2)
131 |
132 | loss_ic = loss_ic_1 + loss_ic_2
133 |
134 | loss_track.append([loss_res.item(), loss_ic.item(), loss_bc.item()])
135 |
136 | loss = loss_res + loss_ic + loss_bc
137 | optim.zero_grad()
138 | loss.backward(retain_graph=True)
139 |         gradient_list_temp.append(torch.cat([(p.grad.view(-1)) if p.grad is not None else torch.zeros(1, device=device) for p in
140 |                                              model.parameters()]).cpu().numpy())  # record the flattened gradient for trust region calibration
141 | return loss
142 |
143 |
144 | optim.step(closure)
145 |
146 | ###### Trust Region Calibration ######
147 | gradient_list_overall.append(np.mean(np.array(gradient_list_temp), axis=0))
148 | gradient_list_overall = gradient_list_overall[-past_iterations:]
149 | gradient_list = np.array(gradient_list_overall)
150 | gradient_variance = (np.std(gradient_list, axis=0) / (
151 | np.mean(np.abs(gradient_list), axis=0) + 1e-6)).mean() # normalized variance
152 | gradient_list_temp = []
153 |     print('Gradient variance: {:.4f}'.format(gradient_variance))
154 | if gradient_variance == 0:
155 | gradient_variance = 1 # for numerical stability
156 |
157 | print('Loss Res: {:.4f}, Loss_IC: {:.4f}, Loss_BC: {:.4f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2]))
158 | print('Train Loss: {:.4f}'.format(np.sum(loss_track[-1])))
159 |
160 | if not os.path.exists('./results/'):
161 | os.makedirs('./results/')
162 | torch.save(model.state_dict(), f'./results/1dwave_{args.model}_region.pt')
163 |
164 | # Visualize PINNs
165 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
166 | res_test = make_time_sequence(res_test, num_step=5, step=1e-4)
167 |
168 | res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
169 | x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]
170 |
171 | with torch.no_grad():
172 | pred = model(x_test, t_test)[:, 0:1]
173 | pred = pred.cpu().detach().numpy()
174 |
175 | pred = pred.reshape(101, 101)
176 |
177 |
178 | def u_ana(x, t):
179 | return np.sin(np.pi * x) * np.cos(2 * np.pi * t) + 0.5 * np.sin(3 * np.pi * x) * np.cos(6 * np.pi * t)
180 |
181 |
182 | res_test, _, _, _, _ = get_data([0, 1], [0, 1], 101, 101)
183 | u = u_ana(res_test[:, 0], res_test[:, 1]).reshape(101, 101)
184 |
185 | rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u))
186 | rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u ** 2))
187 |
188 | print('relative L1 error: {:.4f}'.format(rl1))
189 | print('relative L2 error: {:.4f}'.format(rl2))
190 |
191 | plt.figure(figsize=(4, 3))
192 | plt.imshow(pred, aspect='equal')
193 | plt.xlabel('x')
194 | plt.ylabel('t')
195 | plt.title('Predicted u(x,t)')
196 | plt.colorbar()
197 | plt.tight_layout()
198 | plt.axis('off')
199 | plt.savefig(f'./results/1dwave_{args.model}_region_optimization_pred.pdf', bbox_inches='tight')
200 |
201 | plt.figure(figsize=(4, 3))
202 | plt.imshow(u, aspect='equal')
203 | plt.xlabel('x')
204 | plt.ylabel('t')
205 | plt.title('Exact u(x,t)')
206 | plt.colorbar()
207 | plt.tight_layout()
208 | plt.axis('off')
209 | plt.savefig('./results/1dwave_exact.pdf', bbox_inches='tight')
210 |
211 | plt.figure(figsize=(4, 3))
212 | plt.imshow(pred - u, aspect='equal', cmap='coolwarm', vmin=-0.3, vmax=0.3)
213 | plt.xlabel('x')
214 | plt.ylabel('t')
215 | plt.title('Prediction Error')
216 | plt.colorbar()
217 | plt.tight_layout()
218 | plt.axis('off')
219 | plt.savefig(f'./results/1dwave_{args.model}_region_optimization_error.pdf', bbox_inches='tight')
220 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 THUML @ Tsinghua University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RoPINN (NeurIPS 2024)
2 |
3 | RoPINN: Region Optimized Physics-Informed Neural Networks. See [Paper](https://arxiv.org/abs/2405.14369) or [Slides](https://wuhaixu2016.github.io/pdf/NeurIPS2024_RoPINN.pdf).
4 |
5 | This paper proposes and theoretically studies a new training paradigm for PINNs, **region optimization**, and presents [RoPINN](https://arxiv.org/abs/2405.14369) as a practical algorithm, which brings the following benefits:
6 |
7 | - **Better generalization bound:** Introducing "regions" provably decreases the generalization error and yields a general theoretical framework that, for the first time, reveals the balance between generalization and optimization.
8 | - **Efficient practical algorithm:** We present RoPINN with a trust region calibration strategy, which effectively performs region optimization and reduces the gradient estimation error caused by sampling.
9 | - **Boosts extensive backbones:** RoPINN consistently improves various PINN backbones (e.g., PINN, KAN, and PINNsFormer) on a wide range of PDEs (19 different tasks) without extra gradient calculation.
10 |
11 | ## Point Optimization vs. Region Optimization
12 |
13 | Unlike conventional point optimization, our proposed region optimization extends the optimization process of PINNs from isolated points to their continuous neighborhood region.
14 |
15 |
16 |
17 |
18 | Figure 1. Comparison between previous methods and RoPINN.
19 |
20 |
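In code, the difference amounts to a small perturbation of the collocation points before computing the PDE residual loss. A minimal sketch, assuming a hypothetical `pde_residual` helper and collocation tensors `x_res`, `t_res` in the style of this repo's training scripts:

```python
import torch

# Point optimization: the residual loss is evaluated at fixed collocation points.
loss_point = (pde_residual(model, x_res, t_res) ** 2).mean()

# Region optimization (RoPINN): each collocation point is perturbed uniformly
# within a small neighborhood before evaluating the same residual loss.
region_size = 1e-4  # adaptively calibrated during training (see below)
x_sample = x_res + torch.rand_like(x_res) * region_size
t_sample = t_res + torch.rand_like(t_res) * region_size
loss_region = (pde_residual(model, x_sample, t_sample) ** 2).mean()
```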
21 | ## Practical Algorithm
22 |
23 | We present RoPINN for PINN training based on Monte Carlo sampling, which effectively performs region optimization without extra gradient calculation. A trust region calibration strategy reduces the gradient estimation error caused by sampling, yielding more trustworthy optimization.
24 |
25 |
26 |
27 |
28 | Figure 2. RoPINN algorithm.
29 |
30 |
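The trust region calibration tracks the normalized variance (coefficient of variation) of gradients over recent iterations: noisy gradient estimates shrink the sampling region, while stable ones let it grow. A condensed sketch of the loop implemented in the `*_region_optimization.py` scripts, where `train_step` is a hypothetical stand-in for one optimizer step that returns the flattened gradient:

```python
import numpy as np

initial_region = 1e-4      # --initial_region
past_iterations = 10       # --past_iterations
gradient_history = []      # flattened gradients of recent iterations
gradient_variance = 1.0    # starts neutral

for it in range(num_iterations):
    # region size shrinks when gradients are noisy; clipped for stability
    region_size = np.clip(initial_region / gradient_variance, 0, 0.01)
    grad = train_step(model, region_size)
    gradient_history = (gradient_history + [grad])[-past_iterations:]
    g = np.stack(gradient_history)
    # normalized variance of the gradient, averaged over parameters
    gradient_variance = (g.std(axis=0) / (np.abs(g).mean(axis=0) + 1e-6)).mean()
    if gradient_variance == 0:
        gradient_variance = 1  # numerical stability
```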
31 | ## Get Started
32 |
33 | 1. Install Python 3.8 or 3.9 and **PyTorch 1.13.0**. For convenience, execute the following command.
34 |
35 | ```shell
36 | pip install -r requirements.txt
37 | ```
38 |
39 | 2. Train and evaluate the models. We provide experiment scripts for all benchmarks under `./scripts/`. You can reproduce the experimental results as in the following examples:
40 |
41 | ```shell
42 | bash scripts/1d_reaction_point.sh # canonical point optimization
43 | bash scripts/1d_reaction_region.sh # RoPINN: region optimization
44 | bash scripts/1d_wave_point.sh # canonical point optimization
45 | bash scripts/1d_wave_region.sh # RoPINN: region optimization
46 | bash scripts/convection_point.sh # canonical point optimization
47 | bash scripts/convection_region.sh # RoPINN: region optimization
48 | ```
49 |
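The scripts wrap the Python entry points at the repo root, so you can also call those directly. For example, region optimization exposes its sampling hyperparameters via `argparse` (the values below are the defaults of `1d_reaction_region_optimization.py`):

```shell
python 1d_reaction_region_optimization.py --model PINN --device cuda:0 \
    --initial_region 1e-4 --sample_num 1 --past_iterations 10
```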
50 | Specifically, we have included the following PINN models in this repo:
51 |
52 | - [x] PINN (Journal of Computational Physics 2019) [[Paper]](https://github.com/maziarraissi/PINNs)
53 | - [x] FLS (IEEE Transactions on Artificial Intelligence 2022) [[Paper]](https://arxiv.org/abs/2109.09338)
54 | - [x] QRes (SIAM 2021) [[Paper]](https://arxiv.org/abs/2101.08366)
55 | - [x] KAN (arXiv 2024) [[Paper]](https://arxiv.org/abs/2404.19756)
56 | - [x] PINNsFormer (ICLR 2024) [[Paper]](https://arxiv.org/abs/2307.11833)
57 |
58 | ## Results
59 |
60 | We have experimented with 19 different PDE tasks. See [our paper](https://arxiv.org/abs/2405.14369) for the full results.
61 |
62 |
63 |
64 |
65 | Figure 3. Part of experimental results of RoPINN.
66 |
67 |
68 | ## Citation
69 |
70 | If you find this repo useful, please cite our paper.
71 |
72 | ```
73 | @inproceedings{wu2024ropinn,
74 | title={RoPINN: Region Optimized Physics-Informed Neural Networks},
75 | author={Haixu Wu and Huakun Luo and Yuezhou Ma and Jianmin Wang and Mingsheng Long},
76 | booktitle={Advances in Neural Information Processing Systems},
77 | year={2024}
78 | }
79 | ```
80 |
81 | ## Contact
82 |
83 | If you have any questions or want to use the code, please contact [wuhx23@mails.tsinghua.edu.cn](mailto:wuhx23@mails.tsinghua.edu.cn).
84 |
85 | ## Acknowledgement
86 |
87 | We appreciate the following GitHub repos for their valuable code bases and datasets:
88 |
89 | https://github.com/AdityaLab/pinnsformer
90 |
91 | https://github.com/i207M/PINNacle
92 |
--------------------------------------------------------------------------------
/convection_point_optimization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import matplotlib.pyplot as plt
4 | import random
5 | from torch.optim import LBFGS
6 | from tqdm import tqdm
7 | import os
8 | import argparse
9 | from util import *
10 | from model_dict import get_model
11 |
12 | seed = 0
13 | np.random.seed(seed)
14 | random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 |
18 | parser = argparse.ArgumentParser('Training Point Optimization')
19 | parser.add_argument('--model', type=str, default='PINN')
20 | parser.add_argument('--device', type=str, default='cuda:0')
21 | args = parser.parse_args()
22 | device = args.device
23 |
24 | res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101)
25 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)
26 |
27 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
28 | res = make_time_sequence(res, num_step=5, step=1e-4)
29 | b_left = make_time_sequence(b_left, num_step=5, step=1e-4)
30 | b_right = make_time_sequence(b_right, num_step=5, step=1e-4)
31 | b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4)
32 | b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4)
33 |
34 | res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
35 | b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
36 | b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
37 | b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
38 | b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)
39 |
40 | x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2]
41 | x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2]
42 | x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2]
43 | x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2]
44 | x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2]
45 |
46 |
47 | def init_weights(m):
48 | if isinstance(m, nn.Linear):
49 |         torch.nn.init.xavier_uniform_(m.weight)
50 | m.bias.data.fill_(0.0)
51 |
52 |
53 | if args.model == 'KAN':
54 | model = get_model(args).Model(width=[2, 5, 5, 1], grid=5, k=3, grid_eps=1.0, \
55 | noise_scale_base=0.25, device=device).to(device)
56 | elif args.model == 'QRes':
57 | model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=4).to(device)
58 | model.apply(init_weights)
59 | elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
60 | model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
61 | model.apply(init_weights)
62 | else:
63 | model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
64 | model.apply(init_weights)
65 |
66 | optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')
67 |
68 | print(model)
69 | print(get_n_params(model))
70 |
71 | loss_track = []
72 |
73 | for i in tqdm(range(1000)):
74 | def closure():
75 | pred_res = model(x_res, t_res)
76 | pred_left = model(x_left, t_left)
77 | pred_right = model(x_right, t_right)
78 | pred_upper = model(x_upper, t_upper)
79 | pred_lower = model(x_lower, t_lower)
80 |
81 | u_x = torch.autograd.grad(pred_res, x_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
82 | create_graph=True)[0]
83 | u_t = torch.autograd.grad(pred_res, t_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
84 | create_graph=True)[0]
85 |
86 | loss_res = torch.mean((u_t + 50 * u_x) ** 2)
87 | loss_bc = torch.mean((pred_upper - pred_lower) ** 2)
88 | loss_ic = torch.mean((pred_left[:, 0] - torch.sin(x_left[:, 0])) ** 2)
89 |
90 | loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()])
91 |
92 | loss = loss_res + loss_bc + loss_ic
93 | optim.zero_grad()
94 | loss.backward()
95 | return loss
96 |
97 |
98 | optim.step(closure)
99 |
100 | print('Loss Res: {:.4f}, Loss_BC: {:.4f}, Loss_IC: {:.4f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2]))
101 | print('Train Loss: {:.4f}'.format(np.sum(loss_track[-1])))
102 |
103 | if not os.path.exists('./results/'):
104 | os.makedirs('./results/')
105 |
106 | torch.save(model.state_dict(), f'./results/1dconvection_{args.model}_point.pt')
107 |
108 | # Visualize
109 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
110 | res_test = make_time_sequence(res_test, num_step=5, step=1e-4)
111 |
112 | res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
113 | x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]
114 |
115 | with torch.no_grad():
116 | pred = model(x_test, t_test)[:, 0:1]
117 | pred = pred.cpu().detach().numpy()
118 |
119 | pred = pred.reshape(101, 101)
120 |
121 |
122 | def u_res(x, t):
123 |     # analytic solution of the convection equation u_t + 50 u_x = 0
124 |     # with initial condition u(x, 0) = sin(x)
125 |     return np.sin(x - 50 * t)
126 |
127 |
128 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)
129 | u = u_res(res_test[:, 0], res_test[:, 1]).reshape(101, 101)
130 |
131 | rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u))
132 | rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u ** 2))
133 |
134 | print('relative L1 error: {:.4f}'.format(rl1))
135 | print('relative L2 error: {:.4f}'.format(rl2))
136 |
137 | plt.figure(figsize=(4, 3))
138 | plt.imshow(pred, aspect='equal')
139 | plt.xlabel('x')
140 | plt.ylabel('t')
141 | plt.title('Predicted u(x,t)')
142 | plt.colorbar()
143 | plt.tight_layout()
144 | plt.axis('off')
145 | plt.savefig(f'./results/convection_{args.model}_point_optimization_pred.pdf', bbox_inches='tight')
146 |
147 | plt.figure(figsize=(4, 3))
148 | plt.imshow(u, aspect='equal')
149 | plt.xlabel('x')
150 | plt.ylabel('t')
151 | plt.title('Exact u(x,t)')
152 | plt.colorbar()
153 | plt.tight_layout()
154 | plt.axis('off')
155 | plt.savefig('./results/convection_exact.pdf', bbox_inches='tight')
156 |
157 | plt.figure(figsize=(4, 3))
158 | plt.imshow(pred - u, aspect='equal', cmap='coolwarm', vmin=-1, vmax=1)
159 | plt.xlabel('x')
160 | plt.ylabel('t')
161 | plt.title('Prediction Error')
162 | plt.colorbar()
163 | plt.tight_layout()
164 | plt.axis('off')
165 | plt.savefig(f'./results/convection_{args.model}_point_optimization_error.pdf', bbox_inches='tight')
166 |
--------------------------------------------------------------------------------
/convection_region_optimization.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import matplotlib.pyplot as plt
5 | import random
6 | from torch.optim import LBFGS
7 | from tqdm import tqdm
8 | import argparse
9 | from util import *
10 | from model_dict import get_model
11 |
12 | seed = 0
13 | np.random.seed(seed)
14 | random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 |
18 | parser = argparse.ArgumentParser('Training Region Optimization')
19 | parser.add_argument('--model', type=str, default='PINN')
20 | parser.add_argument('--device', type=str, default='cuda:0')
21 | parser.add_argument('--initial_region', type=float, default=1e-4)
22 | parser.add_argument('--sample_num', type=int, default=1)
23 | parser.add_argument('--past_iterations', type=int, default=5)
24 | args = parser.parse_args()
25 | device = args.device
26 |
27 | res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101)
28 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)
29 |
30 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
31 | res = make_time_sequence(res, num_step=5, step=1e-4)
32 | b_left = make_time_sequence(b_left, num_step=5, step=1e-4)
33 | b_right = make_time_sequence(b_right, num_step=5, step=1e-4)
34 | b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4)
35 | b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4)
36 |
37 | res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
38 | b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
39 | b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
40 | b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
41 | b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)
42 |
43 | x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2]
44 | x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2]
45 | x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2]
46 | x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2]
47 | x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2]
48 |
49 |
50 | def init_weights(m):
51 | if isinstance(m, nn.Linear):
52 |         torch.nn.init.xavier_uniform_(m.weight)
53 | m.bias.data.fill_(0.0)
54 |
55 |
56 | if args.model == 'KAN':
57 | model = get_model(args).Model(width=[2, 5, 5, 1], grid=5, k=3, grid_eps=1.0, \
58 | noise_scale_base=0.25, device=device).to(device)
59 | elif args.model == 'QRes':
60 | model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=2).to(device)
61 | model.apply(init_weights)
62 | elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
63 | model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
64 | model.apply(init_weights)
65 | else:
66 | model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
67 | model.apply(init_weights)
68 |
69 | optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')
70 |
71 | print(model)
72 | print(get_n_params(model))
73 | loss_track = []
74 |
75 | # for region optimization
76 | initial_region = args.initial_region
77 | sample_num = args.sample_num
78 | past_iterations = args.past_iterations
79 | gradient_list_overall = []
80 | gradient_list_temp = []
81 | gradient_variance = 1
82 |
83 | for i in tqdm(range(1000)):
84 |
85 | ###### Region Optimization with Monte Carlo Approximation ######
86 | def closure():
87 | x_res_region_sample_list = []
88 | t_res_region_sample_list = []
89 | for i in range(sample_num):
90 | x_region_sample = (torch.rand(x_res.shape).to(x_res.device)) * np.clip(initial_region / gradient_variance,
91 | a_min=0,
92 | a_max=0.01)
93 |             t_region_sample = (torch.rand(t_res.shape).to(t_res.device)) * np.clip(initial_region / gradient_variance,
94 | a_min=0,
95 | a_max=0.01)
96 | x_res_region_sample_list.append(x_res + x_region_sample)
97 | t_res_region_sample_list.append(t_res + t_region_sample)
98 | x_res_region_sample = torch.cat(x_res_region_sample_list, dim=0)
99 | t_res_region_sample = torch.cat(t_res_region_sample_list, dim=0)
100 | pred_res = model(x_res_region_sample, t_res_region_sample)
101 | pred_left = model(x_left, t_left)
102 | pred_right = model(x_right, t_right)
103 | pred_upper = model(x_upper, t_upper)
104 | pred_lower = model(x_lower, t_lower)
105 |
106 | u_x = \
107 | torch.autograd.grad(pred_res, x_res_region_sample, grad_outputs=torch.ones_like(pred_res),
108 | retain_graph=True,
109 | create_graph=True)[0]
110 | u_t = \
111 | torch.autograd.grad(pred_res, t_res_region_sample, grad_outputs=torch.ones_like(pred_res),
112 | retain_graph=True,
113 | create_graph=True)[0]
114 |
115 | loss_res = torch.mean((u_t + 50 * u_x) ** 2)
116 | loss_bc = torch.mean((pred_upper - pred_lower) ** 2)
117 | loss_ic = torch.mean((pred_left[:, 0] - torch.sin(x_left[:, 0])) ** 2)
118 |
119 | loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()])
120 |
121 | loss = loss_res + loss_bc + loss_ic
122 | optim.zero_grad()
123 | loss.backward(retain_graph=True)
124 |         gradient_list_temp.append(torch.cat([(p.grad.view(-1)) if p.grad is not None else torch.zeros(1, device=device) for p in
125 |                                              model.parameters()]).cpu().numpy())  # record the flattened gradient for trust region calibration
126 | return loss
127 |
128 |
129 | optim.step(closure)
130 |
131 | ###### Trust Region Calibration ######
132 | gradient_list_overall.append(np.mean(np.array(gradient_list_temp), axis=0))
133 | gradient_list_overall = gradient_list_overall[-past_iterations:]
134 | gradient_list = np.array(gradient_list_overall)
135 | gradient_variance = (np.std(gradient_list, axis=0) / (
136 | np.mean(np.abs(gradient_list), axis=0) + 1e-6)).mean() # normalized variance
137 | gradient_list_temp = []
138 | if gradient_variance == 0:
139 | gradient_variance = 1 # for numerical stability
140 |
141 | print('Loss Res: {:.4f}, Loss_BC: {:.4f}, Loss_IC: {:.4f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2]))
142 | print('Train Loss: {:.4f}'.format(np.sum(loss_track[-1])))
143 |
144 | if not os.path.exists('./results/'):
145 | os.makedirs('./results/')
146 | torch.save(model.state_dict(), f'./results/convection_{args.model}_region.pt')
147 |
148 | # Visualize
149 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
150 | res_test = make_time_sequence(res_test, num_step=5, step=1e-4)
151 |
152 | res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
153 | x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]
154 |
155 | with torch.no_grad():
156 | pred = model(x_test, t_test)[:, 0:1]
157 | pred = pred.cpu().detach().numpy()
158 |
159 | pred = pred.reshape(101, 101)
160 |
161 |
162 | def u_res(x, t):
163 |     # analytic solution of the convection equation u_t + 50 u_x = 0
164 |     # with initial condition u(x, 0) = sin(x)
165 |     return np.sin(x - 50 * t)
166 |
167 |
168 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)
169 | u = u_res(res_test[:, 0], res_test[:, 1]).reshape(101, 101)
170 |
171 | rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u))
172 | rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u ** 2))
173 |
174 | print('relative L1 error: {:.4f}'.format(rl1))
175 | print('relative L2 error: {:.4f}'.format(rl2))
176 |
177 | plt.figure(figsize=(4, 3))
178 | plt.imshow(pred, aspect='equal')
179 | plt.xlabel('x')
180 | plt.ylabel('t')
181 | plt.title('Predicted u(x,t)')
182 | plt.colorbar()
183 | plt.tight_layout()
184 | plt.axis('off')
185 | plt.savefig(f'./results/convection_{args.model}_region_optimization_pred.pdf', bbox_inches='tight')
186 |
187 | plt.figure(figsize=(4, 3))
188 | plt.imshow(u, aspect='equal')
189 | plt.xlabel('x')
190 | plt.ylabel('t')
191 | plt.title('Exact u(x,t)')
192 | plt.colorbar()
193 | plt.tight_layout()
194 | plt.axis('off')
195 | plt.savefig('./results/convection_exact.pdf', bbox_inches='tight')
196 |
197 | plt.figure(figsize=(4, 3))
198 | plt.imshow(pred - u, aspect='equal', cmap='coolwarm', vmin=-1, vmax=1)
199 | plt.xlabel('x')
200 | plt.ylabel('t')
201 | plt.title('Prediction Error')
202 | plt.colorbar()
203 | plt.tight_layout()
204 | plt.axis('off')
205 | plt.savefig(f'./results/convection_{args.model}_region_optimization_error.pdf', bbox_inches='tight')
206 |
--------------------------------------------------------------------------------
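Note on the "Trust Region Calibration" block in the script above: the statistic is a mean coefficient of variation of the flattened gradient over a sliding window of recent optimizer steps. A self-contained sketch of the computation (the window size and the 1e-6 floor mirror the script; the gradients here are synthetic stand-ins):

    import numpy as np

    past_iterations = 10
    history = [np.random.randn(1_000) for _ in range(25)]  # stand-in flattened gradients
    history = history[-past_iterations:]                   # keep only the recent window
    grads = np.array(history)                              # shape: (window, n_params)
    # per-parameter std normalized by mean |gradient|, averaged over parameters
    cv = (np.std(grads, axis=0) / (np.mean(np.abs(grads), axis=0) + 1e-6)).mean()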
/model_dict.py:
--------------------------------------------------------------------------------
1 | from models import PINN, QRes, FLS, KAN, PINNsFormer, PINNsFormer_Enc_Only
2 |
3 |
4 | def get_model(args):
5 | model_dict = {
6 | 'PINN': PINN,
7 | 'QRes': QRes,
8 | 'FLS': FLS,
9 | 'KAN': KAN,
10 | 'PINNsFormer': PINNsFormer,
11 |         'PINNsFormer_Enc_Only': PINNsFormer_Enc_Only,  # more efficient than the original PINNsFormer, and performs better
12 | }
13 | return model_dict[args.model]
--------------------------------------------------------------------------------
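A minimal usage sketch for get_model above. How each dictionary entry resolves (module vs. class) depends on models/__init__.py, which is not shown here; this sketch assumes each entry is the model module exposing a Model class, and the constructor arguments are illustrative FLS-style hyperparameters:

    import argparse
    from model_dict import get_model

    args = argparse.Namespace(model='FLS')
    module = get_model(args)           # e.g. the models.FLS module
    net = module.Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4)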
/models/FLS.py:
--------------------------------------------------------------------------------
1 | # baseline implementation of First Layer Sine
2 | # paper: Learning in Sinusoidal Spaces with Physics-Informed Neural Networks
3 | # link: https://arxiv.org/abs/2109.09338
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | class SinAct(nn.Module):
10 | def __init__(self):
11 | super(SinAct, self).__init__()
12 |
13 | def forward(self, x):
14 | return torch.sin(x)
15 |
16 |
17 | class Model(nn.Module):
18 | def __init__(self, in_dim, hidden_dim, out_dim, num_layer):
19 | super(Model, self).__init__()
20 |
21 | layers = []
22 | for i in range(num_layer - 1):
23 | if i == 0:
24 | layers.append(nn.Linear(in_features=in_dim, out_features=hidden_dim))
25 | layers.append(SinAct())
26 | else:
27 | layers.append(nn.Linear(in_features=hidden_dim, out_features=hidden_dim))
28 | layers.append(nn.Tanh())
29 |
30 | layers.append(nn.Linear(in_features=hidden_dim, out_features=out_dim))
31 |
32 | self.linear = nn.Sequential(*layers)
33 |
34 | def forward(self, x, t):
35 | src = torch.cat((x, t), dim=-1)
36 | return self.linear(src)
37 |
--------------------------------------------------------------------------------
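A quick shape check for the FLS model above (a sketch; the hyperparameters are arbitrary):

    import torch
    from models.FLS import Model

    net = Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=4)
    x, t = torch.rand(16, 1), torch.rand(16, 1)
    u = net(x, t)              # (x, t) are concatenated; the first layer uses sin
    assert u.shape == (16, 1)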
/models/KAN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import sympy
5 | from .kan_layer import *
6 | from .Symbolic_KANLayer import *
7 | from .LBFGS import *
5 | import os
6 | import glob
7 | import matplotlib.pyplot as plt
8 | from tqdm import tqdm
9 | import random
10 | import copy
11 |
12 |
13 | class Model(nn.Module):
14 | '''
15 | KAN class
16 |
17 | Attributes:
18 | -----------
19 | biases: a list of nn.Linear()
20 |             biases are added on nodes (in principle biases can be absorbed into the activation functions, but keeping them helps optimization)
21 | act_fun: a list of KANLayer
22 | KANLayers
23 | depth: int
24 | depth of KAN
25 | width: list
26 |             number of neurons in each layer. e.g., [2,5,5,3] means 2D inputs, 3D outputs, and 2 hidden layers with 5 neurons each.
27 | grid: int
28 | the number of grid intervals
29 | k: int
30 | the order of piecewise polynomial
31 | base_fun: fun
32 | residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)
33 | symbolic_fun: a list of Symbolic_KANLayer
34 | Symbolic_KANLayers
35 | symbolic_enabled: bool
36 | If False, the symbolic front is not computed (to save time). Default: True.
37 |
38 | Methods:
39 | --------
40 | __init__():
41 | initialize a KAN
42 | initialize_from_another_model():
43 | initialize a KAN from another KAN (with the same shape, but potentially different grids)
44 | update_grid_from_samples():
45 | update spline grids based on samples
46 | initialize_grid_from_another_model():
47 |             initialize KAN grids from another KAN
48 | forward():
49 | forward
50 | set_mode():
51 | set the mode of an activation function: 'n' for numeric, 's' for symbolic, 'ns' for combined (note they are visualized differently in plot(). 'n' as black, 's' as red, 'ns' as purple).
52 | fix_symbolic():
53 | fix an activation function to be symbolic
54 | suggest_symbolic():
55 |             suggest the symbolic candidates of a numeric spline-based activation function
56 | lock():
57 | lock activation functions to share parameters
58 | unlock():
59 | unlock locked activations
60 | get_range():
61 | get the input and output ranges of an activation function
62 | plot():
63 | plot the diagram of KAN
64 | train():
65 | train KAN
66 | prune():
67 | prune KAN
68 | remove_edge():
69 |             remove an edge of KAN
70 |         remove_node():
71 |             remove a node of KAN
72 | auto_symbolic():
73 | automatically fit all splines to be symbolic functions
74 | symbolic_formula():
75 | obtain the symbolic formula of the KAN network
76 | '''
77 |
78 | def __init__(self, width=None, grid=3, k=3, noise_scale=0.1, noise_scale_base=0.1, base_fun=torch.nn.SiLU(),
79 | symbolic_enabled=True, bias_trainable=True, grid_eps=1.0, grid_range=[-1, 1], sp_trainable=True,
80 | sb_trainable=True,
81 | device='cpu', seed=0):
82 | '''
83 |         initialize a KAN model
84 |
85 | Args:
86 | -----
87 | width : list of int
88 | :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
89 | grid : int
90 | number of grid intervals. Default: 3.
91 | k : int
92 | order of piecewise polynomial. Default: 3.
93 | noise_scale : float
94 | initial injected noise to spline. Default: 0.1.
95 | base_fun : fun
96 | the residual function b(x). Default: torch.nn.SiLU().
97 | symbolic_enabled : bool
98 | compute or skip symbolic computations (for efficiency). By default: True.
99 | bias_trainable : bool
100 | bias parameters are updated or not. By default: True
101 | grid_eps : float
102 |             When grid_eps = 0, the grid is uniform; when grid_eps = 1, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. Default: 1.0.
103 |         grid_range : list/np.array of shape (2,)
104 | setting the range of grids. Default: [-1,1].
105 | sp_trainable : bool
106 | If true, scale_sp is trainable. Default: True.
107 | sb_trainable : bool
108 | If true, scale_base is trainable. Default: True.
109 | device : str
110 | device
111 | seed : int
112 | random seed
113 |
114 | Returns:
115 | --------
116 | self
117 |
118 | Example
119 | -------
120 | >>> model = Model(width=[2,5,1], grid=5, k=3)
121 | >>> (model.act_fun[0].in_dim, model.act_fun[0].out_dim), (model.act_fun[1].in_dim, model.act_fun[1].out_dim)
122 | ((2, 5), (5, 1))
123 | '''
124 | super(Model, self).__init__()
125 |
126 | torch.manual_seed(seed)
127 | np.random.seed(seed)
128 | random.seed(seed)
129 |
130 |         ### initializing the numerical front ###
131 |
132 | self.biases = []
133 | self.act_fun = []
134 | self.depth = len(width) - 1
135 | self.width = width
136 |
137 | for l in range(self.depth):
138 | # splines
139 | scale_base = 1 / np.sqrt(width[l]) + (torch.randn(width[l] * width[l + 1], ) * 2 - 1) * noise_scale_base
140 | sp_batch = KANLayer(in_dim=width[l], out_dim=width[l + 1], num=grid, k=k, noise_scale=noise_scale,
141 | scale_base=scale_base, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps,
142 | grid_range=grid_range, sp_trainable=sp_trainable,
143 | sb_trainable=sb_trainable, device=device)
144 | self.act_fun.append(sp_batch)
145 |
146 | # bias
147 | bias = nn.Linear(width[l + 1], 1, bias=False).requires_grad_(bias_trainable)
148 | bias.weight.data *= 0.
149 | self.biases.append(bias)
150 |
151 | self.biases = nn.ModuleList(self.biases)
152 | self.act_fun = nn.ModuleList(self.act_fun)
153 |
154 | self.grid = grid
155 | self.k = k
156 | self.base_fun = base_fun
157 |
158 | ### initializing the symbolic front ###
159 | self.symbolic_fun = []
160 | for l in range(self.depth):
161 | sb_batch = Symbolic_KANLayer(in_dim=width[l], out_dim=width[l + 1])
162 | self.symbolic_fun.append(sb_batch)
163 |
164 | self.symbolic_fun = nn.ModuleList(self.symbolic_fun)
165 | self.symbolic_enabled = symbolic_enabled
166 |
167 | def initialize_from_another_model(self, another_model, x):
168 | '''
169 | initialize from a parent model. The parent has the same width as the current model but may have different grids.
170 |
171 | Args:
172 | -----
173 | another_model : KAN
174 | the parent model used to initialize the current model
175 | x : 2D torch.float
176 | inputs, shape (batch, input dimension)
177 |
178 | Returns:
179 | --------
180 | self : KAN
181 |
182 | Example
183 | -------
184 | >>> model_coarse = KAN(width=[2,5,1], grid=5, k=3)
185 | >>> model_fine = KAN(width=[2,5,1], grid=10, k=3)
186 | >>> print(model_fine.act_fun[0].coef[0][0].data)
187 | >>> x = torch.normal(0,1,size=(100,2))
188 | >>> model_fine.initialize_from_another_model(model_coarse, x);
189 | >>> print(model_fine.act_fun[0].coef[0][0].data)
190 | tensor(-0.0030)
191 | tensor(0.0506)
192 | '''
193 | another_model(x) # get activations
194 | batch = x.shape[0]
195 |
196 | self.initialize_grid_from_another_model(another_model, x)
197 |
198 | for l in range(self.depth):
199 | spb = self.act_fun[l]
200 | spb_parent = another_model.act_fun[l]
201 |
202 | # spb = spb_parent
203 | preacts = another_model.spline_preacts[l]
204 | postsplines = another_model.spline_postsplines[l]
205 | self.act_fun[l].coef.data = curve2coef(preacts.reshape(batch, spb.size).permute(1, 0),
206 | postsplines.reshape(batch, spb.size).permute(1, 0), spb.grid,
207 | k=spb.k)
208 | spb.scale_base.data = spb_parent.scale_base.data
209 | spb.scale_sp.data = spb_parent.scale_sp.data
210 | spb.mask.data = spb_parent.mask.data
211 | # print(spb.mask.data, self.act_fun[l].mask.data)
212 |
213 | for l in range(self.depth):
214 | self.biases[l].weight.data = another_model.biases[l].weight.data
215 |
216 | for l in range(self.depth):
217 | self.symbolic_fun[l] = another_model.symbolic_fun[l]
218 |
219 | return self
220 |
221 | def update_grid_from_samples(self, x):
222 | '''
223 | update grid from samples
224 |
225 | Args:
226 | -----
227 | x : 2D torch.float
228 | inputs, shape (batch, input dimension)
229 |
230 | Returns:
231 | --------
232 | None
233 |
234 | Example
235 | -------
236 | >>> model = KAN(width=[2,5,1], grid=5, k=3)
237 | >>> print(model.act_fun[0].grid[0].data)
238 | >>> x = torch.rand(100,2)*5
239 | >>> model.update_grid_from_samples(x)
240 | >>> print(model.act_fun[0].grid[0].data)
241 | tensor([-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000])
242 | tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
243 | '''
244 | for l in range(self.depth):
245 | self.forward(x)
246 | self.act_fun[l].update_grid_from_samples(self.acts[l])
247 |
248 | def initialize_grid_from_another_model(self, model, x):
249 | '''
250 | initialize grid from a parent model
251 |
252 | Args:
253 | -----
254 | model : KAN
255 | parent model
256 | x : 2D torch.float
257 | inputs, shape (batch, input dimension)
258 |
259 | Returns:
260 | --------
261 | None
262 |
263 | Example
264 | -------
265 | >>> model_parent = KAN(width=[1,1], grid=5, k=3)
266 | >>> model_parent.act_fun[0].grid.data = torch.linspace(-2,2,steps=6)[None,:]
267 | >>> x = torch.linspace(-2,2,steps=1001)[:,None]
268 | >>> model = KAN(width=[1,1], grid=5, k=3)
269 | >>> print(model.act_fun[0].grid.data)
270 | >>> model = model.initialize_from_another_model(model_parent, x)
271 | >>> print(model.act_fun[0].grid.data)
272 | tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]])
273 | tensor([[-2.0000, -1.2000, -0.4000, 0.4000, 1.2000, 2.0000]])
274 | '''
275 | model(x)
276 | for l in range(self.depth):
277 | self.act_fun[l].initialize_grid_from_parent(model.act_fun[l], model.acts[l])
278 |
279 |     def forward(self, x_res, t_res=None):
280 | '''
281 | KAN forward
282 |
283 | Args:
284 | -----
285 |         x_res, t_res : 2D torch.float
286 |             inputs, concatenated along the last dimension; t_res may be None if x_res is already the full (batch, input dimension) tensor
287 |
288 | Returns:
289 | --------
290 | y : 2D torch.float
291 | outputs, shape (batch, output dimension)
292 |
293 | Example
294 | -------
295 | >>> model = KAN(width=[2,5,3], grid=5, k=3)
296 | >>> x = torch.normal(0,1,size=(100,2))
297 | >>> model(x).shape
298 | torch.Size([100, 3])
299 | '''
300 |         # PINN-style calls pass (x, t) separately; internal utilities and the
301 |         # docstring example pass a single pre-concatenated tensor (t_res=None)
302 |         x = torch.cat([x_res, t_res], dim=-1) if t_res is not None else x_res
301 | self.acts = [] # shape ([batch, n0], [batch, n1], ..., [batch, n_L])
302 | self.spline_preacts = []
303 | self.spline_postsplines = []
304 | self.spline_postacts = []
305 | self.acts_scale = []
306 | self.acts_scale_std = []
307 | # self.neurons_scale = []
308 |
309 | self.acts.append(x) # acts shape: (batch, width[l])
310 |
311 | for l in range(self.depth):
312 |
313 | x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
314 |
315 | if self.symbolic_enabled == True:
316 | x_symbolic, postacts_symbolic = self.symbolic_fun[l](x)
317 | else:
318 | x_symbolic = 0.
319 | postacts_symbolic = 0.
320 |
321 | x = x_numerical + x_symbolic
322 | postacts = postacts_numerical + postacts_symbolic
323 |
324 | # self.neurons_scale.append(torch.mean(torch.abs(x), dim=0))
325 | grid_reshape = self.act_fun[l].grid.reshape(self.width[l + 1], self.width[l], -1)
326 | input_range = grid_reshape[:, :, -1] - grid_reshape[:, :, 0] + 1e-4
327 | output_range = torch.mean(torch.abs(postacts), dim=0)
328 | self.acts_scale.append(output_range / input_range)
329 | self.acts_scale_std.append(torch.std(postacts, dim=0))
330 | self.spline_preacts.append(preacts.detach())
331 | self.spline_postacts.append(postacts.detach())
332 | self.spline_postsplines.append(postspline.detach())
333 |
334 | x = x + self.biases[l].weight
335 | self.acts.append(x)
336 |
337 | return x
338 |
339 | def set_mode(self, l, i, j, mode, mask_n=None):
340 | '''
341 | set (l,i,j) activation to have mode
342 |
343 | Args:
344 | -----
345 | l : int
346 | layer index
347 | i : int
348 | input neuron index
349 | j : int
350 | output neuron index
351 | mode : str
352 | 'n' (numeric) or 's' (symbolic) or 'ns' (combined)
353 |         mask_n : None or float
354 | magnitude of the numeric front
355 |
356 | Returns:
357 | --------
358 | None
359 | '''
360 |         if mode == "s":
361 |             mask_n = 0.
362 |             mask_s = 1.
363 |         elif mode == "n":
364 |             mask_n = 1.
365 |             mask_s = 0.
366 |         elif mode == "sn" or mode == "ns":
367 |             if mask_n is None:
368 |                 mask_n = 1.
369 |             mask_s = 1.
370 |         else:
371 |             mask_n = 0.
372 |             mask_s = 0.
375 |
376 | self.act_fun[l].mask.data[j * self.act_fun[l].in_dim + i] = mask_n
377 | self.symbolic_fun[l].mask.data[j, i] = mask_s
378 |
379 | def fix_symbolic(self, l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True,
380 | random=False):
381 | '''
382 | set (l,i,j) activation to be symbolic (specified by fun_name)
383 |
384 | Args:
385 | -----
386 | l : int
387 | layer index
388 | i : int
389 | input neuron index
390 | j : int
391 | output neuron index
392 | fun_name : str
393 | function name
394 | fit_params_bool : bool
395 | obtaining affine parameters through fitting (True) or setting default values (False)
396 | a_range : tuple
397 | sweeping range of a
398 | b_range : tuple
399 | sweeping range of b
400 | verbose : bool
401 | If True, more information is printed.
402 | random : bool
403 |             initialize affine parameters randomly or as [1,0,1,0]
404 |
405 | Returns:
406 | --------
407 | None or r2 (coefficient of determination)
408 |
409 | Example 1
410 | ---------
411 | >>> # when fit_params_bool = False
412 | >>> model = KAN(width=[2,5,1], grid=5, k=3)
413 | >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
414 | >>> print(model.act_fun[0].mask.reshape(2,5))
415 | >>> print(model.symbolic_fun[0].mask.reshape(2,5))
416 | tensor([[1., 1., 1., 1., 1.],
417 | [1., 1., 0., 1., 1.]])
418 | tensor([[0., 0., 0., 0., 0.],
419 | [0., 0., 1., 0., 0.]])
420 |
421 | Example 2
422 | ---------
423 | >>> # when fit_params_bool = True
424 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.)
425 | >>> x = torch.normal(0,1,size=(100,2))
426 | >>> model(x) # obtain activations (otherwise model does not have attributes acts)
427 | >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True)
428 | >>> print(model.act_fun[0].mask.reshape(2,5))
429 | >>> print(model.symbolic_fun[0].mask.reshape(2,5))
430 | r2 is 0.8131332993507385
431 | r2 is not very high, please double check if you are choosing the correct symbolic function.
432 | tensor([[1., 1., 1., 1., 1.],
433 | [1., 1., 0., 1., 1.]])
434 | tensor([[0., 0., 0., 0., 0.],
435 | [0., 0., 1., 0., 0.]])
436 | '''
437 | self.set_mode(l, i, j, mode="s")
438 | if not fit_params_bool:
439 | self.symbolic_fun[l].fix_symbolic(i, j, fun_name, verbose=verbose, random=random)
440 | return None
441 | else:
442 | x = self.acts[l][:, i]
443 | y = self.spline_postacts[l][:, j, i]
444 | r2 = self.symbolic_fun[l].fix_symbolic(i, j, fun_name, x, y, a_range=a_range, b_range=b_range,
445 | verbose=verbose)
446 | return r2
447 |
448 | def unfix_symbolic(self, l, i, j):
449 | '''
450 | unfix the (l,i,j) activation function.
451 | '''
452 | self.set_mode(l, i, j, mode="n")
453 |
454 | def unfix_symbolic_all(self):
455 | '''
456 | unfix all activation functions.
457 | '''
458 | for l in range(len(self.width) - 1):
459 | for i in range(self.width[l]):
460 | for j in range(self.width[l + 1]):
461 | self.unfix_symbolic(l, i, j)
462 |
463 | def lock(self, l, ids):
464 | '''
465 | lock ids in the l-th layer to be the same function
466 |
467 | Args:
468 | -----
469 | l : int
470 | layer index
471 | ids : 2D list
472 |             :math:`[[i_1,j_1],[i_2,j_2],...]` set :math:`(l,i_1,j_1), (l,i_2,j_2), ...` to be the same function
473 |
474 | Returns:
475 | --------
476 | None
477 |
478 | Example
479 | -------
480 | >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
481 | >>> print(model.act_fun[0].weight_sharing.reshape(3,2))
482 | >>> model.lock(0,[[1,0],[1,1]])
483 | >>> print(model.act_fun[0].weight_sharing.reshape(3,2))
484 | tensor([[0, 1],
485 | [2, 3],
486 | [4, 5]])
487 | tensor([[0, 1],
488 | [2, 1],
489 | [4, 5]])
490 | '''
491 | self.act_fun[l].lock(ids)
492 |
493 | def unlock(self, l, ids):
494 | '''
495 |         unlock previously locked ids in the l-th layer so they no longer share the same function
496 |
497 | Args:
498 | -----
499 | l : int
500 | layer index
501 |         ids : 2D list
502 |             [[i_1,j_1],[i_2,j_2],...] set (l,i_1,j_1), (l,i_2,j_2), ... to be unlocked
503 |
504 | Example:
505 | --------
506 | >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
507 | >>> model.lock(0,[[1,0],[1,1]])
508 | >>> print(model.act_fun[0].weight_sharing.reshape(3,2))
509 | >>> model.unlock(0,[[1,0],[1,1]])
510 | >>> print(model.act_fun[0].weight_sharing.reshape(3,2))
511 | tensor([[0, 1],
512 | [2, 1],
513 | [4, 5]])
514 | tensor([[0, 1],
515 | [2, 3],
516 | [4, 5]])
517 | '''
518 | self.act_fun[l].unlock(ids)
519 |
520 | def get_range(self, l, i, j, verbose=True):
521 | '''
522 | Get the input range and output range of the (l,i,j) activation
523 |
524 | Args:
525 | -----
526 | l : int
527 | layer index
528 | i : int
529 | input neuron index
530 | j : int
531 | output neuron index
532 |
533 | Returns:
534 | --------
535 | x_min : float
536 | minimum of input
537 | x_max : float
538 | maximum of input
539 | y_min : float
540 | minimum of output
541 | y_max : float
542 | maximum of output
543 |
544 | Example
545 | -------
546 | >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
547 | >>> x = torch.normal(0,1,size=(100,2))
548 | >>> model(x) # do a forward pass to obtain model.acts
549 | >>> model.get_range(0,0,0)
550 | x range: [-2.13 , 2.75 ]
551 | y range: [-0.50 , 1.83 ]
552 | (tensor(-2.1288), tensor(2.7498), tensor(-0.5042), tensor(1.8275))
553 | '''
554 | x = self.spline_preacts[l][:, j, i]
555 | y = self.spline_postacts[l][:, j, i]
556 | x_min = torch.min(x)
557 | x_max = torch.max(x)
558 | y_min = torch.min(y)
559 | y_max = torch.max(y)
560 | if verbose:
561 | print('x range: [' + '%.2f' % x_min, ',', '%.2f' % x_max, ']')
562 | print('y range: [' + '%.2f' % y_min, ',', '%.2f' % y_max, ']')
563 | return x_min, x_max, y_min, y_max
564 |
565 | def plot(self, folder="./figures", beta=3, mask=False, mode="supervised", scale=0.5, tick=False, sample=False,
566 | in_vars=None, out_vars=None, title=None):
567 | '''
568 | plot KAN
569 |
570 | Args:
571 | -----
572 | folder : str
573 | the folder to store pngs
574 | beta : float
575 | positive number. control the transparency of each activation. transparency = tanh(beta*l1).
576 | mask : bool
577 | If True, plot with mask (need to run prune() first to obtain mask). If False (by default), plot all activation functions.
578 |         mode : str
579 |             "supervised" or "unsupervised". If "supervised", l1 is measured by absolute value (not subtracting mean); if "unsupervised", l1 is measured by standard deviation (subtracting mean).
580 | scale : float
581 | control the size of the diagram
582 | in_vars: None or list of str
583 | the name(s) of input variables
584 | out_vars: None or list of str
585 | the name(s) of output variables
586 | title: None or str
587 | title
588 |
589 | Returns:
590 | --------
591 | Figure
592 |
593 | Example
594 | -------
595 | >>> # see more interactive examples in demos
596 | >>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0)
597 | >>> x = torch.normal(0,1,size=(100,2))
598 | >>> model(x) # do a forward pass to obtain model.acts
599 | >>> model.plot()
600 | '''
601 | if not os.path.exists(folder):
602 | os.makedirs(folder)
603 | # matplotlib.use('Agg')
604 | depth = len(self.width) - 1
605 | for l in range(depth):
606 | w_large = 2.0
607 | for i in range(self.width[l]):
608 | for j in range(self.width[l + 1]):
609 | rank = torch.argsort(self.acts[l][:, i])
610 | fig, ax = plt.subplots(figsize=(w_large, w_large))
611 |
612 | num = rank.shape[0]
613 |
614 | symbol_mask = self.symbolic_fun[l].mask[j][i]
615 | numerical_mask = self.act_fun[l].mask.reshape(self.width[l + 1], self.width[l])[j][i]
616 | if symbol_mask > 0. and numerical_mask > 0.:
617 | color = 'purple'
618 | alpha_mask = 1
619 | if symbol_mask > 0. and numerical_mask == 0.:
620 | color = "red"
621 | alpha_mask = 1
622 | if symbol_mask == 0. and numerical_mask > 0.:
623 | color = "black"
624 | alpha_mask = 1
625 | if symbol_mask == 0. and numerical_mask == 0.:
626 | color = "white"
627 | alpha_mask = 0
628 |
629 | if tick == True:
630 | ax.tick_params(axis="y", direction="in", pad=-22, labelsize=50)
631 | ax.tick_params(axis="x", direction="in", pad=-15, labelsize=50)
632 | x_min, x_max, y_min, y_max = self.get_range(l, i, j, verbose=False)
633 |                         plt.xticks([x_min, x_max], ['%.2f' % x_min, '%.2f' % x_max])
634 |                         plt.yticks([y_min, y_max], ['%.2f' % y_min, '%.2f' % y_max])
635 | else:
636 | plt.xticks([])
637 | plt.yticks([])
638 | if alpha_mask == 1:
639 | plt.gca().patch.set_edgecolor('black')
640 | else:
641 | plt.gca().patch.set_edgecolor('white')
642 | plt.gca().patch.set_linewidth(1.5)
643 | # plt.axis('off')
644 |
645 | plt.plot(self.acts[l][:, i][rank].cpu().detach().numpy(),
646 | self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, lw=5)
647 | if sample == True:
648 | plt.scatter(self.acts[l][:, i][rank].cpu().detach().numpy(),
649 | self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color,
650 | s=400 * scale ** 2)
651 | plt.gca().spines[:].set_color(color)
652 |
653 | lock_id = self.act_fun[l].lock_id[j * self.width[l] + i].long().item()
654 | if lock_id > 0:
655 | im = plt.imread(f'{folder}/lock.png')
656 | newax = fig.add_axes([0.15, 0.7, 0.15, 0.15])
657 | plt.text(500, 400, lock_id, fontsize=15)
658 | newax.imshow(im)
659 | newax.axis('off')
660 |
661 | plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=400)
662 | plt.close()
663 |
664 | def score2alpha(score):
665 | return np.tanh(beta * score)
666 |
667 | if mode == "supervised":
668 | alpha = [score2alpha(score.cpu().detach().numpy()) for score in self.acts_scale]
669 | elif mode == "unsupervised":
670 | alpha = [score2alpha(score.cpu().detach().numpy()) for score in self.acts_scale_std]
671 |
672 | # draw skeleton
673 | width = np.array(self.width)
674 | A = 1
675 | y0 = 0.4 # 0.4
676 |
677 | # plt.figure(figsize=(5,5*(neuron_depth-1)*y0))
678 | neuron_depth = len(width)
679 | min_spacing = A / np.maximum(np.max(width), 5)
680 |
681 | max_neuron = np.max(width)
682 | max_num_weights = np.max(width[:-1] * width[1:])
683 | y1 = 0.4 / np.maximum(max_num_weights, 3)
684 |
685 | fig, ax = plt.subplots(figsize=(10 * scale, 10 * scale * (neuron_depth - 1) * y0))
686 | # fig, ax = plt.subplots(figsize=(5,5*(neuron_depth-1)*y0))
687 |
688 | # plot scatters and lines
689 | for l in range(neuron_depth):
690 | n = width[l]
691 | spacing = A / n
692 | for i in range(n):
693 | plt.scatter(1 / (2 * n) + i / n, l * y0, s=min_spacing ** 2 * 10000 * scale ** 2, color='black')
694 |
695 | if l < neuron_depth - 1:
696 | # plot connections
697 | n_next = width[l + 1]
698 | N = n * n_next
699 | for j in range(n_next):
700 | id_ = i * n_next + j
701 |
702 | symbol_mask = self.symbolic_fun[l].mask[j][i]
703 | numerical_mask = self.act_fun[l].mask.reshape(self.width[l + 1], self.width[l])[j][i]
704 | if symbol_mask == 1. and numerical_mask == 1.:
705 | color = 'purple'
706 | alpha_mask = 1.
707 | if symbol_mask == 1. and numerical_mask == 0.:
708 | color = "red"
709 | alpha_mask = 1.
710 | if symbol_mask == 0. and numerical_mask == 1.:
711 | color = "black"
712 | alpha_mask = 1.
713 | if symbol_mask == 0. and numerical_mask == 0.:
714 | color = "white"
715 | alpha_mask = 0.
716 | if mask == True:
717 | plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N], [l * y0, (l + 1 / 2) * y0 - y1],
718 | color=color, lw=2 * scale,
719 | alpha=alpha[l][j][i] * self.mask[l][i].item() * self.mask[l + 1][j].item())
720 | plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next],
721 | [(l + 1 / 2) * y0 + y1, (l + 1) * y0], color=color, lw=2 * scale,
722 | alpha=alpha[l][j][i] * self.mask[l][i].item() * self.mask[l + 1][j].item())
723 | else:
724 | plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N], [l * y0, (l + 1 / 2) * y0 - y1],
725 | color=color, lw=2 * scale, alpha=alpha[l][j][i] * alpha_mask)
726 | plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next],
727 | [(l + 1 / 2) * y0 + y1, (l + 1) * y0], color=color, lw=2 * scale,
728 | alpha=alpha[l][j][i] * alpha_mask)
729 |
730 | plt.xlim(0, 1)
731 | plt.ylim(-0.1 * y0, (neuron_depth - 1 + 0.1) * y0)
732 |
733 | # -- Transformation functions
734 | DC_to_FC = ax.transData.transform
735 | FC_to_NFC = fig.transFigure.inverted().transform
736 | # -- Take data coordinates and transform them to normalized figure coordinates
737 | DC_to_NFC = lambda x: FC_to_NFC(DC_to_FC(x))
738 |
739 | plt.axis('off')
740 |
741 | # plot splines
742 | for l in range(neuron_depth - 1):
743 | n = width[l]
744 | for i in range(n):
745 | n_next = width[l + 1]
746 | N = n * n_next
747 | for j in range(n_next):
748 | id_ = i * n_next + j
749 | im = plt.imread(f'{folder}/sp_{l}_{i}_{j}.png')
750 | left = DC_to_NFC([1 / (2 * N) + id_ / N - y1, 0])[0]
751 | right = DC_to_NFC([1 / (2 * N) + id_ / N + y1, 0])[0]
752 | bottom = DC_to_NFC([0, (l + 1 / 2) * y0 - y1])[1]
753 | up = DC_to_NFC([0, (l + 1 / 2) * y0 + y1])[1]
754 | newax = fig.add_axes([left, bottom, right - left, up - bottom])
755 | # newax = fig.add_axes([1/(2*N)+id_/N-y1, (l+1/2)*y0-y1, y1, y1], anchor='NE')
756 | if mask == False:
757 | newax.imshow(im, alpha=alpha[l][j][i])
758 | else:
759 | ### make sure to run model.prune() first to compute mask ###
760 | newax.imshow(im, alpha=alpha[l][j][i] * self.mask[l][i].item() * self.mask[l + 1][j].item())
761 | newax.axis('off')
762 |
763 |         if in_vars is not None:
764 | n = self.width[0]
765 | for i in range(n):
766 | plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, in_vars[i], fontsize=40 * scale,
767 | horizontalalignment='center', verticalalignment='center')
768 |
769 |         if out_vars is not None:
770 | n = self.width[-1]
771 | for i in range(n):
772 | plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), y0 * (len(self.width) - 1) + 0.1, out_vars[i],
773 | fontsize=40 * scale, horizontalalignment='center',
774 | verticalalignment='center')
775 |
776 |         if title is not None:
777 | plt.gcf().get_axes()[0].text(0.5, y0 * (len(self.width) - 1) + 0.2, title, fontsize=40 * scale,
778 | horizontalalignment='center', verticalalignment='center')
779 |
780 | def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0.,
781 | lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., stop_grid_update_step=50,
782 | batch=-1,
783 | small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, save_fig=False,
784 | in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu'):
785 | '''
786 | training
787 |
788 | Args:
789 | -----
790 | dataset : dic
791 | contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
792 | opt : str
793 | "LBFGS" or "Adam"
794 | steps : int
795 | training steps
796 | log : int
797 | logging frequency
798 | lamb : float
799 | overall penalty strength
800 | lamb_l1 : float
801 | l1 penalty strength
802 | lamb_entropy : float
803 | entropy penalty strength
804 | lamb_coef : float
805 | coefficient magnitude penalty strength
806 | lamb_coefdiff : float
807 |             difference of nearby coefficients (smoothness) penalty strength
808 | update_grid : bool
809 | If True, update grid regularly before stop_grid_update_step
810 | grid_update_num : int
811 | the number of grid updates before stop_grid_update_step
812 | stop_grid_update_step : int
813 | no grid updates after this training step
814 | batch : int
815 | batch size, if -1 then full.
816 | small_mag_threshold : float
817 | threshold to determine large or small numbers (may want to apply larger penalty to smaller numbers)
818 | small_reg_factor : float
819 |             penalty strength applied to small factors relative to large factors
820 | device : str
821 | device
822 | save_fig_freq : int
823 | save figure every (save_fig_freq) step
824 |
825 | Returns:
826 | --------
827 | results : dic
828 | results['train_loss'], 1D array of training losses (RMSE)
829 | results['test_loss'], 1D array of test losses (RMSE)
830 | results['reg'], 1D array of regularization
831 |
832 | Example
833 | -------
834 | >>> # for interactive examples, please see demos
835 | >>> from utils import create_dataset
836 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
837 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
838 | >>> dataset = create_dataset(f, n_var=2)
839 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
840 | >>> model.plot()
841 | '''
842 |
843 | def reg(acts_scale):
844 |
845 | def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
846 | return (x < th) * x * factor + (x > th) * (x + (factor - 1) * th)
847 |
848 | reg_ = 0.
849 | for i in range(len(acts_scale)):
850 | vec = acts_scale[i].reshape(-1, )
851 |
852 | p = vec / torch.sum(vec)
853 | l1 = torch.sum(nonlinear(vec))
854 | entropy = - torch.sum(p * torch.log2(p + 1e-4))
855 | reg_ += lamb_l1 * l1 + lamb_entropy * entropy # both l1 and entropy
856 |
857 | # regularize coefficient to encourage spline to be zero
858 | for i in range(len(self.act_fun)):
859 | coeff_l1 = torch.sum(torch.mean(torch.abs(self.act_fun[i].coef), dim=1))
860 | coeff_diff_l1 = torch.sum(torch.mean(torch.abs(torch.diff(self.act_fun[i].coef)), dim=1))
861 | reg_ += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1
862 |
863 | return reg_
864 |
865 | pbar = tqdm(range(steps), desc='description', ncols=100)
866 |
867 |         if loss_fn is None:
868 | loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2)
869 | else:
870 | loss_fn = loss_fn_eval = loss_fn
871 |
872 | grid_update_freq = int(stop_grid_update_step / grid_update_num)
873 |
874 | if opt == "Adam":
875 | optimizer = torch.optim.Adam(self.parameters(), lr=lr)
876 | elif opt == "LBFGS":
877 | optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
878 | tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
879 |
880 | results = {}
881 | results['train_loss'] = []
882 | results['test_loss'] = []
883 | results['reg'] = []
884 |         if metrics is not None:
885 | for i in range(len(metrics)):
886 | results[metrics[i].__name__] = []
887 |
888 | if batch == -1 or batch > dataset['train_input'].shape[0]:
889 | batch_size = dataset['train_input'].shape[0]
890 | batch_size_test = dataset['test_input'].shape[0]
891 | else:
892 | batch_size = batch
893 | batch_size_test = batch
894 |
895 | global train_loss, reg_
896 |
897 | def closure():
898 | global train_loss, reg_
899 | optimizer.zero_grad()
900 | pred = self.forward(dataset['train_input'][train_id].to(device))
901 | if sglr_avoid == True:
902 | id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0]
903 | train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device))
904 | else:
905 | train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device))
906 | reg_ = reg(self.acts_scale)
907 | objective = train_loss + lamb * reg_
908 | objective.backward()
909 | return objective
910 |
911 | if save_fig:
912 | if not os.path.exists(img_folder):
913 | os.makedirs(img_folder)
914 |
915 | for _ in pbar:
916 |
917 | train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
918 | test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
919 |
920 | if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
921 | self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
922 |
923 | if opt == "LBFGS":
924 | optimizer.step(closure)
925 |
926 | if opt == "Adam":
927 | pred = self.forward(dataset['train_input'][train_id].to(device))
928 | if sglr_avoid == True:
929 | id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0]
930 | train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device))
931 | else:
932 | train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device))
933 | reg_ = reg(self.acts_scale)
934 | loss = train_loss + lamb * reg_
935 | optimizer.zero_grad()
936 | loss.backward()
937 | optimizer.step()
938 |
939 | test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(device)),
940 | dataset['test_label'][test_id].to(device))
941 |
942 | if _ % log == 0:
943 | pbar.set_description("train loss: %.2e | test loss: %.2e | reg: %.2e " % (
944 | torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(),
945 | reg_.cpu().detach().numpy()))
946 |
947 |             if metrics is not None:
948 | for i in range(len(metrics)):
949 | results[metrics[i].__name__].append(metrics[i]().item())
950 |
951 | results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
952 | results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
953 | results['reg'].append(reg_.cpu().detach().numpy())
954 |
955 | if save_fig and _ % save_fig_freq == 0:
956 | self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta)
957 | plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight', dpi=200)
958 | plt.close()
959 |
960 | return results
961 |
962 | def prune(self, threshold=1e-2, mode="auto", active_neurons_id=None):
963 | '''
964 | pruning KAN on the node level. If a node has small incoming or outgoing connection, it will be pruned away.
965 |
966 | Args:
967 | -----
968 | threshold : float
969 | the threshold used to determine whether a node is small enough
970 | mode : str
971 |             "auto" or "manual". If "auto", the threshold will be used to automatically prune away nodes. If "manual", active_neurons_id is needed to specify which neurons are kept (others are thrown away).
972 |         active_neurons_id : list of id lists
973 |             For example, [[0,1],[0,2,3]] means keeping the 0/1 neuron in the 1st hidden layer and the 0/2/3 neuron in the 2nd hidden layer. Pruning input and output neurons is not supported yet.
974 |
975 | Returns:
976 | --------
977 | model2 : KAN
978 | pruned model
979 |
980 | Example
981 | -------
982 | >>> # for more interactive examples, please see demos
983 | >>> from utils import create_dataset
984 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
985 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
986 | >>> dataset = create_dataset(f, n_var=2)
987 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
988 | >>> model.prune()
989 | >>> model.plot(mask=True)
990 | '''
991 | mask = [torch.ones(self.width[0], )]
992 | active_neurons = [list(range(self.width[0]))]
993 | for i in range(len(self.acts_scale) - 1):
994 | if mode == "auto":
995 | in_important = torch.max(self.acts_scale[i], dim=1)[0] > threshold
996 | out_important = torch.max(self.acts_scale[i + 1], dim=0)[0] > threshold
997 | overall_important = in_important * out_important
998 | elif mode == "manual":
999 | overall_important = torch.zeros(self.width[i + 1], dtype=torch.bool)
1000 | overall_important[active_neurons_id[i + 1]] = True
1001 | mask.append(overall_important.float())
1002 | active_neurons.append(torch.where(overall_important == True)[0])
1003 | active_neurons.append(list(range(self.width[-1])))
1004 | mask.append(torch.ones(self.width[-1], ))
1005 |
1006 | self.mask = mask # this is neuron mask for the whole model
1007 |
1008 | # update act_fun[l].mask
1009 | for l in range(len(self.acts_scale) - 1):
1010 | for i in range(self.width[l + 1]):
1011 | if i not in active_neurons[l + 1]:
1012 | self.remove_node(l + 1, i)
1013 |
1014 |         model2 = Model(copy.deepcopy(self.width), self.grid, self.k, base_fun=self.base_fun)  # the class in this module is named Model
1015 | model2.load_state_dict(self.state_dict())
1016 | for i in range(len(self.acts_scale)):
1017 | if i < len(self.acts_scale) - 1:
1018 | model2.biases[i].weight.data = model2.biases[i].weight.data[:, active_neurons[i + 1]]
1019 |
1020 | model2.act_fun[i] = model2.act_fun[i].get_subset(active_neurons[i], active_neurons[i + 1])
1021 | model2.width[i] = len(active_neurons[i])
1022 | model2.symbolic_fun[i] = self.symbolic_fun[i].get_subset(active_neurons[i], active_neurons[i + 1])
1023 |
1024 | return model2
1025 |
1026 | def remove_edge(self, l, i, j):
1027 | '''
1028 |         remove activation phi(l,i,j) (set its mask to zero)
1029 |
1030 | Args:
1031 | -----
1032 | l : int
1033 | layer index
1034 | i : int
1035 | input neuron index
1036 | j : int
1037 | output neuron index
1038 |
1039 | Returns:
1040 | --------
1041 | None
1042 | '''
1043 | self.act_fun[l].mask[j * self.width[l] + i] = 0.
1044 |
1045 | def remove_node(self, l, i):
1046 | '''
1047 | remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero)
1048 |
1049 | Args:
1050 | -----
1051 | l : int
1052 | layer index
1053 | i : int
1054 | neuron index
1055 |
1056 | Returns:
1057 | --------
1058 | None
1059 | '''
1060 | self.act_fun[l - 1].mask[i * self.width[l - 1] + torch.arange(self.width[l - 1])] = 0.
1061 | self.act_fun[l].mask[torch.arange(self.width[l + 1]) * self.width[l] + i] = 0.
1062 | self.symbolic_fun[l - 1].mask[i, :] *= 0.
1063 | self.symbolic_fun[l].mask[:, i] *= 0.
1064 |
1065 | def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True):
1066 | '''suggest the symbolic candidates of phi(l,i,j)
1067 |
1068 | Args:
1069 | -----
1070 | l : int
1071 | layer index
1072 | i : int
1073 | input neuron index
1074 | j : int
1075 | output neuron index
1076 | lib : dic
1077 | library of symbolic bases. If lib = None, the global default library will be used.
1078 | topk : int
1079 | display the top k symbolic functions (according to r2)
1080 | verbose : bool
1081 | If True, more information will be printed.
1082 |
1083 | Returns:
1084 | --------
1085 | None
1086 |
1087 | Example
1088 | -------
1089 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
1090 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
1091 | >>> dataset = create_dataset(f, n_var=2)
1092 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
1093 | >>> model = model.prune()
1094 | >>> model(dataset['train_input'])
1095 | >>> model.suggest_symbolic(0,0,0)
1096 | function , r2
1097 | sin , 0.9994412064552307
1098 | gaussian , 0.9196369051933289
1099 | tanh , 0.8608126044273376
1100 | sigmoid , 0.8578218817710876
1101 | arctan , 0.842217743396759
1102 | '''
1103 | r2s = []
1104 |
1105 |         if lib is None:
1106 | symbolic_lib = SYMBOLIC_LIB
1107 | else:
1108 | symbolic_lib = {}
1109 | for item in lib:
1110 | symbolic_lib[item] = SYMBOLIC_LIB[item]
1111 |
1112 | for (name, fun) in symbolic_lib.items():
1113 | r2 = self.fix_symbolic(l, i, j, name, a_range=a_range, b_range=b_range, verbose=False)
1114 | r2s.append(r2.item())
1115 |
1116 | self.unfix_symbolic(l, i, j)
1117 |
1118 | sorted_ids = np.argsort(r2s)[::-1][:topk]
1119 | r2s = np.array(r2s)[sorted_ids][:topk]
1120 | topk = np.minimum(topk, len(symbolic_lib))
1121 | if verbose == True:
1122 | print('function', ',', 'r2')
1123 | for i in range(topk):
1124 | print(list(symbolic_lib.items())[sorted_ids[i]][0], ',', r2s[i])
1125 |
1126 | best_name = list(symbolic_lib.items())[sorted_ids[0]][0]
1127 | best_fun = list(symbolic_lib.items())[sorted_ids[0]][1]
1128 | best_r2 = r2s[0]
1129 | return best_name, best_fun, best_r2
1130 |
1131 | def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1):
1132 | '''
1133 | automatic symbolic regression: using top 1 suggestion from suggest_symbolic to replace splines with symbolic activations
1134 |
1135 | Args:
1136 | -----
1137 | lib : None or a list of function names
1138 | the symbolic library
1139 | verbose : int
1140 | verbosity
1141 |
1142 | Returns:
1143 | --------
1144 | None (print suggested symbolic formulas)
1145 |
1146 | Example 1
1147 | ---------
1148 | >>> # default library
1149 | >>> from utils import create_dataset
1150 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
1151 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
1152 | >>> dataset = create_dataset(f, n_var=2)
1153 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
1154 |         >>> model = model.prune()
1155 | >>> model(dataset['train_input'])
1156 | >>> model.auto_symbolic()
1157 | fixing (0,0,0) with sin, r2=0.9994837045669556
1158 | fixing (0,1,0) with cosh, r2=0.9978033900260925
1159 | fixing (1,0,0) with arctan, r2=0.9997088313102722
1160 |
1161 | Example 2
1162 | ---------
1163 | >>> # customized library
1164 | >>> from utils import create_dataset
1165 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
1166 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
1167 | >>> dataset = create_dataset(f, n_var=2)
1168 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
1169 |         >>> model = model.prune()
1170 | >>> model(dataset['train_input'])
1171 | >>> model.auto_symbolic(lib=['exp','sin','x^2'])
1172 | fixing (0,0,0) with sin, r2=0.999411404132843
1173 | fixing (0,1,0) with x^2, r2=0.9962921738624573
1174 | fixing (1,0,0) with exp, r2=0.9980258941650391
1175 | '''
1176 | for l in range(len(self.width) - 1):
1177 | for i in range(self.width[l]):
1178 | for j in range(self.width[l + 1]):
1179 | if self.symbolic_fun[l].mask[j, i] > 0.:
1180 | print(f'skipping ({l},{i},{j}) since already symbolic')
1181 | else:
1182 | name, fun, r2 = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib,
1183 | verbose=False)
1184 | self.fix_symbolic(l, i, j, name, verbose=verbose > 1)
1185 | if verbose >= 1:
1186 | print(f'fixing ({l},{i},{j}) with {name}, r2={r2}')
1187 |
1188 | def symbolic_formula(self, floating_digit=2, var=None, normalizer=None, simplify=False):
1189 | '''
1190 | obtain the symbolic formula
1191 |
1192 | Args:
1193 | -----
1194 | floating_digit : int
1195 | the number of digits to display
1196 | var : list of str
1197 | the name of variables (if not provided, by default using ['x_1', 'x_2', ...])
1198 |         normalizer : [mean array (floats), variance array (floats)]
1199 | the normalization applied to inputs
1200 | simplify : bool
1201 |             If True, simplify the equation at each step (usually quite slow), so it is False by default.
1202 |
1203 | Returns:
1204 | --------
1205 | symbolic formula : sympy function
1206 |
1207 | Example
1208 | -------
1209 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0, grid_eps=0.02)
1210 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
1211 | >>> dataset = create_dataset(f, n_var=2)
1212 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
1213 | >>> model = model.prune()
1214 | >>> model(dataset['train_input'])
1215 | >>> model.auto_symbolic(lib=['exp','sin','x^2'])
1216 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.00, update_grid=False);
1217 | >>> model.symbolic_formula()
1218 | '''
1219 | symbolic_acts = []
1220 | x = []
1221 |
1222 | def ex_round(ex1, floating_digit=floating_digit):
1223 | ex2 = ex1
1224 | for a in sympy.preorder_traversal(ex1):
1225 | if isinstance(a, sympy.Float):
1226 | ex2 = ex2.subs(a, round(a, floating_digit))
1227 | return ex2
1228 |
1229 |         # define variables (avoid exec: assignments made via exec do not reliably
1230 |         # bind function locals in Python 3)
1231 |         if var is None:
1232 |             x = [sympy.Symbol(f'x_{ii}') for ii in range(1, self.width[0] + 1)]
1233 |         else:
1234 |             x = [sympy.symbols(var_) for var_ in var]
1236 |
1237 | x0 = x
1238 |
1239 |         if normalizer is not None:
1240 | mean = normalizer[0]
1241 | std = normalizer[1]
1242 | x = [(x[i] - mean[i]) / std[i] for i in range(len(x))]
1243 |
1244 | symbolic_acts.append(x)
1245 |
1246 | for l in range(len(self.width) - 1):
1247 | y = []
1248 | for j in range(self.width[l + 1]):
1249 | yj = 0.
1250 | for i in range(self.width[l]):
1251 | a, b, c, d = self.symbolic_fun[l].affine[j, i]
1252 | sympy_fun = self.symbolic_fun[l].funs_sympy[j][i]
1253 | try:
1254 | yj += c * sympy_fun(a * x[i] + b) + d
1255 |                 except Exception:
1256 |                     print('make sure all activations have been converted to symbolic formulas first!')
1257 | return
1258 | if simplify == True:
1259 | y.append(sympy.simplify(yj + self.biases[l].weight.data[0, j]))
1260 | else:
1261 | y.append(yj + self.biases[l].weight.data[0, j])
1262 |
1263 | x = y
1264 | symbolic_acts.append(x)
1265 |
1266 | self.symbolic_acts = [[ex_round(symbolic_acts[l][i]) for i in range(len(symbolic_acts[l]))] for l in
1267 | range(len(symbolic_acts))]
1268 |
1269 | out_dim = len(symbolic_acts[-1])
1270 | return [ex_round(symbolic_acts[-1][i]) for i in range(len(symbolic_acts[-1]))], x0
1271 |
1272 | def clear_ckpts(self, folder='./model_ckpt'):
1273 | '''
1274 | clear all checkpoints
1275 |
1276 | Args:
1277 | -----
1278 | folder : str
1279 | the folder that stores checkpoints
1280 |
1281 | Returns:
1282 | --------
1283 | None
1284 | '''
1285 | if os.path.exists(folder):
1286 | files = glob.glob(folder + '/*')
1287 | for f in files:
1288 | os.remove(f)
1289 | else:
1290 | os.makedirs(folder)
1291 |
1292 | def save_ckpt(self, name, folder='./model_ckpt'):
1293 | '''
1294 | save the current model as checkpoint
1295 |
1296 | Args:
1297 | -----
1298 | name: str
1299 | the name of the checkpoint to be saved
1300 | folder : str
1301 | the folder that stores checkpoints
1302 |
1303 | Returns:
1304 | --------
1305 | None
1306 | '''
1307 |
1308 | if not os.path.exists(folder):
1309 | os.makedirs(folder)
1310 |
1311 | torch.save(self.state_dict(), folder + '/' + name)
1312 |         print('saved this model to', folder + '/' + name)
1313 |
1314 | def load_ckpt(self, name, folder='./model_ckpt'):
1315 | '''
1316 | load a checkpoint to the current model
1317 |
1318 | Args:
1319 | -----
1320 | name: str
1321 | the name of the checkpoint to be loaded
1322 | folder : str
1323 | the folder that stores checkpoints
1324 |
1325 | Returns:
1326 | --------
1327 | None
1328 | '''
1329 | self.load_state_dict(torch.load(folder + '/' + name))
--------------------------------------------------------------------------------
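Given the forward(x_res, t_res=None) signature above, a minimal PINN-style call might look like this (a sketch; the width/grid values follow the docstring examples and are illustrative):

    import torch
    from models.KAN import Model

    net = Model(width=[2, 5, 1], grid=5, k=3)
    x, t = torch.rand(100, 1), torch.rand(100, 1)
    u = net(x, t)              # inputs concatenated to shape (100, 2); output (100, 1)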
/models/LBFGS.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from functools import reduce
3 | from torch.optim import Optimizer
4 |
5 | __all__ = ['LBFGS']
6 |
7 | def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
8 | # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
9 | # Compute bounds of interpolation area
10 | if bounds is not None:
11 | xmin_bound, xmax_bound = bounds
12 | else:
13 | xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
14 |
15 | # Code for most common case: cubic interpolation of 2 points
16 | # w/ function and derivative values for both
17 | # Solution in this case (where x2 is the farthest point):
18 | # d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
19 | # d2 = sqrt(d1^2 - g1*g2);
20 | # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
21 | # t_new = min(max(min_pos,xmin_bound),xmax_bound);
22 | d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
23 | d2_square = d1**2 - g1 * g2
24 | if d2_square >= 0:
25 | d2 = d2_square.sqrt()
26 | if x1 <= x2:
27 | min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
28 | else:
29 | min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
30 | return min(max(min_pos, xmin_bound), xmax_bound)
31 | else:
32 | return (xmin_bound + xmax_bound) / 2.
33 |
34 |
35 | def _strong_wolfe(obj_func,
36 | x,
37 | t,
38 | d,
39 | f,
40 | g,
41 | gtd,
42 | c1=1e-4,
43 | c2=0.9,
44 | tolerance_change=1e-9,
45 | max_ls=25):
46 | # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
47 | d_norm = d.abs().max()
48 | g = g.clone(memory_format=torch.contiguous_format)
49 | # evaluate objective and gradient using initial step
50 | f_new, g_new = obj_func(x, t, d)
51 | ls_func_evals = 1
52 | gtd_new = g_new.dot(d)
53 |
54 | # bracket an interval containing a point satisfying the Wolfe criteria
55 | t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
56 | done = False
57 | ls_iter = 0
58 | while ls_iter < max_ls:
59 | # check conditions
60 | if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
61 | bracket = [t_prev, t]
62 | bracket_f = [f_prev, f_new]
63 | bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
64 | bracket_gtd = [gtd_prev, gtd_new]
65 | break
66 |
67 | if abs(gtd_new) <= -c2 * gtd:
68 | bracket = [t]
69 | bracket_f = [f_new]
70 | bracket_g = [g_new]
71 | done = True
72 | break
73 |
74 | if gtd_new >= 0:
75 | bracket = [t_prev, t]
76 | bracket_f = [f_prev, f_new]
77 | bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
78 | bracket_gtd = [gtd_prev, gtd_new]
79 | break
80 |
81 | # interpolate
82 | min_step = t + 0.01 * (t - t_prev)
83 | max_step = t * 10
84 | tmp = t
85 | t = _cubic_interpolate(
86 | t_prev,
87 | f_prev,
88 | gtd_prev,
89 | t,
90 | f_new,
91 | gtd_new,
92 | bounds=(min_step, max_step))
93 |
94 | # next step
95 | t_prev = tmp
96 | f_prev = f_new
97 | g_prev = g_new.clone(memory_format=torch.contiguous_format)
98 | gtd_prev = gtd_new
99 | f_new, g_new = obj_func(x, t, d)
100 | ls_func_evals += 1
101 | gtd_new = g_new.dot(d)
102 | ls_iter += 1
103 |
104 | # reached max number of iterations?
105 | if ls_iter == max_ls:
106 | bracket = [0, t]
107 | bracket_f = [f, f_new]
108 | bracket_g = [g, g_new]
109 |
110 | # zoom phase: we now have a point satisfying the criteria, or
111 | # a bracket around it. We refine the bracket until we find the
112 | # exact point satisfying the criteria
113 | insuf_progress = False
114 | # find high and low points in bracket
115 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
116 | while not done and ls_iter < max_ls:
117 | # line-search bracket is so small
118 | if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
119 | break
120 |
121 | # compute new trial value
122 | t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
123 | bracket[1], bracket_f[1], bracket_gtd[1])
124 |
125 | # test that we are making sufficient progress:
126 | # in case `t` is so close to boundary, we mark that we are making
127 | # insufficient progress, and if
128 | # + we have made insufficient progress in the last step, or
129 | # + `t` is at one of the boundary,
130 | # we will move `t` to a position which is `0.1 * len(bracket)`
131 | # away from the nearest boundary point.
132 | eps = 0.1 * (max(bracket) - min(bracket))
133 | if min(max(bracket) - t, t - min(bracket)) < eps:
134 | # interpolation close to boundary
135 | if insuf_progress or t >= max(bracket) or t <= min(bracket):
136 | # evaluate at 0.1 away from boundary
137 | if abs(t - max(bracket)) < abs(t - min(bracket)):
138 | t = max(bracket) - eps
139 | else:
140 | t = min(bracket) + eps
141 | insuf_progress = False
142 | else:
143 | insuf_progress = True
144 | else:
145 | insuf_progress = False
146 |
147 | # Evaluate new point
148 | f_new, g_new = obj_func(x, t, d)
149 | ls_func_evals += 1
150 | gtd_new = g_new.dot(d)
151 | ls_iter += 1
152 |
153 | if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
154 | # Armijo condition not satisfied or not lower than lowest point
155 | bracket[high_pos] = t
156 | bracket_f[high_pos] = f_new
157 | bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
158 | bracket_gtd[high_pos] = gtd_new
159 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
160 | else:
161 | if abs(gtd_new) <= -c2 * gtd:
162 | # Wolfe conditions satisfied
163 | done = True
164 | elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
165 | # old low becomes new high
166 | bracket[high_pos] = bracket[low_pos]
167 | bracket_f[high_pos] = bracket_f[low_pos]
168 | bracket_g[high_pos] = bracket_g[low_pos]
169 | bracket_gtd[high_pos] = bracket_gtd[low_pos]
170 |
171 | # new point becomes new low
172 | bracket[low_pos] = t
173 | bracket_f[low_pos] = f_new
174 | bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
175 | bracket_gtd[low_pos] = gtd_new
176 |
177 | # return stuff
178 | t = bracket[low_pos]
179 | f_new = bracket_f[low_pos]
180 | g_new = bracket_g[low_pos]
181 | return f_new, g_new, t, ls_func_evals
182 |
183 |
184 |
185 | class LBFGS(Optimizer):
186 | """Implements L-BFGS algorithm.
187 |
188 |     Heavily inspired by `minFunc
189 |     <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.
190 |
191 | .. warning::
192 | This optimizer doesn't support per-parameter options and parameter
193 | groups (there can be only one).
194 |
195 | .. warning::
196 | Right now all parameters have to be on a single device. This will be
197 | improved in the future.
198 |
199 | .. note::
200 |         This is a very memory intensive optimizer (it requires an additional
201 |         ``param_bytes * (history_size + 1)`` bytes; e.g., 1e6 float32 parameters,
202 |         about 4 MB, with history_size=100 need roughly an extra 404 MB). If it
203 |         doesn't fit in memory, try reducing the history size or use a different
204 |         algorithm.
203 |
204 | Args:
205 | lr (float): learning rate (default: 1)
206 | max_iter (int): maximal number of iterations per optimization step
207 | (default: 20)
208 | max_eval (int): maximal number of function evaluations per optimization
209 | step (default: max_iter * 1.25).
210 | tolerance_grad (float): termination tolerance on first order optimality
211 | (default: 1e-7).
212 |         tolerance_change (float): termination tolerance on function
213 |             value/parameter changes (default: 1e-9).
214 |         tolerance_ys (float): curvature threshold on y^T s below which an
215 |             L-BFGS memory update is skipped (default: 1e-32).
216 |         history_size (int): update history size (default: 100).
215 | line_search_fn (str): either 'strong_wolfe' or None (default: None).
216 | """
217 |
218 | def __init__(self,
219 | params,
220 | lr=1,
221 | max_iter=20,
222 | max_eval=None,
223 | tolerance_grad=1e-7,
224 | tolerance_change=1e-9,
225 | tolerance_ys=1e-32,
226 | history_size=100,
227 | line_search_fn=None):
228 | if max_eval is None:
229 | max_eval = max_iter * 5 // 4
230 | defaults = dict(
231 | lr=lr,
232 | max_iter=max_iter,
233 | max_eval=max_eval,
234 | tolerance_grad=tolerance_grad,
235 | tolerance_change=tolerance_change,
236 | tolerance_ys=tolerance_ys,
237 | history_size=history_size,
238 | line_search_fn=line_search_fn)
239 | super().__init__(params, defaults)
240 |
241 | if len(self.param_groups) != 1:
242 | raise ValueError("LBFGS doesn't support per-parameter options "
243 | "(parameter groups)")
244 |
245 | self._params = self.param_groups[0]['params']
246 | self._numel_cache = None
247 |
248 | def _numel(self):
249 | if self._numel_cache is None:
250 | self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
251 | return self._numel_cache
252 |
253 | def _gather_flat_grad(self):
254 | views = []
255 | for p in self._params:
256 | if p.grad is None:
257 | view = p.new(p.numel()).zero_()
258 | elif p.grad.is_sparse:
259 | view = p.grad.to_dense().view(-1)
260 | else:
261 | view = p.grad.view(-1)
262 | views.append(view)
263 | return torch.cat(views, 0)
264 |
265 | def _add_grad(self, step_size, update):
266 | offset = 0
267 | for p in self._params:
268 | numel = p.numel()
269 | # view as to avoid deprecated pointwise semantics
270 | p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
271 | offset += numel
272 | assert offset == self._numel()
273 |
274 | def _clone_param(self):
275 | return [p.clone(memory_format=torch.contiguous_format) for p in self._params]
276 |
277 | def _set_param(self, params_data):
278 | for p, pdata in zip(self._params, params_data):
279 | p.copy_(pdata)
280 |
281 | def _directional_evaluate(self, closure, x, t, d):
282 | self._add_grad(t, d)
283 | loss = float(closure())
284 | flat_grad = self._gather_flat_grad()
285 | self._set_param(x)
286 | return loss, flat_grad
287 |
288 |
289 | @torch.no_grad()
290 | def step(self, closure):
291 | """Perform a single optimization step.
292 |
293 | Args:
294 | closure (Callable): A closure that reevaluates the model
295 | and returns the loss.
296 | """
297 | assert len(self.param_groups) == 1
298 |
299 | # Make sure the closure is always called with grad enabled
300 | closure = torch.enable_grad()(closure)
301 |
302 | group = self.param_groups[0]
303 | lr = group['lr']
304 | max_iter = group['max_iter']
305 | max_eval = group['max_eval']
306 | tolerance_grad = group['tolerance_grad']
307 | tolerance_change = group['tolerance_change']
308 | tolerance_ys = group['tolerance_ys']
309 | line_search_fn = group['line_search_fn']
310 | history_size = group['history_size']
311 |
312 | # NOTE: LBFGS has only global state, but we register it as state for
313 | # the first param, because this helps with casting in load_state_dict
314 | state = self.state[self._params[0]]
315 | state.setdefault('func_evals', 0)
316 | state.setdefault('n_iter', 0)
317 |
318 | # evaluate initial f(x) and df/dx
319 | orig_loss = closure()
320 | loss = float(orig_loss)
321 | current_evals = 1
322 | state['func_evals'] += 1
323 |
324 | flat_grad = self._gather_flat_grad()
325 | opt_cond = flat_grad.abs().max() <= tolerance_grad
326 |
327 | # optimal condition
328 | if opt_cond:
329 | return orig_loss
330 |
331 | # tensors cached in state (for tracing)
332 | d = state.get('d')
333 | t = state.get('t')
334 | old_dirs = state.get('old_dirs')
335 | old_stps = state.get('old_stps')
336 | ro = state.get('ro')
337 | H_diag = state.get('H_diag')
338 | prev_flat_grad = state.get('prev_flat_grad')
339 | prev_loss = state.get('prev_loss')
340 |
341 | n_iter = 0
342 | # optimize for a max of max_iter iterations
343 | while n_iter < max_iter:
344 | # keep track of nb of iterations
345 | n_iter += 1
346 | state['n_iter'] += 1
347 |
348 | ############################################################
349 | # compute gradient descent direction
350 | ############################################################
351 | if state['n_iter'] == 1:
352 | d = flat_grad.neg()
353 | old_dirs = []
354 | old_stps = []
355 | ro = []
356 | H_diag = 1
357 | else:
358 | # do lbfgs update (update memory)
359 | y = flat_grad.sub(prev_flat_grad)
360 | s = d.mul(t)
361 | ys = y.dot(s) # y*s
362 | if ys > tolerance_ys:
363 | # updating memory
364 | if len(old_dirs) == history_size:
365 | # shift history by one (limited-memory)
366 | old_dirs.pop(0)
367 | old_stps.pop(0)
368 | ro.pop(0)
369 |
370 | # store new direction/step
371 | old_dirs.append(y)
372 | old_stps.append(s)
373 | ro.append(1. / ys)
374 |
375 | # update scale of initial Hessian approximation
376 | H_diag = ys / y.dot(y) # (y*y)
377 |
378 | # compute the approximate (L-BFGS) inverse Hessian
379 | # multiplied by the gradient
380 | num_old = len(old_dirs)
381 |
382 | if 'al' not in state:
383 | state['al'] = [None] * history_size
384 | al = state['al']
385 |
386 | # iteration in L-BFGS loop collapsed to use just one buffer
387 | q = flat_grad.neg()
388 | for i in range(num_old - 1, -1, -1):
389 | al[i] = old_stps[i].dot(q) * ro[i]
390 | q.add_(old_dirs[i], alpha=-al[i])
391 |
392 | # multiply by initial Hessian
393 | # r/d is the final direction
394 | d = r = torch.mul(q, H_diag)
395 | for i in range(num_old):
396 | be_i = old_dirs[i].dot(r) * ro[i]
397 | r.add_(old_stps[i], alpha=al[i] - be_i)
398 |
399 | if prev_flat_grad is None:
400 | prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
401 | else:
402 | prev_flat_grad.copy_(flat_grad)
403 | prev_loss = loss
404 |
405 | ############################################################
406 | # compute step length
407 | ############################################################
408 | # reset initial guess for step size
409 | if state['n_iter'] == 1:
410 | t = min(1., 1. / flat_grad.abs().sum()) * lr
411 | else:
412 | t = lr
413 |
414 | # directional derivative
415 | gtd = flat_grad.dot(d) # g * d
416 |
417 | # directional derivative is below tolerance
418 | if gtd > -tolerance_change:
419 | break
420 |
421 | # optional line search: user function
422 | ls_func_evals = 0
423 | if line_search_fn is not None:
424 | # perform line search, using user function
425 | if line_search_fn != "strong_wolfe":
426 | raise RuntimeError("only 'strong_wolfe' is supported")
427 | else:
428 | x_init = self._clone_param()
429 |
430 | def obj_func(x, t, d):
431 | return self._directional_evaluate(closure, x, t, d)
432 |
433 | loss, flat_grad, t, ls_func_evals = _strong_wolfe(
434 | obj_func, x_init, t, d, loss, flat_grad, gtd)
435 | self._add_grad(t, d)
436 | opt_cond = flat_grad.abs().max() <= tolerance_grad
437 | else:
438 | # no line search, simply move with fixed-step
439 | self._add_grad(t, d)
440 | if n_iter != max_iter:
441 | # re-evaluate function only if not in last iteration
442 | # the reason we do this: in a stochastic setting,
443 | # no use to re-evaluate that function here
444 | with torch.enable_grad():
445 | loss = float(closure())
446 | flat_grad = self._gather_flat_grad()
447 | opt_cond = flat_grad.abs().max() <= tolerance_grad
448 | ls_func_evals = 1
449 |
450 | # update func eval
451 | current_evals += ls_func_evals
452 | state['func_evals'] += ls_func_evals
453 |
454 | ############################################################
455 | # check conditions
456 | ############################################################
457 | if n_iter == max_iter:
458 | break
459 |
460 | if current_evals >= max_eval:
461 | break
462 |
463 | # optimal condition
464 | if opt_cond:
465 | break
466 |
467 | # lack of progress
468 | if d.mul(t).abs().max() <= tolerance_change:
469 | break
470 |
471 | if abs(loss - prev_loss) < tolerance_change:
472 | break
473 |
474 | state['d'] = d
475 | state['t'] = t
476 | state['old_dirs'] = old_dirs
477 | state['old_stps'] = old_stps
478 | state['ro'] = ro
479 | state['H_diag'] = H_diag
480 | state['prev_flat_grad'] = prev_flat_grad
481 | state['prev_loss'] = prev_loss
482 |
483 | return orig_loss
--------------------------------------------------------------------------------
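Usage sketch (annotation, not part of the repository; the import path assumes the repository root is on PYTHONPATH): like torch.optim.LBFGS, this optimizer re-evaluates the model inside step(), so it must be driven with a closure that zeroes gradients, recomputes the loss, and backpropagates.

    import torch
    from models.LBFGS import LBFGS

    model = torch.nn.Linear(2, 1)
    x, y = torch.randn(64, 2), torch.randn(64, 1)
    optimizer = LBFGS(model.parameters(), lr=1.0, max_iter=20,
                      history_size=100, line_search_fn='strong_wolfe')

    def closure():
        # L-BFGS calls this several times per step, so it must rebuild
        # the loss and its gradients from scratch on every call.
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    for _ in range(10):
        optimizer.step(closure)

--------------------------------------------------------------------------------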
/models/PINN.py:
--------------------------------------------------------------------------------
1 | # baseline implementation of PINNs
2 | # paper: Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations
3 | # link: https://www.sciencedirect.com/science/article/pii/S0021999118307125
4 | # code: https://github.com/maziarraissi/PINNs
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class Model(nn.Module):
11 | def __init__(self, in_dim, hidden_dim, out_dim, num_layer):
12 | super(Model, self).__init__()
13 |
14 | layers = []
15 | for i in range(num_layer - 1):
16 | if i == 0:
17 | layers.append(nn.Linear(in_features=in_dim, out_features=hidden_dim))
18 | layers.append(nn.Tanh())
19 | else:
20 | layers.append(nn.Linear(in_features=hidden_dim, out_features=hidden_dim))
21 | layers.append(nn.Tanh())
22 |
23 | layers.append(nn.Linear(in_features=hidden_dim, out_features=out_dim))
24 |
25 | self.linear = nn.Sequential(*layers)
26 |
27 | def forward(self, x, t):
28 | src = torch.cat((x, t), dim=-1)
29 | return self.linear(src)
30 |
--------------------------------------------------------------------------------
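Usage sketch (annotation, not part of the repository; dimensions are assumed): the baseline PINN maps (x, t) pairs to u(x, t), and PDE residual terms are then formed from the output with autograd.

    import torch
    from models.PINN import Model

    model = Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4)
    x = torch.rand(100, 1, requires_grad=True)  # spatial coordinates
    t = torch.rand(100, 1, requires_grad=True)  # temporal coordinates
    u = model(x, t)                             # shape (100, 1)

    # e.g. the first-order time derivative u_t for a residual term
    u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u),
                              create_graph=True)[0]

--------------------------------------------------------------------------------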
/models/PINNsFormer.py:
--------------------------------------------------------------------------------
1 | # implementation of PINNsFormer
2 | # paper: PINNsFormer: A Transformer-Based Framework For Physics-Informed Neural Networks
3 | # link: https://arxiv.org/abs/2307.11833
4 |
5 | import torch
6 | import torch.nn as nn
7 | import pdb
8 | from util import get_clones
9 |
10 |
11 | class WaveAct(nn.Module):
12 | def __init__(self):
13 | super(WaveAct, self).__init__()
14 | self.w1 = nn.Parameter(torch.ones(1), requires_grad=True)
15 | self.w2 = nn.Parameter(torch.ones(1), requires_grad=True)
16 |
17 | def forward(self, x):
18 | return self.w1 * torch.sin(x) + self.w2 * torch.cos(x)
19 |
20 |
21 | class FeedForward(nn.Module):
22 | def __init__(self, d_model, d_ff=256):
23 | super(FeedForward, self).__init__()
24 | self.linear = nn.Sequential(*[
25 | nn.Linear(d_model, d_ff),
26 | WaveAct(),
27 | nn.Linear(d_ff, d_ff),
28 | WaveAct(),
29 | nn.Linear(d_ff, d_model)
30 | ])
31 |
32 | def forward(self, x):
33 | return self.linear(x)
34 |
35 |
36 | class EncoderLayer(nn.Module):
37 | def __init__(self, d_model, heads):
38 | super(EncoderLayer, self).__init__()
39 |
40 | self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=heads, batch_first=True)
41 | self.ff = FeedForward(d_model)
42 | self.act1 = WaveAct()
43 | self.act2 = WaveAct()
44 |
45 | def forward(self, x):
46 | x2 = self.act1(x)
47 | # pdb.set_trace()
48 | x = x + self.attn(x2, x2, x2)[0]
49 | x2 = self.act2(x)
50 | x = x + self.ff(x2)
51 | return x
52 |
53 |
54 | class DecoderLayer(nn.Module):
55 | def __init__(self, d_model, heads):
56 | super(DecoderLayer, self).__init__()
57 |
58 | self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=heads, batch_first=True)
59 | self.ff = FeedForward(d_model)
60 | self.act1 = WaveAct()
61 | self.act2 = WaveAct()
62 |
63 | def forward(self, x, e_outputs):
64 | x2 = self.act1(x)
65 | x = x + self.attn(x2, e_outputs, e_outputs)[0]
66 | x2 = self.act2(x)
67 | x = x + self.ff(x2)
68 | return x
69 |
70 |
71 | class Encoder(nn.Module):
72 | def __init__(self, d_model, N, heads):
73 | super(Encoder, self).__init__()
74 | self.N = N
75 | self.layers = get_clones(EncoderLayer(d_model, heads), N)
76 | self.act = WaveAct()
77 |
78 | def forward(self, x):
79 | for i in range(self.N):
80 | x = self.layers[i](x)
81 | return self.act(x)
82 |
83 |
84 | class Decoder(nn.Module):
85 | def __init__(self, d_model, N, heads):
86 | super(Decoder, self).__init__()
87 | self.N = N
88 | self.layers = get_clones(DecoderLayer(d_model, heads), N)
89 | self.act = WaveAct()
90 |
91 | def forward(self, x, e_outputs):
92 | for i in range(self.N):
93 | x = self.layers[i](x, e_outputs)
94 | return self.act(x)
95 |
96 |
97 | class Model(nn.Module):
98 | def __init__(self, in_dim, out_dim, hidden_dim, num_layer, hidden_d_ff=512, heads=2):
99 | super(Model, self).__init__()
100 |
101 | self.linear_emb = nn.Linear(in_dim, hidden_dim)
102 |
103 | self.encoder = Encoder(hidden_dim, num_layer, heads)
104 | self.decoder = Decoder(hidden_dim, num_layer, heads)
105 | self.linear_out = nn.Sequential(*[
106 | nn.Linear(hidden_dim, hidden_d_ff),
107 | WaveAct(),
108 | nn.Linear(hidden_d_ff, hidden_d_ff),
109 | WaveAct(),
110 | nn.Linear(hidden_d_ff, out_dim)
111 | ])
112 |
113 | def forward(self, x, t):
114 | src = torch.cat((x, t), dim=-1)
115 | src = self.linear_emb(src)
116 |
117 | e_outputs = self.encoder(src)
118 | d_output = self.decoder(src, e_outputs)
119 | output = self.linear_out(d_output)
120 | return output
121 |
--------------------------------------------------------------------------------
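Usage sketch (annotation, not part of the repository; hyperparameters are assumed): PINNsFormer consumes pseudo-sequences, which util.make_time_sequence builds by repeating each collocation point num_step times and shifting t by a small step.

    import numpy as np
    import torch
    from util import make_time_sequence
    from models.PINNsFormer import Model

    pts = np.random.rand(100, 2)                          # (N, 2), columns (x, t)
    seq = torch.tensor(make_time_sequence(pts, num_step=5, step=1e-4),
                       dtype=torch.float32)               # (N, 5, 2)
    x, t = seq[:, :, 0:1], seq[:, :, 1:2]                 # each (N, 5, 1)

    model = Model(in_dim=2, out_dim=1, hidden_dim=32, num_layer=1)
    u = model(x, t)                                       # (N, 5, 1)

--------------------------------------------------------------------------------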
/models/PINNsFormer_Enc_Only.py:
--------------------------------------------------------------------------------
1 | # implementation of PINNsFormer (encoder-only variant)
2 | # paper: PINNsFormer: A Transformer-Based Framework For Physics-Informed Neural Networks
3 | # link: https://arxiv.org/abs/2307.11833
4 |
5 | import torch
6 | import torch.nn as nn
7 | import pdb
8 | from util import get_clones
9 |
10 |
11 | class WaveAct(nn.Module):
12 | def __init__(self):
13 | super(WaveAct, self).__init__()
14 | self.w1 = nn.Parameter(torch.ones(1), requires_grad=True)
15 | self.w2 = nn.Parameter(torch.ones(1), requires_grad=True)
16 |
17 | def forward(self, x):
18 | return self.w1 * torch.sin(x) + self.w2 * torch.cos(x)
19 |
20 |
21 | class FeedForward(nn.Module):
22 | def __init__(self, d_model, d_ff=256):
23 | super(FeedForward, self).__init__()
24 | self.linear = nn.Sequential(*[
25 | nn.Linear(d_model, d_ff),
26 | WaveAct(),
27 | nn.Linear(d_ff, d_ff),
28 | WaveAct(),
29 | nn.Linear(d_ff, d_model)
30 | ])
31 |
32 | def forward(self, x):
33 | return self.linear(x)
34 |
35 |
36 | class EncoderLayer(nn.Module):
37 | def __init__(self, d_model, heads):
38 | super(EncoderLayer, self).__init__()
39 |
40 | self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=heads, batch_first=True)
41 | self.ff = FeedForward(d_model)
42 | self.act1 = WaveAct()
43 | self.act2 = WaveAct()
44 |
45 | def forward(self, x):
46 | x2 = self.act1(x)
47 | # pdb.set_trace()
48 | x = x + self.attn(x2, x2, x2)[0]
49 | x2 = self.act2(x)
50 | x = x + self.ff(x2)
51 | return x
52 |
53 |
54 | class DecoderLayer(nn.Module):
55 | def __init__(self, d_model, heads):
56 | super(DecoderLayer, self).__init__()
57 |
58 | self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=heads, batch_first=True)
59 | self.ff = FeedForward(d_model)
60 | self.act1 = WaveAct()
61 | self.act2 = WaveAct()
62 |
63 | def forward(self, x, e_outputs):
64 | x2 = self.act1(x)
65 | x = x + self.attn(x2, e_outputs, e_outputs)[0]
66 | x2 = self.act2(x)
67 | x = x + self.ff(x2)
68 | return x
69 |
70 |
71 | class Encoder(nn.Module):
72 | def __init__(self, d_model, N, heads):
73 | super(Encoder, self).__init__()
74 | self.N = N
75 | self.layers = get_clones(EncoderLayer(d_model, heads), N)
76 | self.act = WaveAct()
77 |
78 | def forward(self, x):
79 | for i in range(self.N):
80 | x = self.layers[i](x)
81 | return self.act(x)
82 |
83 |
84 | class Decoder(nn.Module):
85 | def __init__(self, d_model, N, heads):
86 | super(Decoder, self).__init__()
87 | self.N = N
88 | self.layers = get_clones(DecoderLayer(d_model, heads), N)
89 | self.act = WaveAct()
90 |
91 | def forward(self, x, e_outputs):
92 | for i in range(self.N):
93 | x = self.layers[i](x, e_outputs)
94 | return self.act(x)
95 |
96 |
97 | class Model(nn.Module):
98 | def __init__(self, in_dim, out_dim, hidden_dim, num_layer, hidden_d_ff=512, heads=2):
99 | super(Model, self).__init__()
100 |
101 | self.linear_emb = nn.Linear(in_dim, hidden_dim)
102 |
103 | self.encoder = Encoder(hidden_dim, num_layer, heads)
104 | self.linear_out = nn.Sequential(*[
105 | nn.Linear(hidden_dim, hidden_d_ff),
106 | WaveAct(),
107 | nn.Linear(hidden_d_ff, hidden_d_ff),
108 | WaveAct(),
109 | nn.Linear(hidden_d_ff, out_dim)
110 | ])
111 |
112 | def forward(self, x, t):
113 | src = torch.cat((x, t), dim=-1)
114 | src = self.linear_emb(src)
115 | e_outputs = self.encoder(src)
116 | output = self.linear_out(e_outputs)
117 | return output
118 |
--------------------------------------------------------------------------------
/models/QRes.py:
--------------------------------------------------------------------------------
1 | # baseline implementation of QRes
2 | # paper: Quadratic residual networks: A new class of neural networks for solving forward and inverse problems in physics involving pdes
3 | # link: https://arxiv.org/abs/2101.08366
4 | # code: https://github.com/jayroxis/qres
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from util import get_clones
10 |
11 |
12 | class QRes_block(nn.Module):
13 | def __init__(self, in_dim, out_dim):
14 | super(QRes_block, self).__init__()
15 | self.H1 = nn.Linear(in_features=in_dim, out_features=out_dim)
16 | self.H2 = nn.Linear(in_features=in_dim, out_features=out_dim)
17 | self.act = nn.Sigmoid()
18 |
19 | def forward(self, x):
20 | x1 = self.H1(x)
21 | x2 = self.H2(x)
22 | return self.act(x1 * x2 + x1)
23 |
24 |
25 | class Model(nn.Module):
26 | def __init__(self, in_dim, hidden_dim, out_dim, num_layer):
27 | super(Model, self).__init__()
28 | self.N = num_layer - 1
29 | self.inlayer = QRes_block(in_dim, hidden_dim)
30 | self.layers = get_clones(QRes_block(hidden_dim, hidden_dim), num_layer - 1)
31 | self.outlayer = nn.Linear(in_features=hidden_dim, out_features=out_dim)
32 |
33 | def forward(self, x, t):
34 | src = torch.cat((x, t), dim=-1)
35 | src = self.inlayer(src)
36 | for i in range(self.N):
37 | src = self.layers[i](src)
38 | src = self.outlayer(src)
39 | return src
40 |
--------------------------------------------------------------------------------
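Verification sketch (annotation, not part of the repository): each QRes_block computes sigmoid(x1 * x2 + x1) from two parallel linear maps, so the Hadamard product injects a quadratic term on top of the usual affine one.

    import torch
    from models.QRes import QRes_block

    block = QRes_block(in_dim=2, out_dim=4)
    x = torch.rand(8, 2)
    x1, x2 = block.H1(x), block.H2(x)
    manual = torch.sigmoid(x1 * x2 + x1)
    print(torch.allclose(block(x), manual))  # True

--------------------------------------------------------------------------------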
/models/Symbolic_KANLayer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import sympy
5 | from .utils import *
6 |
7 |
8 | class Symbolic_KANLayer(nn.Module):
9 | '''
10 |     Symbolic_KANLayer class
11 |
12 | Attributes:
13 | -----------
14 | in_dim: int
15 | input dimension
16 | out_dim: int
17 | output dimension
18 | funs: 2D array of torch functions (or lambda functions)
19 | symbolic functions (torch)
20 |         funs_name: 2D array of str
21 | names of symbolic functions
22 | funs_sympy: 2D array of sympy functions (or lambda functions)
23 | symbolic functions (sympy)
24 | affine: 3D array of floats
25 | affine transformations of inputs and outputs
26 |
27 | Methods:
28 | --------
29 | __init__():
30 | initialize a Symbolic_KANLayer
31 | forward():
32 | forward
33 | get_subset():
34 | get subset of the KANLayer (used for pruning)
35 | fix_symbolic():
36 | fix an activation function to be symbolic
37 | '''
38 |
39 | def __init__(self, in_dim=3, out_dim=2):
40 | '''
41 | initialize a Symbolic_KANLayer (activation functions are initialized to be identity functions)
42 |
43 | Args:
44 | -----
45 | in_dim : int
46 | input dimension
47 | out_dim : int
48 | output dimension
49 |
50 | Returns:
51 | --------
52 | self
53 |
54 | Example
55 | -------
56 | >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=3)
57 | >>> len(sb.funs), len(sb.funs[0])
58 | (3, 3)
59 | '''
60 | super(Symbolic_KANLayer, self).__init__()
61 | self.out_dim = out_dim
62 | self.in_dim = in_dim
63 | self.mask = torch.nn.Parameter(torch.zeros(out_dim, in_dim)).requires_grad_(False)
64 | # torch
65 | self.funs = [[lambda x: x for i in range(self.in_dim)] for j in range(self.out_dim)]
66 | # name
67 | self.funs_name = [['' for i in range(self.in_dim)] for j in range(self.out_dim)]
68 | # sympy
69 | self.funs_sympy = [['' for i in range(self.in_dim)] for j in range(self.out_dim)]
70 |
71 | self.affine = torch.nn.Parameter(torch.zeros(out_dim, in_dim, 4))
72 | # c*f(a*x+b)+d
73 |
74 | def forward(self, x):
75 | '''
76 | forward
77 |
78 | Args:
79 | -----
80 | x : 2D array
81 | inputs, shape (batch, input dimension)
82 |
83 | Returns:
84 | --------
85 | y : 2D array
86 | outputs, shape (batch, output dimension)
87 | postacts : 3D array
88 | activations after activation functions but before summing on nodes
89 |
90 | Example
91 | -------
92 | >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=5)
93 | >>> x = torch.normal(0,1,size=(100,3))
94 | >>> y, postacts = sb(x)
95 | >>> y.shape, postacts.shape
96 | (torch.Size([100, 5]), torch.Size([100, 5, 3]))
97 | '''
98 |
99 | batch = x.shape[0]
100 | postacts = []
101 |
102 | for i in range(self.in_dim):
103 | postacts_ = []
104 | for j in range(self.out_dim):
105 | xij = self.affine[j, i, 2] * self.funs[j][i](self.affine[j, i, 0] * x[:, [i]] + self.affine[j, i, 1]) + \
106 | self.affine[j, i, 3]
107 | postacts_.append(self.mask[j][i] * xij)
108 | postacts.append(torch.stack(postacts_))
109 |
110 | postacts = torch.stack(postacts)
111 | postacts = postacts.permute(2, 1, 0, 3)[:, :, :, 0]
112 | y = torch.sum(postacts, dim=2)
113 |
114 | return y, postacts
115 |
116 | def get_subset(self, in_id, out_id):
117 | '''
118 | get a smaller Symbolic_KANLayer from a larger Symbolic_KANLayer (used for pruning)
119 |
120 | Args:
121 | -----
122 | in_id : list
123 | id of selected input neurons
124 | out_id : list
125 | id of selected output neurons
126 |
127 | Returns:
128 | --------
129 | spb : Symbolic_KANLayer
130 |
131 | Example
132 | -------
133 | >>> sb_large = Symbolic_KANLayer(in_dim=10, out_dim=10)
134 | >>> sb_small = sb_large.get_subset([0,9],[1,2,3])
135 | >>> sb_small.in_dim, sb_small.out_dim
136 | (2, 3)
137 | '''
138 | sbb = Symbolic_KANLayer(self.in_dim, self.out_dim)
139 | sbb.in_dim = len(in_id)
140 | sbb.out_dim = len(out_id)
141 | sbb.mask.data = self.mask.data[out_id][:, in_id]
142 | sbb.funs = [[self.funs[j][i] for i in in_id] for j in out_id]
143 | sbb.funs_sympy = [[self.funs_sympy[j][i] for i in in_id] for j in out_id]
144 | sbb.funs_name = [[self.funs_name[j][i] for i in in_id] for j in out_id]
145 | sbb.affine.data = self.affine.data[out_id][:, in_id]
146 | return sbb
147 |
148 | def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-10, 10), b_range=(-10, 10),
149 | verbose=True):
150 | '''
151 | fix an activation function to be symbolic
152 |
153 | Args:
154 | -----
155 | i : int
156 | the id of input neuron
157 | j : int
158 | the id of output neuron
159 | fun_name : str
160 | the name of the symbolic functions
161 | x : 1D array
162 | preactivations
163 | y : 1D array
164 | postactivations
165 | a_range : tuple
166 | sweeping range of a
167 | b_range : tuple
168 |             sweeping range of b
169 | verbose : bool
170 | print more information if True
171 |
172 | Returns:
173 | --------
174 | r2 (coefficient of determination)
175 |
176 | Example 1
177 | ---------
178 | >>> # when x & y are not provided. Affine parameters are set to a = 1, b = 0, c = 1, d = 0
179 | >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
180 | >>> sb.fix_symbolic(2,1,'sin')
181 | >>> print(sb.funs_name)
182 | >>> print(sb.affine)
183 | [['', '', ''], ['', '', 'sin']]
184 | Parameter containing:
185 | tensor([[0., 0., 0., 0.],
186 | [0., 0., 0., 0.],
187 | [1., 0., 1., 0.]], requires_grad=True)
188 | Example 2
189 | ---------
190 | >>> # when x & y are provided, fit_params() is called to find the best fit coefficients
191 | >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
192 | >>> batch = 100
193 | >>> x = torch.linspace(-1,1,steps=batch)
194 | >>> noises = torch.normal(0,1,(batch,)) * 0.02
195 | >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
196 | >>> sb.fix_symbolic(2,1,'sin',x,y)
197 | >>> print(sb.funs_name)
198 | >>> print(sb.affine[1,2,:].data)
199 | r2 is 0.9999701976776123
200 | [['', '', ''], ['', '', 'sin']]
201 | tensor([2.9981, 1.9997, 5.0039, 0.6978])
202 | '''
203 | if isinstance(fun_name, str):
204 | fun = SYMBOLIC_LIB[fun_name][0]
205 | fun_sympy = SYMBOLIC_LIB[fun_name][1]
206 | self.funs_sympy[j][i] = fun_sympy
207 | self.funs_name[j][i] = fun_name
208 |             if x is None or y is None:
209 |                 # initialize from just fun
210 | self.funs[j][i] = fun
211 | if random == False:
212 | self.affine.data[j][i] = torch.tensor([1., 0., 1., 0.])
213 | else:
214 | self.affine.data[j][i] = torch.rand(4, ) * 2 - 1
215 | return None
216 | else:
217 | # initialize from x & y and fun
218 | params, r2 = fit_params(x, y, fun, a_range=a_range, b_range=b_range, verbose=verbose)
219 | self.funs[j][i] = fun
220 | self.affine.data[j][i] = params
221 | return r2
222 | else:
223 | # if fun_name itself is a function
224 | fun = fun_name
225 | fun_sympy = fun_name
226 | self.funs_sympy[j][i] = fun_sympy
227 | self.funs_name[j][i] = "anonymous"
228 |
229 | self.funs[j][i] = fun
230 | if random == False:
231 | self.affine.data[j][i] = torch.tensor([1., 0., 1., 0.])
232 | else:
233 | self.affine.data[j][i] = torch.rand(4, ) * 2 - 1
234 | return None
--------------------------------------------------------------------------------
/models/kan_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from .spline import *
5 |
6 |
7 | class KANLayer(nn.Module):
8 | """
9 | KANLayer class
10 |
11 |
12 | Attributes:
13 | -----------
14 | in_dim: int
15 | input dimension
16 | out_dim: int
17 | output dimension
18 | size: int
19 | the number of splines = input dimension * output dimension
20 | k: int
21 | the piecewise polynomial order of splines
22 | grid: 2D torch.float
23 | grid points
24 | noises: 2D torch.float
25 | injected noises to splines at initialization (to break degeneracy)
26 | coef: 2D torch.tensor
27 | coefficients of B-spline bases
28 | scale_base: 1D torch.float
29 | magnitude of the residual function b(x)
30 | scale_sp: 1D torch.float
31 |         magnitude of the spline function spline(x)
32 | base_fun: fun
33 | residual function b(x)
34 | mask: 1D torch.float
35 |         mask of spline functions. Setting an element of the mask to zero sets the corresponding activation to the zero function.
36 | grid_eps: float in [0,1]
37 | a hyperparameter used in update_grid_from_samples. When grid_eps = 0, the grid is uniform; when grid_eps = 1, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
38 | weight_sharing: 1D tensor int
39 | allow spline activations to share parameters
40 | lock_counter: int
41 |         counts how many activation functions are locked (weight sharing)
42 | lock_id: 1D torch.int
43 | the id of activation functions that are locked
44 | device: str
45 | device
46 |
47 | Methods:
48 | --------
49 | __init__():
50 | initialize a KANLayer
51 | forward():
52 | forward
53 | update_grid_from_samples():
54 | update grids based on samples' incoming activations
55 | initialize_grid_from_parent():
56 | initialize grids from another model
57 | get_subset():
58 | get subset of the KANLayer (used for pruning)
59 | lock():
60 | lock several activation functions to share parameters
61 | unlock():
62 | unlock already locked activation functions
63 | """
64 |
65 | def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=1.0, scale_sp=1.0,
66 | base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True,
67 | device='cpu'):
68 |         '''
69 | initialize a KANLayer
70 |
71 | Args:
72 | -----
73 | in_dim : int
74 |             input dimension. Default: 3.
75 |         out_dim : int
76 |             output dimension. Default: 2.
77 | num : int
78 | the number of grid intervals = G. Default: 5.
79 | k : int
80 | the order of piecewise polynomial. Default: 3.
81 | noise_scale : float
82 | the scale of noise injected at initialization. Default: 0.1.
83 | scale_base : float
84 | the scale of the residual function b(x). Default: 1.0.
85 | scale_sp : float
86 |             the scale of the spline function spline(x). Default: 1.0.
87 | base_fun : function
88 | residual function b(x). Default: torch.nn.SiLU()
89 | grid_eps : float
90 | When grid_eps = 0, the grid is uniform; when grid_eps = 1, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
91 | grid_range : list/np.array of shape (2,)
92 | setting the range of grids. Default: [-1,1].
93 | sp_trainable : bool
94 | If true, scale_sp is trainable. Default: True.
95 | sb_trainable : bool
96 | If true, scale_base is trainable. Default: True.
97 | device : str
98 | device
99 |
100 | Returns:
101 | --------
102 | self
103 |
104 | Example
105 | -------
106 | >>> model = KANLayer(in_dim=3, out_dim=5)
107 | >>> (model.in_dim, model.out_dim)
108 | (3, 5)
109 | '''
110 | super(KANLayer, self).__init__()
111 | # size
112 | self.size = size = out_dim * in_dim
113 | self.out_dim = out_dim
114 | self.in_dim = in_dim
115 | self.num = num
116 | self.k = k
117 |
118 |         # shape: (size, num + 1)
119 | self.grid = torch.einsum('i,j->ij', torch.ones(size, ),
120 | torch.linspace(grid_range[0], grid_range[1], steps=num + 1))
121 | self.grid = torch.nn.Parameter(self.grid).requires_grad_(False)
122 | noises = (torch.rand(size, self.grid.shape[1]) - 1 / 2) * noise_scale / num
123 | noises = noises.to(device)
124 | # shape: (size, coef)
125 | self.coef = torch.nn.Parameter(curve2coef(self.grid, noises, self.grid, k))
126 | if isinstance(scale_base, float):
127 | self.scale_base = torch.nn.Parameter(torch.ones(size, ) * scale_base).requires_grad_(
128 | sb_trainable) # make scale trainable
129 | else:
130 | self.scale_base = torch.nn.Parameter(scale_base).requires_grad_(sb_trainable)
131 | self.scale_sp = torch.nn.Parameter(torch.ones(size, ) * scale_sp).requires_grad_(
132 | sp_trainable) # make scale trainable
133 | self.base_fun = base_fun
134 |
135 | self.mask = torch.nn.Parameter(torch.ones(size, )).requires_grad_(False)
136 | self.grid_eps = grid_eps
137 | self.weight_sharing = torch.arange(size)
138 | self.lock_counter = 0
139 | self.lock_id = torch.zeros(size)
140 | self.device = device
141 |
142 | def forward(self, x):
143 | '''
144 | KANLayer forward given input x
145 |
146 | Args:
147 | -----
148 | x : 2D torch.float
149 | inputs, shape (number of samples, input dimension)
150 |
151 | Returns:
152 | --------
153 | y : 2D torch.float
154 | outputs, shape (number of samples, output dimension)
155 | preacts : 3D torch.float
156 |             fan out x into activations, shape (number of samples, output dimension, input dimension)
157 | postacts : 3D torch.float
158 | the outputs of activation functions with preacts as inputs
159 | postspline : 3D torch.float
160 | the outputs of spline functions with preacts as inputs
161 |
162 | Example
163 | -------
164 | >>> model = KANLayer(in_dim=3, out_dim=5)
165 | >>> x = torch.normal(0,1,size=(100,3))
166 | >>> y, preacts, postacts, postspline = model(x)
167 | >>> y.shape, preacts.shape, postacts.shape, postspline.shape
168 | (torch.Size([100, 5]),
169 | torch.Size([100, 5, 3]),
170 | torch.Size([100, 5, 3]),
171 | torch.Size([100, 5, 3]))
172 | '''
173 | batch = x.shape[0]
174 | # x: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
175 | x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(
176 | 1, 0)
177 | preacts = x.permute(1, 0).clone().reshape(batch, self.out_dim, self.in_dim)
178 | base = self.base_fun(x).permute(1, 0) # shape (batch, size)
179 | y = coef2curve(x_eval=x, grid=self.grid[self.weight_sharing], coef=self.coef[self.weight_sharing], k=self.k,
180 | device=self.device) # shape (size, batch)
181 | y = y.permute(1, 0) # shape (batch, size)
182 | postspline = y.clone().reshape(batch, self.out_dim, self.in_dim)
183 | y = self.scale_base.unsqueeze(dim=0) * base + self.scale_sp.unsqueeze(dim=0) * y
184 | y = self.mask[None, :] * y
185 | postacts = y.clone().reshape(batch, self.out_dim, self.in_dim)
186 | y = torch.sum(y.reshape(batch, self.out_dim, self.in_dim), dim=2) # shape (batch, out_dim)
187 |         # y shape: (batch, out_dim); preacts shape: (batch, out_dim, in_dim)
188 |         # postspline shape: (batch, out_dim, in_dim); postacts: (batch, out_dim, in_dim)
189 | # postspline is for extension; postacts is for visualization
190 | return y, preacts, postacts, postspline
191 |
192 | def update_grid_from_samples(self, x):
193 | '''
194 | update grid from samples
195 |
196 | Args:
197 | -----
198 | x : 2D torch.float
199 | inputs, shape (number of samples, input dimension)
200 |
201 | Returns:
202 | --------
203 | None
204 |
205 | Example
206 | -------
207 | >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
208 | >>> print(model.grid.data)
209 | >>> x = torch.linspace(-3,3,steps=100)[:,None]
210 | >>> model.update_grid_from_samples(x)
211 | >>> print(model.grid.data)
212 | tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]])
213 | tensor([[-3.0002, -1.7882, -0.5763, 0.6357, 1.8476, 3.0002]])
214 | '''
215 | batch = x.shape[0]
216 | x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(
217 | 1, 0)
218 | x_pos = torch.sort(x, dim=1)[0]
219 | y_eval = coef2curve(x_pos, self.grid, self.coef, self.k, device=self.device)
220 | num_interval = self.grid.shape[1] - 1
221 | ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
222 | grid_adaptive = x_pos[:, ids]
223 | margin = 0.01
224 | grid_uniform = torch.cat(
225 | [grid_adaptive[:, [0]] - margin + (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) * a for a in
226 | np.linspace(0, 1, num=self.grid.shape[1])], dim=1)
227 | self.grid.data = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
228 | self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k, device=self.device)
229 |
230 | def initialize_grid_from_parent(self, parent, x):
231 | '''
232 | update grid from a parent KANLayer & samples
233 |
234 | Args:
235 | -----
236 | parent : KANLayer
237 | a parent KANLayer (whose grid is usually coarser than the current model)
238 | x : 2D torch.float
239 | inputs, shape (number of samples, input dimension)
240 |
241 | Returns:
242 | --------
243 | None
244 |
245 | Example
246 | -------
247 | >>> batch = 100
248 | >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
249 | >>> print(parent_model.grid.data)
250 | >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3)
251 | >>> x = torch.normal(0,1,size=(batch, 1))
252 | >>> model.initialize_grid_from_parent(parent_model, x)
253 | >>> print(model.grid.data)
254 | tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]])
255 | tensor([[-1.0000, -0.8000, -0.6000, -0.4000, -0.2000, 0.0000, 0.2000, 0.4000,
256 | 0.6000, 0.8000, 1.0000]])
257 | '''
258 | batch = x.shape[0]
259 | # preacts: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
260 | x_eval = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch,
261 | self.size).permute(1,
262 | 0)
263 | x_pos = parent.grid
264 | sp2 = KANLayer(in_dim=1, out_dim=self.size, k=1, num=x_pos.shape[1] - 1, scale_base=0.).to(self.device)
265 | sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1)
266 | y_eval = coef2curve(x_eval, parent.grid, parent.coef, parent.k, device=self.device)
267 | percentile = torch.linspace(-1, 1, self.num + 1).to(self.device)
268 | self.grid.data = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)
269 | self.coef.data = curve2coef(x_eval, y_eval, self.grid, self.k, self.device)
270 |
271 | def get_subset(self, in_id, out_id):
272 | '''
273 | get a smaller KANLayer from a larger KANLayer (used for pruning)
274 |
275 | Args:
276 | -----
277 | in_id : list
278 | id of selected input neurons
279 | out_id : list
280 | id of selected output neurons
281 |
282 | Returns:
283 | --------
284 | spb : KANLayer
285 |
286 | Example
287 | -------
288 | >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3)
289 | >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3])
290 | >>> kanlayer_small.in_dim, kanlayer_small.out_dim
291 | (2, 3)
292 | '''
293 | spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun)
294 | spb.grid.data = self.grid.reshape(self.out_dim, self.in_dim, spb.num + 1)[out_id][:, in_id].reshape(-1,
295 | spb.num + 1)
296 | spb.coef.data = self.coef.reshape(self.out_dim, self.in_dim, spb.coef.shape[1])[out_id][:, in_id].reshape(-1,
297 | spb.coef.shape[
298 | 1])
299 | spb.scale_base.data = self.scale_base.reshape(self.out_dim, self.in_dim)[out_id][:, in_id].reshape(-1, )
300 | spb.scale_sp.data = self.scale_sp.reshape(self.out_dim, self.in_dim)[out_id][:, in_id].reshape(-1, )
301 | spb.mask.data = self.mask.reshape(self.out_dim, self.in_dim)[out_id][:, in_id].reshape(-1, )
302 |
303 | spb.in_dim = len(in_id)
304 | spb.out_dim = len(out_id)
305 | spb.size = spb.in_dim * spb.out_dim
306 | return spb
307 |
308 | def lock(self, ids):
309 | '''
310 | lock activation functions to share parameters based on ids
311 |
312 | Args:
313 | -----
314 | ids : list
315 | list of ids of activation functions
316 |
317 | Returns:
318 | --------
319 | None
320 |
321 | Example
322 | -------
323 | >>> model = KANLayer(in_dim=3, out_dim=3, num=5, k=3)
324 | >>> print(model.weight_sharing.reshape(3,3))
325 | >>> model.lock([[0,0],[1,2],[2,1]]) # set (0,0),(1,2),(2,1) functions to be the same
326 | >>> print(model.weight_sharing.reshape(3,3))
327 | tensor([[0, 1, 2],
328 | [3, 4, 5],
329 | [6, 7, 8]])
330 | tensor([[0, 1, 2],
331 | [3, 4, 0],
332 | [6, 0, 8]])
333 | '''
334 | self.lock_counter += 1
335 | # ids: [[i1,j1],[i2,j2],[i3,j3],...]
336 | for i in range(len(ids)):
337 | if i != 0:
338 | self.weight_sharing[ids[i][1] * self.in_dim + ids[i][0]] = ids[0][1] * self.in_dim + ids[0][0]
339 | self.lock_id[ids[i][1] * self.in_dim + ids[i][0]] = self.lock_counter
340 |
341 | def unlock(self, ids):
342 | '''
343 | unlock activation functions
344 |
345 | Args:
346 | -----
347 | ids : list
348 | list of ids of activation functions
349 |
350 | Returns:
351 | --------
352 | None
353 |
354 | Example
355 | -------
356 | >>> model = KANLayer(in_dim=3, out_dim=3, num=5, k=3)
357 | >>> model.lock([[0,0],[1,2],[2,1]]) # set (0,0),(1,2),(2,1) functions to be the same
358 | >>> print(model.weight_sharing.reshape(3,3))
359 | >>> model.unlock([[0,0],[1,2],[2,1]]) # unlock the locked functions
360 | >>> print(model.weight_sharing.reshape(3,3))
361 | tensor([[0, 1, 2],
362 | [3, 4, 0],
363 | [6, 0, 8]])
364 | tensor([[0, 1, 2],
365 | [3, 4, 5],
366 | [6, 7, 8]])
367 | '''
368 | # check ids are locked
369 | num = len(ids)
370 | locked = True
371 | for i in range(num):
372 | locked *= (self.weight_sharing[ids[i][1] * self.in_dim + ids[i][0]] == self.weight_sharing[
373 | ids[0][1] * self.in_dim + ids[0][0]])
374 | if locked == False:
375 | print("they are not locked. unlock failed.")
376 | return 0
377 | for i in range(len(ids)):
378 | self.weight_sharing[ids[i][1] * self.in_dim + ids[i][0]] = ids[i][1] * self.in_dim + ids[i][0]
379 | self.lock_id[ids[i][1] * self.in_dim + ids[i][0]] = 0
380 | self.lock_counter -= 1
--------------------------------------------------------------------------------
/models/spline.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def B_batch(x, grid, k=0, extend=True, device='cpu'):
5 | '''
6 |     evaluate x on B-spline bases
7 |
8 | Args:
9 | -----
10 | x : 2D torch.tensor
11 | inputs, shape (number of splines, number of samples)
12 | grid : 2D torch.tensor
13 | grids, shape (number of splines, number of grid points)
14 | k : int
15 | the piecewise polynomial order of splines.
16 | extend : bool
17 | If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True
18 | device : str
19 |             device
20 |
21 | Returns:
22 | --------
23 | spline values : 3D torch.tensor
24 |         shape (number of splines, number of B-spline bases (coefficients), number of samples). The number of B-spline bases = number of grid points + k - 1.
25 |
26 | Example
27 | -------
28 | >>> num_spline = 5
29 | >>> num_sample = 100
30 | >>> num_grid_interval = 10
31 | >>> k = 3
32 | >>> x = torch.normal(0,1,size=(num_spline, num_sample))
33 | >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1))
34 | >>> B_batch(x, grids, k=k).shape
35 | torch.Size([5, 13, 100])
36 | '''
37 |
38 | # x shape: (size, x); grid shape: (size, grid)
39 | def extend_grid(grid, k_extend=0):
40 | # pad k to left and right
41 | # grid shape: (batch, grid)
42 | h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
43 |
44 | for i in range(k_extend):
45 | grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
46 | grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)
47 | grid = grid.to(device)
48 | return grid
49 |
50 | if extend == True:
51 | grid = extend_grid(grid, k_extend=k)
52 |
53 | grid = grid.unsqueeze(dim=2).to(device)
54 | x = x.unsqueeze(dim=1).to(device)
55 |
56 | if k == 0:
57 | value = (x >= grid[:, :-1]) * (x < grid[:, 1:])
58 | else:
59 | B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False, device=device)
60 | value = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)]) * B_km1[:, :-1] + (
61 | grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * B_km1[:, 1:]
62 | return value
63 |
64 |
65 | def coef2curve(x_eval, grid, coef, k, device="cpu"):
66 | '''
67 | converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis).
68 |
69 | Args:
70 | -----
71 |     x_eval : 2D torch.tensor
72 |         shape (number of splines, number of samples)
73 |     grid : 2D torch.tensor
74 |         shape (number of splines, number of grid points)
75 |     coef : 2D torch.tensor
76 | shape (number of splines, number of coef params). number of coef params = number of grid intervals + k
77 | k : int
78 | the piecewise polynomial order of splines.
79 | device : str
80 |         device
81 |
82 | Returns:
83 | --------
84 | y_eval : 2D torch.tensor
85 | shape (number of splines, number of samples)
86 |
87 | Example
88 | -------
89 | >>> num_spline = 5
90 | >>> num_sample = 100
91 | >>> num_grid_interval = 10
92 | >>> k = 3
93 | >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample))
94 | >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1))
95 | >>> coef = torch.normal(0,1,size=(num_spline, num_grid_interval+k))
96 | >>> coef2curve(x_eval, grids, coef, k=k).shape
97 | torch.Size([5, 100])
98 | '''
99 | # x_eval: (size, batch), grid: (size, grid), coef: (size, coef)
100 |     # coef: (size, coef), B_batch: (size, coef, batch), summed over coef
101 | y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device))
102 | return y_eval
103 |
104 |
105 | def curve2coef(x_eval, y_eval, grid, k, device="cpu"):
106 | '''
107 | converting B-spline curves to B-spline coefficients using least squares.
108 |
109 | Args:
110 | -----
111 | x_eval : 2D torch.tensor
112 | shape (number of splines, number of samples)
113 | y_eval : 2D torch.tensor
114 | shape (number of splines, number of samples)
115 | grid : 2D torch.tensor
116 | shape (number of splines, number of grid points)
117 | k : int
118 | the piecewise polynomial order of splines.
119 | device : str
120 |         device
121 |
122 | Example
123 | -------
124 | >>> num_spline = 5
125 | >>> num_sample = 100
126 | >>> num_grid_interval = 10
127 | >>> k = 3
128 | >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample))
129 | >>> y_eval = torch.normal(0,1,size=(num_spline, num_sample))
130 | >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1))
131 | >>> curve2coef(x_eval, y_eval, grids, k=k).shape
132 | torch.Size([5, 13])
133 | '''
134 | # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar
135 | mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1)
136 | coef = torch.linalg.lstsq(mat.to('cpu'), y_eval.unsqueeze(dim=2).to('cpu')).solution[:, :,
137 | 0] # sometimes 'cuda' version may diverge
138 | return coef.to(device)
--------------------------------------------------------------------------------
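Round-trip sketch (annotation, not part of the repository; values are illustrative): curve2coef solves the least-squares problem that inverts coef2curve on the same grid, so refitting a curve's own samples should approximately recover its coefficients.

    import torch
    from models.spline import coef2curve, curve2coef

    num_spline, num_sample, num_interval, k = 5, 100, 10, 3
    x = torch.normal(0, 1, size=(num_spline, num_sample))
    grid = torch.einsum('i,j->ij', torch.ones(num_spline),
                        torch.linspace(-1, 1, steps=num_interval + 1))
    coef = torch.normal(0, 1, size=(num_spline, num_interval + k))

    y = coef2curve(x, grid, coef, k=k)      # (5, 100)
    coef_rec = curve2coef(x, y, grid, k=k)  # (5, 13)
    print((coef - coef_rec).abs().max())    # expected to be small

--------------------------------------------------------------------------------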
/models/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.linear_model import LinearRegression
4 | import sympy
5 |
6 | # sigmoid = sympy.Function('sigmoid')
7 | # name: (torch implementation, sympy implementation)
8 | SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x),
9 | 'x^2': (lambda x: x ** 2, lambda x: x ** 2),
10 | 'x^3': (lambda x: x ** 3, lambda x: x ** 3),
11 | 'x^4': (lambda x: x ** 4, lambda x: x ** 4),
12 | '1/x': (lambda x: 1 / x, lambda x: 1 / x),
13 | '1/x^2': (lambda x: 1 / x ** 2, lambda x: 1 / x ** 2),
14 | '1/x^3': (lambda x: 1 / x ** 3, lambda x: 1 / x ** 3),
15 | '1/x^4': (lambda x: 1 / x ** 4, lambda x: 1 / x ** 4),
16 | 'sqrt': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x)),
17 | '1/sqrt(x)': (lambda x: 1 / torch.sqrt(x), lambda x: 1 / sympy.sqrt(x)),
18 | 'exp': (lambda x: torch.exp(x), lambda x: sympy.exp(x)),
19 | 'log': (lambda x: torch.log(x), lambda x: sympy.log(x)),
20 | 'abs': (lambda x: torch.abs(x), lambda x: sympy.Abs(x)),
21 | 'sin': (lambda x: torch.sin(x), lambda x: sympy.sin(x)),
22 | 'tan': (lambda x: torch.tan(x), lambda x: sympy.tan(x)),
23 | 'tanh': (lambda x: torch.tanh(x), lambda x: sympy.tanh(x)),
24 | 'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid')),
25 | # 'relu': (lambda x: torch.relu(x), relu),
26 | 'sgn': (lambda x: torch.sign(x), lambda x: sympy.sign(x)),
27 | 'arcsin': (lambda x: torch.arcsin(x), lambda x: sympy.arcsin(x)),
28 | 'arctan': (lambda x: torch.arctan(x), lambda x: sympy.atan(x)),
29 | 'arctanh': (lambda x: torch.arctanh(x), lambda x: sympy.atanh(x)),
30 | '0': (lambda x: x * 0, lambda x: x * 0),
31 | 'gaussian': (lambda x: torch.exp(-x ** 2), lambda x: sympy.exp(-x ** 2)),
32 | 'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x)),
33 | # 'logcosh': (lambda x: torch.log(torch.cosh(x)), lambda x: sympy.log(sympy.cosh(x))),
34 | # 'cosh^2': (lambda x: torch.cosh(x)**2, lambda x: sympy.cosh(x)**2),
35 | }
36 |
37 |
38 | def create_dataset(f,
39 | n_var=2,
40 | ranges=[-1, 1],
41 | train_num=1000,
42 | test_num=1000,
43 | normalize_input=False,
44 | normalize_label=False,
45 | device='cpu',
46 | seed=0):
47 | '''
48 | create dataset
49 |
50 | Args:
51 | -----
52 | f : function
53 | the symbolic formula used to create the synthetic dataset
54 | ranges : list or np.array; shape (2,) or (n_var, 2)
55 | the range of input variables. Default: [-1,1].
56 | train_num : int
57 | the number of training samples. Default: 1000.
58 | test_num : int
59 | the number of test samples. Default: 1000.
60 | normalize_input : bool
61 | If True, apply normalization to inputs. Default: False.
62 | normalize_label : bool
63 | If True, apply normalization to labels. Default: False.
64 | device : str
65 | device. Default: 'cpu'.
66 | seed : int
67 | random seed. Default: 0.
68 |
69 | Returns:
70 | --------
71 |     dataset : dict
72 | Train/test inputs/labels are dataset['train_input'], dataset['train_label'],
73 | dataset['test_input'], dataset['test_label']
74 |
75 | Example
76 | -------
77 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
78 | >>> dataset = create_dataset(f, n_var=2, train_num=100)
79 | >>> dataset['train_input'].shape
80 | torch.Size([100, 2])
81 | '''
82 |
83 | np.random.seed(seed)
84 | torch.manual_seed(seed)
85 |
86 | if len(np.array(ranges).shape) == 1:
87 | ranges = np.array(ranges * n_var).reshape(n_var, 2)
88 | else:
89 | ranges = np.array(ranges)
90 |
91 | train_input = torch.zeros(train_num, n_var)
92 | test_input = torch.zeros(test_num, n_var)
93 | for i in range(n_var):
94 | train_input[:, i] = torch.rand(train_num, ) * (ranges[i, 1] - ranges[i, 0]) + ranges[i, 0]
95 | test_input[:, i] = torch.rand(test_num, ) * (ranges[i, 1] - ranges[i, 0]) + ranges[i, 0]
96 |
97 | train_label = f(train_input)
98 | test_label = f(test_input)
99 |
100 | def normalize(data, mean, std):
101 | return (data - mean) / std
102 |
103 | if normalize_input == True:
104 | mean_input = torch.mean(train_input, dim=0, keepdim=True)
105 | std_input = torch.std(train_input, dim=0, keepdim=True)
106 | train_input = normalize(train_input, mean_input, std_input)
107 | test_input = normalize(test_input, mean_input, std_input)
108 |
109 | if normalize_label == True:
110 | mean_label = torch.mean(train_label, dim=0, keepdim=True)
111 | std_label = torch.std(train_label, dim=0, keepdim=True)
112 | train_label = normalize(train_label, mean_label, std_label)
113 | test_label = normalize(test_label, mean_label, std_label)
114 |
115 | dataset = {}
116 | dataset['train_input'] = train_input.to(device)
117 | dataset['test_input'] = test_input.to(device)
118 |
119 | dataset['train_label'] = train_label.to(device)
120 | dataset['test_label'] = test_label.to(device)
121 |
122 | return dataset
123 |
124 |
125 | def fit_params(x, y, fun, a_range=(-10, 10), b_range=(-10, 10), grid_number=101, iteration=3, verbose=True):
126 | '''
127 | fit a, b, c, d such that
128 |
129 | .. math::
130 | |y-(cf(ax+b)+d)|^2
131 |
132 | is minimized. Both x and y are 1D array. Sweep a and b, find the best fitted model.
133 |
134 | Args:
135 | -----
136 | x : 1D array
137 | x values
138 | y : 1D array
139 | y values
140 | fun : function
141 | symbolic function
142 | a_range : tuple
143 | sweeping range of a
144 | b_range : tuple
145 | sweeping range of b
146 |     grid_number : int
147 | number of steps along a and b
148 | iteration : int
149 |         number of zoom-in refinements
150 | verbose : bool
151 | print extra information if True
152 |
153 | Returns:
154 | --------
155 | a_best : float
156 | best fitted a
157 | b_best : float
158 | best fitted b
159 | c_best : float
160 | best fitted c
161 | d_best : float
162 | best fitted d
163 | r2_best : float
164 | best r2 (coefficient of determination)
165 |
166 | Example
167 | -------
168 | >>> num = 100
169 | >>> x = torch.linspace(-1,1,steps=num)
170 | >>> noises = torch.normal(0,1,(num,)) * 0.02
171 | >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
172 | >>> fit_params(x, y, torch.sin)
173 | r2 is 0.9999727010726929
174 | (tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000))
175 | '''
176 | # fit a, b, c, d such that y=c*fun(a*x+b)+d; both x and y are 1D array.
177 | # sweep a and b, choose the best fitted model
178 | for _ in range(iteration):
179 | a_ = torch.linspace(a_range[0], a_range[1], steps=grid_number)
180 | b_ = torch.linspace(b_range[0], b_range[1], steps=grid_number)
181 | a_grid, b_grid = torch.meshgrid(a_, b_, indexing='ij')
182 | post_fun = fun(a_grid[None, :, :] * x[:, None, None] + b_grid[None, :, :])
183 | x_mean = torch.mean(post_fun, dim=[0], keepdim=True)
184 | y_mean = torch.mean(y, dim=[0], keepdim=True)
185 | numerator = torch.sum((post_fun - x_mean) * (y - y_mean)[:, None, None], dim=0) ** 2
186 | denominator = torch.sum((post_fun - x_mean) ** 2, dim=0) * torch.sum((y - y_mean)[:, None, None] ** 2, dim=0)
187 | r2 = numerator / (denominator + 1e-4)
188 | r2 = torch.nan_to_num(r2)
189 |
190 | best_id = torch.argmax(r2)
191 | a_id, b_id = torch.div(best_id, grid_number, rounding_mode='floor'), best_id % grid_number
192 |
193 | if a_id == 0 or a_id == grid_number - 1 or b_id == 0 or b_id == grid_number - 1:
194 | if _ == 0 and verbose == True:
195 | print('Best value at boundary.')
196 |             if a_id == 0:
197 |                 a_range = [a_[0], a_[1]]
198 |             if a_id == grid_number - 1:
199 |                 a_range = [a_[-2], a_[-1]]
200 |             if b_id == 0:
201 |                 b_range = [b_[0], b_[1]]
202 |             if b_id == grid_number - 1:
203 |                 b_range = [b_[-2], b_[-1]]
204 |
205 | else:
206 | a_range = [a_[a_id - 1], a_[a_id + 1]]
207 | b_range = [b_[b_id - 1], b_[b_id + 1]]
208 |
209 | a_best = a_[a_id]
210 | b_best = b_[b_id]
211 | post_fun = fun(a_best * x + b_best)
212 | r2_best = r2[a_id, b_id]
213 |
214 | if verbose == True:
215 | print(f"r2 is {r2_best}")
216 | if r2_best < 0.9:
217 | print(f'r2 is not very high, please double check if you are choosing the correct symbolic function.')
218 |
219 | post_fun = torch.nan_to_num(post_fun)
220 | reg = LinearRegression().fit(post_fun[:, None].detach().numpy(), y.detach().numpy())
221 | c_best = torch.from_numpy(reg.coef_)[0]
222 | d_best = torch.from_numpy(np.array(reg.intercept_))
223 | return torch.stack([a_best, b_best, c_best, d_best]), r2_best
224 |
225 |
226 | def add_symbolic(name, fun):
227 | '''
228 | add a symbolic function to library
229 |
230 | Args:
231 | -----
232 | name : str
233 | name of the function
234 | fun : fun
235 | torch function or lambda function
236 |
237 | Returns:
238 | --------
239 | None
240 |
241 | Example
242 | -------
243 | >>> print(SYMBOLIC_LIB['Bessel'])
244 | KeyError: 'Bessel'
245 | >>> add_symbolic('Bessel', torch.special.bessel_j0)
246 | >>> print(SYMBOLIC_LIB['Bessel'])
247 |     (<built-in function bessel_j0>, Bessel)
248 | '''
249 | exec(f"globals()['{name}'] = sympy.Function('{name}')")
250 | SYMBOLIC_LIB[name] = (fun, globals()[name])
251 |
--------------------------------------------------------------------------------
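Usage sketch (annotation, not part of the repository): each SYMBOLIC_LIB entry pairs a torch implementation with a sympy one, so the same named function can be evaluated numerically or manipulated symbolically.

    import sympy
    import torch
    from models.utils import SYMBOLIC_LIB

    torch_fn, sympy_fn = SYMBOLIC_LIB['sin']
    print(torch_fn(torch.tensor([0.0, 1.0])))  # tensor([0.0000, 0.8415])
    print(sympy_fn(sympy.Symbol('x')))         # sin(x)

--------------------------------------------------------------------------------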
/pic/algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/RoPINN/0996b0df232ef2270cc9a97e308a79e5d4cb9cb1/pic/algorithm.png
--------------------------------------------------------------------------------
/pic/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/RoPINN/0996b0df232ef2270cc9a97e308a79e5d4cb9cb1/pic/comparison.png
--------------------------------------------------------------------------------
/pic/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/RoPINN/0996b0df232ef2270cc9a97e308a79e5d4cb9cb1/pic/results.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | tqdm
3 | einops
4 | deepxde
5 | scipy
6 | scikit-learn
7 | pykan
8 | sympy
9 | torch==1.13.0
10 |
--------------------------------------------------------------------------------
/scripts/1d_reaction_point.sh:
--------------------------------------------------------------------------------
1 | python 1d_reaction_point_optimization.py --model PINN --device 'cuda:0'
2 | python 1d_reaction_point_optimization.py --model QRes --device 'cuda:0'
3 | python 1d_reaction_point_optimization.py --model FLS --device 'cuda:0'
4 | python 1d_reaction_point_optimization.py --model KAN --device 'cuda:0'
5 | python 1d_reaction_point_optimization.py --model PINNsFormer --device 'cuda:0'
--------------------------------------------------------------------------------
/scripts/1d_reaction_region.sh:
--------------------------------------------------------------------------------
1 | python 1d_reaction_region_optimization.py --model PINN --device 'cuda:0'
2 | python 1d_reaction_region_optimization.py --model QRes --device 'cuda:0'
3 | python 1d_reaction_region_optimization.py --model FLS --device 'cuda:0'
4 | python 1d_reaction_region_optimization.py --model KAN --device 'cuda:0'
5 | python 1d_reaction_region_optimization.py --model PINNsFormer --device 'cuda:0'
--------------------------------------------------------------------------------
/scripts/1d_wave_point.sh:
--------------------------------------------------------------------------------
1 | python 1d_wave_point_optimization.py --model PINN --device 'cuda:0'
2 | python 1d_wave_point_optimization.py --model QRes --device 'cuda:0'
3 | python 1d_wave_point_optimization.py --model FLS --device 'cuda:0'
4 | python 1d_wave_point_optimization.py --model KAN --device 'cuda:0'
5 | python 1d_wave_point_optimization.py --model PINNsFormer_Enc_Only --device 'cuda:0'
--------------------------------------------------------------------------------
/scripts/1d_wave_region.sh:
--------------------------------------------------------------------------------
1 | python 1d_wave_region_optimization.py --model PINN --device 'cuda:0'
2 | python 1d_wave_region_optimization.py --model QRes --device 'cuda:0'
3 | python 1d_wave_region_optimization.py --model FLS --device 'cuda:0'
4 | python 1d_wave_region_optimization.py --model KAN --device 'cuda:0'
5 | python 1d_wave_region_optimization.py --model PINNsFormer_Enc_Only --device 'cuda:0'
--------------------------------------------------------------------------------
/scripts/convection_point.sh:
--------------------------------------------------------------------------------
1 | python convection_point_optimization.py --model PINN --device 'cuda:0'
2 | python convection_point_optimization.py --model QRes --device 'cuda:0'
3 | python convection_point_optimization.py --model FLS --device 'cuda:0'
4 | python convection_point_optimization.py --model KAN --device 'cuda:0'
5 | python convection_point_optimization.py --model PINNsFormer --device 'cuda:0'
--------------------------------------------------------------------------------
/scripts/convection_region.sh:
--------------------------------------------------------------------------------
1 | python convection_region_optimization.py --model PINN --device 'cuda:0'
2 | python convection_region_optimization.py --model QRes --device 'cuda:0'
3 | python convection_region_optimization.py --model FLS --device 'cuda:0'
4 | python convection_region_optimization.py --model KAN --device 'cuda:0' # for KAN, the best past_iterations is 15
5 | python convection_region_optimization.py --model PINNsFormer --device 'cuda:0'
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | import copy
4 |
5 |
6 | def get_data(x_range, y_range, x_num, y_num):
7 | x = np.linspace(x_range[0], x_range[1], x_num)
8 | t = np.linspace(y_range[0], y_range[1], y_num)
9 |
10 | x_mesh, t_mesh = np.meshgrid(x, t)
11 | data = np.concatenate((np.expand_dims(x_mesh, -1), np.expand_dims(t_mesh, -1)), axis=-1)
12 |
13 | b_left = data[0, :, :]
14 | b_right = data[-1, :, :]
15 | b_upper = data[:, -1, :]
16 | b_lower = data[:, 0, :]
17 | res = data.reshape(-1, 2)
18 |
19 | return res, b_left, b_right, b_upper, b_lower
20 |
21 |
22 | def get_n_params(model):
23 | pp = 0
24 | for p in list(model.parameters()):
25 |         num = 1  # renamed from `nn`, which shadowed the torch.nn import alias
26 |         for s in list(p.size()):
27 |             num = num * s
28 |         pp += num
29 | return pp
30 |
31 |
32 | def make_time_sequence(src, num_step=5, step=1e-4):
33 | dim = num_step
34 | src = np.repeat(np.expand_dims(src, axis=1), dim, axis=1) # (N, L, 2)
35 | for i in range(num_step):
36 | src[:, i, -1] += step * i
37 | return src
38 |
39 |
40 | def make_space_time_sequence(src, space_num_step=5, space_step=1e-4, time_num_step=5, time_step=1e-4):
41 | dim = space_num_step * time_num_step
42 | src = np.repeat(np.expand_dims(src, axis=1), dim, axis=1) # (N, L, 2)
43 | for i in range(time_num_step):
44 | for j in range(space_num_step):
45 | src[:, i * space_num_step + j, -1] += time_step * i
46 | src[:, i * space_num_step + j, 0] += space_step * (j - space_num_step // 2)
47 | return src
48 |
49 |
50 | def get_clones(module, N):
51 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
52 |
53 |
54 | def get_data_3d(x_range, y_range, t_range, x_num, y_num, t_num):
55 | step_x = (x_range[1] - x_range[0]) / float(x_num - 1)
56 | step_y = (y_range[1] - y_range[0]) / float(y_num - 1)
57 | step_t = (t_range[1] - t_range[0]) / float(t_num - 1)
58 |
59 | x_mesh, y_mesh, t_mesh = np.mgrid[x_range[0]:x_range[1] + step_x:step_x, y_range[0]:y_range[1] + step_y:step_y,
60 | t_range[0]:t_range[1] + step_t:step_t]
61 |
62 | data = np.concatenate((np.expand_dims(x_mesh, -1), np.expand_dims(y_mesh, -1), np.expand_dims(t_mesh, -1)), axis=-1)
63 | res = data.reshape(-1, 3)
64 |
65 | x_mesh, y_mesh, t_mesh = np.mgrid[x_range[0]:x_range[0] + step_x:step_x, y_range[0]:y_range[1] + step_y:step_y,
66 | t_range[0]:t_range[1] + step_t:step_t]
67 | b_left = np.squeeze(
68 | np.concatenate((np.expand_dims(x_mesh, -1), np.expand_dims(y_mesh, -1), np.expand_dims(t_mesh, -1)), axis=-1))[
69 | 1:-1].reshape(-1, 3)
70 |
71 | x_mesh, y_mesh, t_mesh = np.mgrid[x_range[1]:x_range[1] + step_x:step_x, y_range[0]:y_range[1] + step_y:step_y,
72 | t_range[0]:t_range[1] + step_t:step_t]
73 | b_right = np.squeeze(
74 | np.concatenate((np.expand_dims(x_mesh, -1), np.expand_dims(y_mesh, -1), np.expand_dims(t_mesh, -1)), axis=-1))[
75 | 1:-1].reshape(-1, 3)
76 |
77 | x_mesh, y_mesh, t_mesh = np.mgrid[x_range[0]:x_range[1] + step_x:step_x, y_range[0]:y_range[0] + step_y:step_y,
78 | t_range[0]:t_range[1] + step_t:step_t]
79 | b_lower = np.squeeze(
80 | np.concatenate((np.expand_dims(x_mesh, -1), np.expand_dims(y_mesh, -1), np.expand_dims(t_mesh, -1)), axis=-1))[
81 | 1:-1].reshape(-1, 3)
82 |
83 | x_mesh, y_mesh, t_mesh = np.mgrid[x_range[0]:x_range[1] + step_x:step_x, y_range[1]:y_range[1] + step_y:step_y,
84 | t_range[0]:t_range[1] + step_t:step_t]
85 | b_upper = np.squeeze(
86 | np.concatenate((np.expand_dims(x_mesh, -1), np.expand_dims(y_mesh, -1), np.expand_dims(t_mesh, -1)), axis=-1))[
87 | 1:-1].reshape(-1, 3)
88 |
89 | return res, b_left, b_right, b_upper, b_lower
90 |
--------------------------------------------------------------------------------
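Usage sketch (annotation, not part of the repository; ranges are assumed): get_data builds a uniform space-time mesh plus its four boundary slices, and make_time_sequence expands each point into the short pseudo-sequence that PINNsFormer consumes.

    import numpy as np
    from util import get_data, make_time_sequence

    res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101)
    print(res.shape, b_left.shape)  # (10201, 2) (101, 2)

    seq = make_time_sequence(res, num_step=5, step=1e-4)
    print(seq.shape)                # (10201, 5, 2)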