├── .gitignore
├── 1d_reaction_point_optimization.py
├── 1d_reaction_region_optimization.py
├── 1d_wave_point_optimization.py
├── 1d_wave_region_optimization.py
├── LICENSE
├── README.md
├── convection_point_optimization.py
├── convection_region_optimization.py
├── model_dict.py
├── models
│   ├── FLS.py
│   ├── KAN.py
│   ├── LBFGS.py
│   ├── PINN.py
│   ├── PINNsFormer.py
│   ├── PINNsFormer_Enc_Only.py
│   ├── QRes.py
│   ├── Symbolic_KANLayer.py
│   ├── kan_layer.py
│   ├── spline.py
│   └── utils.py
├── pic
│   ├── algorithm.png
│   ├── comparison.png
│   └── results.png
├── requirements.txt
├── scripts
│   ├── 1d_reaction_point.sh
│   ├── 1d_reaction_region.sh
│   ├── 1d_wave_point.sh
│   ├── 1d_wave_region.sh
│   ├── convection_point.sh
│   └── convection_region.sh
└── util.py
-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /1d_reaction_point_optimization.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import matplotlib.pyplot as plt 6 | import random 7 | from torch.optim import LBFGS 8 | from tqdm import tqdm 9 | import argparse 10 | from util import * 11 | from model_dict import get_model 12 | 13 | seed = 0 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | 19 | parser = argparse.ArgumentParser('Training Point Optimization') 20 | parser.add_argument('--model', type=str, default='pinn') 21 | parser.add_argument('--device', type=str, default='cuda:0') 22 | args = parser.parse_args() 23 | device = args.device 24 | 25 | res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101) 26 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101) 27 | 28 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 29 | res = make_time_sequence(res, num_step=5, step=1e-4) 30 | b_left = make_time_sequence(b_left, num_step=5, step=1e-4) 31 | b_right = make_time_sequence(b_right, num_step=5, step=1e-4) 32 | b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4) 33 | b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4) 34 | 35 | res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device) 36 | b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device) 37 | b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device) 38 | b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device) 39 | b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device) 40 | 41 | x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2] 42 | x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2] 43 | x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2] 44 | x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2] 45 | x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2] 46 | 47 | 48 | def init_weights(m): 49 | 
if isinstance(m, nn.Linear): 50 | torch.nn.init.xavier_uniform(m.weight) 51 | m.bias.data.fill_(0.01) 52 | 53 | 54 | if args.model == 'KAN': 55 | model = get_model(args).Model(width=[2, 5, 1], grid=5, k=3, grid_eps=1.0, \ 56 | noise_scale_base=0.25, device=device).to(device) 57 | elif args.model == 'QRes': 58 | model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=4).to(device) 59 | model.apply(init_weights) 60 | elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 61 | model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device) 62 | model.apply(init_weights) 63 | else: 64 | model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device) 65 | model.apply(init_weights) 66 | 67 | optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe') 68 | 69 | print(model) 70 | print(get_n_params(model)) 71 | loss_track = [] 72 | 73 | for i in tqdm(range(1000)): 74 | def closure(): 75 | pred_res = model(x_res, t_res) 76 | pred_left = model(x_left, t_left) 77 | pred_right = model(x_right, t_right) 78 | pred_upper = model(x_upper, t_upper) 79 | pred_lower = model(x_lower, t_lower) 80 | 81 | u_x = torch.autograd.grad(pred_res, x_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True, 82 | create_graph=True)[0] 83 | u_t = torch.autograd.grad(pred_res, t_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True, 84 | create_graph=True)[0] 85 | 86 | loss_res = torch.mean((u_t - 5 * pred_res * (1 - pred_res)) ** 2) 87 | loss_bc = torch.mean((pred_upper - pred_lower) ** 2) 88 | loss_ic = torch.mean( 89 | (pred_left[:, 0] - torch.exp(- (x_left[:, 0] - torch.pi) ** 2 / (2 * (torch.pi / 4) ** 2))) ** 2) 90 | 91 | loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()]) 92 | 93 | loss = loss_res + loss_bc + loss_ic 94 | optim.zero_grad() 95 | loss.backward() 96 | return loss 97 | 98 | 99 | optim.step(closure) 100 | 101 | print('Loss Res: {:4f}, Loss_BC: {:4f}, Loss_IC: {:4f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2])) 102 | print('Train Loss: {:4f}'.format(np.sum(loss_track[-1]))) 103 | 104 | if not os.path.exists('./results/'): 105 | os.makedirs('./results/') 106 | torch.save(model.state_dict(), f'./results/1dreaction_{args.model}_point.pt') 107 | 108 | # Visualize 109 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 110 | res_test = make_time_sequence(res_test, num_step=5, step=1e-4) 111 | 112 | res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device) 113 | x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2] 114 | 115 | with torch.no_grad(): 116 | pred = model(x_test, t_test)[:, 0:1] 117 | pred = pred.cpu().detach().numpy() 118 | 119 | pred = pred.reshape(101, 101) 120 | 121 | 122 | def h(x): 123 | return np.exp(- (x - np.pi) ** 2 / (2 * (np.pi / 4) ** 2)) 124 | 125 | 126 | def u_ana(x, t): 127 | return h(x) * np.exp(5 * t) / (h(x) * np.exp(5 * t) + 1 - h(x)) 128 | 129 | 130 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101) 131 | u = u_ana(res_test[:, 0], res_test[:, 1]).reshape(101, 101) 132 | 133 | rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u)) 134 | rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u ** 2)) 135 | 136 | print('relative L1 error: {:4f}'.format(rl1)) 137 | print('relative L2 error: {:4f}'.format(rl2)) 138 | 139 | plt.figure(figsize=(4, 3)) 140 | plt.imshow(pred, aspect='equal') 141 | plt.xlabel('x') 142 | plt.ylabel('t') 143 | plt.title('Predicted u(x,t)') 144 | 
plt.colorbar() 145 | plt.tight_layout() 146 | plt.axis('off') 147 | plt.savefig(f'./results/1dreaction_{args.model}_point_optimization_pred.pdf', bbox_inches='tight') 148 | 149 | plt.figure(figsize=(4, 3)) 150 | plt.imshow(u, aspect='equal') 151 | plt.xlabel('x') 152 | plt.ylabel('t') 153 | plt.title('Exact u(x,t)') 154 | plt.colorbar() 155 | plt.tight_layout() 156 | plt.axis('off') 157 | plt.savefig('./results/1dreaction_exact.pdf', bbox_inches='tight') 158 | 159 | plt.figure(figsize=(4, 3)) 160 | plt.imshow(pred - u, aspect='equal', cmap='coolwarm', vmin=-0.15, vmax=0.15) 161 | plt.xlabel('x') 162 | plt.ylabel('t') 163 | plt.title('Absolute Error') 164 | plt.colorbar() 165 | plt.tight_layout() 166 | plt.axis('off') 167 | plt.savefig(f'./results/1dreaction_{args.model}_point_optimization_error.pdf', bbox_inches='tight') 168 | -------------------------------------------------------------------------------- /1d_reaction_region_optimization.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import matplotlib.pyplot as plt 6 | import random 7 | from torch.optim import LBFGS 8 | from tqdm import tqdm 9 | import argparse 10 | from util import * 11 | from model_dict import get_model 12 | 13 | seed = 0 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | 19 | parser = argparse.ArgumentParser('Training Region Optimization') 20 | parser.add_argument('--model', type=str, default='pinn') 21 | parser.add_argument('--device', type=str, default='cuda:0') 22 | parser.add_argument('--initial_region', type=float, default=1e-4) 23 | parser.add_argument('--sample_num', type=int, default=1) 24 | parser.add_argument('--past_iterations', type=int, default=10) 25 | args = parser.parse_args() 26 | device = args.device 27 | 28 | res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101) 29 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101) 30 | 31 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 32 | res = make_time_sequence(res, num_step=5, step=1e-4) 33 | b_left = make_time_sequence(b_left, num_step=5, step=1e-4) 34 | b_right = make_time_sequence(b_right, num_step=5, step=1e-4) 35 | b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4) 36 | b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4) 37 | 38 | res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device) 39 | b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device) 40 | b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device) 41 | b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device) 42 | b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device) 43 | 44 | x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2] 45 | x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2] 46 | x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2] 47 | x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2] 48 | x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2] 49 | 50 | 51 | def init_weights(m): 52 | if isinstance(m, nn.Linear): 53 | torch.nn.init.xavier_uniform(m.weight) 54 | m.bias.data.fill_(0.01) 55 | 56 | 57 | if args.model == 'KAN': 58 | model = get_model(args).Model(width=[2, 5, 1], grid=5, k=3, grid_eps=1.0, \ 59 | noise_scale_base=0.25, device=device).to(device) 60 | 
elif args.model == 'QRes': 61 | model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=2).to(device) 62 | model.apply(init_weights) 63 | elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 64 | model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device) 65 | model.apply(init_weights) 66 | else: 67 | model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device) 68 | model.apply(init_weights) 69 | 70 | optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe') 71 | 72 | print(model) 73 | print(get_n_params(model)) 74 | loss_track = [] 75 | 76 | # for region optimization 77 | initial_region = args.initial_region 78 | sample_num = args.sample_num 79 | past_iterations = args.past_iterations 80 | gradient_list_overall = [] 81 | gradient_list_temp = [] 82 | gradient_variance = 1 83 | 84 | for i in tqdm(range(1000)): 85 | 86 | ###### Region Optimization with Monte Carlo Approximation ###### 87 | def closure(): 88 | x_res_region_sample_list = [] 89 | t_res_region_sample_list = [] 90 | for i in range(sample_num): 91 | x_region_sample = (torch.rand(x_res.shape).to(x_res.device)) * np.clip(initial_region / gradient_variance, 92 | a_min=0, 93 | a_max=0.01) 94 | t_region_sample = (torch.rand(x_res.shape).to(t_res.device)) * np.clip(initial_region / gradient_variance, 95 | a_min=0, 96 | a_max=0.01) 97 | x_res_region_sample_list.append(x_res + x_region_sample) 98 | t_res_region_sample_list.append(t_res + t_region_sample) 99 | x_res_region_sample = torch.cat(x_res_region_sample_list, dim=0) 100 | t_res_region_sample = torch.cat(t_res_region_sample_list, dim=0) 101 | pred_res = model(x_res_region_sample, t_res_region_sample) 102 | pred_left = model(x_left, t_left) 103 | pred_right = model(x_right, t_right) 104 | pred_upper = model(x_upper, t_upper) 105 | pred_lower = model(x_lower, t_lower) 106 | 107 | u_x = \ 108 | torch.autograd.grad(pred_res, x_res_region_sample, grad_outputs=torch.ones_like(pred_res), 109 | retain_graph=True, 110 | create_graph=True)[0] 111 | u_t = \ 112 | torch.autograd.grad(pred_res, t_res_region_sample, grad_outputs=torch.ones_like(pred_res), 113 | retain_graph=True, 114 | create_graph=True)[0] 115 | 116 | loss_res = torch.mean((u_t - 5 * pred_res * (1 - pred_res)) ** 2) 117 | loss_bc = torch.mean((pred_upper - pred_lower) ** 2) 118 | loss_ic = torch.mean( 119 | (pred_left[:, 0] - torch.exp(- (x_left[:, 0] - torch.pi) ** 2 / (2 * (torch.pi / 4) ** 2))) ** 2) 120 | 121 | loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()]) 122 | 123 | loss = loss_res + loss_bc + loss_ic 124 | optim.zero_grad() 125 | loss.backward(retain_graph=True) 126 | gradient_list_temp.append(torch.cat([(p.grad.view(-1)) if p.grad is not None else torch.zeros(1).cuda() for p in 127 | model.parameters()]).cpu().numpy()) # hook gradients from computation graph 128 | return loss 129 | 130 | 131 | optim.step(closure) 132 | 133 | ###### Trust Region Calibration ###### 134 | gradient_list_overall.append(np.mean(np.array(gradient_list_temp), axis=0)) 135 | gradient_list_overall = gradient_list_overall[-past_iterations:] 136 | gradient_list = np.array(gradient_list_overall) 137 | gradient_variance = ( 138 | np.std(gradient_list, axis=0) / (np.mean(np.abs(gradient_list), axis=0) + 1e-6)).mean() 139 | gradient_list_temp = [] 140 | if gradient_variance == 0: 141 | gradient_variance = 1 # for numerical stability 142 | 143 | print('Loss Res: {:4f}, Loss_BC: {:4f}, Loss_IC: 
{:4f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2])) 144 | print('Train Loss: {:4f}'.format(np.sum(loss_track[-1]))) 145 | 146 | if not os.path.exists('./results/'): 147 | os.makedirs('./results/') 148 | torch.save(model.state_dict(), f'./results/1dreaction_{args.model}_region.pt') 149 | 150 | # Visualize 151 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 152 | res_test = make_time_sequence(res_test, num_step=5, step=1e-4) 153 | 154 | res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device) 155 | x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2] 156 | 157 | with torch.no_grad(): 158 | pred = model(x_test, t_test)[:, 0:1] 159 | pred = pred.cpu().detach().numpy() 160 | 161 | pred = pred.reshape(101, 101) 162 | 163 | 164 | def h(x): 165 | return np.exp(- (x - np.pi) ** 2 / (2 * (np.pi / 4) ** 2)) 166 | 167 | 168 | def u_ana(x, t): 169 | return h(x) * np.exp(5 * t) / (h(x) * np.exp(5 * t) + 1 - h(x)) 170 | 171 | 172 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101) 173 | u = u_ana(res_test[:, 0], res_test[:, 1]).reshape(101, 101) 174 | 175 | rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u)) 176 | rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u ** 2)) 177 | 178 | print('relative L1 error: {:4f}'.format(rl1)) 179 | print('relative L2 error: {:4f}'.format(rl2)) 180 | 181 | plt.figure(figsize=(4, 3)) 182 | plt.imshow(pred, aspect='equal') 183 | plt.xlabel('x') 184 | plt.ylabel('t') 185 | plt.title('Predicted u(x,t)') 186 | plt.colorbar() 187 | plt.tight_layout() 188 | plt.axis('off') 189 | plt.savefig(f'./results/1dreaction_{args.model}_region_optimization_pred.pdf', bbox_inches='tight') 190 | 191 | plt.figure(figsize=(4, 3)) 192 | plt.imshow(u, aspect='equal') 193 | plt.xlabel('x') 194 | plt.ylabel('t') 195 | plt.title('Exact u(x,t)') 196 | plt.colorbar() 197 | plt.tight_layout() 198 | plt.axis('off') 199 | plt.savefig('./results/1dreaction_exact.pdf', bbox_inches='tight') 200 | 201 | plt.figure(figsize=(4, 3)) 202 | plt.imshow(pred - u, aspect='equal', cmap='coolwarm', vmin=-0.15, vmax=0.15) 203 | plt.xlabel('x') 204 | plt.ylabel('t') 205 | plt.title('Absolute Error') 206 | plt.colorbar() 207 | plt.tight_layout() 208 | plt.axis('off') 209 | plt.savefig(f'./results/1dreaction_{args.model}_region_optimization_error.pdf', bbox_inches='tight') 210 | -------------------------------------------------------------------------------- /1d_wave_point_optimization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import matplotlib.pyplot as plt 5 | import random 6 | from torch.optim import LBFGS, Adam 7 | from tqdm import tqdm 8 | import argparse 9 | from util import * 10 | from model_dict import get_model 11 | 12 | seed = 0 13 | np.random.seed(seed) 14 | random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | 18 | parser = argparse.ArgumentParser('Training Point Optimization') 19 | parser.add_argument('--model', type=str, default='PINN') 20 | parser.add_argument('--device', type=str, default='cuda:0') 21 | args = parser.parse_args() 22 | device = args.device 23 | 24 | res, b_left, b_right, b_upper, b_lower = get_data([0, 1], [0, 1], 101, 101) 25 | res_test, _, _, _, _ = get_data([0, 1], [0, 1], 101, 101) 26 | 27 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 28 | res = make_time_sequence(res, num_step=5, step=1e-4) 29 | b_left = 
make_time_sequence(b_left, num_step=5, step=1e-4) 30 | b_right = make_time_sequence(b_right, num_step=5, step=1e-4) 31 | b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4) 32 | b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4) 33 | 34 | res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device) 35 | b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device) 36 | b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device) 37 | b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device) 38 | b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device) 39 | 40 | x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2] 41 | x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2] 42 | x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2] 43 | x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2] 44 | x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2] 45 | 46 | 47 | def init_weights(m): 48 | if isinstance(m, nn.Linear): 49 | torch.nn.init.xavier_uniform(m.weight) 50 | m.bias.data.fill_(0.01) 51 | 52 | 53 | if args.model == 'KAN': 54 | model = get_model(args).Model(width=[2, 5, 5, 1], grid=5, k=3, grid_eps=1.0, \ 55 | noise_scale_base=0.25, device=device).to(device) 56 | elif args.model == 'QRes': 57 | model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=4).to(device) 58 | model.apply(init_weights) 59 | elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 60 | model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device) 61 | model.apply(init_weights) 62 | else: 63 | model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device) 64 | model.apply(init_weights) 65 | 66 | optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe') 67 | 68 | n_params = get_n_params(model) 69 | 70 | print(model) 71 | print(get_n_params(model)) 72 | 73 | loss_track = [] 74 | pi = torch.tensor(np.pi, dtype=torch.float32, requires_grad=False).to(device) 75 | 76 | for i in tqdm(range(1000)): 77 | def closure(): 78 | pred_res = model(x_res, t_res) 79 | pred_left = model(x_left, t_left) 80 | pred_right = model(x_right, t_right) 81 | pred_upper = model(x_upper, t_upper) 82 | pred_lower = model(x_lower, t_lower) 83 | 84 | u_x = torch.autograd.grad(pred_res, x_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True, 85 | create_graph=True)[0] 86 | u_xx = \ 87 | torch.autograd.grad(u_x, x_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True, 88 | create_graph=True)[0] 89 | u_t = torch.autograd.grad(pred_res, t_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True, 90 | create_graph=True)[0] 91 | u_tt = \ 92 | torch.autograd.grad(u_t, t_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True, 93 | create_graph=True)[0] 94 | 95 | loss_res = torch.mean((u_tt - 4 * u_xx) ** 2) 96 | loss_bc = torch.mean((pred_upper) ** 2) + torch.mean((pred_lower) ** 2) 97 | 98 | ui_t = torch.autograd.grad(pred_left, t_left, grad_outputs=torch.ones_like(pred_left), retain_graph=True, 99 | create_graph=True)[0] 100 | 101 | loss_ic_1 = torch.mean( 102 | (pred_left[:, 0] - torch.sin(pi * x_left[:, 0]) - 0.5 * torch.sin(3 * pi * x_left[:, 0])) ** 2) 103 | loss_ic_2 = torch.mean((ui_t) ** 2) 104 | 105 | loss_ic = loss_ic_1 + loss_ic_2 106 | 107 | loss_track.append([loss_res.item(), loss_ic.item(), loss_bc.item()]) 108 | 109 | loss = loss_res + loss_ic + loss_bc 110 | 
optim.zero_grad()
111 |         loss.backward()
112 |         return loss
113 | 
114 | 
115 |     optim.step(closure)
116 | 
117 | print('Loss Res: {:4f}, Loss_IC: {:4f}, Loss_BC: {:4f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2]))
118 | print('Train Loss: {:4f}'.format(np.sum(loss_track[-1])))
119 | 
120 | if not os.path.exists('./results/'):
121 |     os.makedirs('./results/')
122 | 
123 | torch.save(model.state_dict(), f'./results/1dwave_{args.model}_point.pt')
124 | 
125 | # Visualize
126 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
127 |     res_test = make_time_sequence(res_test, num_step=5, step=1e-4)
128 | 
129 | res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
130 | x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]
131 | 
132 | with torch.no_grad():
133 |     pred = model(x_test, t_test)[:, 0:1]
134 |     pred = pred.cpu().detach().numpy()
135 | 
136 | pred = pred.reshape(101, 101)
137 | 
138 | 
139 | def u_ana(x, t):
140 |     return np.sin(np.pi * x) * np.cos(2 * np.pi * t) + 0.5 * np.sin(3 * np.pi * x) * np.cos(6 * np.pi * t)
141 | 
142 | 
143 | res_test, _, _, _, _ = get_data([0, 1], [0, 1], 101, 101)
144 | u = u_ana(res_test[:, 0], res_test[:, 1]).reshape(101, 101)
145 | 
146 | rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u))
147 | rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u ** 2))
148 | 
149 | print('relative L1 error: {:4f}'.format(rl1))
150 | print('relative L2 error: {:4f}'.format(rl2))
151 | 
152 | plt.figure(figsize=(4, 3))
153 | plt.imshow(pred, aspect='equal')
154 | plt.xlabel('x')
155 | plt.ylabel('t')
156 | plt.title('Predicted u(x,t)')
157 | plt.colorbar()
158 | plt.tight_layout()
159 | plt.axis('off')
160 | plt.savefig(f'./results/1dwave_{args.model}_point_optimization_pred.pdf', bbox_inches='tight')
161 | 
162 | plt.figure(figsize=(4, 3))
163 | plt.imshow(u, aspect='equal')
164 | plt.xlabel('x')
165 | plt.ylabel('t')
166 | plt.title('Exact u(x,t)')
167 | plt.colorbar()
168 | plt.tight_layout()
169 | plt.axis('off')
170 | plt.savefig('./results/1dwave_exact.pdf', bbox_inches='tight')
171 | 
172 | plt.figure(figsize=(4, 3))
173 | plt.imshow(pred - u, aspect='equal', cmap='coolwarm', vmin=-0.3, vmax=0.3)
174 | plt.xlabel('x')
175 | plt.ylabel('t')
176 | plt.title('Absolute Error')
177 | plt.colorbar()
178 | plt.tight_layout()
179 | plt.axis('off')
180 | plt.savefig(f'./results/1dwave_{args.model}_point_optimization_error.pdf', bbox_inches='tight')
181 | 
-------------------------------------------------------------------------------- /1d_wave_region_optimization.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import matplotlib.pyplot as plt
5 | import random
6 | from torch.optim import LBFGS, Adam
7 | from tqdm import tqdm
8 | import argparse
9 | from util import *
10 | from model_dict import get_model
11 | 
12 | seed = 0
13 | np.random.seed(seed)
14 | random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 | 
18 | parser = argparse.ArgumentParser('Training Region Optimization')
19 | parser.add_argument('--model', type=str, default='pinn')
20 | parser.add_argument('--device', type=str, default='cuda:0')
21 | parser.add_argument('--initial_region', type=float, default=1e-4)
22 | parser.add_argument('--sample_num', type=int, default=1)
23 | parser.add_argument('--past_iterations', type=int, default=10)
24 | args = parser.parse_args()
25 | device = args.device
26 | 
27 | res, b_left,
b_right, b_upper, b_lower = get_data([0, 1], [0, 1], 101, 101) 28 | res_test, _, _, _, _ = get_data([0, 1], [0, 1], 101, 101) 29 | 30 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 31 | res = make_time_sequence(res, num_step=5, step=1e-4) 32 | b_left = make_time_sequence(b_left, num_step=5, step=1e-4) 33 | b_right = make_time_sequence(b_right, num_step=5, step=1e-4) 34 | b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4) 35 | b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4) 36 | 37 | res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device) 38 | b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device) 39 | b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device) 40 | b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device) 41 | b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device) 42 | 43 | x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2] 44 | x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2] 45 | x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2] 46 | x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2] 47 | x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2] 48 | 49 | 50 | def init_weights(m): 51 | if isinstance(m, nn.Linear): 52 | torch.nn.init.xavier_uniform(m.weight) 53 | m.bias.data.fill_(0.01) 54 | 55 | 56 | if args.model == 'KAN': 57 | model = get_model(args).Model(width=[2, 5, 5, 1], grid=5, k=3, grid_eps=1.0, \ 58 | noise_scale_base=0.25, device=device).to(device) 59 | elif args.model == 'QRes': 60 | model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=2).to(device) 61 | model.apply(init_weights) 62 | elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 63 | model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device) 64 | model.apply(init_weights) 65 | else: 66 | model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device) 67 | model.apply(init_weights) 68 | 69 | optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe') 70 | 71 | n_params = get_n_params(model) 72 | 73 | print(model) 74 | print(get_n_params(model)) 75 | loss_track = [] 76 | pi = torch.tensor(np.pi, dtype=torch.float32, requires_grad=False).to(device) 77 | 78 | # for region optimization 79 | initial_region = args.initial_region 80 | sample_num = args.sample_num 81 | past_iterations = args.past_iterations 82 | gradient_list_overall = [] 83 | gradient_list_temp = [] 84 | gradient_variance = 1 85 | 86 | for i in tqdm(range(1000)): 87 | 88 | ###### Region Optimization with Monte Carlo Approximation ###### 89 | def closure(): 90 | x_res_region_sample_list = [] 91 | t_res_region_sample_list = [] 92 | for i in range(sample_num): 93 | x_region_sample = (torch.rand(x_res.shape).to(x_res.device)) * np.clip(initial_region / gradient_variance, 94 | a_min=0, 95 | a_max=0.01) 96 | t_region_sample = (torch.rand(x_res.shape).to(t_res.device)) * np.clip(initial_region / gradient_variance, 97 | a_min=0, 98 | a_max=0.01) 99 | x_res_region_sample_list.append(x_res + x_region_sample) 100 | t_res_region_sample_list.append(t_res + t_region_sample) 101 | x_res_region_sample = torch.cat(x_res_region_sample_list, dim=0) 102 | t_res_region_sample = torch.cat(t_res_region_sample_list, dim=0) 103 | pred_res = model(x_res_region_sample, t_res_region_sample) 104 | pred_left = model(x_left, t_left) 105 | pred_right = 
model(x_right, t_right)
106 |         pred_upper = model(x_upper, t_upper)
107 |         pred_lower = model(x_lower, t_lower)
108 | 
109 |         u_x = \
110 |             torch.autograd.grad(pred_res, x_res_region_sample, grad_outputs=torch.ones_like(pred_res),
111 |                                 retain_graph=True,
112 |                                 create_graph=True)[0]
113 |         u_xx = torch.autograd.grad(u_x, x_res_region_sample, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
114 |                                    create_graph=True)[0]
115 |         u_t = \
116 |             torch.autograd.grad(pred_res, t_res_region_sample, grad_outputs=torch.ones_like(pred_res),
117 |                                 retain_graph=True,
118 |                                 create_graph=True)[0]
119 |         u_tt = torch.autograd.grad(u_t, t_res_region_sample, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
120 |                                    create_graph=True)[0]
121 | 
122 |         loss_res = torch.mean((u_tt - 4 * u_xx) ** 2)
123 |         loss_bc = torch.mean((pred_upper) ** 2) + torch.mean((pred_lower) ** 2)
124 | 
125 |         ui_t = torch.autograd.grad(pred_left, t_left, grad_outputs=torch.ones_like(pred_left), retain_graph=True,
126 |                                    create_graph=True)[0]
127 | 
128 |         loss_ic_1 = torch.mean(
129 |             (pred_left[:, 0] - torch.sin(pi * x_left[:, 0]) - 0.5 * torch.sin(3 * pi * x_left[:, 0])) ** 2)
130 |         loss_ic_2 = torch.mean((ui_t) ** 2)
131 | 
132 |         loss_ic = loss_ic_1 + loss_ic_2
133 | 
134 |         loss_track.append([loss_res.item(), loss_ic.item(), loss_bc.item()])
135 | 
136 |         loss = loss_res + loss_ic + loss_bc
137 |         optim.zero_grad()
138 |         loss.backward(retain_graph=True)
139 |         gradient_list_temp.append(torch.cat([(p.grad.view(-1)) if p.grad is not None else torch.zeros(1).cuda() for p in
140 |                                              model.parameters()]).cpu().numpy())  # hook gradients from computation graph
141 |         return loss
142 | 
143 | 
144 |     optim.step(closure)
145 | 
146 |     ###### Trust Region Calibration ######
147 |     gradient_list_overall.append(np.mean(np.array(gradient_list_temp), axis=0))
148 |     gradient_list_overall = gradient_list_overall[-past_iterations:]
149 |     gradient_list = np.array(gradient_list_overall)
150 |     gradient_variance = (np.std(gradient_list, axis=0) / (
151 |             np.mean(np.abs(gradient_list), axis=0) + 1e-6)).mean()  # normalized variance
152 |     gradient_list_temp = []
153 |     print(gradient_variance)
154 |     if gradient_variance == 0:
155 |         gradient_variance = 1  # for numerical stability
156 | 
157 | print('Loss Res: {:4f}, Loss_IC: {:4f}, Loss_BC: {:4f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2]))
158 | print('Train Loss: {:4f}'.format(np.sum(loss_track[-1])))
159 | 
160 | if not os.path.exists('./results/'):
161 |     os.makedirs('./results/')
162 | torch.save(model.state_dict(), f'./results/1dwave_{args.model}_region.pt')
163 | 
164 | # Visualize PINNs
165 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
166 |     res_test = make_time_sequence(res_test, num_step=5, step=1e-4)
167 | 
168 | res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
169 | x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]
170 | 
171 | with torch.no_grad():
172 |     pred = model(x_test, t_test)[:, 0:1]
173 |     pred = pred.cpu().detach().numpy()
174 | 
175 | pred = pred.reshape(101, 101)
176 | 
177 | 
178 | def u_ana(x, t):
179 |     return np.sin(np.pi * x) * np.cos(2 * np.pi * t) + 0.5 * np.sin(3 * np.pi * x) * np.cos(6 * np.pi * t)
180 | 
181 | 
182 | res_test, _, _, _, _ = get_data([0, 1], [0, 1], 101, 101)
183 | u = u_ana(res_test[:, 0], res_test[:, 1]).reshape(101, 101)
184 | 
185 | rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u))
186 | rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u ** 2))
187 | 
188 | print('relative L1 error:
{:4f}'.format(rl1)) 189 | print('relative L2 error: {:4f}'.format(rl2)) 190 | 191 | plt.figure(figsize=(4, 3)) 192 | plt.imshow(pred, aspect='equal') 193 | plt.xlabel('x') 194 | plt.ylabel('t') 195 | plt.title('Predicted u(x,t)') 196 | plt.colorbar() 197 | plt.tight_layout() 198 | plt.axis('off') 199 | plt.savefig(f'./results/1dwave_{args.model}_region_optimization_pred.pdf', bbox_inches='tight') 200 | 201 | plt.figure(figsize=(4, 3)) 202 | plt.imshow(u, aspect='equal') 203 | plt.xlabel('x') 204 | plt.ylabel('t') 205 | plt.title('Exact u(x,t)') 206 | plt.colorbar() 207 | plt.tight_layout() 208 | plt.axis('off') 209 | plt.savefig('./results/1dwave_exact.pdf', bbox_inches='tight') 210 | 211 | plt.figure(figsize=(4, 3)) 212 | plt.imshow(pred - u, aspect='equal', cmap='coolwarm', vmin=-0.3, vmax=0.3) 213 | plt.xlabel('x') 214 | plt.ylabel('t') 215 | plt.title('Absolute Error') 216 | plt.colorbar() 217 | plt.tight_layout() 218 | plt.axis('off') 219 | plt.savefig(f'./results/1dwave_{args.model}_region_optimization_error.pdf', bbox_inches='tight') 220 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 THUML @ Tsinghua University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RoPINN (NeurIPS 2024) 2 | 3 | RoPINN: Region Optimized Physics-Informed Neural Networks. See [Paper](https://arxiv.org/abs/2405.14369) or [Slides](https://wuhaixu2016.github.io/pdf/NeurIPS2024_RoPINN.pdf). 4 | 5 | This paper proposes and theoretically studies a new training paradigm of PINNs as **region optimization** and presents [RoPINN](https://arxiv.org/abs/2405.14369) as a practical algorithm, which can bring the following benefits: 6 | 7 | - **Better generalization bound:** Introducing "region" can theoretically decrease generalization error and provide a general theoretical framework that first reveals the balance between generalization and optimization. 8 | - **Efficient practical algorithm:** We present RoPINN with a trust region calibration strategy, which can effectively accomplish the region optimization and reduce the gradient estimation error caused by sampling. 
9 | - **Boost extensive backbones:** RoPINN consistently improves various PINN backbones (e.g., PINN, KAN, and PINNsFormer) on a wide range of PDEs (19 different tasks) without extra gradient calculation.
10 | 
11 | ## Point Optimization vs. Region Optimization
12 | 
13 | Unlike conventional point optimization, our proposed region optimization extends the optimization process of PINNs from isolated points to their continuous neighborhood region.
14 | 
15 | 

16 | ![comparison](./pic/comparison.png)
17 | 
18 | Figure 1. Comparison between previous methods and RoPINN.
19 | 
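In the paper's notation (sketched informally here), point optimization trains PINNs on the loss evaluated at isolated collocation points,

$$\min_\theta \frac{1}{N}\sum_{i=1}^{N} \mathcal{L}(\theta; x_i),$$

while region optimization replaces each point-wise term with its average over a neighborhood region $\Omega_{x_i}$ of the point:

$$\min_\theta \frac{1}{N}\sum_{i=1}^{N} \frac{1}{|\Omega_{x_i}|}\int_{\Omega_{x_i}} \mathcal{L}(\theta; x)\,\mathrm{d}x.$$

RoPINN approximates this integral by Monte Carlo sampling, as sketched after Figure 2.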

20 | 21 | ## Practical Algorithm 22 | 23 | We present RoPINN for PINN training based on Monte Carlo sampling, which can effectively accomplish the region optimization without extra gradient calculation. A trust region calibration strategy is proposed to reduce the gradient estimation error caused by sampling for more trustworthy optimization. 24 | 25 |

26 | ![algorithm](./pic/algorithm.png)
27 | 
28 | Figure 2. RoPINN algorithm.
29 | 
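Concretely, one RoPINN iteration has three steps: perturb the collocation points uniformly within the current trust region, take an optimizer step on the resulting loss, and recalibrate the region size from the normalized variance of recent gradients. The following is a minimal sketch condensed from the `*_region_optimization.py` scripts in this repo; `pinn_loss` is a placeholder for the PDE-specific residual/boundary/initial losses, and the per-iteration gradient averaging over multiple L-BFGS closure calls is simplified to a flat history:

```python
import numpy as np
import torch

def ropinn_iteration(model, optim, x_res, t_res, pinn_loss, grad_history,
                     grad_variance, initial_region=1e-4, past_iterations=10):
    """One region-optimization step; returns the recalibrated gradient variance."""
    def closure():
        # Monte Carlo approximation: sample uniformly inside the trust region
        delta = float(np.clip(initial_region / grad_variance, 0, 0.01))
        x_sample = x_res + torch.rand_like(x_res) * delta
        t_sample = t_res + torch.rand_like(t_res) * delta
        loss = pinn_loss(model, x_sample, t_sample)
        optim.zero_grad()
        loss.backward(retain_graph=True)
        # record the flattened gradient for the trust region calibration below
        grad_history.append(torch.cat(
            [p.grad.view(-1) if p.grad is not None else torch.zeros(1, device=p.device)
             for p in model.parameters()]).detach().cpu().numpy())
        return loss

    optim.step(closure)

    # Trust region calibration: the noisier the recent gradients, the smaller the region
    recent = np.array(grad_history[-past_iterations:])
    variance = (np.std(recent, axis=0) / (np.mean(np.abs(recent), axis=0) + 1e-6)).mean()
    return variance if variance != 0 else 1.0  # guard for numerical stability
```

In the actual scripts, `sample_num` perturbed copies of the points are concatenated per closure call, the recorded gradients are first averaged within each iteration, and `grad_variance` starts at 1 before the first calibration.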

30 | 
31 | ## Get Started
32 | 
33 | 1. Install Python 3.8 or Python 3.9 and **PyTorch 1.13.0**. For convenience, execute the following command.
34 | 
35 | ```shell
36 | pip install -r requirements.txt
37 | ```
38 | 
39 | 2. Train and evaluate the models. We provide the experiment scripts for all benchmarks under the folder `./scripts/`. You can reproduce the experiment results as follows:
40 | 
41 | ```shell
42 | bash scripts/1d_reaction_point.sh # canonical point optimization
43 | bash scripts/1d_reaction_region.sh # RoPINN: region optimization
44 | bash scripts/1d_wave_point.sh # canonical point optimization
45 | bash scripts/1d_wave_region.sh # RoPINN: region optimization
46 | bash scripts/convection_point.sh # canonical point optimization
47 | bash scripts/convection_region.sh # RoPINN: region optimization
48 | ```
49 | 
50 | Specifically, we have included the following PINN models in this repo:
51 | 
52 | - [x] PINN (Journal of Computational Physics 2019) [[Paper]](https://github.com/maziarraissi/PINNs)
53 | - [x] FLS (IEEE Transactions on Artificial Intelligence 2022) [[Paper]](https://arxiv.org/abs/2109.09338)
54 | - [x] QRes (SIAM 2021) [[Paper]](https://arxiv.org/abs/2101.08366)
55 | - [x] KAN (arXiv 2024) [[Paper]](https://arxiv.org/abs/2404.19756)
56 | - [x] PINNsFormer (ICLR 2024) [[Paper]](https://arxiv.org/abs/2307.11833)
57 | 
58 | ## Results
59 | 
60 | We have experimented with 19 different PDE tasks. See [our paper](https://arxiv.org/abs/2405.14369) for the full results.
61 | 
62 | 

63 | ![results](./pic/results.png)
64 | 
65 | Figure 3. Part of experimental results of RoPINN.
66 | 

67 | 68 | ## Citation 69 | 70 | If you find this repo useful, please cite our paper. 71 | 72 | ``` 73 | @inproceedings{wu2024ropinn, 74 | title={RoPINN: Region Optimized Physics-Informed Neural Networks}, 75 | author={Haixu Wu and Huakun Luo and Yuezhou Ma and Jianmin Wang and Mingsheng Long}, 76 | booktitle={Advances in Neural Information Processing Systems}, 77 | year={2024} 78 | } 79 | ``` 80 | 81 | ## Contact 82 | 83 | If you have any questions or want to use the code, please contact [wuhx23@mails.tsinghua.edu.cn](mailto:wuhx23@mails.tsinghua.edu.cn). 84 | 85 | ## Acknowledgement 86 | 87 | We appreciate the following GitHub repos a lot for their valuable code base or datasets: 88 | 89 | https://github.com/AdityaLab/pinnsformer 90 | 91 | https://github.com/i207M/PINNacle 92 | -------------------------------------------------------------------------------- /convection_point_optimization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import matplotlib.pyplot as plt 4 | import random 5 | from torch.optim import LBFGS 6 | from tqdm import tqdm 7 | import os 8 | import argparse 9 | from util import * 10 | from model_dict import get_model 11 | 12 | seed = 0 13 | np.random.seed(seed) 14 | random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | 18 | parser = argparse.ArgumentParser('Training Point Optimization') 19 | parser.add_argument('--model', type=str, default='pinn') 20 | parser.add_argument('--device', type=str, default='cuda:0') 21 | args = parser.parse_args() 22 | device = args.device 23 | 24 | res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101) 25 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101) 26 | 27 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 28 | res = make_time_sequence(res, num_step=5, step=1e-4) 29 | b_left = make_time_sequence(b_left, num_step=5, step=1e-4) 30 | b_right = make_time_sequence(b_right, num_step=5, step=1e-4) 31 | b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4) 32 | b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4) 33 | 34 | res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device) 35 | b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device) 36 | b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device) 37 | b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device) 38 | b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device) 39 | 40 | x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2] 41 | x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2] 42 | x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2] 43 | x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2] 44 | x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2] 45 | 46 | 47 | def init_weights(m): 48 | if isinstance(m, nn.Linear): 49 | torch.nn.init.xavier_uniform(m.weight) 50 | m.bias.data.fill_(0.0) 51 | 52 | 53 | if args.model == 'KAN': 54 | model = get_model(args).Model(width=[2, 5, 5, 1], grid=5, k=3, grid_eps=1.0, \ 55 | noise_scale_base=0.25, device=device).to(device) 56 | elif args.model == 'QRes': 57 | model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=4).to(device) 58 | model.apply(init_weights) 59 | elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 60 | model = 
get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
61 |     model.apply(init_weights)
62 | else:
63 |     model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
64 |     model.apply(init_weights)
65 | 
66 | optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')
67 | 
68 | print(model)
69 | print(get_n_params(model))
70 | 
71 | loss_track = []
72 | 
73 | for i in tqdm(range(1000)):
74 |     def closure():
75 |         pred_res = model(x_res, t_res)
76 |         pred_left = model(x_left, t_left)
77 |         pred_right = model(x_right, t_right)
78 |         pred_upper = model(x_upper, t_upper)
79 |         pred_lower = model(x_lower, t_lower)
80 | 
81 |         u_x = torch.autograd.grad(pred_res, x_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
82 |                                   create_graph=True)[0]
83 |         u_t = torch.autograd.grad(pred_res, t_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True,
84 |                                   create_graph=True)[0]
85 | 
86 |         loss_res = torch.mean((u_t + 50 * u_x) ** 2)
87 |         loss_bc = torch.mean((pred_upper - pred_lower) ** 2)
88 |         loss_ic = torch.mean((pred_left[:, 0] - torch.sin(x_left[:, 0])) ** 2)
89 | 
90 |         loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()])
91 | 
92 |         loss = loss_res + loss_bc + loss_ic
93 |         optim.zero_grad()
94 |         loss.backward()
95 |         return loss
96 | 
97 | 
98 |     optim.step(closure)
99 | 
100 | print('Loss Res: {:4f}, Loss_BC: {:4f}, Loss_IC: {:4f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2]))
101 | print('Train Loss: {:4f}'.format(np.sum(loss_track[-1])))
102 | 
103 | if not os.path.exists('./results/'):
104 |     os.makedirs('./results/')
105 | 
106 | torch.save(model.state_dict(), f'./results/1dconvection_{args.model}_point.pt')
107 | 
108 | # Visualize
109 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
110 |     res_test = make_time_sequence(res_test, num_step=5, step=1e-4)
111 | 
112 | res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
113 | x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2]
114 | 
115 | with torch.no_grad():
116 |     pred = model(x_test, t_test)[:, 0:1]
117 |     pred = pred.cpu().detach().numpy()
118 | 
119 | pred = pred.reshape(101, 101)
120 | 
121 | 
122 | def u_res(x, t):
123 |     print(x.shape)
124 |     print(t.shape)
125 |     return np.sin(x - 50 * t)
126 | 
127 | 
128 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)
129 | u = u_res(res_test[:, 0], res_test[:, 1]).reshape(101, 101)
130 | 
131 | rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u))
132 | rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u ** 2))
133 | 
134 | print('relative L1 error: {:4f}'.format(rl1))
135 | print('relative L2 error: {:4f}'.format(rl2))
136 | 
137 | plt.figure(figsize=(4, 3))
138 | plt.imshow(pred, aspect='equal')
139 | plt.xlabel('x')
140 | plt.ylabel('t')
141 | plt.title('Predicted u(x,t)')
142 | plt.colorbar()
143 | plt.tight_layout()
144 | plt.axis('off')
145 | plt.savefig(f'./results/convection_{args.model}_point_optimization_pred.pdf', bbox_inches='tight')
146 | 
147 | plt.figure(figsize=(4, 3))
148 | plt.imshow(u, aspect='equal')
149 | plt.xlabel('x')
150 | plt.ylabel('t')
151 | plt.title('Exact u(x,t)')
152 | plt.colorbar()
153 | plt.tight_layout()
154 | plt.axis('off')
155 | plt.savefig('./results/convection_exact.pdf', bbox_inches='tight')
156 | 
157 | plt.figure(figsize=(4, 3))
158 | plt.imshow(pred - u, aspect='equal', cmap='coolwarm', vmin=-1, vmax=1)
159 | plt.xlabel('x')
160 | plt.ylabel('t')
161 | plt.title('Absolute Error')
162 | plt.colorbar()
163 | plt.tight_layout()
164 | plt.axis('off')
165 | plt.savefig(f'./results/convection_{args.model}_point_optimization_error.pdf', bbox_inches='tight')
166 | 
-------------------------------------------------------------------------------- /convection_region_optimization.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import matplotlib.pyplot as plt
5 | import random
6 | from torch.optim import LBFGS
7 | from tqdm import tqdm
8 | import argparse
9 | from util import *
10 | from model_dict import get_model
11 | 
12 | seed = 0
13 | np.random.seed(seed)
14 | random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 | 
18 | parser = argparse.ArgumentParser('Training Region Optimization')
19 | parser.add_argument('--model', type=str, default='pinn')
20 | parser.add_argument('--device', type=str, default='cuda:0')
21 | parser.add_argument('--initial_region', type=float, default=1e-4)
22 | parser.add_argument('--sample_num', type=int, default=1)
23 | parser.add_argument('--past_iterations', type=int, default=5)
24 | args = parser.parse_args()
25 | device = args.device
26 | 
27 | res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101)
28 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)
29 | 
30 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
31 |     res = make_time_sequence(res, num_step=5, step=1e-4)
32 |     b_left = make_time_sequence(b_left, num_step=5, step=1e-4)
33 |     b_right = make_time_sequence(b_right, num_step=5, step=1e-4)
34 |     b_upper = make_time_sequence(b_upper, num_step=5, step=1e-4)
35 |     b_lower = make_time_sequence(b_lower, num_step=5, step=1e-4)
36 | 
37 | res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
38 | b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
39 | b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
40 | b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
41 | b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)
42 | 
43 | x_res, t_res = res[:, ..., 0:1], res[:, ..., 1:2]
44 | x_left, t_left = b_left[:, ..., 0:1], b_left[:, ..., 1:2]
45 | x_right, t_right = b_right[:, ..., 0:1], b_right[:, ..., 1:2]
46 | x_upper, t_upper = b_upper[:, ..., 0:1], b_upper[:, ..., 1:2]
47 | x_lower, t_lower = b_lower[:, ..., 0:1], b_lower[:, ..., 1:2]
48 | 
49 | 
50 | def init_weights(m):
51 |     if isinstance(m, nn.Linear):
52 |         torch.nn.init.xavier_uniform(m.weight)
53 |         m.bias.data.fill_(0.0)
54 | 
55 | 
56 | if args.model == 'KAN':
57 |     model = get_model(args).Model(width=[2, 5, 5, 1], grid=5, k=3, grid_eps=1.0, \
58 |                                   noise_scale_base=0.25, device=device).to(device)
59 | elif args.model == 'QRes':
60 |     model = get_model(args).Model(in_dim=2, hidden_dim=256, out_dim=1, num_layer=2).to(device)
61 |     model.apply(init_weights)
62 | elif args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only':
63 |     model = get_model(args).Model(in_dim=2, hidden_dim=32, out_dim=1, num_layer=1).to(device)
64 |     model.apply(init_weights)
65 | else:
66 |     model = get_model(args).Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
67 |     model.apply(init_weights)
68 | 
69 | optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')
70 | 
71 | print(model)
72 | print(get_n_params(model))
73 | loss_track = []
74 | 
75 | # for region optimization
76 | initial_region =
args.initial_region 77 | sample_num = args.sample_num 78 | past_iterations = args.past_iterations 79 | gradient_list_overall = [] 80 | gradient_list_temp = [] 81 | gradient_variance = 1 82 | 83 | for i in tqdm(range(1000)): 84 | 85 | ###### Region Optimization with Monte Carlo Approximation ###### 86 | def closure(): 87 | x_res_region_sample_list = [] 88 | t_res_region_sample_list = [] 89 | for i in range(sample_num): 90 | x_region_sample = (torch.rand(x_res.shape).to(x_res.device)) * np.clip(initial_region / gradient_variance, 91 | a_min=0, 92 | a_max=0.01) 93 | t_region_sample = (torch.rand(x_res.shape).to(t_res.device)) * np.clip(initial_region / gradient_variance, 94 | a_min=0, 95 | a_max=0.01) 96 | x_res_region_sample_list.append(x_res + x_region_sample) 97 | t_res_region_sample_list.append(t_res + t_region_sample) 98 | x_res_region_sample = torch.cat(x_res_region_sample_list, dim=0) 99 | t_res_region_sample = torch.cat(t_res_region_sample_list, dim=0) 100 | pred_res = model(x_res_region_sample, t_res_region_sample) 101 | pred_left = model(x_left, t_left) 102 | pred_right = model(x_right, t_right) 103 | pred_upper = model(x_upper, t_upper) 104 | pred_lower = model(x_lower, t_lower) 105 | 106 | u_x = \ 107 | torch.autograd.grad(pred_res, x_res_region_sample, grad_outputs=torch.ones_like(pred_res), 108 | retain_graph=True, 109 | create_graph=True)[0] 110 | u_t = \ 111 | torch.autograd.grad(pred_res, t_res_region_sample, grad_outputs=torch.ones_like(pred_res), 112 | retain_graph=True, 113 | create_graph=True)[0] 114 | 115 | loss_res = torch.mean((u_t + 50 * u_x) ** 2) 116 | loss_bc = torch.mean((pred_upper - pred_lower) ** 2) 117 | loss_ic = torch.mean((pred_left[:, 0] - torch.sin(x_left[:, 0])) ** 2) 118 | 119 | loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()]) 120 | 121 | loss = loss_res + loss_bc + loss_ic 122 | optim.zero_grad() 123 | loss.backward(retain_graph=True) 124 | gradient_list_temp.append(torch.cat([(p.grad.view(-1)) if p.grad is not None else torch.zeros(1).cuda() for p in 125 | model.parameters()]).cpu().numpy()) # hook gradients from computation graph 126 | return loss 127 | 128 | 129 | optim.step(closure) 130 | 131 | ###### Trust Region Calibration ###### 132 | gradient_list_overall.append(np.mean(np.array(gradient_list_temp), axis=0)) 133 | gradient_list_overall = gradient_list_overall[-past_iterations:] 134 | gradient_list = np.array(gradient_list_overall) 135 | gradient_variance = (np.std(gradient_list, axis=0) / ( 136 | np.mean(np.abs(gradient_list), axis=0) + 1e-6)).mean() # normalized variance 137 | gradient_list_temp = [] 138 | if gradient_variance == 0: 139 | gradient_variance = 1 # for numerical stability 140 | 141 | print('Loss Res: {:4f}, Loss_BC: {:4f}, Loss_IC: {:4f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2])) 142 | print('Train Loss: {:4f}'.format(np.sum(loss_track[-1]))) 143 | 144 | if not os.path.exists('./results/'): 145 | os.makedirs('./results/') 146 | torch.save(model.state_dict(), f'./results/convection_{args.model}_region.pt') 147 | 148 | # Visualize 149 | if args.model == 'PINNsFormer' or args.model == 'PINNsFormer_Enc_Only': 150 | res_test = make_time_sequence(res_test, num_step=5, step=1e-4) 151 | 152 | res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device) 153 | x_test, t_test = res_test[:, ..., 0:1], res_test[:, ..., 1:2] 154 | 155 | with torch.no_grad(): 156 | pred = model(x_test, t_test)[:, 0:1] 157 | pred = pred.cpu().detach().numpy() 158 | 159 | pred = 
pred.reshape(101, 101) 160 | 161 | 162 | def u_res(x, t): 163 | print(x.shape) 164 | print(t.shape) 165 | return np.sin(x - 50 * t) 166 | 167 | 168 | res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101) 169 | u = u_res(res_test[:, 0], res_test[:, 1]).reshape(101, 101) 170 | 171 | rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u)) 172 | rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u ** 2)) 173 | 174 | print('relative L1 error: {:4f}'.format(rl1)) 175 | print('relative L2 error: {:4f}'.format(rl2)) 176 | 177 | plt.figure(figsize=(4, 3)) 178 | plt.imshow(pred, aspect='equal') 179 | plt.xlabel('x') 180 | plt.ylabel('t') 181 | plt.title('Predicted u(x,t)') 182 | plt.colorbar() 183 | plt.tight_layout() 184 | plt.axis('off') 185 | plt.savefig(f'./results/convection_{args.model}_region_optimization_pred.pdf', bbox_inches='tight') 186 | 187 | plt.figure(figsize=(4, 3)) 188 | plt.imshow(u, aspect='equal') 189 | plt.xlabel('x') 190 | plt.ylabel('t') 191 | plt.title('Exact u(x,t)') 192 | plt.colorbar() 193 | plt.tight_layout() 194 | plt.axis('off') 195 | plt.savefig('./results/convection_exact.pdf', bbox_inches='tight') 196 | 197 | plt.figure(figsize=(4, 3)) 198 | plt.imshow(pred - u, aspect='equal', cmap='coolwarm', vmin=-1, vmax=1) 199 | plt.xlabel('x') 200 | plt.ylabel('t') 201 | plt.title('Absolute Error') 202 | plt.colorbar() 203 | plt.tight_layout() 204 | plt.axis('off') 205 | plt.savefig(f'./results/convection_{args.model}_region_optimization_error.pdf', bbox_inches='tight') 206 | -------------------------------------------------------------------------------- /model_dict.py: -------------------------------------------------------------------------------- 1 | from models import PINN, QRes, FLS, KAN, PINNsFormer, PINNsFormer_Enc_Only 2 | 3 | 4 | def get_model(args): 5 | model_dict = { 6 | 'PINN': PINN, 7 | 'QRes': QRes, 8 | 'FLS': FLS, 9 | 'KAN': KAN, 10 | 'PINNsFormer': PINNsFormer, 11 | 'PINNsFormer_Enc_Only': PINNsFormer_Enc_Only, # more efficient and with better performance than original PINNsFormer 12 | } 13 | return model_dict[args.model] -------------------------------------------------------------------------------- /models/FLS.py: -------------------------------------------------------------------------------- 1 | # baseline implementation of First Layer Sine 2 | # paper: Learning in Sinusoidal Spaces with Physics-Informed Neural Networks 3 | # link: https://arxiv.org/abs/2109.09338 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class SinAct(nn.Module): 10 | def __init__(self): 11 | super(SinAct, self).__init__() 12 | 13 | def forward(self, x): 14 | return torch.sin(x) 15 | 16 | 17 | class Model(nn.Module): 18 | def __init__(self, in_dim, hidden_dim, out_dim, num_layer): 19 | super(Model, self).__init__() 20 | 21 | layers = [] 22 | for i in range(num_layer - 1): 23 | if i == 0: 24 | layers.append(nn.Linear(in_features=in_dim, out_features=hidden_dim)) 25 | layers.append(SinAct()) 26 | else: 27 | layers.append(nn.Linear(in_features=hidden_dim, out_features=hidden_dim)) 28 | layers.append(nn.Tanh()) 29 | 30 | layers.append(nn.Linear(in_features=hidden_dim, out_features=out_dim)) 31 | 32 | self.linear = nn.Sequential(*layers) 33 | 34 | def forward(self, x, t): 35 | src = torch.cat((x, t), dim=-1) 36 | return self.linear(src) 37 | -------------------------------------------------------------------------------- /models/KAN.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .kan_layer import * 3 | 
from .Symbolic_KANLayer import * 4 | from .LBFGS import * 5 | import os 6 | import glob 7 | import matplotlib.pyplot as plt 8 | from tqdm import tqdm 9 | import random 10 | import copy 11 | 12 | 13 | class Model(nn.Module): 14 | ''' 15 | KAN class 16 | 17 | Attributes: 18 | ----------- 19 | biases: a list of nn.Linear() 20 | biases are added on nodes (in principle, biases can be absorbed into activation functions. However, we still have them for better optimization) 21 | act_fun: a list of KANLayer 22 | KANLayers 23 | depth: int 24 | depth of KAN 25 | width: list 26 | number of neurons in each layer. e.g., [2,5,5,3] means 2D inputs, 5D outputs, with 2 layers of 5 hidden neurons. 27 | grid: int 28 | the number of grid intervals 29 | k: int 30 | the order of piecewise polynomial 31 | base_fun: fun 32 | residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x) 33 | symbolic_fun: a list of Symbolic_KANLayer 34 | Symbolic_KANLayers 35 | symbolic_enabled: bool 36 | If False, the symbolic front is not computed (to save time). Default: True. 37 | 38 | Methods: 39 | -------- 40 | __init__(): 41 | initialize a KAN 42 | initialize_from_another_model(): 43 | initialize a KAN from another KAN (with the same shape, but potentially different grids) 44 | update_grid_from_samples(): 45 | update spline grids based on samples 46 | initialize_grid_from_another_model(): 47 | initalize KAN grids from another KAN 48 | forward(): 49 | forward 50 | set_mode(): 51 | set the mode of an activation function: 'n' for numeric, 's' for symbolic, 'ns' for combined (note they are visualized differently in plot(). 'n' as black, 's' as red, 'ns' as purple). 52 | fix_symbolic(): 53 | fix an activation function to be symbolic 54 | suggest_symbolic(): 55 | suggest the symbolic candicates of a numeric spline-based activation function 56 | lock(): 57 | lock activation functions to share parameters 58 | unlock(): 59 | unlock locked activations 60 | get_range(): 61 | get the input and output ranges of an activation function 62 | plot(): 63 | plot the diagram of KAN 64 | train(): 65 | train KAN 66 | prune(): 67 | prune KAN 68 | remove_edge(): 69 | remove some edge of KAN 70 | remove_node(): 71 | remove some node of KAN 72 | auto_symbolic(): 73 | automatically fit all splines to be symbolic functions 74 | symbolic_formula(): 75 | obtain the symbolic formula of the KAN network 76 | ''' 77 | 78 | def __init__(self, width=None, grid=3, k=3, noise_scale=0.1, noise_scale_base=0.1, base_fun=torch.nn.SiLU(), 79 | symbolic_enabled=True, bias_trainable=True, grid_eps=1.0, grid_range=[-1, 1], sp_trainable=True, 80 | sb_trainable=True, 81 | device='cpu', seed=0): 82 | ''' 83 | initalize a KAN model 84 | 85 | Args: 86 | ----- 87 | width : list of int 88 | :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs) 89 | grid : int 90 | number of grid intervals. Default: 3. 91 | k : int 92 | order of piecewise polynomial. Default: 3. 93 | noise_scale : float 94 | initial injected noise to spline. Default: 0.1. 95 | base_fun : fun 96 | the residual function b(x). Default: torch.nn.SiLU(). 97 | symbolic_enabled : bool 98 | compute or skip symbolic computations (for efficiency). By default: True. 99 | bias_trainable : bool 100 | bias parameters are updated or not. By default: True 101 | grid_eps : float 102 | When grid_eps = 0, the grid is uniform; when grid_eps = 1, the grid is partitioned using percentiles of samples. 
0 < grid_eps < 1 interpolates between the two extremes. Default: 1.0.
103 |             grid_range : list/np.array of shape (2,)
104 |                 setting the range of grids. Default: [-1,1].
105 |             sp_trainable : bool
106 |                 If true, scale_sp is trainable. Default: True.
107 |             sb_trainable : bool
108 |                 If true, scale_base is trainable. Default: True.
109 |             device : str
110 |                 device
111 |             seed : int
112 |                 random seed
113 |
114 |         Returns:
115 |         --------
116 |             self
117 |
118 |         Example
119 |         -------
120 |         >>> model = Model(width=[2,5,1], grid=5, k=3)
121 |         >>> (model.act_fun[0].in_dim, model.act_fun[0].out_dim), (model.act_fun[1].in_dim, model.act_fun[1].out_dim)
122 |         ((2, 5), (5, 1))
123 |         '''
124 |         super(Model, self).__init__()
125 |
126 |         torch.manual_seed(seed)
127 |         np.random.seed(seed)
128 |         random.seed(seed)
129 |
130 |         ### initializing the numerical front ###
131 |
132 |         self.biases = []
133 |         self.act_fun = []
134 |         self.depth = len(width) - 1
135 |         self.width = width
136 |
137 |         for l in range(self.depth):
138 |             # splines
139 |             scale_base = 1 / np.sqrt(width[l]) + (torch.randn(width[l] * width[l + 1], ) * 2 - 1) * noise_scale_base
140 |             sp_batch = KANLayer(in_dim=width[l], out_dim=width[l + 1], num=grid, k=k, noise_scale=noise_scale,
141 |                                 scale_base=scale_base, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps,
142 |                                 grid_range=grid_range, sp_trainable=sp_trainable,
143 |                                 sb_trainable=sb_trainable, device=device)
144 |             self.act_fun.append(sp_batch)
145 |
146 |             # bias
147 |             bias = nn.Linear(width[l + 1], 1, bias=False).requires_grad_(bias_trainable)
148 |             bias.weight.data *= 0.
149 |             self.biases.append(bias)
150 |
151 |         self.biases = nn.ModuleList(self.biases)
152 |         self.act_fun = nn.ModuleList(self.act_fun)
153 |
154 |         self.grid = grid
155 |         self.k = k
156 |         self.base_fun = base_fun
157 |
158 |         ### initializing the symbolic front ###
159 |         self.symbolic_fun = []
160 |         for l in range(self.depth):
161 |             sb_batch = Symbolic_KANLayer(in_dim=width[l], out_dim=width[l + 1])
162 |             self.symbolic_fun.append(sb_batch)
163 |
164 |         self.symbolic_fun = nn.ModuleList(self.symbolic_fun)
165 |         self.symbolic_enabled = symbolic_enabled
166 |
167 |     def initialize_from_another_model(self, another_model, x):
168 |         '''
169 |         initialize from a parent model. The parent has the same width as the current model but may have different grids.
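        Typically used for grid refinement: a model trained on a coarse grid seeds a model of the same shape with a finer grid, as in the example below.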
170 | 171 | Args: 172 | ----- 173 | another_model : KAN 174 | the parent model used to initialize the current model 175 | x : 2D torch.float 176 | inputs, shape (batch, input dimension) 177 | 178 | Returns: 179 | -------- 180 | self : KAN 181 | 182 | Example 183 | ------- 184 | >>> model_coarse = KAN(width=[2,5,1], grid=5, k=3) 185 | >>> model_fine = KAN(width=[2,5,1], grid=10, k=3) 186 | >>> print(model_fine.act_fun[0].coef[0][0].data) 187 | >>> x = torch.normal(0,1,size=(100,2)) 188 | >>> model_fine.initialize_from_another_model(model_coarse, x); 189 | >>> print(model_fine.act_fun[0].coef[0][0].data) 190 | tensor(-0.0030) 191 | tensor(0.0506) 192 | ''' 193 | another_model(x) # get activations 194 | batch = x.shape[0] 195 | 196 | self.initialize_grid_from_another_model(another_model, x) 197 | 198 | for l in range(self.depth): 199 | spb = self.act_fun[l] 200 | spb_parent = another_model.act_fun[l] 201 | 202 | # spb = spb_parent 203 | preacts = another_model.spline_preacts[l] 204 | postsplines = another_model.spline_postsplines[l] 205 | self.act_fun[l].coef.data = curve2coef(preacts.reshape(batch, spb.size).permute(1, 0), 206 | postsplines.reshape(batch, spb.size).permute(1, 0), spb.grid, 207 | k=spb.k) 208 | spb.scale_base.data = spb_parent.scale_base.data 209 | spb.scale_sp.data = spb_parent.scale_sp.data 210 | spb.mask.data = spb_parent.mask.data 211 | # print(spb.mask.data, self.act_fun[l].mask.data) 212 | 213 | for l in range(self.depth): 214 | self.biases[l].weight.data = another_model.biases[l].weight.data 215 | 216 | for l in range(self.depth): 217 | self.symbolic_fun[l] = another_model.symbolic_fun[l] 218 | 219 | return self 220 | 221 | def update_grid_from_samples(self, x): 222 | ''' 223 | update grid from samples 224 | 225 | Args: 226 | ----- 227 | x : 2D torch.float 228 | inputs, shape (batch, input dimension) 229 | 230 | Returns: 231 | -------- 232 | None 233 | 234 | Example 235 | ------- 236 | >>> model = KAN(width=[2,5,1], grid=5, k=3) 237 | >>> print(model.act_fun[0].grid[0].data) 238 | >>> x = torch.rand(100,2)*5 239 | >>> model.update_grid_from_samples(x) 240 | >>> print(model.act_fun[0].grid[0].data) 241 | tensor([-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]) 242 | tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809]) 243 | ''' 244 | for l in range(self.depth): 245 | self.forward(x) 246 | self.act_fun[l].update_grid_from_samples(self.acts[l]) 247 | 248 | def initialize_grid_from_another_model(self, model, x): 249 | ''' 250 | initialize grid from a parent model 251 | 252 | Args: 253 | ----- 254 | model : KAN 255 | parent model 256 | x : 2D torch.float 257 | inputs, shape (batch, input dimension) 258 | 259 | Returns: 260 | -------- 261 | None 262 | 263 | Example 264 | ------- 265 | >>> model_parent = KAN(width=[1,1], grid=5, k=3) 266 | >>> model_parent.act_fun[0].grid.data = torch.linspace(-2,2,steps=6)[None,:] 267 | >>> x = torch.linspace(-2,2,steps=1001)[:,None] 268 | >>> model = KAN(width=[1,1], grid=5, k=3) 269 | >>> print(model.act_fun[0].grid.data) 270 | >>> model = model.initialize_from_another_model(model_parent, x) 271 | >>> print(model.act_fun[0].grid.data) 272 | tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]]) 273 | tensor([[-2.0000, -1.2000, -0.4000, 0.4000, 1.2000, 2.0000]]) 274 | ''' 275 | model(x) 276 | for l in range(self.depth): 277 | self.act_fun[l].initialize_grid_from_parent(model.act_fun[l], model.acts[l]) 278 | 279 | def forward(self, x_res, t_res): 280 | ''' 281 | KAN forward 282 | 283 | Args: 284 | ----- 285 | x : 2D torch.float 
286 | inputs, shape (batch, input dimension) 287 | 288 | Returns: 289 | -------- 290 | y : 2D torch.float 291 | outputs, shape (batch, output dimension) 292 | 293 | Example 294 | ------- 295 | >>> model = KAN(width=[2,5,3], grid=5, k=3) 296 | >>> x = torch.normal(0,1,size=(100,2)) 297 | >>> model(x).shape 298 | torch.Size([100, 3]) 299 | ''' 300 | x = torch.cat([x_res, t_res], dim=-1) 301 | self.acts = [] # shape ([batch, n0], [batch, n1], ..., [batch, n_L]) 302 | self.spline_preacts = [] 303 | self.spline_postsplines = [] 304 | self.spline_postacts = [] 305 | self.acts_scale = [] 306 | self.acts_scale_std = [] 307 | # self.neurons_scale = [] 308 | 309 | self.acts.append(x) # acts shape: (batch, width[l]) 310 | 311 | for l in range(self.depth): 312 | 313 | x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x) 314 | 315 | if self.symbolic_enabled == True: 316 | x_symbolic, postacts_symbolic = self.symbolic_fun[l](x) 317 | else: 318 | x_symbolic = 0. 319 | postacts_symbolic = 0. 320 | 321 | x = x_numerical + x_symbolic 322 | postacts = postacts_numerical + postacts_symbolic 323 | 324 | # self.neurons_scale.append(torch.mean(torch.abs(x), dim=0)) 325 | grid_reshape = self.act_fun[l].grid.reshape(self.width[l + 1], self.width[l], -1) 326 | input_range = grid_reshape[:, :, -1] - grid_reshape[:, :, 0] + 1e-4 327 | output_range = torch.mean(torch.abs(postacts), dim=0) 328 | self.acts_scale.append(output_range / input_range) 329 | self.acts_scale_std.append(torch.std(postacts, dim=0)) 330 | self.spline_preacts.append(preacts.detach()) 331 | self.spline_postacts.append(postacts.detach()) 332 | self.spline_postsplines.append(postspline.detach()) 333 | 334 | x = x + self.biases[l].weight 335 | self.acts.append(x) 336 | 337 | return x 338 | 339 | def set_mode(self, l, i, j, mode, mask_n=None): 340 | ''' 341 | set (l,i,j) activation to have mode 342 | 343 | Args: 344 | ----- 345 | l : int 346 | layer index 347 | i : int 348 | input neuron index 349 | j : int 350 | output neuron index 351 | mode : str 352 | 'n' (numeric) or 's' (symbolic) or 'ns' (combined) 353 | mask_n : None or float) 354 | magnitude of the numeric front 355 | 356 | Returns: 357 | -------- 358 | None 359 | ''' 360 | if mode == "s": 361 | mask_n = 0.; 362 | mask_s = 1. 363 | elif mode == "n": 364 | mask_n = 1.; 365 | mask_s = 0. 366 | elif mode == "sn" or mode == "ns": 367 | if mask_n == None: 368 | mask_n = 1. 369 | else: 370 | mask_n = mask_n 371 | mask_s = 1. 372 | else: 373 | mask_n = 0.; 374 | mask_s = 0. 375 | 376 | self.act_fun[l].mask.data[j * self.act_fun[l].in_dim + i] = mask_n 377 | self.symbolic_fun[l].mask.data[j, i] = mask_s 378 | 379 | def fix_symbolic(self, l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True, 380 | random=False): 381 | ''' 382 | set (l,i,j) activation to be symbolic (specified by fun_name) 383 | 384 | Args: 385 | ----- 386 | l : int 387 | layer index 388 | i : int 389 | input neuron index 390 | j : int 391 | output neuron index 392 | fun_name : str 393 | function name 394 | fit_params_bool : bool 395 | obtaining affine parameters through fitting (True) or setting default values (False) 396 | a_range : tuple 397 | sweeping range of a 398 | b_range : tuple 399 | sweeping range of b 400 | verbose : bool 401 | If True, more information is printed. 
402 | random : bool 403 | initialize affine parameteres randomly or as [1,0,1,0] 404 | 405 | Returns: 406 | -------- 407 | None or r2 (coefficient of determination) 408 | 409 | Example 1 410 | --------- 411 | >>> # when fit_params_bool = False 412 | >>> model = KAN(width=[2,5,1], grid=5, k=3) 413 | >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False) 414 | >>> print(model.act_fun[0].mask.reshape(2,5)) 415 | >>> print(model.symbolic_fun[0].mask.reshape(2,5)) 416 | tensor([[1., 1., 1., 1., 1.], 417 | [1., 1., 0., 1., 1.]]) 418 | tensor([[0., 0., 0., 0., 0.], 419 | [0., 0., 1., 0., 0.]]) 420 | 421 | Example 2 422 | --------- 423 | >>> # when fit_params_bool = True 424 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.) 425 | >>> x = torch.normal(0,1,size=(100,2)) 426 | >>> model(x) # obtain activations (otherwise model does not have attributes acts) 427 | >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True) 428 | >>> print(model.act_fun[0].mask.reshape(2,5)) 429 | >>> print(model.symbolic_fun[0].mask.reshape(2,5)) 430 | r2 is 0.8131332993507385 431 | r2 is not very high, please double check if you are choosing the correct symbolic function. 432 | tensor([[1., 1., 1., 1., 1.], 433 | [1., 1., 0., 1., 1.]]) 434 | tensor([[0., 0., 0., 0., 0.], 435 | [0., 0., 1., 0., 0.]]) 436 | ''' 437 | self.set_mode(l, i, j, mode="s") 438 | if not fit_params_bool: 439 | self.symbolic_fun[l].fix_symbolic(i, j, fun_name, verbose=verbose, random=random) 440 | return None 441 | else: 442 | x = self.acts[l][:, i] 443 | y = self.spline_postacts[l][:, j, i] 444 | r2 = self.symbolic_fun[l].fix_symbolic(i, j, fun_name, x, y, a_range=a_range, b_range=b_range, 445 | verbose=verbose) 446 | return r2 447 | 448 | def unfix_symbolic(self, l, i, j): 449 | ''' 450 | unfix the (l,i,j) activation function. 451 | ''' 452 | self.set_mode(l, i, j, mode="n") 453 | 454 | def unfix_symbolic_all(self): 455 | ''' 456 | unfix all activation functions. 457 | ''' 458 | for l in range(len(self.width) - 1): 459 | for i in range(self.width[l]): 460 | for j in range(self.width[l + 1]): 461 | self.unfix_symbolic(l, i, j) 462 | 463 | def lock(self, l, ids): 464 | ''' 465 | lock ids in the l-th layer to be the same function 466 | 467 | Args: 468 | ----- 469 | l : int 470 | layer index 471 | ids : 2D list 472 | :math:`[[i_1,j_1],[i_2,j_2],...]` set :math:`(l,i_i,j_1), (l,i_2,j_2), ...` to be the same function 473 | 474 | Returns: 475 | -------- 476 | None 477 | 478 | Example 479 | ------- 480 | >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.) 481 | >>> print(model.act_fun[0].weight_sharing.reshape(3,2)) 482 | >>> model.lock(0,[[1,0],[1,1]]) 483 | >>> print(model.act_fun[0].weight_sharing.reshape(3,2)) 484 | tensor([[0, 1], 485 | [2, 3], 486 | [4, 5]]) 487 | tensor([[0, 1], 488 | [2, 1], 489 | [4, 5]]) 490 | ''' 491 | self.act_fun[l].lock(ids) 492 | 493 | def unlock(self, l, ids): 494 | ''' 495 | unlock ids in the l-th layer to be the same function 496 | 497 | Args: 498 | ----- 499 | l : int 500 | layer index 501 | ids : 2D list) 502 | [[i1,j1],[i2,j2],...] set (l,ii,j1), (l,i2,j2), ... to be unlocked 503 | 504 | Example: 505 | -------- 506 | >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.) 
507 | >>> model.lock(0,[[1,0],[1,1]]) 508 | >>> print(model.act_fun[0].weight_sharing.reshape(3,2)) 509 | >>> model.unlock(0,[[1,0],[1,1]]) 510 | >>> print(model.act_fun[0].weight_sharing.reshape(3,2)) 511 | tensor([[0, 1], 512 | [2, 1], 513 | [4, 5]]) 514 | tensor([[0, 1], 515 | [2, 3], 516 | [4, 5]]) 517 | ''' 518 | self.act_fun[l].unlock(ids) 519 | 520 | def get_range(self, l, i, j, verbose=True): 521 | ''' 522 | Get the input range and output range of the (l,i,j) activation 523 | 524 | Args: 525 | ----- 526 | l : int 527 | layer index 528 | i : int 529 | input neuron index 530 | j : int 531 | output neuron index 532 | 533 | Returns: 534 | -------- 535 | x_min : float 536 | minimum of input 537 | x_max : float 538 | maximum of input 539 | y_min : float 540 | minimum of output 541 | y_max : float 542 | maximum of output 543 | 544 | Example 545 | ------- 546 | >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.) 547 | >>> x = torch.normal(0,1,size=(100,2)) 548 | >>> model(x) # do a forward pass to obtain model.acts 549 | >>> model.get_range(0,0,0) 550 | x range: [-2.13 , 2.75 ] 551 | y range: [-0.50 , 1.83 ] 552 | (tensor(-2.1288), tensor(2.7498), tensor(-0.5042), tensor(1.8275)) 553 | ''' 554 | x = self.spline_preacts[l][:, j, i] 555 | y = self.spline_postacts[l][:, j, i] 556 | x_min = torch.min(x) 557 | x_max = torch.max(x) 558 | y_min = torch.min(y) 559 | y_max = torch.max(y) 560 | if verbose: 561 | print('x range: [' + '%.2f' % x_min, ',', '%.2f' % x_max, ']') 562 | print('y range: [' + '%.2f' % y_min, ',', '%.2f' % y_max, ']') 563 | return x_min, x_max, y_min, y_max 564 | 565 | def plot(self, folder="./figures", beta=3, mask=False, mode="supervised", scale=0.5, tick=False, sample=False, 566 | in_vars=None, out_vars=None, title=None): 567 | ''' 568 | plot KAN 569 | 570 | Args: 571 | ----- 572 | folder : str 573 | the folder to store pngs 574 | beta : float 575 | positive number. control the transparency of each activation. transparency = tanh(beta*l1). 576 | mask : bool 577 | If True, plot with mask (need to run prune() first to obtain mask). If False (by default), plot all activation functions. 578 | mode : bool 579 | "supervised" or "unsupervised". If "supervised", l1 is measured by absolution value (not subtracting mean); if "unsupervised", l1 is measured by standard deviation (subtracting mean). 580 | scale : float 581 | control the size of the diagram 582 | in_vars: None or list of str 583 | the name(s) of input variables 584 | out_vars: None or list of str 585 | the name(s) of output variables 586 | title: None or str 587 | title 588 | 589 | Returns: 590 | -------- 591 | Figure 592 | 593 | Example 594 | ------- 595 | >>> # see more interactive examples in demos 596 | >>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0) 597 | >>> x = torch.normal(0,1,size=(100,2)) 598 | >>> model(x) # do a forward pass to obtain model.acts 599 | >>> model.plot() 600 | ''' 601 | if not os.path.exists(folder): 602 | os.makedirs(folder) 603 | # matplotlib.use('Agg') 604 | depth = len(self.width) - 1 605 | for l in range(depth): 606 | w_large = 2.0 607 | for i in range(self.width[l]): 608 | for j in range(self.width[l + 1]): 609 | rank = torch.argsort(self.acts[l][:, i]) 610 | fig, ax = plt.subplots(figsize=(w_large, w_large)) 611 | 612 | num = rank.shape[0] 613 | 614 | symbol_mask = self.symbolic_fun[l].mask[j][i] 615 | numerical_mask = self.act_fun[l].mask.reshape(self.width[l + 1], self.width[l])[j][i] 616 | if symbol_mask > 0. 
and numerical_mask > 0.: 617 | color = 'purple' 618 | alpha_mask = 1 619 | if symbol_mask > 0. and numerical_mask == 0.: 620 | color = "red" 621 | alpha_mask = 1 622 | if symbol_mask == 0. and numerical_mask > 0.: 623 | color = "black" 624 | alpha_mask = 1 625 | if symbol_mask == 0. and numerical_mask == 0.: 626 | color = "white" 627 | alpha_mask = 0 628 | 629 | if tick == True: 630 | ax.tick_params(axis="y", direction="in", pad=-22, labelsize=50) 631 | ax.tick_params(axis="x", direction="in", pad=-15, labelsize=50) 632 | x_min, x_max, y_min, y_max = self.get_range(l, i, j, verbose=False) 633 | plt.xticks([x_min, x_max], ['%2.f' % x_min, '%2.f' % x_max]) 634 | plt.yticks([y_min, y_max], ['%2.f' % y_min, '%2.f' % y_max]) 635 | else: 636 | plt.xticks([]) 637 | plt.yticks([]) 638 | if alpha_mask == 1: 639 | plt.gca().patch.set_edgecolor('black') 640 | else: 641 | plt.gca().patch.set_edgecolor('white') 642 | plt.gca().patch.set_linewidth(1.5) 643 | # plt.axis('off') 644 | 645 | plt.plot(self.acts[l][:, i][rank].cpu().detach().numpy(), 646 | self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, lw=5) 647 | if sample == True: 648 | plt.scatter(self.acts[l][:, i][rank].cpu().detach().numpy(), 649 | self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, 650 | s=400 * scale ** 2) 651 | plt.gca().spines[:].set_color(color) 652 | 653 | lock_id = self.act_fun[l].lock_id[j * self.width[l] + i].long().item() 654 | if lock_id > 0: 655 | im = plt.imread(f'{folder}/lock.png') 656 | newax = fig.add_axes([0.15, 0.7, 0.15, 0.15]) 657 | plt.text(500, 400, lock_id, fontsize=15) 658 | newax.imshow(im) 659 | newax.axis('off') 660 | 661 | plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=400) 662 | plt.close() 663 | 664 | def score2alpha(score): 665 | return np.tanh(beta * score) 666 | 667 | if mode == "supervised": 668 | alpha = [score2alpha(score.cpu().detach().numpy()) for score in self.acts_scale] 669 | elif mode == "unsupervised": 670 | alpha = [score2alpha(score.cpu().detach().numpy()) for score in self.acts_scale_std] 671 | 672 | # draw skeleton 673 | width = np.array(self.width) 674 | A = 1 675 | y0 = 0.4 # 0.4 676 | 677 | # plt.figure(figsize=(5,5*(neuron_depth-1)*y0)) 678 | neuron_depth = len(width) 679 | min_spacing = A / np.maximum(np.max(width), 5) 680 | 681 | max_neuron = np.max(width) 682 | max_num_weights = np.max(width[:-1] * width[1:]) 683 | y1 = 0.4 / np.maximum(max_num_weights, 3) 684 | 685 | fig, ax = plt.subplots(figsize=(10 * scale, 10 * scale * (neuron_depth - 1) * y0)) 686 | # fig, ax = plt.subplots(figsize=(5,5*(neuron_depth-1)*y0)) 687 | 688 | # plot scatters and lines 689 | for l in range(neuron_depth): 690 | n = width[l] 691 | spacing = A / n 692 | for i in range(n): 693 | plt.scatter(1 / (2 * n) + i / n, l * y0, s=min_spacing ** 2 * 10000 * scale ** 2, color='black') 694 | 695 | if l < neuron_depth - 1: 696 | # plot connections 697 | n_next = width[l + 1] 698 | N = n * n_next 699 | for j in range(n_next): 700 | id_ = i * n_next + j 701 | 702 | symbol_mask = self.symbolic_fun[l].mask[j][i] 703 | numerical_mask = self.act_fun[l].mask.reshape(self.width[l + 1], self.width[l])[j][i] 704 | if symbol_mask == 1. and numerical_mask == 1.: 705 | color = 'purple' 706 | alpha_mask = 1. 707 | if symbol_mask == 1. and numerical_mask == 0.: 708 | color = "red" 709 | alpha_mask = 1. 710 | if symbol_mask == 0. and numerical_mask == 1.: 711 | color = "black" 712 | alpha_mask = 1. 713 | if symbol_mask == 0. 
and numerical_mask == 0.: 714 | color = "white" 715 | alpha_mask = 0. 716 | if mask == True: 717 | plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N], [l * y0, (l + 1 / 2) * y0 - y1], 718 | color=color, lw=2 * scale, 719 | alpha=alpha[l][j][i] * self.mask[l][i].item() * self.mask[l + 1][j].item()) 720 | plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next], 721 | [(l + 1 / 2) * y0 + y1, (l + 1) * y0], color=color, lw=2 * scale, 722 | alpha=alpha[l][j][i] * self.mask[l][i].item() * self.mask[l + 1][j].item()) 723 | else: 724 | plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N], [l * y0, (l + 1 / 2) * y0 - y1], 725 | color=color, lw=2 * scale, alpha=alpha[l][j][i] * alpha_mask) 726 | plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next], 727 | [(l + 1 / 2) * y0 + y1, (l + 1) * y0], color=color, lw=2 * scale, 728 | alpha=alpha[l][j][i] * alpha_mask) 729 | 730 | plt.xlim(0, 1) 731 | plt.ylim(-0.1 * y0, (neuron_depth - 1 + 0.1) * y0) 732 | 733 | # -- Transformation functions 734 | DC_to_FC = ax.transData.transform 735 | FC_to_NFC = fig.transFigure.inverted().transform 736 | # -- Take data coordinates and transform them to normalized figure coordinates 737 | DC_to_NFC = lambda x: FC_to_NFC(DC_to_FC(x)) 738 | 739 | plt.axis('off') 740 | 741 | # plot splines 742 | for l in range(neuron_depth - 1): 743 | n = width[l] 744 | for i in range(n): 745 | n_next = width[l + 1] 746 | N = n * n_next 747 | for j in range(n_next): 748 | id_ = i * n_next + j 749 | im = plt.imread(f'{folder}/sp_{l}_{i}_{j}.png') 750 | left = DC_to_NFC([1 / (2 * N) + id_ / N - y1, 0])[0] 751 | right = DC_to_NFC([1 / (2 * N) + id_ / N + y1, 0])[0] 752 | bottom = DC_to_NFC([0, (l + 1 / 2) * y0 - y1])[1] 753 | up = DC_to_NFC([0, (l + 1 / 2) * y0 + y1])[1] 754 | newax = fig.add_axes([left, bottom, right - left, up - bottom]) 755 | # newax = fig.add_axes([1/(2*N)+id_/N-y1, (l+1/2)*y0-y1, y1, y1], anchor='NE') 756 | if mask == False: 757 | newax.imshow(im, alpha=alpha[l][j][i]) 758 | else: 759 | ### make sure to run model.prune() first to compute mask ### 760 | newax.imshow(im, alpha=alpha[l][j][i] * self.mask[l][i].item() * self.mask[l + 1][j].item()) 761 | newax.axis('off') 762 | 763 | if in_vars != None: 764 | n = self.width[0] 765 | for i in range(n): 766 | plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, in_vars[i], fontsize=40 * scale, 767 | horizontalalignment='center', verticalalignment='center') 768 | 769 | if out_vars != None: 770 | n = self.width[-1] 771 | for i in range(n): 772 | plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), y0 * (len(self.width) - 1) + 0.1, out_vars[i], 773 | fontsize=40 * scale, horizontalalignment='center', 774 | verticalalignment='center') 775 | 776 | if title != None: 777 | plt.gcf().get_axes()[0].text(0.5, y0 * (len(self.width) - 1) + 0.2, title, fontsize=40 * scale, 778 | horizontalalignment='center', verticalalignment='center') 779 | 780 | def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., 781 | lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., stop_grid_update_step=50, 782 | batch=-1, 783 | small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, save_fig=False, 784 | in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu'): 785 | ''' 786 | training 787 | 788 | Args: 789 | ----- 790 | dataset : dic 791 | contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label'] 792 | opt : 
str
793 |                 "LBFGS" or "Adam"
794 |             steps : int
795 |                 training steps
796 |             log : int
797 |                 logging frequency
798 |             lamb : float
799 |                 overall penalty strength
800 |             lamb_l1 : float
801 |                 l1 penalty strength
802 |             lamb_entropy : float
803 |                 entropy penalty strength
804 |             lamb_coef : float
805 |                 coefficient magnitude penalty strength
806 |             lamb_coefdiff : float
807 |                 difference of nearby coefficients (smoothness) penalty strength
808 |             update_grid : bool
809 |                 If True, update grid regularly before stop_grid_update_step
810 |             grid_update_num : int
811 |                 the number of grid updates before stop_grid_update_step
812 |             stop_grid_update_step : int
813 |                 no grid updates after this training step
814 |             batch : int
815 |                 batch size, if -1 then full.
816 |             small_mag_threshold : float
817 |                 threshold to determine large or small numbers (may want to apply larger penalty to smaller numbers)
818 |             small_reg_factor : float
819 |                 penalty strength applied to small factors relative to large factors
820 |             device : str
821 |                 device
822 |             save_fig_freq : int
823 |                 save figure every (save_fig_freq) step
824 |
825 |         Returns:
826 |         --------
827 |             results : dic
828 |                 results['train_loss'], 1D array of training losses (RMSE)
829 |                 results['test_loss'], 1D array of test losses (RMSE)
830 |                 results['reg'], 1D array of regularization
831 |
832 |         Example
833 |         -------
834 |         >>> # for interactive examples, please see demos
835 |         >>> from utils import create_dataset
836 |         >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
837 |         >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
838 |         >>> dataset = create_dataset(f, n_var=2)
839 |         >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
840 |         >>> model.plot()
841 |         '''
842 |
843 |         def reg(acts_scale):
844 |
845 |             def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
846 |                 return (x < th) * x * factor + (x > th) * (x + (factor - 1) * th)
847 |
848 |             reg_ = 0.
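            # accumulate per-layer penalties: the L1 term on the (normalized) activation
            # scales drives edge magnitudes toward zero, while the entropy term concentrates
            # the remaining magnitude on a few edges, which is what makes the model prunable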
849 | for i in range(len(acts_scale)): 850 | vec = acts_scale[i].reshape(-1, ) 851 | 852 | p = vec / torch.sum(vec) 853 | l1 = torch.sum(nonlinear(vec)) 854 | entropy = - torch.sum(p * torch.log2(p + 1e-4)) 855 | reg_ += lamb_l1 * l1 + lamb_entropy * entropy # both l1 and entropy 856 | 857 | # regularize coefficient to encourage spline to be zero 858 | for i in range(len(self.act_fun)): 859 | coeff_l1 = torch.sum(torch.mean(torch.abs(self.act_fun[i].coef), dim=1)) 860 | coeff_diff_l1 = torch.sum(torch.mean(torch.abs(torch.diff(self.act_fun[i].coef)), dim=1)) 861 | reg_ += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1 862 | 863 | return reg_ 864 | 865 | pbar = tqdm(range(steps), desc='description', ncols=100) 866 | 867 | if loss_fn == None: 868 | loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2) 869 | else: 870 | loss_fn = loss_fn_eval = loss_fn 871 | 872 | grid_update_freq = int(stop_grid_update_step / grid_update_num) 873 | 874 | if opt == "Adam": 875 | optimizer = torch.optim.Adam(self.parameters(), lr=lr) 876 | elif opt == "LBFGS": 877 | optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", 878 | tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32) 879 | 880 | results = {} 881 | results['train_loss'] = [] 882 | results['test_loss'] = [] 883 | results['reg'] = [] 884 | if metrics != None: 885 | for i in range(len(metrics)): 886 | results[metrics[i].__name__] = [] 887 | 888 | if batch == -1 or batch > dataset['train_input'].shape[0]: 889 | batch_size = dataset['train_input'].shape[0] 890 | batch_size_test = dataset['test_input'].shape[0] 891 | else: 892 | batch_size = batch 893 | batch_size_test = batch 894 | 895 | global train_loss, reg_ 896 | 897 | def closure(): 898 | global train_loss, reg_ 899 | optimizer.zero_grad() 900 | pred = self.forward(dataset['train_input'][train_id].to(device)) 901 | if sglr_avoid == True: 902 | id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0] 903 | train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device)) 904 | else: 905 | train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device)) 906 | reg_ = reg(self.acts_scale) 907 | objective = train_loss + lamb * reg_ 908 | objective.backward() 909 | return objective 910 | 911 | if save_fig: 912 | if not os.path.exists(img_folder): 913 | os.makedirs(img_folder) 914 | 915 | for _ in pbar: 916 | 917 | train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False) 918 | test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False) 919 | 920 | if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid: 921 | self.update_grid_from_samples(dataset['train_input'][train_id].to(device)) 922 | 923 | if opt == "LBFGS": 924 | optimizer.step(closure) 925 | 926 | if opt == "Adam": 927 | pred = self.forward(dataset['train_input'][train_id].to(device)) 928 | if sglr_avoid == True: 929 | id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0] 930 | train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device)) 931 | else: 932 | train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device)) 933 | reg_ = reg(self.acts_scale) 934 | loss = train_loss + lamb * reg_ 935 | optimizer.zero_grad() 936 | loss.backward() 937 | optimizer.step() 938 | 939 | test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(device)), 940 | dataset['test_label'][test_id].to(device)) 941 | 942 | if _ % log == 0: 943 
| pbar.set_description("train loss: %.2e | test loss: %.2e | reg: %.2e " % ( 944 | torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), 945 | reg_.cpu().detach().numpy())) 946 | 947 | if metrics != None: 948 | for i in range(len(metrics)): 949 | results[metrics[i].__name__].append(metrics[i]().item()) 950 | 951 | results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy()) 952 | results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy()) 953 | results['reg'].append(reg_.cpu().detach().numpy()) 954 | 955 | if save_fig and _ % save_fig_freq == 0: 956 | self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta) 957 | plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight', dpi=200) 958 | plt.close() 959 | 960 | return results 961 | 962 | def prune(self, threshold=1e-2, mode="auto", active_neurons_id=None): 963 | ''' 964 | pruning KAN on the node level. If a node has small incoming or outgoing connection, it will be pruned away. 965 | 966 | Args: 967 | ----- 968 | threshold : float 969 | the threshold used to determine whether a node is small enough 970 | mode : str 971 | "auto" or "manual". If "auto", the thresold will be used to automatically prune away nodes. If "manual", active_neuron_id is needed to specify which neurons are kept (others are thrown away). 972 | active_neuron_id : list of id lists 973 | For example, [[0,1],[0,2,3]] means keeping the 0/1 neuron in the 1st hidden layer and the 0/2/3 neuron in the 2nd hidden layer. Pruning input and output neurons is not supported yet. 974 | 975 | Returns: 976 | -------- 977 | model2 : KAN 978 | pruned model 979 | 980 | Example 981 | ------- 982 | >>> # for more interactive examples, please see demos 983 | >>> from utils import create_dataset 984 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0) 985 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 986 | >>> dataset = create_dataset(f, n_var=2) 987 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); 988 | >>> model.prune() 989 | >>> model.plot(mask=True) 990 | ''' 991 | mask = [torch.ones(self.width[0], )] 992 | active_neurons = [list(range(self.width[0]))] 993 | for i in range(len(self.acts_scale) - 1): 994 | if mode == "auto": 995 | in_important = torch.max(self.acts_scale[i], dim=1)[0] > threshold 996 | out_important = torch.max(self.acts_scale[i + 1], dim=0)[0] > threshold 997 | overall_important = in_important * out_important 998 | elif mode == "manual": 999 | overall_important = torch.zeros(self.width[i + 1], dtype=torch.bool) 1000 | overall_important[active_neurons_id[i + 1]] = True 1001 | mask.append(overall_important.float()) 1002 | active_neurons.append(torch.where(overall_important == True)[0]) 1003 | active_neurons.append(list(range(self.width[-1]))) 1004 | mask.append(torch.ones(self.width[-1], )) 1005 | 1006 | self.mask = mask # this is neuron mask for the whole model 1007 | 1008 | # update act_fun[l].mask 1009 | for l in range(len(self.acts_scale) - 1): 1010 | for i in range(self.width[l + 1]): 1011 | if i not in active_neurons[l + 1]: 1012 | self.remove_node(l + 1, i) 1013 | 1014 | model2 = KAN(copy.deepcopy(self.width), self.grid, self.k, base_fun=self.base_fun) 1015 | model2.load_state_dict(self.state_dict()) 1016 | for i in range(len(self.acts_scale)): 1017 | if i < len(self.acts_scale) - 1: 1018 | model2.biases[i].weight.data = model2.biases[i].weight.data[:, active_neurons[i + 1]] 1019 | 1020 | 
model2.act_fun[i] = model2.act_fun[i].get_subset(active_neurons[i], active_neurons[i + 1]) 1021 | model2.width[i] = len(active_neurons[i]) 1022 | model2.symbolic_fun[i] = self.symbolic_fun[i].get_subset(active_neurons[i], active_neurons[i + 1]) 1023 | 1024 | return model2 1025 | 1026 | def remove_edge(self, l, i, j): 1027 | ''' 1028 | remove activtion phi(l,i,j) (set its mask to zero) 1029 | 1030 | Args: 1031 | ----- 1032 | l : int 1033 | layer index 1034 | i : int 1035 | input neuron index 1036 | j : int 1037 | output neuron index 1038 | 1039 | Returns: 1040 | -------- 1041 | None 1042 | ''' 1043 | self.act_fun[l].mask[j * self.width[l] + i] = 0. 1044 | 1045 | def remove_node(self, l, i): 1046 | ''' 1047 | remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero) 1048 | 1049 | Args: 1050 | ----- 1051 | l : int 1052 | layer index 1053 | i : int 1054 | neuron index 1055 | 1056 | Returns: 1057 | -------- 1058 | None 1059 | ''' 1060 | self.act_fun[l - 1].mask[i * self.width[l - 1] + torch.arange(self.width[l - 1])] = 0. 1061 | self.act_fun[l].mask[torch.arange(self.width[l + 1]) * self.width[l] + i] = 0. 1062 | self.symbolic_fun[l - 1].mask[i, :] *= 0. 1063 | self.symbolic_fun[l].mask[:, i] *= 0. 1064 | 1065 | def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True): 1066 | '''suggest the symbolic candidates of phi(l,i,j) 1067 | 1068 | Args: 1069 | ----- 1070 | l : int 1071 | layer index 1072 | i : int 1073 | input neuron index 1074 | j : int 1075 | output neuron index 1076 | lib : dic 1077 | library of symbolic bases. If lib = None, the global default library will be used. 1078 | topk : int 1079 | display the top k symbolic functions (according to r2) 1080 | verbose : bool 1081 | If True, more information will be printed. 
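            a_range : tuple
                sweeping range of a (forwarded to fix_symbolic for each candidate)
            b_range : tuple
                sweeping range of b (forwarded to fix_symbolic for each candidate)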
1082 |
1083 |         Returns:
1084 |         --------
1085 |             best_name : str, best_fun : fun, best_r2 : float (the top candidate)
1086 |
1087 |         Example
1088 |         -------
1089 |         >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
1090 |         >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
1091 |         >>> dataset = create_dataset(f, n_var=2)
1092 |         >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
1093 |         >>> model = model.prune()
1094 |         >>> model(dataset['train_input'])
1095 |         >>> model.suggest_symbolic(0,0,0)
1096 |         function , r2
1097 |         sin , 0.9994412064552307
1098 |         gaussian , 0.9196369051933289
1099 |         tanh , 0.8608126044273376
1100 |         sigmoid , 0.8578218817710876
1101 |         arctan , 0.842217743396759
1102 |         '''
1103 |         r2s = []
1104 |
1105 |         if lib == None:
1106 |             symbolic_lib = SYMBOLIC_LIB
1107 |         else:
1108 |             symbolic_lib = {}
1109 |             for item in lib:
1110 |                 symbolic_lib[item] = SYMBOLIC_LIB[item]
1111 |
1112 |         for (name, fun) in symbolic_lib.items():
1113 |             r2 = self.fix_symbolic(l, i, j, name, a_range=a_range, b_range=b_range, verbose=False)
1114 |             r2s.append(r2.item())
1115 |
1116 |         self.unfix_symbolic(l, i, j)
1117 |
1118 |         topk = np.minimum(topk, len(symbolic_lib))
1119 |         sorted_ids = np.argsort(r2s)[::-1][:topk]
1120 |         r2s = np.array(r2s)[sorted_ids]
1121 |         if verbose == True:
1122 |             print('function', ',', 'r2')
1123 |             for i in range(topk):
1124 |                 print(list(symbolic_lib.items())[sorted_ids[i]][0], ',', r2s[i])
1125 |
1126 |         best_name = list(symbolic_lib.items())[sorted_ids[0]][0]
1127 |         best_fun = list(symbolic_lib.items())[sorted_ids[0]][1]
1128 |         best_r2 = r2s[0]
1129 |         return best_name, best_fun, best_r2
1130 |
1131 |     def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1):
1132 |         '''
1133 |         automatic symbolic regression: using top 1 suggestion from suggest_symbolic to replace splines with symbolic activations
1134 |
1135 |         Args:
1136 |         -----
1137 |             lib : None or a list of function names
1138 |                 the symbolic library
1139 |             verbose : int
1140 |                 verbosity
1141 |
1142 |         Returns:
1143 |         --------
1144 |             None (prints the suggested symbolic formulas)
1145 |
1146 |         Example 1
1147 |         ---------
1148 |         >>> # default library
1149 |         >>> from utils import create_dataset
1150 |         >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
1151 |         >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
1152 |         >>> dataset = create_dataset(f, n_var=2)
1153 |         >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
1154 |         >>> model = model.prune()
1155 |         >>> model(dataset['train_input'])
1156 |         >>> model.auto_symbolic()
1157 |         fixing (0,0,0) with sin, r2=0.9994837045669556
1158 |         fixing (0,1,0) with cosh, r2=0.9978033900260925
1159 |         fixing (1,0,0) with arctan, r2=0.9997088313102722
1160 |
1161 |         Example 2
1162 |         ---------
1163 |         >>> # customized library
1164 |         >>> from utils import create_dataset
1165 |         >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
1166 |         >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
1167 |         >>> dataset = create_dataset(f, n_var=2)
1168 |         >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
1169 |         >>> model = model.prune()
1170 |         >>> model(dataset['train_input'])
1171 |         >>> model.auto_symbolic(lib=['exp','sin','x^2'])
1172 |         fixing (0,0,0) with sin, r2=0.999411404132843
1173 |         fixing (0,1,0) with x^2, r2=0.9962921738624573
1174 |         fixing (1,0,0) with exp, r2=0.9980258941650391
1175 |         '''
1176 |         for l in range(len(self.width) - 1):
1177 |             for i in range(self.width[l]):
1178 |                 for j in range(self.width[l + 1]):
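                    # visit every edge (l, i, j): edges already fixed to a symbolic function
                    # are skipped; every remaining numeric spline is replaced by its
                    # best-R^2 symbolic candidate from the library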
1179 | if self.symbolic_fun[l].mask[j, i] > 0.: 1180 | print(f'skipping ({l},{i},{j}) since already symbolic') 1181 | else: 1182 | name, fun, r2 = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, 1183 | verbose=False) 1184 | self.fix_symbolic(l, i, j, name, verbose=verbose > 1) 1185 | if verbose >= 1: 1186 | print(f'fixing ({l},{i},{j}) with {name}, r2={r2}') 1187 | 1188 | def symbolic_formula(self, floating_digit=2, var=None, normalizer=None, simplify=False): 1189 | ''' 1190 | obtain the symbolic formula 1191 | 1192 | Args: 1193 | ----- 1194 | floating_digit : int 1195 | the number of digits to display 1196 | var : list of str 1197 | the name of variables (if not provided, by default using ['x_1', 'x_2', ...]) 1198 | normalizer : [mean array (floats), varaince array (floats)] 1199 | the normalization applied to inputs 1200 | simplify : bool 1201 | If True, simplify the equation at each step (usually quite slow), so set up False by default. 1202 | 1203 | Returns: 1204 | -------- 1205 | symbolic formula : sympy function 1206 | 1207 | Example 1208 | ------- 1209 | >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0, grid_eps=0.02) 1210 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 1211 | >>> dataset = create_dataset(f, n_var=2) 1212 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); 1213 | >>> model = model.prune() 1214 | >>> model(dataset['train_input']) 1215 | >>> model.auto_symbolic(lib=['exp','sin','x^2']) 1216 | >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.00, update_grid=False); 1217 | >>> model.symbolic_formula() 1218 | ''' 1219 | symbolic_acts = [] 1220 | x = [] 1221 | 1222 | def ex_round(ex1, floating_digit=floating_digit): 1223 | ex2 = ex1 1224 | for a in sympy.preorder_traversal(ex1): 1225 | if isinstance(a, sympy.Float): 1226 | ex2 = ex2.subs(a, round(a, floating_digit)) 1227 | return ex2 1228 | 1229 | # define variables 1230 | if var == None: 1231 | for ii in range(1, self.width[0] + 1): 1232 | exec(f"x{ii} = sympy.Symbol('x_{ii}')") 1233 | exec(f"x.append(x{ii})") 1234 | else: 1235 | x = [sympy.symbols(var_) for var_ in var] 1236 | 1237 | x0 = x 1238 | 1239 | if normalizer != None: 1240 | mean = normalizer[0] 1241 | std = normalizer[1] 1242 | x = [(x[i] - mean[i]) / std[i] for i in range(len(x))] 1243 | 1244 | symbolic_acts.append(x) 1245 | 1246 | for l in range(len(self.width) - 1): 1247 | y = [] 1248 | for j in range(self.width[l + 1]): 1249 | yj = 0. 
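                # each incoming symbolic edge contributes an affine-wrapped primitive,
                # i.e. yj += c * f(a * x_i + b) + d, accumulated over the inputs i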
1250 | for i in range(self.width[l]): 1251 | a, b, c, d = self.symbolic_fun[l].affine[j, i] 1252 | sympy_fun = self.symbolic_fun[l].funs_sympy[j][i] 1253 | try: 1254 | yj += c * sympy_fun(a * x[i] + b) + d 1255 | except: 1256 | print('make sure all activations need to be converted to symbolic formulas first!') 1257 | return 1258 | if simplify == True: 1259 | y.append(sympy.simplify(yj + self.biases[l].weight.data[0, j])) 1260 | else: 1261 | y.append(yj + self.biases[l].weight.data[0, j]) 1262 | 1263 | x = y 1264 | symbolic_acts.append(x) 1265 | 1266 | self.symbolic_acts = [[ex_round(symbolic_acts[l][i]) for i in range(len(symbolic_acts[l]))] for l in 1267 | range(len(symbolic_acts))] 1268 | 1269 | out_dim = len(symbolic_acts[-1]) 1270 | return [ex_round(symbolic_acts[-1][i]) for i in range(len(symbolic_acts[-1]))], x0 1271 | 1272 | def clear_ckpts(self, folder='./model_ckpt'): 1273 | ''' 1274 | clear all checkpoints 1275 | 1276 | Args: 1277 | ----- 1278 | folder : str 1279 | the folder that stores checkpoints 1280 | 1281 | Returns: 1282 | -------- 1283 | None 1284 | ''' 1285 | if os.path.exists(folder): 1286 | files = glob.glob(folder + '/*') 1287 | for f in files: 1288 | os.remove(f) 1289 | else: 1290 | os.makedirs(folder) 1291 | 1292 | def save_ckpt(self, name, folder='./model_ckpt'): 1293 | ''' 1294 | save the current model as checkpoint 1295 | 1296 | Args: 1297 | ----- 1298 | name: str 1299 | the name of the checkpoint to be saved 1300 | folder : str 1301 | the folder that stores checkpoints 1302 | 1303 | Returns: 1304 | -------- 1305 | None 1306 | ''' 1307 | 1308 | if not os.path.exists(folder): 1309 | os.makedirs(folder) 1310 | 1311 | torch.save(self.state_dict(), folder + '/' + name) 1312 | print('save this model to', folder + '/' + name) 1313 | 1314 | def load_ckpt(self, name, folder='./model_ckpt'): 1315 | ''' 1316 | load a checkpoint to the current model 1317 | 1318 | Args: 1319 | ----- 1320 | name: str 1321 | the name of the checkpoint to be loaded 1322 | folder : str 1323 | the folder that stores checkpoints 1324 | 1325 | Returns: 1326 | -------- 1327 | None 1328 | ''' 1329 | self.load_state_dict(torch.load(folder + '/' + name)) -------------------------------------------------------------------------------- /models/LBFGS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import reduce 3 | from torch.optim import Optimizer 4 | 5 | __all__ = ['LBFGS'] 6 | 7 | def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): 8 | # ported from https://github.com/torch/optim/blob/master/polyinterp.lua 9 | # Compute bounds of interpolation area 10 | if bounds is not None: 11 | xmin_bound, xmax_bound = bounds 12 | else: 13 | xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1) 14 | 15 | # Code for most common case: cubic interpolation of 2 points 16 | # w/ function and derivative values for both 17 | # Solution in this case (where x2 is the farthest point): 18 | # d1 = g1 + g2 - 3*(f1-f2)/(x1-x2); 19 | # d2 = sqrt(d1^2 - g1*g2); 20 | # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2)); 21 | # t_new = min(max(min_pos,xmin_bound),xmax_bound); 22 | d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2) 23 | d2_square = d1**2 - g1 * g2 24 | if d2_square >= 0: 25 | d2 = d2_square.sqrt() 26 | if x1 <= x2: 27 | min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)) 28 | else: 29 | min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)) 30 | return min(max(min_pos, xmin_bound), xmax_bound) 31 | else: 32 | 
return (xmin_bound + xmax_bound) / 2. 33 | 34 | 35 | def _strong_wolfe(obj_func, 36 | x, 37 | t, 38 | d, 39 | f, 40 | g, 41 | gtd, 42 | c1=1e-4, 43 | c2=0.9, 44 | tolerance_change=1e-9, 45 | max_ls=25): 46 | # ported from https://github.com/torch/optim/blob/master/lswolfe.lua 47 | d_norm = d.abs().max() 48 | g = g.clone(memory_format=torch.contiguous_format) 49 | # evaluate objective and gradient using initial step 50 | f_new, g_new = obj_func(x, t, d) 51 | ls_func_evals = 1 52 | gtd_new = g_new.dot(d) 53 | 54 | # bracket an interval containing a point satisfying the Wolfe criteria 55 | t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd 56 | done = False 57 | ls_iter = 0 58 | while ls_iter < max_ls: 59 | # check conditions 60 | if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev): 61 | bracket = [t_prev, t] 62 | bracket_f = [f_prev, f_new] 63 | bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] 64 | bracket_gtd = [gtd_prev, gtd_new] 65 | break 66 | 67 | if abs(gtd_new) <= -c2 * gtd: 68 | bracket = [t] 69 | bracket_f = [f_new] 70 | bracket_g = [g_new] 71 | done = True 72 | break 73 | 74 | if gtd_new >= 0: 75 | bracket = [t_prev, t] 76 | bracket_f = [f_prev, f_new] 77 | bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] 78 | bracket_gtd = [gtd_prev, gtd_new] 79 | break 80 | 81 | # interpolate 82 | min_step = t + 0.01 * (t - t_prev) 83 | max_step = t * 10 84 | tmp = t 85 | t = _cubic_interpolate( 86 | t_prev, 87 | f_prev, 88 | gtd_prev, 89 | t, 90 | f_new, 91 | gtd_new, 92 | bounds=(min_step, max_step)) 93 | 94 | # next step 95 | t_prev = tmp 96 | f_prev = f_new 97 | g_prev = g_new.clone(memory_format=torch.contiguous_format) 98 | gtd_prev = gtd_new 99 | f_new, g_new = obj_func(x, t, d) 100 | ls_func_evals += 1 101 | gtd_new = g_new.dot(d) 102 | ls_iter += 1 103 | 104 | # reached max number of iterations? 105 | if ls_iter == max_ls: 106 | bracket = [0, t] 107 | bracket_f = [f, f_new] 108 | bracket_g = [g, g_new] 109 | 110 | # zoom phase: we now have a point satisfying the criteria, or 111 | # a bracket around it. We refine the bracket until we find the 112 | # exact point satisfying the criteria 113 | insuf_progress = False 114 | # find high and low points in bracket 115 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) 116 | while not done and ls_iter < max_ls: 117 | # line-search bracket is so small 118 | if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: 119 | break 120 | 121 | # compute new trial value 122 | t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0], 123 | bracket[1], bracket_f[1], bracket_gtd[1]) 124 | 125 | # test that we are making sufficient progress: 126 | # in case `t` is so close to boundary, we mark that we are making 127 | # insufficient progress, and if 128 | # + we have made insufficient progress in the last step, or 129 | # + `t` is at one of the boundary, 130 | # we will move `t` to a position which is `0.1 * len(bracket)` 131 | # away from the nearest boundary point. 
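        # (e.g., with bracket [0, 1] we get eps = 0.1, so a stalled trial point
        # t = 0.99 is pulled back to 0.9)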
132 | eps = 0.1 * (max(bracket) - min(bracket)) 133 | if min(max(bracket) - t, t - min(bracket)) < eps: 134 | # interpolation close to boundary 135 | if insuf_progress or t >= max(bracket) or t <= min(bracket): 136 | # evaluate at 0.1 away from boundary 137 | if abs(t - max(bracket)) < abs(t - min(bracket)): 138 | t = max(bracket) - eps 139 | else: 140 | t = min(bracket) + eps 141 | insuf_progress = False 142 | else: 143 | insuf_progress = True 144 | else: 145 | insuf_progress = False 146 | 147 | # Evaluate new point 148 | f_new, g_new = obj_func(x, t, d) 149 | ls_func_evals += 1 150 | gtd_new = g_new.dot(d) 151 | ls_iter += 1 152 | 153 | if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]: 154 | # Armijo condition not satisfied or not lower than lowest point 155 | bracket[high_pos] = t 156 | bracket_f[high_pos] = f_new 157 | bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) 158 | bracket_gtd[high_pos] = gtd_new 159 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0) 160 | else: 161 | if abs(gtd_new) <= -c2 * gtd: 162 | # Wolfe conditions satisfied 163 | done = True 164 | elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0: 165 | # old low becomes new high 166 | bracket[high_pos] = bracket[low_pos] 167 | bracket_f[high_pos] = bracket_f[low_pos] 168 | bracket_g[high_pos] = bracket_g[low_pos] 169 | bracket_gtd[high_pos] = bracket_gtd[low_pos] 170 | 171 | # new point becomes new low 172 | bracket[low_pos] = t 173 | bracket_f[low_pos] = f_new 174 | bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) 175 | bracket_gtd[low_pos] = gtd_new 176 | 177 | # return stuff 178 | t = bracket[low_pos] 179 | f_new = bracket_f[low_pos] 180 | g_new = bracket_g[low_pos] 181 | return f_new, g_new, t, ls_func_evals 182 | 183 | 184 | 185 | class LBFGS(Optimizer): 186 | """Implements L-BFGS algorithm. 187 | 188 | Heavily inspired by `minFunc 189 | `_. 190 | 191 | .. warning:: 192 | This optimizer doesn't support per-parameter options and parameter 193 | groups (there can be only one). 194 | 195 | .. warning:: 196 | Right now all parameters have to be on a single device. This will be 197 | improved in the future. 198 | 199 | .. note:: 200 | This is a very memory intensive optimizer (it requires additional 201 | ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory 202 | try reducing the history size, or use a different algorithm. 203 | 204 | Args: 205 | lr (float): learning rate (default: 1) 206 | max_iter (int): maximal number of iterations per optimization step 207 | (default: 20) 208 | max_eval (int): maximal number of function evaluations per optimization 209 | step (default: max_iter * 1.25). 210 | tolerance_grad (float): termination tolerance on first order optimality 211 | (default: 1e-7). 212 | tolerance_change (float): termination tolerance on function 213 | value/parameter changes (default: 1e-9). 214 | history_size (int): update history size (default: 100). 215 | line_search_fn (str): either 'strong_wolfe' or None (default: None). 
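    Example:
        A minimal sketch of the intended usage (``model``, ``inputs`` and
        ``targets`` are assumed to be defined by the caller):

        >>> optimizer = LBFGS(model.parameters(), lr=1., history_size=10,
        ...                   line_search_fn='strong_wolfe')
        >>> def closure():
        ...     optimizer.zero_grad()
        ...     loss = torch.mean((model(inputs) - targets) ** 2)
        ...     loss.backward()
        ...     return loss
        >>> optimizer.step(closure)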
216 | """ 217 | 218 | def __init__(self, 219 | params, 220 | lr=1, 221 | max_iter=20, 222 | max_eval=None, 223 | tolerance_grad=1e-7, 224 | tolerance_change=1e-9, 225 | tolerance_ys=1e-32, 226 | history_size=100, 227 | line_search_fn=None): 228 | if max_eval is None: 229 | max_eval = max_iter * 5 // 4 230 | defaults = dict( 231 | lr=lr, 232 | max_iter=max_iter, 233 | max_eval=max_eval, 234 | tolerance_grad=tolerance_grad, 235 | tolerance_change=tolerance_change, 236 | tolerance_ys=tolerance_ys, 237 | history_size=history_size, 238 | line_search_fn=line_search_fn) 239 | super().__init__(params, defaults) 240 | 241 | if len(self.param_groups) != 1: 242 | raise ValueError("LBFGS doesn't support per-parameter options " 243 | "(parameter groups)") 244 | 245 | self._params = self.param_groups[0]['params'] 246 | self._numel_cache = None 247 | 248 | def _numel(self): 249 | if self._numel_cache is None: 250 | self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0) 251 | return self._numel_cache 252 | 253 | def _gather_flat_grad(self): 254 | views = [] 255 | for p in self._params: 256 | if p.grad is None: 257 | view = p.new(p.numel()).zero_() 258 | elif p.grad.is_sparse: 259 | view = p.grad.to_dense().view(-1) 260 | else: 261 | view = p.grad.view(-1) 262 | views.append(view) 263 | return torch.cat(views, 0) 264 | 265 | def _add_grad(self, step_size, update): 266 | offset = 0 267 | for p in self._params: 268 | numel = p.numel() 269 | # view as to avoid deprecated pointwise semantics 270 | p.add_(update[offset:offset + numel].view_as(p), alpha=step_size) 271 | offset += numel 272 | assert offset == self._numel() 273 | 274 | def _clone_param(self): 275 | return [p.clone(memory_format=torch.contiguous_format) for p in self._params] 276 | 277 | def _set_param(self, params_data): 278 | for p, pdata in zip(self._params, params_data): 279 | p.copy_(pdata) 280 | 281 | def _directional_evaluate(self, closure, x, t, d): 282 | self._add_grad(t, d) 283 | loss = float(closure()) 284 | flat_grad = self._gather_flat_grad() 285 | self._set_param(x) 286 | return loss, flat_grad 287 | 288 | 289 | @torch.no_grad() 290 | def step(self, closure): 291 | """Perform a single optimization step. 292 | 293 | Args: 294 | closure (Callable): A closure that reevaluates the model 295 | and returns the loss. 
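        The closure must zero the gradients, recompute the loss, call
        ``backward()``, and return the loss. A sketch (``loss_fn`` is assumed
        to be defined by the caller):

        >>> def closure():
        ...     optimizer.zero_grad()
        ...     loss = loss_fn()
        ...     loss.backward()
        ...     return loss
        >>> optimizer.step(closure)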
296 | """ 297 | assert len(self.param_groups) == 1 298 | 299 | # Make sure the closure is always called with grad enabled 300 | closure = torch.enable_grad()(closure) 301 | 302 | group = self.param_groups[0] 303 | lr = group['lr'] 304 | max_iter = group['max_iter'] 305 | max_eval = group['max_eval'] 306 | tolerance_grad = group['tolerance_grad'] 307 | tolerance_change = group['tolerance_change'] 308 | tolerance_ys = group['tolerance_ys'] 309 | line_search_fn = group['line_search_fn'] 310 | history_size = group['history_size'] 311 | 312 | # NOTE: LBFGS has only global state, but we register it as state for 313 | # the first param, because this helps with casting in load_state_dict 314 | state = self.state[self._params[0]] 315 | state.setdefault('func_evals', 0) 316 | state.setdefault('n_iter', 0) 317 | 318 | # evaluate initial f(x) and df/dx 319 | orig_loss = closure() 320 | loss = float(orig_loss) 321 | current_evals = 1 322 | state['func_evals'] += 1 323 | 324 | flat_grad = self._gather_flat_grad() 325 | opt_cond = flat_grad.abs().max() <= tolerance_grad 326 | 327 | # optimal condition 328 | if opt_cond: 329 | return orig_loss 330 | 331 | # tensors cached in state (for tracing) 332 | d = state.get('d') 333 | t = state.get('t') 334 | old_dirs = state.get('old_dirs') 335 | old_stps = state.get('old_stps') 336 | ro = state.get('ro') 337 | H_diag = state.get('H_diag') 338 | prev_flat_grad = state.get('prev_flat_grad') 339 | prev_loss = state.get('prev_loss') 340 | 341 | n_iter = 0 342 | # optimize for a max of max_iter iterations 343 | while n_iter < max_iter: 344 | # keep track of nb of iterations 345 | n_iter += 1 346 | state['n_iter'] += 1 347 | 348 | ############################################################ 349 | # compute gradient descent direction 350 | ############################################################ 351 | if state['n_iter'] == 1: 352 | d = flat_grad.neg() 353 | old_dirs = [] 354 | old_stps = [] 355 | ro = [] 356 | H_diag = 1 357 | else: 358 | # do lbfgs update (update memory) 359 | y = flat_grad.sub(prev_flat_grad) 360 | s = d.mul(t) 361 | ys = y.dot(s) # y*s 362 | if ys > tolerance_ys: 363 | # updating memory 364 | if len(old_dirs) == history_size: 365 | # shift history by one (limited-memory) 366 | old_dirs.pop(0) 367 | old_stps.pop(0) 368 | ro.pop(0) 369 | 370 | # store new direction/step 371 | old_dirs.append(y) 372 | old_stps.append(s) 373 | ro.append(1. 
/ ys) 374 | 375 | # update scale of initial Hessian approximation 376 | H_diag = ys / y.dot(y) # (y*y) 377 | 378 | # compute the approximate (L-BFGS) inverse Hessian 379 | # multiplied by the gradient 380 | num_old = len(old_dirs) 381 | 382 | if 'al' not in state: 383 | state['al'] = [None] * history_size 384 | al = state['al'] 385 | 386 | # iteration in L-BFGS loop collapsed to use just one buffer 387 | q = flat_grad.neg() 388 | for i in range(num_old - 1, -1, -1): 389 | al[i] = old_stps[i].dot(q) * ro[i] 390 | q.add_(old_dirs[i], alpha=-al[i]) 391 | 392 | # multiply by initial Hessian 393 | # r/d is the final direction 394 | d = r = torch.mul(q, H_diag) 395 | for i in range(num_old): 396 | be_i = old_dirs[i].dot(r) * ro[i] 397 | r.add_(old_stps[i], alpha=al[i] - be_i) 398 | 399 | if prev_flat_grad is None: 400 | prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format) 401 | else: 402 | prev_flat_grad.copy_(flat_grad) 403 | prev_loss = loss 404 | 405 | ############################################################ 406 | # compute step length 407 | ############################################################ 408 | # reset initial guess for step size 409 | if state['n_iter'] == 1: 410 | t = min(1., 1. / flat_grad.abs().sum()) * lr 411 | else: 412 | t = lr 413 | 414 | # directional derivative 415 | gtd = flat_grad.dot(d) # g * d 416 | 417 | # directional derivative is below tolerance 418 | if gtd > -tolerance_change: 419 | break 420 | 421 | # optional line search: user function 422 | ls_func_evals = 0 423 | if line_search_fn is not None: 424 | # perform line search, using user function 425 | if line_search_fn != "strong_wolfe": 426 | raise RuntimeError("only 'strong_wolfe' is supported") 427 | else: 428 | x_init = self._clone_param() 429 | 430 | def obj_func(x, t, d): 431 | return self._directional_evaluate(closure, x, t, d) 432 | 433 | loss, flat_grad, t, ls_func_evals = _strong_wolfe( 434 | obj_func, x_init, t, d, loss, flat_grad, gtd) 435 | self._add_grad(t, d) 436 | opt_cond = flat_grad.abs().max() <= tolerance_grad 437 | else: 438 | # no line search, simply move with fixed-step 439 | self._add_grad(t, d) 440 | if n_iter != max_iter: 441 | # re-evaluate function only if not in last iteration 442 | # the reason we do this: in a stochastic setting, 443 | # no use to re-evaluate that function here 444 | with torch.enable_grad(): 445 | loss = float(closure()) 446 | flat_grad = self._gather_flat_grad() 447 | opt_cond = flat_grad.abs().max() <= tolerance_grad 448 | ls_func_evals = 1 449 | 450 | # update func eval 451 | current_evals += ls_func_evals 452 | state['func_evals'] += ls_func_evals 453 | 454 | ############################################################ 455 | # check conditions 456 | ############################################################ 457 | if n_iter == max_iter: 458 | break 459 | 460 | if current_evals >= max_eval: 461 | break 462 | 463 | # optimal condition 464 | if opt_cond: 465 | break 466 | 467 | # lack of progress 468 | if d.mul(t).abs().max() <= tolerance_change: 469 | break 470 | 471 | if abs(loss - prev_loss) < tolerance_change: 472 | break 473 | 474 | state['d'] = d 475 | state['t'] = t 476 | state['old_dirs'] = old_dirs 477 | state['old_stps'] = old_stps 478 | state['ro'] = ro 479 | state['H_diag'] = H_diag 480 | state['prev_flat_grad'] = prev_flat_grad 481 | state['prev_loss'] = prev_loss 482 | 483 | return orig_loss -------------------------------------------------------------------------------- /models/PINN.py: 
-------------------------------------------------------------------------------- 1 | # baseline implementation of PINNs 2 | # paper: Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations 3 | # link: https://www.sciencedirect.com/science/article/pii/S0021999118307125 4 | # code: https://github.com/maziarraissi/PINNs 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class Model(nn.Module): 11 | def __init__(self, in_dim, hidden_dim, out_dim, num_layer): 12 | super(Model, self).__init__() 13 | 14 | layers = [] 15 | for i in range(num_layer - 1): 16 | if i == 0: 17 | layers.append(nn.Linear(in_features=in_dim, out_features=hidden_dim)) 18 | layers.append(nn.Tanh()) 19 | else: 20 | layers.append(nn.Linear(in_features=hidden_dim, out_features=hidden_dim)) 21 | layers.append(nn.Tanh()) 22 | 23 | layers.append(nn.Linear(in_features=hidden_dim, out_features=out_dim)) 24 | 25 | self.linear = nn.Sequential(*layers) 26 | 27 | def forward(self, x, t): 28 | src = torch.cat((x, t), dim=-1) 29 | return self.linear(src) 30 | -------------------------------------------------------------------------------- /models/PINNsFormer.py: -------------------------------------------------------------------------------- 1 | # implementation of PINNsformer 2 | # paper: PINNsFormer: A Transformer-Based Framework For Physics-Informed Neural Networks 3 | # link: https://arxiv.org/abs/2307.11833 4 | 5 | import torch 6 | import torch.nn as nn 7 | import pdb 8 | from util import get_clones 9 | 10 | 11 | class WaveAct(nn.Module): 12 | def __init__(self): 13 | super(WaveAct, self).__init__() 14 | self.w1 = nn.Parameter(torch.ones(1), requires_grad=True) 15 | self.w2 = nn.Parameter(torch.ones(1), requires_grad=True) 16 | 17 | def forward(self, x): 18 | return self.w1 * torch.sin(x) + self.w2 * torch.cos(x) 19 | 20 | 21 | class FeedForward(nn.Module): 22 | def __init__(self, d_model, d_ff=256): 23 | super(FeedForward, self).__init__() 24 | self.linear = nn.Sequential(*[ 25 | nn.Linear(d_model, d_ff), 26 | WaveAct(), 27 | nn.Linear(d_ff, d_ff), 28 | WaveAct(), 29 | nn.Linear(d_ff, d_model) 30 | ]) 31 | 32 | def forward(self, x): 33 | return self.linear(x) 34 | 35 | 36 | class EncoderLayer(nn.Module): 37 | def __init__(self, d_model, heads): 38 | super(EncoderLayer, self).__init__() 39 | 40 | self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=heads, batch_first=True) 41 | self.ff = FeedForward(d_model) 42 | self.act1 = WaveAct() 43 | self.act2 = WaveAct() 44 | 45 | def forward(self, x): 46 | x2 = self.act1(x) 47 | # pdb.set_trace() 48 | x = x + self.attn(x2, x2, x2)[0] 49 | x2 = self.act2(x) 50 | x = x + self.ff(x2) 51 | return x 52 | 53 | 54 | class DecoderLayer(nn.Module): 55 | def __init__(self, d_model, heads): 56 | super(DecoderLayer, self).__init__() 57 | 58 | self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=heads, batch_first=True) 59 | self.ff = FeedForward(d_model) 60 | self.act1 = WaveAct() 61 | self.act2 = WaveAct() 62 | 63 | def forward(self, x, e_outputs): 64 | x2 = self.act1(x) 65 | x = x + self.attn(x2, e_outputs, e_outputs)[0] 66 | x2 = self.act2(x) 67 | x = x + self.ff(x2) 68 | return x 69 | 70 | 71 | class Encoder(nn.Module): 72 | def __init__(self, d_model, N, heads): 73 | super(Encoder, self).__init__() 74 | self.N = N 75 | self.layers = get_clones(EncoderLayer(d_model, heads), N) 76 | self.act = WaveAct() 77 | 78 | def forward(self, x): 79 | for i in range(self.N): 80 | x = 
self.layers[i](x) 81 | return self.act(x) 82 | 83 | 84 | class Decoder(nn.Module): 85 | def __init__(self, d_model, N, heads): 86 | super(Decoder, self).__init__() 87 | self.N = N 88 | self.layers = get_clones(DecoderLayer(d_model, heads), N) 89 | self.act = WaveAct() 90 | 91 | def forward(self, x, e_outputs): 92 | for i in range(self.N): 93 | x = self.layers[i](x, e_outputs) 94 | return self.act(x) 95 | 96 | 97 | class Model(nn.Module): 98 | def __init__(self, in_dim, out_dim, hidden_dim, num_layer, hidden_d_ff=512, heads=2): 99 | super(Model, self).__init__() 100 | 101 | self.linear_emb = nn.Linear(in_dim, hidden_dim) 102 | 103 | self.encoder = Encoder(hidden_dim, num_layer, heads) 104 | self.decoder = Decoder(hidden_dim, num_layer, heads) 105 | self.linear_out = nn.Sequential(*[ 106 | nn.Linear(hidden_dim, hidden_d_ff), 107 | WaveAct(), 108 | nn.Linear(hidden_d_ff, hidden_d_ff), 109 | WaveAct(), 110 | nn.Linear(hidden_d_ff, out_dim) 111 | ]) 112 | 113 | def forward(self, x, t): 114 | src = torch.cat((x, t), dim=-1) 115 | src = self.linear_emb(src) 116 | 117 | e_outputs = self.encoder(src) 118 | d_output = self.decoder(src, e_outputs) 119 | output = self.linear_out(d_output) 120 | return output 121 | --------------------------------------------------------------------------------
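A minimal usage sketch for the model above (illustrative, not part of the repository; hyperparameters are placeholders). PINNsFormer consumes the pseudo-sequences produced by util.make_time_sequence, so x and t carry an extra sequence axis, matching the slicing in the driver scripts:

import numpy as np
import torch
from util import make_time_sequence
from models.PINNsFormer import Model

pts = np.random.rand(16, 2)                           # 16 collocation points (x, t)
seq = make_time_sequence(pts, num_step=5, step=1e-4)  # (16, 5, 2): each point expanded along t
seq = torch.tensor(seq, dtype=torch.float32, requires_grad=True)
x, t = seq[..., 0:1], seq[..., 1:2]                   # each (16, 5, 1)
model = Model(in_dim=2, out_dim=1, hidden_dim=32, num_layer=1, heads=2)
u = model(x, t)                                       # (16, 5, 1)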
/models/PINNsFormer_Enc_Only.py: -------------------------------------------------------------------------------- 1 | # implementation of PINNsformer 2 | # paper: PINNsFormer: A Transformer-Based Framework For Physics-Informed Neural Networks 3 | # link: https://arxiv.org/abs/2307.11833 4 | 5 | import torch 6 | import torch.nn as nn 7 | import pdb 8 | from util import get_clones 9 | 10 | 11 | class WaveAct(nn.Module): 12 | def __init__(self): 13 | super(WaveAct, self).__init__() 14 | self.w1 = nn.Parameter(torch.ones(1), requires_grad=True) 15 | self.w2 = nn.Parameter(torch.ones(1), requires_grad=True) 16 | 17 | def forward(self, x): 18 | return self.w1 * torch.sin(x) + self.w2 * torch.cos(x) 19 | 20 | 21 | class FeedForward(nn.Module): 22 | def __init__(self, d_model, d_ff=256): 23 | super(FeedForward, self).__init__() 24 | self.linear = nn.Sequential(*[ 25 | nn.Linear(d_model, d_ff), 26 | WaveAct(), 27 | nn.Linear(d_ff, d_ff), 28 | WaveAct(), 29 | nn.Linear(d_ff, d_model) 30 | ]) 31 | 32 | def forward(self, x): 33 | return self.linear(x) 34 | 35 | 36 | class EncoderLayer(nn.Module): 37 | def __init__(self, d_model, heads): 38 | super(EncoderLayer, self).__init__() 39 | 40 | self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=heads, batch_first=True) 41 | self.ff = FeedForward(d_model) 42 | self.act1 = WaveAct() 43 | self.act2 = WaveAct() 44 | 45 | def forward(self, x): 46 | x2 = self.act1(x) 47 | # pdb.set_trace() 48 | x = x + self.attn(x2, x2, x2)[0] 49 | x2 = self.act2(x) 50 | x = x + self.ff(x2) 51 | return x 52 | 53 | 54 | class DecoderLayer(nn.Module): 55 | def __init__(self, d_model, heads): 56 | super(DecoderLayer, self).__init__() 57 | 58 | self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=heads, batch_first=True) 59 | self.ff = FeedForward(d_model) 60 | self.act1 = WaveAct() 61 | self.act2 = WaveAct() 62 | 63 | def forward(self, x, e_outputs): 64 | x2 = self.act1(x) 65 | x = x + self.attn(x2, e_outputs, e_outputs)[0] 66 | x2 = self.act2(x) 67 | x = x + self.ff(x2) 68 | return x 69 | 70 | 71 | class Encoder(nn.Module): 72 | def __init__(self, d_model, N, heads): 73 | super(Encoder, self).__init__() 74 | self.N = N 75 | self.layers = get_clones(EncoderLayer(d_model, heads), N) 76 | self.act = WaveAct() 77 | 78 | def forward(self, x): 79 | for i in range(self.N): 80 | x = self.layers[i](x) 81 | return self.act(x) 82 | 83 | 84 | class Decoder(nn.Module): 85 | def __init__(self, d_model, N, heads): 86 | super(Decoder, self).__init__() 87 | self.N = N 88 | self.layers = get_clones(DecoderLayer(d_model, heads), N) 89 | self.act = WaveAct() 90 | 91 | def forward(self, x, e_outputs): 92 | for i in range(self.N): 93 | x = self.layers[i](x, e_outputs) 94 | return self.act(x) 95 | 96 | 97 | class Model(nn.Module): 98 | def __init__(self, in_dim, out_dim, hidden_dim, num_layer, hidden_d_ff=512, heads=2): 99 | super(Model, self).__init__() 100 | 101 | self.linear_emb = nn.Linear(in_dim, hidden_dim) 102 | 103 | self.encoder = Encoder(hidden_dim, num_layer, heads) 104 | self.linear_out = nn.Sequential(*[ 105 | nn.Linear(hidden_dim, hidden_d_ff), 106 | WaveAct(), 107 | nn.Linear(hidden_d_ff, hidden_d_ff), 108 | WaveAct(), 109 | nn.Linear(hidden_d_ff, out_dim) 110 | ]) 111 | 112 | def forward(self, x, t): 113 | src = torch.cat((x, t), dim=-1) 114 | src = self.linear_emb(src) 115 | e_outputs = self.encoder(src) 116 | output = self.linear_out(e_outputs) 117 | return output 118 | -------------------------------------------------------------------------------- /models/QRes.py: -------------------------------------------------------------------------------- 1 | # baseline implementation of QRes 2 | # paper: Quadratic residual networks: A new class of neural networks for solving forward and inverse problems in physics involving pdes 3 | # link: https://arxiv.org/abs/2101.08366 4 | # code: https://github.com/jayroxis/qres 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from util import get_clones 10 | 11 | 12 | class QRes_block(nn.Module): 13 | def __init__(self, in_dim, out_dim): 14 | super(QRes_block, self).__init__() 15 | self.H1 = nn.Linear(in_features=in_dim, out_features=out_dim) 16 | self.H2 = nn.Linear(in_features=in_dim, out_features=out_dim) 17 | self.act = nn.Sigmoid() 18 | 19 | def forward(self, x): 20 | x1 = self.H1(x) 21 | x2 = self.H2(x) 22 | return self.act(x1 * x2 + x1) 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, in_dim, hidden_dim, out_dim, num_layer): 27 | super(Model, self).__init__() 28 | self.N = num_layer - 1 29 | self.inlayer = QRes_block(in_dim, hidden_dim) 30 | self.layers = get_clones(QRes_block(hidden_dim, hidden_dim), num_layer - 1) 31 | self.outlayer = nn.Linear(in_features=hidden_dim, out_features=out_dim) 32 | 33 | def forward(self, x, t): 34 | src = torch.cat((x, t), dim=-1) 35 | src = self.inlayer(src) 36 | for i in range(self.N): 37 | src = self.layers[i](src) 38 | src = self.outlayer(src) 39 | return src 40 | -------------------------------------------------------------------------------- /models/Symbolic_KANLayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import sympy 5 | from .utils import * 6 | 7 | 8 | class Symbolic_KANLayer(nn.Module): 9 | ''' 10 | Symbolic_KANLayer class 11 | 12 | Attributes: 13 | ----------- 14 | in_dim: int 15 | input dimension 16 | out_dim: int 17 | output dimension 18 | funs: 2D array of torch functions (or lambda functions) 19 | symbolic functions (torch) 20 | funs_name: 2D array of str 21 | names of symbolic functions 22 | funs_sympy: 2D array of sympy functions (or lambda functions) 23 | symbolic functions (sympy) 24 | affine: 3D array of floats 25 | affine transformations of inputs and outputs 26
| 27 | Methods: 28 | -------- 29 | __init__(): 30 | initialize a Symbolic_KANLayer 31 | forward(): 32 | forward 33 | get_subset(): 34 | get subset of the KANLayer (used for pruning) 35 | fix_symbolic(): 36 | fix an activation function to be symbolic 37 | ''' 38 | 39 | def __init__(self, in_dim=3, out_dim=2): 40 | ''' 41 | initialize a Symbolic_KANLayer (activation functions are initialized to be identity functions) 42 | 43 | Args: 44 | ----- 45 | in_dim : int 46 | input dimension 47 | out_dim : int 48 | output dimension 49 | 50 | Returns: 51 | -------- 52 | self 53 | 54 | Example 55 | ------- 56 | >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=3) 57 | >>> len(sb.funs), len(sb.funs[0]) 58 | (3, 3) 59 | ''' 60 | super(Symbolic_KANLayer, self).__init__() 61 | self.out_dim = out_dim 62 | self.in_dim = in_dim 63 | self.mask = torch.nn.Parameter(torch.zeros(out_dim, in_dim)).requires_grad_(False) 64 | # torch 65 | self.funs = [[lambda x: x for i in range(self.in_dim)] for j in range(self.out_dim)] 66 | # name 67 | self.funs_name = [['' for i in range(self.in_dim)] for j in range(self.out_dim)] 68 | # sympy 69 | self.funs_sympy = [['' for i in range(self.in_dim)] for j in range(self.out_dim)] 70 | 71 | self.affine = torch.nn.Parameter(torch.zeros(out_dim, in_dim, 4)) 72 | # c*f(a*x+b)+d 73 | 74 | def forward(self, x): 75 | ''' 76 | forward 77 | 78 | Args: 79 | ----- 80 | x : 2D array 81 | inputs, shape (batch, input dimension) 82 | 83 | Returns: 84 | -------- 85 | y : 2D array 86 | outputs, shape (batch, output dimension) 87 | postacts : 3D array 88 | activations after activation functions but before summing on nodes 89 | 90 | Example 91 | ------- 92 | >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=5) 93 | >>> x = torch.normal(0,1,size=(100,3)) 94 | >>> y, postacts = sb(x) 95 | >>> y.shape, postacts.shape 96 | (torch.Size([100, 5]), torch.Size([100, 5, 3])) 97 | ''' 98 | 99 | batch = x.shape[0] 100 | postacts = [] 101 | 102 | for i in range(self.in_dim): 103 | postacts_ = [] 104 | for j in range(self.out_dim): 105 | xij = self.affine[j, i, 2] * self.funs[j][i](self.affine[j, i, 0] * x[:, [i]] + self.affine[j, i, 1]) + \ 106 | self.affine[j, i, 3] 107 | postacts_.append(self.mask[j][i] * xij) 108 | postacts.append(torch.stack(postacts_)) 109 | 110 | postacts = torch.stack(postacts) 111 | postacts = postacts.permute(2, 1, 0, 3)[:, :, :, 0] 112 | y = torch.sum(postacts, dim=2) 113 | 114 | return y, postacts 115 | 116 | def get_subset(self, in_id, out_id): 117 | ''' 118 | get a smaller Symbolic_KANLayer from a larger Symbolic_KANLayer (used for pruning) 119 | 120 | Args: 121 | ----- 122 | in_id : list 123 | id of selected input neurons 124 | out_id : list 125 | id of selected output neurons 126 | 127 | Returns: 128 | -------- 129 | spb : Symbolic_KANLayer 130 | 131 | Example 132 | ------- 133 | >>> sb_large = Symbolic_KANLayer(in_dim=10, out_dim=10) 134 | >>> sb_small = sb_large.get_subset([0,9],[1,2,3]) 135 | >>> sb_small.in_dim, sb_small.out_dim 136 | (2, 3) 137 | ''' 138 | sbb = Symbolic_KANLayer(self.in_dim, self.out_dim) 139 | sbb.in_dim = len(in_id) 140 | sbb.out_dim = len(out_id) 141 | sbb.mask.data = self.mask.data[out_id][:, in_id] 142 | sbb.funs = [[self.funs[j][i] for i in in_id] for j in out_id] 143 | sbb.funs_sympy = [[self.funs_sympy[j][i] for i in in_id] for j in out_id] 144 | sbb.funs_name = [[self.funs_name[j][i] for i in in_id] for j in out_id] 145 | sbb.affine.data = self.affine.data[out_id][:, in_id] 146 | return sbb 147 | 148 | def fix_symbolic(self, i, j, fun_name, x=None, y=None, 
random=False, a_range=(-10, 10), b_range=(-10, 10), 149 | verbose=True): 150 | ''' 151 | fix an activation function to be symbolic 152 | 153 | Args: 154 | ----- 155 | i : int 156 | the id of input neuron 157 | j : int 158 | the id of output neuron 159 | fun_name : str 160 | the name of the symbolic functions 161 | x : 1D array 162 | preactivations 163 | y : 1D array 164 | postactivations 165 | a_range : tuple 166 | sweeping range of a 167 | b_range : tuple 168 | sweeping range of b 169 | verbose : bool 170 | print more information if True 171 | 172 | Returns: 173 | -------- 174 | r2 (coefficient of determination) 175 | 176 | Example 1 177 | --------- 178 | >>> # when x & y are not provided. Affine parameters are set to a = 1, b = 0, c = 1, d = 0 179 | >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2) 180 | >>> sb.fix_symbolic(2,1,'sin') 181 | >>> print(sb.funs_name) 182 | >>> print(sb.affine) 183 | [['', '', ''], ['', '', 'sin']] 184 | Parameter containing: 185 | tensor([[0., 0., 0., 0.], 186 | [0., 0., 0., 0.], 187 | [1., 0., 1., 0.]], requires_grad=True) 188 | Example 2 189 | --------- 190 | >>> # when x & y are provided, fit_params() is called to find the best fit coefficients 191 | >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2) 192 | >>> batch = 100 193 | >>> x = torch.linspace(-1,1,steps=batch) 194 | >>> noises = torch.normal(0,1,(batch,)) * 0.02 195 | >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises 196 | >>> sb.fix_symbolic(2,1,'sin',x,y) 197 | >>> print(sb.funs_name) 198 | >>> print(sb.affine[1,2,:].data) 199 | r2 is 0.9999701976776123 200 | [['', '', ''], ['', '', 'sin']] 201 | tensor([2.9981, 1.9997, 5.0039, 0.6978]) 202 | ''' 203 | if isinstance(fun_name, str): 204 | fun = SYMBOLIC_LIB[fun_name][0] 205 | fun_sympy = SYMBOLIC_LIB[fun_name][1] 206 | self.funs_sympy[j][i] = fun_sympy 207 | self.funs_name[j][i] = fun_name 208 | if x is None or y is None: 209 | # initialize from just fun 210 | self.funs[j][i] = fun 211 | if random == False: 212 | self.affine.data[j][i] = torch.tensor([1., 0., 1., 0.]) 213 | else: 214 | self.affine.data[j][i] = torch.rand(4, ) * 2 - 1 215 | return None 216 | else: 217 | # initialize from x & y and fun 218 | params, r2 = fit_params(x, y, fun, a_range=a_range, b_range=b_range, verbose=verbose) 219 | self.funs[j][i] = fun 220 | self.affine.data[j][i] = params 221 | return r2 222 | else: 223 | # if fun_name itself is a function 224 | fun = fun_name 225 | fun_sympy = fun_name 226 | self.funs_sympy[j][i] = fun_sympy 227 | self.funs_name[j][i] = "anonymous" 228 | 229 | self.funs[j][i] = fun 230 | if random == False: 231 | self.affine.data[j][i] = torch.tensor([1., 0., 1., 0.]) 232 | else: 233 | self.affine.data[j][i] = torch.rand(4, ) * 2 - 1 234 | return None --------------------------------------------------------------------------------
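A small illustrative sketch of the layer above (not part of the repository; it assumes the repository root is on PYTHONPATH). Note that fix_symbolic only sets the edge's function and affine parameters; mask is initialized to zeros, so an edge must also be unmasked before it contributes to the output:

import torch
from models.Symbolic_KANLayer import Symbolic_KANLayer

sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
sb.fix_symbolic(2, 1, 'sin')    # edge (input 2 -> output 1) becomes c*sin(a*x+b)+d with a=c=1, b=d=0
sb.mask.data[1, 2] = 1.0        # unmask that edge; all other edges still output zero
x = torch.linspace(-1, 1, steps=5)[:, None].repeat(1, 3)  # (batch, in_dim)
y, postacts = sb(x)
# y[:, 1] equals torch.sin(x[:, 2]); y[:, 0] is identically zero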
/models/kan_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from .spline import * 5 | 6 | 7 | class KANLayer(nn.Module): 8 | """ 9 | KANLayer class 10 | 11 | 12 | Attributes: 13 | ----------- 14 | in_dim: int 15 | input dimension 16 | out_dim: int 17 | output dimension 18 | size: int 19 | the number of splines = input dimension * output dimension 20 | k: int 21 | the piecewise polynomial order of splines 22 | grid: 2D torch.float 23 | grid points 24 | noises: 2D torch.float 25 | injected noises to splines at initialization (to break degeneracy) 26 | coef: 2D torch.tensor 27 | coefficients of B-spline bases 28 | scale_base: 1D torch.float 29 | magnitude of the residual function b(x) 30 | scale_sp: 1D torch.float 31 | magnitude of the spline function spline(x) 32 | base_fun: fun 33 | residual function b(x) 34 | mask: 1D torch.float 35 | mask of spline functions. setting some element of the mask to zero means setting the corresponding activation to zero function. 36 | grid_eps: float in [0,1] 37 | a hyperparameter used in update_grid_from_samples. When grid_eps = 0, the grid is uniform; when grid_eps = 1, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. 38 | weight_sharing: 1D tensor int 39 | allow spline activations to share parameters 40 | lock_counter: int 41 | counts how many activation functions are locked (weight sharing) 42 | lock_id: 1D torch.int 43 | the id of activation functions that are locked 44 | device: str 45 | device 46 | 47 | Methods: 48 | -------- 49 | __init__(): 50 | initialize a KANLayer 51 | forward(): 52 | forward 53 | update_grid_from_samples(): 54 | update grids based on samples' incoming activations 55 | initialize_grid_from_parent(): 56 | initialize grids from another model 57 | get_subset(): 58 | get subset of the KANLayer (used for pruning) 59 | lock(): 60 | lock several activation functions to share parameters 61 | unlock(): 62 | unlock already locked activation functions 63 | """ 64 | 65 | def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=1.0, scale_sp=1.0, 66 | base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, 67 | device='cpu'): 68 | ''' 69 | initialize a KANLayer 70 | 71 | Args: 72 | ----- 73 | in_dim : int 74 | input dimension. Default: 3. 75 | out_dim : int 76 | output dimension. Default: 2. 77 | num : int 78 | the number of grid intervals = G. Default: 5. 79 | k : int 80 | the order of piecewise polynomial. Default: 3. 81 | noise_scale : float 82 | the scale of noise injected at initialization. Default: 0.1. 83 | scale_base : float 84 | the scale of the residual function b(x). Default: 1.0. 85 | scale_sp : float 86 | the scale of the spline function spline(x). Default: 1.0. 87 | base_fun : function 88 | residual function b(x). Default: torch.nn.SiLU() 89 | grid_eps : float 90 | When grid_eps = 0, the grid is uniform; when grid_eps = 1, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02. 91 | grid_range : list/np.array of shape (2,) 92 | setting the range of grids. Default: [-1,1]. 93 | sp_trainable : bool 94 | If true, scale_sp is trainable. Default: True. 95 | sb_trainable : bool 96 | If true, scale_base is trainable. Default: True.
97 | device : str 98 | device 99 | 100 | Returns: 101 | -------- 102 | self 103 | 104 | Example 105 | ------- 106 | >>> model = KANLayer(in_dim=3, out_dim=5) 107 | >>> (model.in_dim, model.out_dim) 108 | (3, 5) 109 | ''' 110 | super(KANLayer, self).__init__() 111 | # size 112 | self.size = size = out_dim * in_dim 113 | self.out_dim = out_dim 114 | self.in_dim = in_dim 115 | self.num = num 116 | self.k = k 117 | 118 | # shape: (size, num) 119 | self.grid = torch.einsum('i,j->ij', torch.ones(size, ), 120 | torch.linspace(grid_range[0], grid_range[1], steps=num + 1)) 121 | self.grid = torch.nn.Parameter(self.grid).requires_grad_(False) 122 | noises = (torch.rand(size, self.grid.shape[1]) - 1 / 2) * noise_scale / num 123 | noises = noises.to(device) 124 | # shape: (size, coef) 125 | self.coef = torch.nn.Parameter(curve2coef(self.grid, noises, self.grid, k)) 126 | if isinstance(scale_base, float): 127 | self.scale_base = torch.nn.Parameter(torch.ones(size, ) * scale_base).requires_grad_( 128 | sb_trainable)  # make scale trainable 129 | else: 130 | self.scale_base = torch.nn.Parameter(scale_base).requires_grad_(sb_trainable) 131 | self.scale_sp = torch.nn.Parameter(torch.ones(size, ) * scale_sp).requires_grad_( 132 | sp_trainable)  # make scale trainable 133 | self.base_fun = base_fun 134 | 135 | self.mask = torch.nn.Parameter(torch.ones(size, )).requires_grad_(False) 136 | self.grid_eps = grid_eps 137 | self.weight_sharing = torch.arange(size) 138 | self.lock_counter = 0 139 | self.lock_id = torch.zeros(size) 140 | self.device = device 141 | 142 | def forward(self, x): 143 | ''' 144 | KANLayer forward given input x 145 | 146 | Args: 147 | ----- 148 | x : 2D torch.float 149 | inputs, shape (number of samples, input dimension) 150 | 151 | Returns: 152 | -------- 153 | y : 2D torch.float 154 | outputs, shape (number of samples, output dimension) 155 | preacts : 3D torch.float 156 | fan out x into activations, shape (number of samples, output dimension, input dimension) 157 | postacts : 3D torch.float 158 | the outputs of activation functions with preacts as inputs 159 | postspline : 3D torch.float 160 | the outputs of spline functions with preacts as inputs 161 | 162 | Example 163 | ------- 164 | >>> model = KANLayer(in_dim=3, out_dim=5) 165 | >>> x = torch.normal(0,1,size=(100,3)) 166 | >>> y, preacts, postacts, postspline = model(x) 167 | >>> y.shape, preacts.shape, postacts.shape, postspline.shape 168 | (torch.Size([100, 5]), 169 | torch.Size([100, 5, 3]), 170 | torch.Size([100, 5, 3]), 171 | torch.Size([100, 5, 3])) 172 | ''' 173 | batch = x.shape[0] 174 | # x: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim) 175 | x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute( 176 | 1, 0) 177 | preacts = x.permute(1, 0).clone().reshape(batch, self.out_dim, self.in_dim) 178 | base = self.base_fun(x).permute(1, 0)  # shape (batch, size) 179 | y = coef2curve(x_eval=x, grid=self.grid[self.weight_sharing], coef=self.coef[self.weight_sharing], k=self.k, 180 | device=self.device)  # shape (size, batch) 181 | y = y.permute(1, 0)  # shape (batch, size) 182 | postspline = y.clone().reshape(batch, self.out_dim, self.in_dim) 183 | y = self.scale_base.unsqueeze(dim=0) * base + self.scale_sp.unsqueeze(dim=0) * y 184 | y = self.mask[None, :] * y 185 | postacts = y.clone().reshape(batch, self.out_dim, self.in_dim) 186 | y = torch.sum(y.reshape(batch, self.out_dim, self.in_dim), dim=2)  # shape (batch, out_dim) 187 | # y shape: (batch, out_dim); preacts shape: (batch, out_dim, in_dim) 188 | # postspline shape: (batch, out_dim, in_dim); postacts: (batch, out_dim, in_dim) 189 | # postspline is for extension; postacts is for visualization 190 | return y, preacts, postacts, postspline 191 | 192 | def update_grid_from_samples(self, x): 193 | ''' 194 | update grid from samples 195 | 196 | Args: 197 | ----- 198 | x : 2D torch.float 199 | inputs, shape (number of samples, input dimension) 200 | 201 | Returns: 202 | -------- 203 | None 204 | 205 | Example 206 | ------- 207 | >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) 208 | >>> print(model.grid.data) 209 | >>> x = torch.linspace(-3,3,steps=100)[:,None] 210 | >>> model.update_grid_from_samples(x) 211 | >>> print(model.grid.data) 212 | tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]]) 213 | tensor([[-3.0002, -1.7882, -0.5763, 0.6357, 1.8476, 3.0002]]) 214 | ''' 215 | batch = x.shape[0] 216 | x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute( 217 | 1, 0) 218 | x_pos = torch.sort(x, dim=1)[0] 219 | y_eval = coef2curve(x_pos, self.grid, self.coef, self.k, device=self.device) 220 | num_interval = self.grid.shape[1] - 1 221 | ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] 222 | grid_adaptive = x_pos[:, ids] 223 | margin = 0.01 224 | grid_uniform = torch.cat( 225 | [grid_adaptive[:, [0]] - margin + (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) * a for a in 226 | np.linspace(0, 1, num=self.grid.shape[1])], dim=1) 227 | self.grid.data = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive 228 | self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k, device=self.device) 229 | 230 | def initialize_grid_from_parent(self, parent, x): 231 | ''' 232 | update grid from a parent KANLayer & samples 233 | 234 | Args: 235 | ----- 236 | parent : KANLayer 237 | a parent KANLayer (whose grid is usually coarser than the current model) 238 | x : 2D torch.float 239 | inputs, shape (number of samples, input dimension) 240 | 241 | Returns: 242 | -------- 243 | None 244 | 245 | Example 246 | ------- 247 | >>> batch = 100 248 | >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) 249 | >>> print(parent_model.grid.data) 250 | >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3) 251 | >>> x = torch.normal(0,1,size=(batch, 1)) 252 | >>> model.initialize_grid_from_parent(parent_model, x) 253 | >>> print(model.grid.data) 254 | tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]]) 255 | tensor([[-1.0000, -0.8000, -0.6000, -0.4000, -0.2000, 0.0000, 0.2000, 0.4000, 256 | 0.6000, 0.8000, 1.0000]]) 257 | ''' 258 | batch = x.shape[0] 259 | # preacts: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim) 260 | x_eval = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, 261 | self.size).permute(1, 262 | 0) 263 | x_pos = parent.grid 264 | sp2 = KANLayer(in_dim=1, out_dim=self.size, k=1, num=x_pos.shape[1] - 1, scale_base=0.).to(self.device) 265 | sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1) 266 | y_eval = coef2curve(x_eval, parent.grid, parent.coef, parent.k, device=self.device) 267 | percentile = torch.linspace(-1, 1, self.num + 1).to(self.device) 268 | self.grid.data = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0) 269 | self.coef.data = curve2coef(x_eval, y_eval, self.grid, self.k, self.device) 270 | 271 | def get_subset(self, in_id, out_id): 272 | ''' 273 | get a smaller KANLayer from a larger
KANLayer (used for pruning) 274 | 275 | Args: 276 | ----- 277 | in_id : list 278 | id of selected input neurons 279 | out_id : list 280 | id of selected output neurons 281 | 282 | Returns: 283 | -------- 284 | spb : KANLayer 285 | 286 | Example 287 | ------- 288 | >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3) 289 | >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3]) 290 | >>> kanlayer_small.in_dim, kanlayer_small.out_dim 291 | (2, 3) 292 | ''' 293 | spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun) 294 | spb.grid.data = self.grid.reshape(self.out_dim, self.in_dim, spb.num + 1)[out_id][:, in_id].reshape(-1, 295 | spb.num + 1) 296 | spb.coef.data = self.coef.reshape(self.out_dim, self.in_dim, spb.coef.shape[1])[out_id][:, in_id].reshape(-1, 297 | spb.coef.shape[ 298 | 1]) 299 | spb.scale_base.data = self.scale_base.reshape(self.out_dim, self.in_dim)[out_id][:, in_id].reshape(-1, ) 300 | spb.scale_sp.data = self.scale_sp.reshape(self.out_dim, self.in_dim)[out_id][:, in_id].reshape(-1, ) 301 | spb.mask.data = self.mask.reshape(self.out_dim, self.in_dim)[out_id][:, in_id].reshape(-1, ) 302 | 303 | spb.in_dim = len(in_id) 304 | spb.out_dim = len(out_id) 305 | spb.size = spb.in_dim * spb.out_dim 306 | return spb 307 | 308 | def lock(self, ids): 309 | ''' 310 | lock activation functions to share parameters based on ids 311 | 312 | Args: 313 | ----- 314 | ids : list 315 | list of ids of activation functions 316 | 317 | Returns: 318 | -------- 319 | None 320 | 321 | Example 322 | ------- 323 | >>> model = KANLayer(in_dim=3, out_dim=3, num=5, k=3) 324 | >>> print(model.weight_sharing.reshape(3,3)) 325 | >>> model.lock([[0,0],[1,2],[2,1]]) # set (0,0),(1,2),(2,1) functions to be the same 326 | >>> print(model.weight_sharing.reshape(3,3)) 327 | tensor([[0, 1, 2], 328 | [3, 4, 5], 329 | [6, 7, 8]]) 330 | tensor([[0, 1, 2], 331 | [3, 4, 0], 332 | [6, 0, 8]]) 333 | ''' 334 | self.lock_counter += 1 335 | # ids: [[i1,j1],[i2,j2],[i3,j3],...] 336 | for i in range(len(ids)): 337 | if i != 0: 338 | self.weight_sharing[ids[i][1] * self.in_dim + ids[i][0]] = ids[0][1] * self.in_dim + ids[0][0] 339 | self.lock_id[ids[i][1] * self.in_dim + ids[i][0]] = self.lock_counter 340 | 341 | def unlock(self, ids): 342 | ''' 343 | unlock activation functions 344 | 345 | Args: 346 | ----- 347 | ids : list 348 | list of ids of activation functions 349 | 350 | Returns: 351 | -------- 352 | None 353 | 354 | Example 355 | ------- 356 | >>> model = KANLayer(in_dim=3, out_dim=3, num=5, k=3) 357 | >>> model.lock([[0,0],[1,2],[2,1]]) # set (0,0),(1,2),(2,1) functions to be the same 358 | >>> print(model.weight_sharing.reshape(3,3)) 359 | >>> model.unlock([[0,0],[1,2],[2,1]]) # unlock the locked functions 360 | >>> print(model.weight_sharing.reshape(3,3)) 361 | tensor([[0, 1, 2], 362 | [3, 4, 0], 363 | [6, 0, 8]]) 364 | tensor([[0, 1, 2], 365 | [3, 4, 5], 366 | [6, 7, 8]]) 367 | ''' 368 | # check ids are locked 369 | num = len(ids) 370 | locked = True 371 | for i in range(num): 372 | locked *= (self.weight_sharing[ids[i][1] * self.in_dim + ids[i][0]] == self.weight_sharing[ 373 | ids[0][1] * self.in_dim + ids[0][0]]) 374 | if locked == False: 375 | print("they are not locked. 
unlock failed.") 376 | return 0 377 | for i in range(len(ids)): 378 | self.weight_sharing[ids[i][1] * self.in_dim + ids[i][0]] = ids[i][1] * self.in_dim + ids[i][0] 379 | self.lock_id[ids[i][1] * self.in_dim + ids[i][0]] = 0 380 | self.lock_counter -= 1 -------------------------------------------------------------------------------- /models/spline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def B_batch(x, grid, k=0, extend=True, device='cpu'): 5 | ''' 6 | evaluate x on B-spline bases 7 | 8 | Args: 9 | ----- 10 | x : 2D torch.tensor 11 | inputs, shape (number of splines, number of samples) 12 | grid : 2D torch.tensor 13 | grids, shape (number of splines, number of grid points) 14 | k : int 15 | the piecewise polynomial order of splines. 16 | extend : bool 17 | If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True 18 | device : str 19 | device 20 | 21 | Returns: 22 | -------- 23 | spline values : 3D torch.tensor 24 | shape (number of splines, number of B-spline bases (coefficients), number of samples). The number of B-spline bases = number of grid points + k - 1. 25 | 26 | Example 27 | ------- 28 | >>> num_spline = 5 29 | >>> num_sample = 100 30 | >>> num_grid_interval = 10 31 | >>> k = 3 32 | >>> x = torch.normal(0,1,size=(num_spline, num_sample)) 33 | >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1)) 34 | >>> B_batch(x, grids, k=k).shape 35 | torch.Size([5, 13, 100]) 36 | ''' 37 | 38 | # x shape: (size, x); grid shape: (size, grid) 39 | def extend_grid(grid, k_extend=0): 40 | # pad k to left and right 41 | # grid shape: (batch, grid) 42 | h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) 43 | 44 | for i in range(k_extend): 45 | grid = torch.cat([grid[:, [0]] - h, grid], dim=1) 46 | grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) 47 | grid = grid.to(device) 48 | return grid 49 | 50 | if extend == True: 51 | grid = extend_grid(grid, k_extend=k) 52 | 53 | grid = grid.unsqueeze(dim=2).to(device) 54 | x = x.unsqueeze(dim=1).to(device) 55 | 56 | if k == 0: 57 | value = (x >= grid[:, :-1]) * (x < grid[:, 1:]) 58 | else: 59 | B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False, device=device) 60 | value = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)]) * B_km1[:, :-1] + ( 61 | grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * B_km1[:, 1:] 62 | return value 63 | 64 | 65 | def coef2curve(x_eval, grid, coef, k, device="cpu"): 66 | ''' 67 | converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis). 68 | 69 | Args: 70 | ----- 71 | x_eval : 2D torch.tensor 72 | shape (number of splines, number of samples) 73 | grid : 2D torch.tensor 74 | shape (number of splines, number of grid points) 75 | coef : 2D torch.tensor 76 | shape (number of splines, number of coef params). number of coef params = number of grid intervals + k 77 | k : int 78 | the piecewise polynomial order of splines. 79 | device : str 80 | device 81 | 82 | Returns: 83 | -------- 84 | y_eval : 2D torch.tensor 85 | shape (number of splines, number of samples) 86 | 87 | Example 88 | ------- 89 | >>> num_spline = 5 90 | >>> num_sample = 100 91 | >>> num_grid_interval = 10 92 | >>> k = 3 93 | >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample)) 94 | >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1)) 95 | >>> coef = torch.normal(0,1,size=(num_spline, num_grid_interval+k)) 96 | >>> coef2curve(x_eval, grids, coef, k=k).shape 97 | torch.Size([5, 100]) 98 | ''' 99 | # x_eval: (size, batch), grid: (size, grid), coef: (size, coef) 100 | # coef: (size, coef), B_batch: (size, coef, batch), summed over coef 101 | y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device)) 102 | return y_eval 103 | 104 | 105 | def curve2coef(x_eval, y_eval, grid, k, device="cpu"): 106 | ''' 107 | converting B-spline curves to B-spline coefficients using least squares. 108 | 109 | Args: 110 | ----- 111 | x_eval : 2D torch.tensor 112 | shape (number of splines, number of samples) 113 | y_eval : 2D torch.tensor 114 | shape (number of splines, number of samples) 115 | grid : 2D torch.tensor 116 | shape (number of splines, number of grid points) 117 | k : int 118 | the piecewise polynomial order of splines. 119 | device : str 120 | device 121 | 122 | Example 123 | ------- 124 | >>> num_spline = 5 125 | >>> num_sample = 100 126 | >>> num_grid_interval = 10 127 | >>> k = 3 128 | >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample)) 129 | >>> y_eval = torch.normal(0,1,size=(num_spline, num_sample)) 130 | >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1)) 131 | >>> curve2coef(x_eval, y_eval, grids, k=k).shape 132 | torch.Size([5, 13]) 133 | ''' 134 | # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar 135 | mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1) 136 | coef = torch.linalg.lstsq(mat.to('cpu'), y_eval.unsqueeze(dim=2).to('cpu')).solution[:, :, 137 | 0]  # sometimes 'cuda' version may diverge 138 | return coef.to(device) --------------------------------------------------------------------------------
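A quick round-trip sketch of the two helpers above (illustrative, not part of the repository): fit B-spline coefficients to samples of a smooth curve with curve2coef, then evaluate the fitted splines with coef2curve.

import torch
from models.spline import coef2curve, curve2coef

num_spline, num_sample, G, k = 2, 200, 10, 3
x = torch.linspace(-1, 1, steps=num_sample).repeat(num_spline, 1)  # (2, 200)
grid = torch.einsum('i,j->ij', torch.ones(num_spline, ), torch.linspace(-1, 1, steps=G + 1))
y = torch.sin(3 * x)
coef = curve2coef(x, y, grid, k)      # least-squares fit, shape (2, G + k) = (2, 13)
y_hat = coef2curve(x, grid, coef, k)  # evaluate the fitted curves, shape (2, 200)
print((y - y_hat).abs().max())        # small residual for a smooth target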
/models/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.linear_model import LinearRegression 4 | import sympy 5 | 6 | # sigmoid = sympy.Function('sigmoid') 7 | # name: (torch implementation, sympy implementation) 8 | SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x), 9 | 'x^2': (lambda x: x ** 2, lambda x: x ** 2), 10 | 'x^3': (lambda x: x ** 3, lambda x: x ** 3), 11 | 'x^4': (lambda x: x ** 4, lambda x: x ** 4), 12 | '1/x': (lambda x: 1 / x, lambda x: 1 / x), 13 | '1/x^2': (lambda x: 1 / x ** 2, lambda x: 1 / x ** 2), 14 | '1/x^3': (lambda x: 1 / x ** 3, lambda x: 1 / x ** 3), 15 | '1/x^4': (lambda x: 1 / x ** 4, lambda x: 1 / x ** 4), 16 | 'sqrt': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x)), 17 | '1/sqrt(x)': (lambda x: 1 / torch.sqrt(x), lambda x: 1 / sympy.sqrt(x)), 18 | 'exp': (lambda x: torch.exp(x), lambda x: sympy.exp(x)), 19 | 'log': (lambda x: torch.log(x), lambda x: sympy.log(x)), 20 | 'abs': (lambda x: torch.abs(x), lambda x: sympy.Abs(x)), 21 | 'sin': (lambda x: torch.sin(x), lambda x: sympy.sin(x)), 22 | 'tan': (lambda x: torch.tan(x), lambda x: sympy.tan(x)), 23 | 'tanh': (lambda x: torch.tanh(x), lambda x: sympy.tanh(x)), 24 | 'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid')), 25 | # 'relu': (lambda x: torch.relu(x), relu), 26 | 'sgn': (lambda x: torch.sign(x), lambda x: sympy.sign(x)), 27 | 'arcsin': (lambda x: torch.arcsin(x), lambda x: sympy.arcsin(x)), 28 | 'arctan': (lambda x: torch.arctan(x), lambda x: sympy.atan(x)), 29 | 'arctanh': (lambda x: torch.arctanh(x), lambda x: sympy.atanh(x)), 30 | '0': (lambda x: x * 0, lambda x: x * 0), 31 | 'gaussian': (lambda x: torch.exp(-x ** 2), lambda x: sympy.exp(-x ** 2)), 32 | 'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x)), 33 | # 'logcosh': (lambda x: torch.log(torch.cosh(x)), lambda x: sympy.log(sympy.cosh(x))), 34 | # 'cosh^2': (lambda x: torch.cosh(x)**2, lambda x: sympy.cosh(x)**2), 35 | } 36 | 37 | 38 | def create_dataset(f, 39 | n_var=2, 40 | ranges=[-1, 1], 41 | train_num=1000, 42 | test_num=1000, 43 | normalize_input=False, 44 | normalize_label=False, 45 | device='cpu', 46 | seed=0): 47 | ''' 48 | create dataset 49 | 50 | Args: 51 | ----- 52 | f : function 53 | the symbolic formula used to create the synthetic dataset 54 | ranges : list or np.array; shape (2,) or (n_var, 2) 55 | the range of input variables. Default: [-1,1]. 56 | train_num : int 57 | the number of training samples. Default: 1000. 58 | test_num : int 59 | the number of test samples. Default: 1000. 60 | normalize_input : bool 61 | If True, apply normalization to inputs. Default: False. 62 | normalize_label : bool 63 | If True, apply normalization to labels. Default: False. 64 | device : str 65 | device. Default: 'cpu'. 66 | seed : int 67 | random seed. Default: 0. 68 | 69 | Returns: 70 | -------- 71 | dataset : dict 72 | Train/test inputs/labels are dataset['train_input'], dataset['train_label'], 73 | dataset['test_input'], dataset['test_label'] 74 | 75 | Example 76 | ------- 77 | >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 78 | >>> dataset = create_dataset(f, n_var=2, train_num=100) 79 | >>> dataset['train_input'].shape 80 | torch.Size([100, 2]) 81 | ''' 82 | 83 | np.random.seed(seed) 84 | torch.manual_seed(seed) 85 | 86 | if len(np.array(ranges).shape) == 1: 87 | ranges = np.array(ranges * n_var).reshape(n_var, 2) 88 | else: 89 | ranges = np.array(ranges) 90 | 91 | train_input = torch.zeros(train_num, n_var) 92 | test_input = torch.zeros(test_num, n_var) 93 | for i in range(n_var): 94 | train_input[:, i] = torch.rand(train_num, ) * (ranges[i, 1] - ranges[i, 0]) + ranges[i, 0] 95 | test_input[:, i] = torch.rand(test_num, ) * (ranges[i, 1] - ranges[i, 0]) + ranges[i, 0] 96 | 97 | train_label = f(train_input) 98 | test_label = f(test_input) 99 | 100 | def normalize(data, mean, std): 101 | return (data - mean) / std 102 | 103 | if normalize_input == True: 104 | mean_input = torch.mean(train_input, dim=0, keepdim=True) 105 | std_input = torch.std(train_input, dim=0, keepdim=True) 106 | train_input = normalize(train_input, mean_input, std_input) 107 | test_input = normalize(test_input, mean_input, std_input) 108 | 109 | if normalize_label == True: 110 | mean_label = torch.mean(train_label, dim=0, keepdim=True) 111 | std_label = torch.std(train_label, dim=0, keepdim=True) 112 | train_label = normalize(train_label, mean_label, std_label) 113 | test_label = normalize(test_label, mean_label, std_label) 114 | 115 | dataset = {} 116 | dataset['train_input'] = train_input.to(device) 117 | dataset['test_input'] = test_input.to(device) 118 | 119 | dataset['train_label'] = train_label.to(device) 120 | dataset['test_label'] = test_label.to(device) 121 | 122 | return dataset 123 | 124 | 125 | def fit_params(x, y, fun, a_range=(-10, 10), b_range=(-10, 10), grid_number=101, iteration=3, verbose=True): 126 | ''' 127 | fit a, b, c, d such that 128 | 129 | .. math:: 130 | |y-(cf(ax+b)+d)|^2 131 | 132 | is minimized. Both x and y are 1D array. Sweep a and b, find the best fitted model. 133 | 134 | Args: 135 | ----- 136 | x : 1D array 137 | x values 138 | y : 1D array 139 | y values 140 | fun : function 141 | symbolic function 142 | a_range : tuple 143 | sweeping range of a 144 | b_range : tuple 145 | sweeping range of b 146 | grid_number : int 147 | number of steps along a and b 148 | iteration : int 149 | number of zooming in 150 | verbose : bool 151 | print extra information if True 152 | 153 | Returns: 154 | -------- 155 | a_best : float 156 | best fitted a 157 | b_best : float 158 | best fitted b 159 | c_best : float 160 | best fitted c 161 | d_best : float 162 | best fitted d 163 | r2_best : float 164 | best r2 (coefficient of determination) 165 | 166 | Example 167 | ------- 168 | >>> num = 100 169 | >>> x = torch.linspace(-1,1,steps=num) 170 | >>> noises = torch.normal(0,1,(num,)) * 0.02 171 | >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises 172 | >>> fit_params(x, y, torch.sin) 173 | r2 is 0.9999727010726929 174 | (tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000)) 175 | ''' 176 | # fit a, b, c, d such that y=c*fun(a*x+b)+d; both x and y are 1D array. 177 | # sweep a and b, choose the best fitted model 178 | for _ in range(iteration): 179 | a_ = torch.linspace(a_range[0], a_range[1], steps=grid_number) 180 | b_ = torch.linspace(b_range[0], b_range[1], steps=grid_number) 181 | a_grid, b_grid = torch.meshgrid(a_, b_, indexing='ij') 182 | post_fun = fun(a_grid[None, :, :] * x[:, None, None] + b_grid[None, :, :]) 183 | x_mean = torch.mean(post_fun, dim=[0], keepdim=True) 184 | y_mean = torch.mean(y, dim=[0], keepdim=True) 185 | numerator = torch.sum((post_fun - x_mean) * (y - y_mean)[:, None, None], dim=0) ** 2 186 | denominator = torch.sum((post_fun - x_mean) ** 2, dim=0) * torch.sum((y - y_mean)[:, None, None] ** 2, dim=0) 187 | r2 = numerator / (denominator + 1e-4) 188 | r2 = torch.nan_to_num(r2) 189 | 190 | best_id = torch.argmax(r2) 191 | a_id, b_id = torch.div(best_id, grid_number, rounding_mode='floor'), best_id % grid_number 192 | 193 | if a_id == 0 or a_id == grid_number - 1 or b_id == 0 or b_id == grid_number - 1: 194 | if _ == 0 and verbose == True: 195 | print('Best value at boundary.') 196 | if a_id == 0: 197 | a_range = [a_[0], a_[1]] 198 | if a_id == grid_number - 1: 199 | a_range = [a_[-2], a_[-1]] 200 | if b_id == 0: 201 | b_range = [b_[0], b_[1]] 202 | if b_id == grid_number - 1: 203 | b_range = [b_[-2], b_[-1]] 204 | 205 | else: 206 | a_range = [a_[a_id - 1], a_[a_id + 1]] 207 | b_range = [b_[b_id - 1], b_[b_id + 1]] 208 | 209 | a_best = a_[a_id] 210 | b_best = b_[b_id] 211 | post_fun = fun(a_best * x + b_best) 212 | r2_best = r2[a_id, b_id] 213 | 214 | if verbose == True: 215 | print(f"r2 is {r2_best}") 216 | if r2_best < 0.9: 217 | print(f'r2 is not very high, please double check if you are choosing the correct symbolic function.') 218 | 219 | post_fun = torch.nan_to_num(post_fun) 220 | reg = LinearRegression().fit(post_fun[:, None].detach().numpy(), y.detach().numpy()) 221 | c_best = torch.from_numpy(reg.coef_)[0] 222 | d_best = torch.from_numpy(np.array(reg.intercept_)) 223 | return torch.stack([a_best, b_best, c_best, d_best]), r2_best 224 | 225 | 226 | def add_symbolic(name, fun): 227 | '''
228 | add a symbolic function to library 229 | 230 | Args: 231 | ----- 232 | name : str 233 | name of the function 234 | fun : function 235 | torch function or lambda function 236 | 237 | Returns: 238 | -------- 239 | None 240 | 241 | Example 242 | ------- 243 | >>> print(SYMBOLIC_LIB['Bessel']) 244 | KeyError: 'Bessel' 245 | >>> add_symbolic('Bessel', torch.special.bessel_j0) 246 | >>> print(SYMBOLIC_LIB['Bessel']) 247 | (<built-in function bessel_j0>, Bessel) 248 | ''' 249 | exec(f"globals()['{name}'] = sympy.Function('{name}')") 250 | SYMBOLIC_LIB[name] = (fun, globals()[name]) 251 | -------------------------------------------------------------------------------- /pic/algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/RoPINN/0996b0df232ef2270cc9a97e308a79e5d4cb9cb1/pic/algorithm.png -------------------------------------------------------------------------------- /pic/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/RoPINN/0996b0df232ef2270cc9a97e308a79e5d4cb9cb1/pic/comparison.png -------------------------------------------------------------------------------- /pic/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/RoPINN/0996b0df232ef2270cc9a97e308a79e5d4cb9cb1/pic/results.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | tqdm 3 | einops 4 | deepxde 5 | scipy 6 | scikit-learn 7 | pykan 8 | sympy 9 | torch==1.13.0 10 | -------------------------------------------------------------------------------- /scripts/1d_reaction_point.sh: -------------------------------------------------------------------------------- 1 | python 1d_reaction_point_optimization.py --model PINN --device 'cuda:0' 2 | python 1d_reaction_point_optimization.py --model QRes --device 'cuda:0' 3 | python 1d_reaction_point_optimization.py --model FLS --device 'cuda:0' 4 | python 1d_reaction_point_optimization.py --model KAN --device 'cuda:0' 5 | python 1d_reaction_point_optimization.py --model PINNsFormer --device 'cuda:0' -------------------------------------------------------------------------------- /scripts/1d_reaction_region.sh: -------------------------------------------------------------------------------- 1 | python 1d_reaction_region_optimization.py --model PINN --device 'cuda:0' 2 | python 1d_reaction_region_optimization.py --model QRes --device 'cuda:0' 3 | python 1d_reaction_region_optimization.py --model FLS --device 'cuda:0' 4 | python 1d_reaction_region_optimization.py --model KAN --device 'cuda:0' 5 | python 1d_reaction_region_optimization.py --model PINNsFormer --device 'cuda:0' -------------------------------------------------------------------------------- /scripts/1d_wave_point.sh: -------------------------------------------------------------------------------- 1 | python 1d_wave_point_optimization.py --model PINN --device 'cuda:0' 2 | python 1d_wave_point_optimization.py --model QRes --device 'cuda:0' 3 | python 1d_wave_point_optimization.py --model FLS --device 'cuda:0' 4 | python 1d_wave_point_optimization.py --model KAN --device 'cuda:0' 5 | python 1d_wave_point_optimization.py --model PINNsFormer_Enc_Only --device 'cuda:0' -------------------------------------------------------------------------------- /scripts/1d_wave_region.sh:
-------------------------------------------------------------------------------- 1 | python 1d_wave_region_optimization.py --model PINN --device 'cuda:0' 2 | python 1d_wave_region_optimization.py --model QRes --device 'cuda:0' 3 | python 1d_wave_region_optimization.py --model FLS --device 'cuda:0' 4 | python 1d_wave_region_optimization.py --model KAN --device 'cuda:0' 5 | python 1d_wave_region_optimization.py --model PINNsFormer_Enc_Only --device 'cuda:0' -------------------------------------------------------------------------------- /scripts/convection_point.sh: -------------------------------------------------------------------------------- 1 | python convection_point_optimization.py --model PINN --device 'cuda:0' 2 | python convection_point_optimization.py --model QRes --device 'cuda:0' 3 | python convection_point_optimization.py --model FLS --device 'cuda:0' 4 | python convection_point_optimization.py --model KAN --device 'cuda:0' 5 | python convection_point_optimization.py --model PINNsFormer --device 'cuda:0' -------------------------------------------------------------------------------- /scripts/convection_region.sh: -------------------------------------------------------------------------------- 1 | python convection_region_optimization.py --model PINN --device 'cuda:0' 2 | python convection_region_optimization.py --model QRes --device 'cuda:0' 3 | python convection_region_optimization.py --model FLS --device 'cuda:0' 4 | python convection_region_optimization.py --model KAN --device 'cuda:0' # for KAN, the best past_iterations is 15 5 | python convection_region_optimization.py --model PINNsFormer --device 'cuda:0' -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import copy 4 | 5 | 6 | def get_data(x_range, y_range, x_num, y_num): 7 | x = np.linspace(x_range[0], x_range[1], x_num) 8 | t = np.linspace(y_range[0], y_range[1], y_num) 9 | 10 | x_mesh, t_mesh = np.meshgrid(x, t) 11 | data = np.concatenate((np.expand_dims(x_mesh, -1), np.expand_dims(t_mesh, -1)), axis=-1) 12 | 13 | b_left = data[0, :, :] 14 | b_right = data[-1, :, :] 15 | b_upper = data[:, -1, :] 16 | b_lower = data[:, 0, :] 17 | res = data.reshape(-1, 2) 18 | 19 | return res, b_left, b_right, b_upper, b_lower 20 | 21 | 22 | def get_n_params(model): 23 | pp = 0 24 | for p in list(model.parameters()): 25 | n = 1 26 | for s in list(p.size()): 27 | n = n * s 28 | pp += n 29 | return pp 30 | 31 | 32 | def make_time_sequence(src, num_step=5, step=1e-4): 33 | dim = num_step 34 | src = np.repeat(np.expand_dims(src, axis=1), dim, axis=1)  # (N, L, 2) 35 | for i in range(num_step): 36 | src[:, i, -1] += step * i 37 | return src 38 | 39 | 40 | def make_space_time_sequence(src, space_num_step=5, space_step=1e-4, time_num_step=5, time_step=1e-4): 41 | dim = space_num_step * time_num_step 42 | src = np.repeat(np.expand_dims(src, axis=1), dim, axis=1)  # (N, L, 2) 43 | for i in range(time_num_step): 44 | for j in range(space_num_step): 45 | src[:, i * space_num_step + j, -1] += time_step * i 46 | src[:, i * space_num_step + j, 0] += space_step * (j - space_num_step // 2) 47 | return src 48 | 49 | 50 | def get_clones(module, N): 51 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 52 | 53 | 54 | def get_data_3d(x_range, y_range, t_range, x_num, y_num, t_num): 55 | step_x = (x_range[1] - x_range[0]) / float(x_num - 1) 56 | step_y = (y_range[1]
- y_range[0]) / float(y_num - 1) 57 | step_t = (t_range[1] - t_range[0]) / float(t_num - 1) 58 | 59 | x_mesh, y_mesh, t_mesh = np.mgrid[x_range[0]:x_range[1] + step_x:step_x, y_range[0]:y_range[1] + step_y:step_y, 60 | t_range[0]:t_range[1] + step_t:step_t] 61 | 62 | data = np.concatenate((np.expand_dims(x_mesh, -1), np.expand_dims(y_mesh, -1), np.expand_dims(t_mesh, -1)), axis=-1) 63 | res = data.reshape(-1, 3) 64 | 65 | x_mesh, y_mesh, t_mesh = np.mgrid[x_range[0]:x_range[0] + step_x:step_x, y_range[0]:y_range[1] + step_y:step_y, 66 | t_range[0]:t_range[1] + step_t:step_t] 67 | b_left = np.squeeze( 68 | np.concatenate((np.expand_dims(x_mesh, -1), np.expand_dims(y_mesh, -1), np.expand_dims(t_mesh, -1)), axis=-1))[ 69 | 1:-1].reshape(-1, 3) 70 | 71 | x_mesh, y_mesh, t_mesh = np.mgrid[x_range[1]:x_range[1] + step_x:step_x, y_range[0]:y_range[1] + step_y:step_y, 72 | t_range[0]:t_range[1] + step_t:step_t] 73 | b_right = np.squeeze( 74 | np.concatenate((np.expand_dims(x_mesh, -1), np.expand_dims(y_mesh, -1), np.expand_dims(t_mesh, -1)), axis=-1))[ 75 | 1:-1].reshape(-1, 3) 76 | 77 | x_mesh, y_mesh, t_mesh = np.mgrid[x_range[0]:x_range[1] + step_x:step_x, y_range[0]:y_range[0] + step_y:step_y, 78 | t_range[0]:t_range[1] + step_t:step_t] 79 | b_lower = np.squeeze( 80 | np.concatenate((np.expand_dims(x_mesh, -1), np.expand_dims(y_mesh, -1), np.expand_dims(t_mesh, -1)), axis=-1))[ 81 | 1:-1].reshape(-1, 3) 82 | 83 | x_mesh, y_mesh, t_mesh = np.mgrid[x_range[0]:x_range[1] + step_x:step_x, y_range[1]:y_range[1] + step_y:step_y, 84 | t_range[0]:t_range[1] + step_t:step_t] 85 | b_upper = np.squeeze( 86 | np.concatenate((np.expand_dims(x_mesh, -1), np.expand_dims(y_mesh, -1), np.expand_dims(t_mesh, -1)), axis=-1))[ 87 | 1:-1].reshape(-1, 3) 88 | 89 | return res, b_left, b_right, b_upper, b_lower 90 | --------------------------------------------------------------------------------
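To tie the pieces together, a minimal end-to-end sketch (illustrative only; it mirrors the driver scripts, assumes the optimizer class defined in models/LBFGS.py is named LBFGS, and uses the 1D reaction residual u_t = rho * u * (1 - u) with rho = 5 as an assumed example PDE; boundary and initial-condition losses are omitted for brevity):

import numpy as np
import torch
from util import get_data
from models.PINN import Model
from models.LBFGS import LBFGS  # assumed class name; the drivers above use torch.optim.LBFGS

device = 'cpu'
res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 101, 101)
res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
x, t = res[:, 0:1], res[:, 1:2]

model = Model(in_dim=2, hidden_dim=512, out_dim=1, num_layer=4).to(device)
optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')

def closure():
    optim.zero_grad()
    u = model(x, t)
    u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    loss = torch.mean((u_t - 5 * u * (1 - u)) ** 2)  # PDE residual only (rho = 5, assumed)
    loss.backward()
    return loss

for _ in range(10):
    optim.step(closure)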