├── requirements.txt
├── source
│   ├── dataset
│   │   ├── KdV.mat
│   │   ├── Burgers.npz
│   │   ├── Allen_Cahn.mat
│   │   └── CH_C1_02_2.mat
│   ├── visualize.py
│   ├── choose_optimizer.py
│   ├── systems_pbc.py
│   ├── scripts
│   │   └── run.sh
│   ├── utils.py
│   ├── lbfgs.py
│   ├── main.py
│   └── net.py
└── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | h5py
3 | matplotlib
4 | tqdm
--------------------------------------------------------------------------------
/source/dataset/KdV.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/AdaAFforPINNs/HEAD/source/dataset/KdV.mat
--------------------------------------------------------------------------------
/source/dataset/Burgers.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/AdaAFforPINNs/HEAD/source/dataset/Burgers.npz
--------------------------------------------------------------------------------
/source/dataset/Allen_Cahn.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/AdaAFforPINNs/HEAD/source/dataset/Allen_Cahn.mat
--------------------------------------------------------------------------------
/source/dataset/CH_C1_02_2.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/AdaAFforPINNs/HEAD/source/dataset/CH_C1_02_2.mat
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Specialized Activation Functions for Physics-informed Neural Networks
2 | 
3 | This repository contains the PyTorch source code for the experiments in the manuscript:
4 | 
5 | [Learning Specialized Activation Functions for Physics-informed Neural Networks](https://arxiv.org/abs/2308.04073).
6 | 
7 | ## Introduction
8 | 
9 | In this work, we reveal the connection between the optimization difficulty of PINNs and the choice of activation function. Specifically, we show that PINNs are highly sensitive to the activation function when solving PDEs with distinct properties. Existing works usually choose activation functions by inefficient trial and error. To avoid this manual selection and to alleviate the optimization difficulty of PINNs, we introduce adaptive activation functions that search for the optimal function when solving different problems. We compare different adaptive activation functions and discuss their limitations in the context of PINNs. Furthermore, we tailor the idea of learning combinations of candidate activation functions to PINN optimization, which places higher requirements on the smoothness and diversity of the learned functions. This is achieved by removing activation functions that cannot provide higher-order derivatives from the candidate set and by incorporating elementary functions with different properties according to prior knowledge about the PDE at hand. We further enhance the search space with adaptive slopes. The proposed adaptive activation function can be used to solve different PDE systems in an interpretable way. Its effectiveness is demonstrated on a series of benchmarks.
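As a rough illustration of the idea, the learned activation can be viewed as a softmax-weighted combination of smooth candidate functions with an adaptive slope. The sketch below is a simplified, self-contained example with made-up names (`LearnedActivation`), not the implementation used for the experiments (see the code under `source/` for that):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnedActivation(nn.Module):
    """Sketch only: softmax-weighted combination of smooth candidate
    activations with a learnable (adaptive) slope."""
    def __init__(self, tau=1.0, adaptive_slope=True):
        super().__init__()
        # Candidates are smooth, so the higher-order derivatives needed by
        # PDE residuals remain well defined.
        self.candidates = [torch.sin, torch.tanh, torch.sigmoid, F.gelu,
                           lambda x: x]
        self.coeff = nn.Parameter(torch.zeros(len(self.candidates)))
        self.tau = tau
        self.slope = nn.Parameter(torch.ones(1)) if adaptive_slope else None

    def forward(self, x):
        if self.slope is not None:
            x = self.slope * x                       # adaptive slope
        w = F.softmax(self.coeff / self.tau, dim=0)  # convex combination
        return sum(w_i * f(x) for w_i, f in zip(w, self.candidates))

if __name__ == "__main__":
    act = LearnedActivation()  # one shared instance, akin to weight sharing
    net = nn.Sequential(nn.Linear(2, 32), act, nn.Linear(32, 32), act,
                        nn.Linear(32, 1))
    xt = torch.rand(8, 2, requires_grad=True)
    u = net(xt)
    # Input derivatives stay available because every candidate is smooth.
    u_x = torch.autograd.grad(u.sum(), xt, create_graph=True)[0]
    print(u.shape, u_x.shape)
```

In the experiments, the candidate pool (`--poolsize`), the aggregation (`--aggregate softmax`), the adaptive slope (enabled via `--enable_scaling` in the `w_adaptive_slope` runs), and the sharing of coefficients across layers (`--weight_sharing`) are controlled by the options used in `source/scripts/run.sh`.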
10 | 
11 | ## Installation
12 | 
13 | ```
14 | git clone git@github.com:LeapLabTHU/AdaAFforPINNs.git
15 | cd AdaAFforPINNs
16 | conda create -n myenv python=3.9
17 | conda activate myenv
18 | conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 -c pytorch -c conda-forge
19 | conda install --file=requirements.txt
20 | ```
21 | 
22 | ## Instructions
23 | 
24 | Train on the convection, Allen-Cahn, KdV, or Cahn-Hilliard equations:
25 | ```
26 | cd source
27 | bash scripts/run.sh
28 | ```
29 | 
30 | 
43 | 
44 | ## Contact
45 | If you have any questions, please feel free to contact the authors. Honghui Wang: wanghh20@mails.tsinghua.edu.cn.
46 | 
47 | ## Acknowledgments
48 | This codebase is built on the repository [Characterizing possible failure modes in physics-informed neural networks](https://github.com/a1k12/characterizing-pinns-failure-modes). Please consider citing or starring that project.
49 | 
--------------------------------------------------------------------------------
/source/visualize.py:
--------------------------------------------------------------------------------
1 | """ 2 | Visualize outputs. 3 | """ 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from mpl_toolkits.axes_grid1 import make_axes_locatable 7 | import matplotlib.gridspec as gridspec 8 | 9 | def exact_u(Exact, x, t, nu, beta, rho, layers, N_f, L, source, u0_str, system, path): 10 | """Visualize exact solution.""" 11 | fig = plt.figure(figsize=(9, 6)) 12 | ax = fig.add_subplot(111) 13 | 14 | h = ax.imshow(Exact.T, interpolation='nearest', cmap='rainbow', 15 | extent=[t.min(), t.max(), x.min(), x.max()], 16 | origin='lower', aspect='auto') 17 | divider = make_axes_locatable(ax) 18 | cax = divider.append_axes("right", size="5%", pad=0.10) 19 | cbar = fig.colorbar(h, cax=cax) 20 | cbar.ax.tick_params(labelsize=15) 21 | 22 | line = np.linspace(x.min(), x.max(), 2)[:,None] 23 | 24 | ax.set_xlabel('t', fontweight='bold', size=30) 25 | ax.set_ylabel('x', fontweight='bold', size=30) 26 | ax.legend( 27 | loc='upper center', 28 | bbox_to_anchor=(0.9, -0.05), 29 | ncol=5, 30 | frameon=False, 31 | prop={'size': 15} 32 | ) 33 | 34 | ax.tick_params(labelsize=15) 35 | 36 | plt.savefig(f"{path}/exactu_{system}_nu{nu}_beta{beta}_rho{rho}_Nf{N_f}_{layers}_L{L}_source{source}_{u0_str}.png") 37 | plt.close() 38 | 39 | return None 40 | 41 | def u_diff(X_f_train, Exact, U_pred, x, t, nu, beta, rho, seed, layers, N_f, L, source, lr, u0_str, system, path, relative_error = False): 42 | """Visualize abs(u_pred - u_exact).""" 43 | 44 | fig = plt.figure(figsize=(9, 6)) 45 | ax = fig.add_subplot(111) 46 | 47 | if relative_error: 48 | h = ax.imshow(np.abs(Exact.T - U_pred.T)/np.abs(Exact.T), interpolation='nearest', cmap='binary', 49 | extent=[t.min(), t.max(), x.min(), x.max()], 50 | origin='lower', aspect='auto') 51 | else: 52 | h = ax.imshow(np.abs(Exact.T - U_pred.T), interpolation='nearest', cmap='binary', 53 | extent=[t.min(), t.max(), x.min(), x.max()], 54 | origin='lower', aspect='auto') 55 | divider = make_axes_locatable(ax) 56 | cax = divider.append_axes("right", size="5%", pad=0.10) 57 | cbar = fig.colorbar(h, cax=cax) 58 | cbar.ax.tick_params(labelsize=15) 59 | 60 | ax.scatter(X_f_train[:, 1], X_f_train[:, 0]) 61 | # ax.scatter(X_f_train[:int(0.9*N_f), 1], X_f_train[:int(0.9*N_f), 0]) 62 | # ax.scatter(X_f_train[int(0.9*N_f):, 1], X_f_train[int(0.9*N_f):, 0], color='red') 63 | line = np.linspace(x.min(), x.max(), 2)[:,None] 64 | 65 | ax.set_xlabel('t', fontweight='bold', size=30) 66 | ax.set_ylabel('x',
fontweight='bold', size=30) 67 | 68 | ax.legend( 69 | loc='upper center', 70 | bbox_to_anchor=(0.9, -0.05), 71 | ncol=5, 72 | frameon=False, 73 | prop={'size': 15} 74 | ) 75 | 76 | ax.tick_params(labelsize=15) 77 | 78 | plt.savefig(f"{path}/udiff_{system}_nu{nu}_beta{beta}_rho{rho}_Nf{N_f}_{layers}_L{L}_seed{seed}_source{source}_{u0_str}_lr{lr}.png") 79 | 80 | return None 81 | 82 | def u_predict(u_vals, U_pred, x, t, nu, beta, rho, seed, layers, N_f, L, source, lr, u0_str, system, path): 83 | """Visualize u_predicted.""" 84 | 85 | fig = plt.figure(figsize=(9, 6)) 86 | ax = fig.add_subplot(111) 87 | 88 | # colorbar for prediction: set min/max to ground truth solution. 89 | h = ax.imshow(U_pred.T, interpolation='nearest', cmap='rainbow', 90 | extent=[t.min(), t.max(), x.min(), x.max()], 91 | origin='lower', aspect='auto', vmin=u_vals.min(), vmax=u_vals.max()) 92 | divider = make_axes_locatable(ax) 93 | cax = divider.append_axes("right", size="5%", pad=0.10) 94 | cbar = fig.colorbar(h, cax=cax) 95 | cbar.ax.tick_params(labelsize=15) 96 | 97 | line = np.linspace(x.min(), x.max(), 2)[:,None] 98 | 99 | ax.set_xlabel('t', fontweight='bold', size=30) 100 | ax.set_ylabel('x', fontweight='bold', size=30) 101 | 102 | ax.legend( 103 | loc='upper center', 104 | bbox_to_anchor=(0.9, -0.05), 105 | ncol=5, 106 | frameon=False, 107 | prop={'size': 15} 108 | ) 109 | 110 | ax.tick_params(labelsize=15) 111 | 112 | plt.savefig(f"{path}/upredicted_{system}_nu{nu}_beta{beta}_rho{rho}_Nf{N_f}_{layers}_L{L}_seed{seed}_source{source}_{u0_str}_lr{lr}.png") 113 | 114 | plt.close() 115 | return None 116 | -------------------------------------------------------------------------------- /source/choose_optimizer.py: -------------------------------------------------------------------------------- 1 | """Optimizer choices.""" 2 | 3 | import torch 4 | import numpy as np 5 | from lbfgs import LBFGS 6 | 7 | def choose_optimizer(optimizer_name: str, *params, **kwargs): 8 | # print(params) 9 | if optimizer_name == 'LBFGS': 10 | return LBFGS_fn(*params, **kwargs) 11 | elif optimizer_name == 'AdaHessian': 12 | return AdaHessian(*params) 13 | elif optimizer_name == 'Shampoo': 14 | return Shampoo(*params) 15 | elif optimizer_name == 'Yogi': 16 | return Yogi(*params) 17 | elif optimizer_name == 'Apollo': 18 | return Apollo(*params) 19 | elif optimizer_name == 'Adam': 20 | return Adam(*params, **kwargs) 21 | elif optimizer_name == 'AdamW': 22 | return AdamW(*params, **kwargs) 23 | elif optimizer_name == 'SGD': 24 | return SGD(*params) 25 | 26 | def LBFGS_fn(model_param, 27 | lr=1.0, 28 | max_iter=15000, 29 | max_eval=None, 30 | history_size=100, 31 | tolerance_grad=1e-8, 32 | tolerance_change=0, 33 | line_search_fn='strong_wolfe'): 34 | print(lr) 35 | print(line_search_fn) 36 | optimizer = LBFGS( 37 | model_param, 38 | lr=lr, 39 | max_iter=max_iter, 40 | max_eval=max_eval, 41 | history_size=history_size, 42 | tolerance_grad=tolerance_grad, 43 | tolerance_change=tolerance_change, 44 | line_search_fn=line_search_fn 45 | ) 46 | 47 | return optimizer 48 | 49 | def Adam(model_param, lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False): 50 | print(lr) 51 | optimizer = torch.optim.Adam( 52 | model_param, 53 | lr=lr, 54 | betas=betas, 55 | eps=eps, 56 | weight_decay=weight_decay, 57 | amsgrad=amsgrad 58 | ) 59 | return optimizer 60 | 61 | def AdamW(model_param, lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False): 62 | print(lr) 63 | print(weight_decay) 64 | print(betas) 65 | print(eps) 66 | print(amsgrad) 
67 | optimizer = torch.optim.AdamW( 68 | model_param, 69 | lr=lr, 70 | betas=betas, 71 | eps=eps, 72 | weight_decay=weight_decay, 73 | amsgrad=amsgrad 74 | ) 75 | return optimizer 76 | 77 | 78 | def SGD(model_param, lr=1e-4, momentum=0.9, dampening=0, weight_decay=0, nesterov=False): 79 | 80 | optimizer = torch.optim.SGD( 81 | model_param, 82 | lr=lr, 83 | momentum=momentum, 84 | dampening=dampening, 85 | weight_decay=weight_decay, 86 | nesterov=False 87 | ) 88 | 89 | return optimizer 90 | 91 | def AdaHessian(model_param, lr=1.0, betas=(0.9, 0.999), 92 | eps=1e-4, weight_decay=0.0, hessian_power=0.5): 93 | """ 94 | Arguments: 95 | params (iterable): iterable of parameters to optimize or dicts defining 96 | parameter groups 97 | lr (float, optional): learning rate (default: 0.15) 98 | betas (Tuple[float, float], optional): coefficients used for computing 99 | running averages of gradient and its square (default: (0.9, 0.999)) 100 | eps (float, optional): term added to the denominator to improve 101 | numerical stability (default: 1e-4) 102 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 103 | hessian_power (float, optional): Hessian power (default: 0.5) 104 | """ 105 | 106 | optimizer = Adahessian(model_param, 107 | lr=lr, 108 | betas=betas, 109 | eps=eps, 110 | weight_decay=weight_decay, 111 | hessian_power=hessian_power, 112 | single_gpu=False) 113 | 114 | return optimizer 115 | 116 | def Shampoo(model_param, lr=1e-1, momentum=0.0, weight_decay=0.0, 117 | epsilon=1e-4, update_freq=1): 118 | """ 119 | Args: 120 | params: params of model 121 | lr: learning rate 122 | momentum: momentum factor 123 | weight_decay: weight decay (L2 penalty) 124 | epsilon: epsilon added to each mat_gbar_j for numerical stability 125 | update_freq: update frequency to compute inverse 126 | """ 127 | optimizer = optim.Shampoo(model_param, 128 | lr=lr, 129 | momentum=momentum, 130 | weight_decay=weight_decay, 131 | epsilon=epsilon, 132 | update_freq=update_freq) 133 | 134 | return optimizer 135 | 136 | def Yogi(model_param, lr=1e-2, betas=(0.9, 0.999), eps=1e-3, initial_accumulator=1e-6, 137 | weight_decay=0): 138 | 139 | optimizer = optim.Yogi(model_param, 140 | lr=lr, 141 | betas=betas, 142 | eps=eps, 143 | initial_accumulator=initial_accumulator, 144 | weight_decay=weight_decay) 145 | 146 | return optimizer 147 | 148 | def Apollo(model_param, lr=1e-2, beta=0.9, eps=1e-4, warmup=5, init_lr=0.01, weight_decay=0): 149 | """Apollo already includes warmup! 
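Note: this wrapper, like the AdaHessian, Shampoo and Yogi wrappers above, assumes the corresponding optimizer implementation (e.g. from the ``torch-optimizer`` package, imported as ``optim``) is made available elsewhere; the import is not included in this file.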
150 | 151 | Arguments: 152 | params: iterable of parameters to optimize or dicts defining 153 | parameter groups 154 | lr: learning rate (default: 1e-2) 155 | beta: coefficient used for computing 156 | running averages of gradient (default: 0.9) 157 | eps: term added to the denominator to improve 158 | numerical stability (default: 1e-4) 159 | warmup: number of warmup steps (default: 5) 160 | init_lr: initial learning rate for warmup (default: 0.01) 161 | weight_decay: weight decay (L2 penalty) (default: 0) 162 | """ 163 | 164 | optimizer = optim.Apollo(model_param, 165 | lr=lr, 166 | beta=beta, 167 | eps=eps, 168 | warmup=warmup, 169 | init_lr=init_lr, 170 | weight_decay=weight_decay) 171 | return optimizer 172 | -------------------------------------------------------------------------------- /source/systems_pbc.py: -------------------------------------------------------------------------------- 1 | """Pick a system to study here for Poisson's/diffusion.""" 2 | import numpy as np 3 | import torch 4 | import torch.fft 5 | 6 | def function(u0: str): 7 | """Initial condition, string --> function.""" 8 | 9 | if u0 == 'sin(x)': 10 | u0 = lambda x: np.sin(x) 11 | elif u0 == 'sin(pix)': 12 | u0 = lambda x: np.sin(np.pi*x) 13 | elif u0 == 'sin^2(x)': 14 | u0 = lambda x: np.sin(x)**2 15 | elif u0 == 'sin(x)cos(x)': 16 | u0 = lambda x: np.sin(x)*np.cos(x) 17 | elif u0 == '0.1sin(x)': 18 | u0 = lambda x: 0.1*np.sin(x) 19 | elif u0 == '0.5sin(x)': 20 | u0 = lambda x: 0.5*np.sin(x) 21 | elif u0 == '10sin(x)': 22 | u0 = lambda x: 10*np.sin(x) 23 | elif u0 == '50sin(x)': 24 | u0 = lambda x: 50*np.sin(x) 25 | elif u0 == '1+sin(x)': 26 | u0 = lambda x: 1 + np.sin(x) 27 | elif u0 == '2+sin(x)': 28 | u0 = lambda x: 2 + np.sin(x) 29 | elif u0 == '6+sin(x)': 30 | u0 = lambda x: 6 + np.sin(x) 31 | elif u0 == '10+sin(x)': 32 | u0 = lambda x: 10 + np.sin(x) 33 | elif u0 == 'sin(2x)': 34 | u0 = lambda x: np.sin(2*x) 35 | elif u0 == 'tanh(x)': 36 | u0 = lambda x: np.tanh(x) 37 | elif u0 == '2x': 38 | u0 = lambda x: 2*x 39 | elif u0 == 'x^2': 40 | u0 = lambda x: x**2 41 | elif u0 == 'gauss': 42 | x0 = np.pi 43 | sigma = np.pi/4 44 | u0 = lambda x: np.exp(-np.power((x - x0)/sigma, 2.)/2.) 45 | return u0 46 | 47 | def reaction(u, rho, dt): 48 | """ du/dt = rho*u*(1-u) 49 | """ 50 | factor_1 = u * np.exp(rho * dt) 51 | factor_2 = (1 - u) 52 | u = factor_1 / (factor_2 + factor_1) 53 | return u 54 | 55 | def diffusion(u, nu, dt, IKX2): 56 | """ du/dt = nu*d2u/dx2 57 | """ 58 | factor = np.exp(nu * IKX2 * dt) 59 | u_hat = np.fft.fft(u) 60 | u_hat *= factor 61 | u = np.real(np.fft.ifft(u_hat)) 62 | return u 63 | 64 | def reaction_solution(u0: str, rho, nx=256, nt=100): 65 | L = 2*np.pi 66 | T = 1 67 | dx = L/nx 68 | dt = T/nt 69 | x = np.arange(0, 2*np.pi, dx) 70 | t = np.linspace(0, T, nt).reshape(-1, 1) 71 | X, T = np.meshgrid(x, t) 72 | 73 | # call u0 this way so array is (n, ), so each row of u should also be (n, ) 74 | u0 = function(u0) 75 | u0 = u0(x) 76 | 77 | u = reaction(u0, rho, T) 78 | 79 | u = u.flatten() 80 | return u 81 | 82 | def reaction_diffusion_discrete_solution(u0 : str, nu, rho, nx = 256, nt = 100): 83 | """ Computes the discrete solution of the reaction-diffusion PDE using 84 | pseudo-spectral operator splitting. 
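Each time step applies two exact sub-steps: the reaction term du/dt = rho*u*(1-u) is advanced in real space via u <- u*exp(rho*dt) / (1 - u + u*exp(rho*dt)), and the diffusion term du/dt = nu*d2u/dx2 is advanced in Fourier space by multiplying each mode by exp(nu*(ik)^2*dt).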
85 | Args: 86 | u0: initial condition 87 | nu: diffusion coefficient 88 | rho: reaction coefficient 89 | nx: size of x-tgrid 90 | nt: number of points in the t grid 91 | Returns: 92 | u: solution 93 | """ 94 | L = 2*np.pi 95 | T = 1 96 | dx = L/nx 97 | dt = T/nt 98 | x = np.arange(0, L, dx) # not inclusive of the last point 99 | t = np.linspace(0, T, nt).reshape(-1, 1) 100 | X, T = np.meshgrid(x, t) 101 | u = np.zeros((nx, nt)) 102 | 103 | IKX_pos = 1j * np.arange(0, nx/2+1, 1) 104 | IKX_neg = 1j * np.arange(-nx/2+1, 0, 1) 105 | IKX = np.concatenate((IKX_pos, IKX_neg)) 106 | IKX2 = IKX * IKX 107 | 108 | # call u0 this way so array is (n, ), so each row of u should also be (n, ) 109 | u0 = function(u0) 110 | u0 = u0(x) 111 | 112 | u[:,0] = u0 113 | u_ = u0 114 | for i in range(nt-1): 115 | u_ = reaction(u_, rho, dt) 116 | u_ = diffusion(u_, nu, dt, IKX2) 117 | u[:,i+1] = u_ 118 | 119 | u = u.T 120 | u = u.flatten() 121 | return u 122 | 123 | def convection_diffusion(u0: str, nu, beta, source=0, xgrid=256, nt=100): 124 | """Calculate the u solution for convection/diffusion, assuming PBCs. 125 | Args: 126 | u0: Initial condition 127 | nu: viscosity coefficient 128 | beta: wavespeed coefficient 129 | source: q (forcing term), option to have this be a constant 130 | xgrid: size of the x grid 131 | Returns: 132 | u_vals: solution 133 | """ 134 | 135 | N = xgrid 136 | h = 2 * np.pi / N 137 | x = np.arange(0, 2*np.pi, h) # not inclusive of the last point 138 | t = np.linspace(0, 1, nt).reshape(-1, 1) 139 | X, T = np.meshgrid(x, t) 140 | 141 | # call u0 this way so array is (n, ), so each row of u should also be (n, ) 142 | u0 = function(u0) 143 | u0 = u0(x) 144 | 145 | G = (np.copy(u0)*0)+source # G is the same size as u0 146 | 147 | IKX_pos =1j * np.arange(0, N/2+1, 1) 148 | IKX_neg = 1j * np.arange(-N/2+1, 0, 1) 149 | IKX = np.concatenate((IKX_pos, IKX_neg)) 150 | IKX2 = IKX * IKX 151 | 152 | uhat0 = np.fft.fft(u0) 153 | nu_factor = np.exp(nu * IKX2 * T - beta * IKX * T) 154 | A = uhat0 - np.fft.fft(G)*0 # at t=0, second term goes away 155 | uhat = A*nu_factor + np.fft.fft(G)*T # for constant, fft(p) dt = fft(p)*T 156 | u = np.real(np.fft.ifft(uhat)) 157 | 158 | u_vals = u.flatten() 159 | return u_vals 160 | 161 | def convection_diffusion_u0(u0, nu, beta, source=0, xgrid=256, nt=100): 162 | """Calculate the u solution for convection/diffusion, assuming PBCs. 
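The solution is computed spectrally: uhat(k, t) = uhat0(k)*exp((nu*(ik)^2 - beta*(ik))*t) + fft(G)*t for a constant source G, followed by an inverse FFT back to real space.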
163 | Args: 164 | u0: Initial condition 165 | nu: viscosity coefficient 166 | beta: wavespeed coefficient 167 | source: q (forcing term), option to have this be a constant 168 | xgrid: size of the x grid 169 | Returns: 170 | u_vals: solution 171 | """ 172 | 173 | N = xgrid 174 | h = 2 * np.pi / N 175 | x = np.arange(0, 2*np.pi, h) # not inclusive of the last point 176 | t = np.linspace(0, 1, nt).reshape(-1, 1) 177 | X, T = np.meshgrid(x, t) 178 | 179 | # call u0 this way so array is (n, ), so each row of u should also be (n, ) 180 | # u0 = function(u0) 181 | # u0 = u0(x) 182 | 183 | G = (np.copy(u0)*0)+source # G is the same size as u0 184 | 185 | IKX_pos =1j * np.arange(0, N/2+1, 1) 186 | IKX_neg = 1j * np.arange(-N/2+1, 0, 1) 187 | IKX = np.concatenate((IKX_pos, IKX_neg)) 188 | IKX2 = IKX * IKX 189 | 190 | uhat0 = np.fft.fft(u0) 191 | nu_factor = np.exp(nu * IKX2 * T - beta * IKX * T) 192 | A = uhat0 - np.fft.fft(G)*0 # at t=0, second term goes away 193 | uhat = A*nu_factor + np.fft.fft(G)*T # for constant, fft(p) dt = fft(p)*T 194 | u = np.real(np.fft.ifft(uhat)) 195 | 196 | u_vals = u.flatten() 197 | return u_vals 198 | -------------------------------------------------------------------------------- /source/scripts/run.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | # convection 4 | 5 | CUDA_VISIBLE_DEVICES=0, python -u main.py --visualize True --system convection --beta 64.0 --gpu --xgrid 512 --nt 200 --N_f 6400 --exp_dir '/cluster/home2/whh/workspace/pinn/exp' --work_dir rep_convection_act_stable --sub_name 'wo_adaptive_slope' --layers 64,64,64,64,64,1 --plot_loss --save_model True --seed 111 --init --adam_lr 2e-3 --epoch 100000 --activation sin --repeat 5 --start_repeat 0 --sample_type grid --N_f_x 64 --N_f_t 100 --fix_sample --linearpool --poolsize '0,1,2,3,4' --aggregate softmax --weight_sharing --coeff_lr 1.0 --coeff_lr_first_layer 1.0 --lr_first_layer 2e-3 --lr_second_layer 2e-3 --tau 1.0 --cosine_decay --warm_up_iter 1000 --sep_cosine_decay --sep_warm_up_iter 1000 --sep_optim --sep_optimizer sgd --momentum 0.75 --l2_reg 0.0 --weight_decay 0.0 --coeff_weight_decay 0.0 --port 39080 --clip 100.0 --disable_lbfgs & 6 | 7 | sleep 10s 8 | 9 | CUDA_VISIBLE_DEVICES=1, python -u main.py --visualize True --system convection --beta 64.0 --gpu --xgrid 512 --nt 200 --N_f 6400 --exp_dir '/cluster/home2/whh/workspace/pinn/exp' --work_dir rep_convection_act_stable --sub_name 'w_adaptive_slope' --layers 64,64,64,64,64,1 --plot_loss --save_model True --seed 111 --init --adam_lr 2e-3 --epoch 100000 --activation sin --repeat 5 --start_repeat 0 --sample_type grid --N_f_x 64 --N_f_t 100 --fix_sample --linearpool --poolsize '0,1,2,3,4' --aggregate softmax --weight_sharing --coeff_lr 1e-3 --coeff_lr_first_layer 1e-3 --lr_first_layer 2e-3 --lr_second_layer 2e-3 --tau 1.0 --cosine_decay --warm_up_iter 1000 --sep_cosine_decay --sep_warm_up_iter 1000 --sep_optim --sep_optimizer adam --coeff_beta1 0.9 --coeff_beta2 0.9 --l2_reg 0.0 --weight_decay 0.0 --coeff_weight_decay 0.0 --port 39080 --clip 100.0 --disable_lbfgs --enable_scaling & 10 | 11 | sleep 10s 12 | 13 | 14 | # AC 15 | 16 | CUDA_VISIBLE_DEVICES=2, python -u main.py --visualize True --system AC --beta 0.0 --rho 0.0 --nu 0.0 --data_path './dataset/Allen_Cahn.mat' --exp_dir '/cluster/home2/whh/workspace/pinn/exp' --gpu --xgrid 201 --nt 101 --N_f 8000 --work_dir rep_AC_act_stable_adam --sub_name 'wo_adaptive_slope' --layers 32,32,32,1 --plot_loss --save_model True --seed 111 --adam_lr 1e-3 --init 
--epoch 40000 --activation sin --sample_type interval --repeat 5 --L_u 0.0 --L_b 0.0 --line_search_fn strong_wolfe --fix_sample --port 29064 --clip 100.0 --linearpool --poolsize '0,1,2,3,4' --aggregate softmax --weight_sharing --coeff_lr 1.6e-2 --coeff_lr_first_layer 1.6e-2 --lr_first_layer 1e-3 --lr_second_layer 1e-3 --cosine_decay --warm_up_iter 1000 --sep_cosine_decay --sep_warm_up_iter 1000 --sep_optim --sep_optimizer adam --coeff_beta1 0.9 --coeff_beta2 0.9 --l2_reg 0.0 --weight_decay 0.0 --coeff_weight_decay 0.0 --hard_ibc & 17 | 18 | 19 | sleep 10s 20 | 21 | CUDA_VISIBLE_DEVICES=3, python -u main.py --visualize True --system AC --beta 0.0 --rho 0.0 --nu 0.0 --data_path './dataset/Allen_Cahn.mat' --exp_dir '/cluster/home2/whh/workspace/pinn/exp' --gpu --xgrid 201 --nt 101 --N_f 8000 --work_dir rep_AC_act_stable_adam --sub_name 'w_adaptive_slope' --layers 32,32,32,1 --plot_loss --save_model True --seed 111 --adam_lr 1e-3 --init --epoch 40000 --activation sin --sample_type interval --repeat 5 --L_u 0.0 --L_b 0.0 --line_search_fn strong_wolfe --fix_sample --port 29064 --clip 100.0 --linearpool --poolsize '0,1,2,3,4' --aggregate softmax --weight_sharing --coeff_lr 5e-4 --coeff_lr_first_layer 5e-4 --lr_first_layer 1e-3 --lr_second_layer 1e-3 --cosine_decay --warm_up_iter 1000 --sep_cosine_decay --sep_warm_up_iter 1000 --sep_optim --sep_optimizer adam --coeff_beta1 0.9 --coeff_beta2 0.3 --l2_reg 0.0 --weight_decay 0.0 --coeff_weight_decay 0.0 --hard_ibc --enable_scaling --scaler 1.0 & 22 | 23 | sleep 10s 24 | 25 | wait 26 | 27 | # KdV 28 | 29 | 30 | CUDA_VISIBLE_DEVICES=0, python -u main.py --visualize True --system KdV --beta 0.0 --rho 0.0 --nu 0.0 --data_path './dataset/KdV.mat' --exp_dir '/cluster/home2/whh/workspace/pinn/exp' --gpu --xgrid 513 --nt 201 --N_f 8000 --work_dir rep_KdV_act_stable_adam_hard_2_ibc --sub_name 'wo_adaptive_slope' --layers 32,32,32,1 --plot_loss --save_model True --seed 111 --adam_lr 1e-3 --init --epoch 40000 --activation sin --sample_type interval --repeat 5 --L_u 0.0 --L_b 1.0 --L_f 1.0 --line_search_fn strong_wolfe --fix_sample --port 29064 --clip 100.0 --extra_N_f 0 --range 0.1 --linearpool --poolsize '0,1,2,3,4' --aggregate softmax --weight_sharing --coeff_lr 1e-3 --coeff_lr_first_layer 1e-3 --lr_first_layer 1e-3 --lr_second_layer 1e-3 --cosine_decay --warm_up_iter 1000 --sep_cosine_decay --sep_warm_up_iter 1000 --sep_optim --sep_optimizer adam --coeff_beta1 0.9 --coeff_beta2 0.9 --l2_reg 0.0 --weight_decay 0.0 --coeff_weight_decay 0.0 --max_iter 15000 --hard_ibc & 31 | 32 | sleep 10s 33 | 34 | CUDA_VISIBLE_DEVICES=1, python -u main.py --visualize True --system KdV --beta 0.0 --rho 0.0 --nu 0.0 --data_path './dataset/KdV.mat' --exp_dir '/cluster/home2/whh/workspace/pinn/exp' --gpu --xgrid 513 --nt 201 --N_f 8000 --work_dir rep_KdV_act_stable_adam_hard_2_ibc --sub_name 'w_adaptive_slope' --layers 32,32,32,1 --plot_loss --save_model True --seed 111 --adam_lr 1e-3 --init --epoch 40000 --activation sin --sample_type interval --repeat 5 --L_u 0.0 --L_b 1.0 --L_f 1.0 --line_search_fn strong_wolfe --fix_sample --port 29064 --clip 100.0 --extra_N_f 0 --range 0.1 --linearpool --poolsize '0,1,2,3,4' --aggregate softmax --weight_sharing --coeff_lr 1e-2 --coeff_lr_first_layer 1e-2 --lr_first_layer 1e-3 --lr_second_layer 1e-3 --cosine_decay --warm_up_iter 1000 --sep_cosine_decay --sep_warm_up_iter 1000 --sep_optim --sep_optimizer adam --coeff_beta1 0.9 --coeff_beta2 0.99 --l2_reg 0.0 --weight_decay 0.0 --coeff_weight_decay 0.0 --max_iter 15000 --hard_ibc 
--enable_scaling & 35 | 36 | sleep 10s 37 | 38 | # CH 39 | 40 | CUDA_VISIBLE_DEVICES=2, python -u main.py --visualize True --system CH --beta 0.0 --rho 0.0 --nu 0.0 --data_path './dataset/CH_C1_02_2.mat' --exp_dir '/cluster/home2/whh/workspace/pinn/exp' --gpu --xgrid 513 --nt 201 --N_f 12000 --work_dir rep_CH_act_decouple_1_fine_grid_adam_100k --sub_name 'wo_adaptive_slope' --layers 32,32,32,2 --plot_loss --save_model True --seed 111 --adam_lr 1e-3 --init --epoch 100000 --activation sin --sample_type grid --N_f_x 80 --N_f_t 100 --repeat 5 --L_u 100.0 --L_b 1.0 --L_f 1.0 --line_search_fn strong_wolfe --fix_sample --port 29064 --clip 100.0 --linearpool --poolsize '0,1,2,3,4' --aggregate softmax --weight_sharing --coeff_lr 5e-4 --coeff_lr_first_layer 5e-4 --lr_first_layer 1e-3 --lr_second_layer 1e-3 --cosine_decay --warm_up_iter 1000 --sep_cosine_decay --sep_warm_up_iter 1000 --sep_optim --sep_optimizer adam --coeff_beta1 0.9 --coeff_beta2 0.5 --l2_reg 0.0 --weight_decay 0.0 --coeff_weight_decay 0.0 --max_iter 15000 --decouple --four_order --fine_grid & 41 | 42 | sleep 10s 43 | 44 | CUDA_VISIBLE_DEVICES=3, python -u main.py --visualize True --system CH --beta 0.0 --rho 0.0 --nu 0.0 --data_path './dataset/CH_C1_02_2.mat' --exp_dir '/cluster/home2/whh/workspace/pinn/exp' --gpu --xgrid 513 --nt 201 --N_f 12000 --work_dir rep_CH_act_decouple_1_fine_grid_adam_100k --sub_name 'w_adaptive_slope' --layers 32,32,32,2 --plot_loss --save_model True --seed 111 --adam_lr 1e-3 --init --epoch 100000 --activation sin --sample_type grid --N_f_x 80 --N_f_t 100 --repeat 5 --L_u 100.0 --L_b 1.0 --L_f 1.0 --line_search_fn strong_wolfe --fix_sample --port 29064 --clip 100.0 --linearpool --poolsize '0,1,2,3,4' --aggregate softmax --weight_sharing --coeff_lr 1e-3 --coeff_lr_first_layer 1e-3 --lr_first_layer 1e-3 --lr_second_layer 1e-3 --cosine_decay --warm_up_iter 1000 --sep_cosine_decay --sep_warm_up_iter 1000 --sep_optim --sep_optimizer adam --coeff_beta1 0.9 --coeff_beta2 0.9 --l2_reg 0.0 --weight_decay 0.0 --coeff_weight_decay 0.0 --max_iter 15000 --decouple --four_order --fine_grid --enable_scaling & 45 | -------------------------------------------------------------------------------- /source/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | import torch.distributed as dist 5 | 6 | import os 7 | import time 8 | from pathlib import Path 9 | import logging 10 | import subprocess 11 | 12 | def set_seed(seed): 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | if torch.cuda.is_available(): 17 | torch.cuda.manual_seed_all(seed) 18 | # random.seed(seed) 19 | # np.random.seed(seed) 20 | # torch.manual_seed(seed) 21 | # torch.cuda.manual_seed(seed) 22 | # torch.cuda.manual_seed_all(seed) 23 | # os.environ['PYTHONHASHSEED'] = str(seed) 24 | 25 | def sample_random_type(X_all, N, extra_N=0, range=0): 26 | # def sample_random_type(X_all, N): 27 | """Given an array of (x,t) points, sample N points from this.""" 28 | # set_seed(seed) # this can be fixed for all N_f 29 | if isinstance(X_all, dict): 30 | point_dim = [] 31 | lb = X_all['lb'] 32 | ub = X_all['ub'] 33 | for lb_i, ub_i in zip(lb, ub): 34 | point_dim.append(np.random.uniform(low=lb_i, high=ub_i, size=(N,1))) 35 | X_sampled = np.hstack(point_dim) 36 | if extra_N > 0: 37 | extra_point = [] 38 | extra_point.append(np.random.uniform(low=-1.0*range, high=range, size=(extra_N,1))) 39 | extra_point.append(np.random.uniform(low=0, high=1, 
size=(extra_N,1))) 40 | extra_X_sampled = np.hstack(extra_point) 41 | # import pdb 42 | # pdb.set_trace() 43 | X_sampled = np.vstack([X_sampled, extra_X_sampled]) 44 | idx_sorted = np.argsort(X_sampled[:, -1]) 45 | 46 | return X_sampled[idx_sorted], None 47 | else: 48 | # print(X_all.shape[0], N) 49 | idx = np.random.choice(X_all.shape[0], N, replace=False) 50 | X_sampled = X_all[idx, :] 51 | idx_sorted = np.argsort(X_sampled[:, -1]) 52 | 53 | return X_sampled[idx_sorted], idx[idx_sorted] 54 | 55 | def sample_random(X_all, N): 56 | """Given an array of (x,t) points, sample N points from this.""" 57 | # set_seed(seed) # this can be fixed for all N_f 58 | 59 | idx = np.random.choice(X_all.shape[0], N, replace=False) 60 | X_sampled = X_all[idx, :] 61 | 62 | return X_sampled, idx 63 | 64 | def sample_random_interval(lb=[0], ub=[1], N=512): 65 | """Given an array of (x,t) points, sample N points from this.""" 66 | # set_seed(seed) # this can be fixed for all N_f 67 | point_dim = [] 68 | for lb_i, ub_i in zip(lb, ub): 69 | point_dim.append(np.random.uniform(low=lb_i, high=ub_i, size=(N,1))) 70 | 71 | return np.hstack(point_dim) 72 | 73 | def set_activation(activation): 74 | if activation == 'identity': 75 | return nn.Identity() 76 | elif activation == 'tanh': 77 | return nn.Tanh() 78 | elif activation == 'relu': 79 | return nn.ReLU() 80 | elif activation == 'gelu': 81 | return nn.GELU() 82 | else: 83 | print("WARNING: unknown activation function!") 84 | return -1 85 | 86 | logger_initialized = {} 87 | 88 | 89 | def init_environ(cfg): 90 | # build work dir 91 | exp_name = cfg.name # or config name 92 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 93 | # work_dir = os.path.join('./exp', exp_name, timestamp+'_'+cfg.sub_name) 94 | work_dir = os.path.join(cfg.exp_dir, exp_name, timestamp+'_'+cfg.sub_name) 95 | Path(work_dir).mkdir(parents=True, exist_ok=True) 96 | cfg.work_dir = work_dir 97 | 98 | # init distributed parallel 99 | if cfg.launcher == 'slurm': 100 | _init_dist_slurm('nccl', cfg, cfg.port) 101 | # else: 102 | # raise NotImplementedError(f'launcher {cfg.launcher} has not been implemented.') 103 | 104 | # create logger 105 | log_file = os.path.join(work_dir, 'log.txt') 106 | logger = get_logger('search', log_file) 107 | cfg.log_file = log_file 108 | # set random seed 109 | # if cfg.seed is not None: 110 | # set_random_seed(cfg.seed) 111 | # logger.info(f'set random seed to {cfg.seed}') 112 | 113 | return logger 114 | 115 | 116 | def _init_dist_slurm(backend, cfg, port=None): 117 | """Initialize slurm distributed training environment. 118 | If argument ``port`` is not specified, then the master port will be system 119 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 120 | environment variable, then a default port ``29500`` will be used. 121 | Args: 122 | backend (str): Backend of torch.distributed. 123 | port (int, optional): Master port. Defaults to None. 
124 | """ 125 | proc_id = int(os.environ['SLURM_PROCID']) 126 | ntasks = int(os.environ['SLURM_NTASKS']) 127 | node_list = os.environ['SLURM_NODELIST'] 128 | num_gpus = torch.cuda.device_count() 129 | torch.cuda.set_device(proc_id % num_gpus) 130 | addr = subprocess.getoutput( 131 | f'scontrol show hostname {node_list} | head -n1') 132 | # specify master port 133 | if port is not None: 134 | os.environ['MASTER_PORT'] = str(port) 135 | elif 'MASTER_PORT' in os.environ: 136 | pass # use MASTER_PORT in the environment variable 137 | else: 138 | # 29500 is torch.distributed default port 139 | os.environ['MASTER_PORT'] = '29500' 140 | # use MASTER_ADDR in the environment variable if it already exists 141 | if 'MASTER_ADDR' not in os.environ: 142 | os.environ['MASTER_ADDR'] = addr 143 | os.environ['WORLD_SIZE'] = str(ntasks) 144 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 145 | os.environ['RANK'] = str(proc_id) 146 | cfg.world_size = ntasks 147 | cfg.gpu_id = proc_id % num_gpus 148 | cfg.rank = proc_id 149 | 150 | dist.init_process_group(backend=backend) 151 | print(f'Distributed training on {proc_id}/{ntasks}') 152 | 153 | 154 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 155 | """Initialize and get a logger by name. 156 | If the logger has not been initialized, this method will initialize the 157 | logger by adding one or two handlers, otherwise the initialized logger will 158 | be directly returned. During initialization, a StreamHandler will always be 159 | added. If `log_file` is specified and the process rank is 0, a FileHandler 160 | will also be added. 161 | Args: 162 | name (str): Logger name. 163 | log_file (str | None): The log filename. If specified, a FileHandler 164 | will be added to the logger. 165 | log_level (int): The logger level. Note that only the process of 166 | rank 0 is affected, and other processes will set the level to 167 | "Error" thus be silent most of the time. 168 | file_mode (str): The file mode used in opening log file. 169 | Defaults to 'w'. 170 | Returns: 171 | logging.Logger: The expected logger. 172 | """ 173 | logger = logging.getLogger(name) 174 | if name in logger_initialized: 175 | return logger 176 | # handle hierarchical names 177 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 178 | # initialization since it is a child of "a". 179 | for logger_name in logger_initialized: 180 | if name.startswith(logger_name): 181 | return logger 182 | 183 | stream_handler = logging.StreamHandler() 184 | handlers = [stream_handler] 185 | 186 | if dist.is_available() and dist.is_initialized(): 187 | rank = dist.get_rank() 188 | else: 189 | rank = 0 190 | 191 | # only rank 0 will add a FileHandler 192 | if rank == 0 and log_file is not None: 193 | # Here, the default behaviour of the official logger is 'a'. Thus, we 194 | # provide an interface to change the file mode to the default 195 | # behaviour. 
196 | file_handler = logging.FileHandler(log_file, file_mode) 197 | handlers.append(file_handler) 198 | 199 | formatter = logging.Formatter( 200 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 201 | for handler in handlers: 202 | handler.setFormatter(formatter) 203 | handler.setLevel(log_level) 204 | logger.addHandler(handler) 205 | 206 | if rank == 0: 207 | logger.setLevel(log_level) 208 | else: 209 | logger.setLevel(logging.ERROR) 210 | 211 | logger_initialized[name] = True 212 | logger.propagate = False 213 | 214 | return logger 215 | 216 | 217 | def set_random_seed(seed, deterministic=False, use_rank_shift=False): 218 | """Set random seed. 219 | Args: 220 | seed (int): Seed to be used. 221 | deterministic (bool): Whether to set the deterministic option for 222 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 223 | to True and `torch.backends.cudnn.benchmark` to False. 224 | Default: False. 225 | rank_shift (bool): Whether to add rank number to the random seed to 226 | have different random seed in different threads. Default: False. 227 | """ 228 | if use_rank_shift: 229 | rank = dist.get_rank() 230 | seed += rank 231 | random.seed(seed) 232 | np.random.seed(seed) 233 | torch.manual_seed(seed) 234 | torch.cuda.manual_seed(seed) 235 | torch.cuda.manual_seed_all(seed) 236 | os.environ['PYTHONHASHSEED'] = str(seed) 237 | if deterministic: 238 | torch.backends.cudnn.deterministic = True 239 | torch.backends.cudnn.benchmark = False -------------------------------------------------------------------------------- /source/lbfgs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import reduce 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): 7 | # ported from https://github.com/torch/optim/blob/master/polyinterp.lua 8 | # Compute bounds of interpolation area 9 | if bounds is not None: 10 | xmin_bound, xmax_bound = bounds 11 | else: 12 | xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1) 13 | 14 | # Code for most common case: cubic interpolation of 2 points 15 | # w/ function and derivative values for both 16 | # Solution in this case (where x2 is the farthest point): 17 | # d1 = g1 + g2 - 3*(f1-f2)/(x1-x2); 18 | # d2 = sqrt(d1^2 - g1*g2); 19 | # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2)); 20 | # t_new = min(max(min_pos,xmin_bound),xmax_bound); 21 | d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2) 22 | d2_square = d1**2 - g1 * g2 23 | if d2_square >= 0: 24 | d2 = d2_square.sqrt() 25 | if x1 <= x2: 26 | min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)) 27 | else: 28 | min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)) 29 | return min(max(min_pos, xmin_bound), xmax_bound) 30 | else: 31 | return (xmin_bound + xmax_bound) / 2. 
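# Illustrative example added for documentation; `_cubic_interpolate_example`
# is a hypothetical helper and is not called anywhere in this file.
def _cubic_interpolate_example():
    """Sanity check of the cubic interpolation: for f(x) = (x - 2)**2 sampled
    at x1 = 0 (f1 = 4, g1 = -4) and x2 = 3 (f2 = 1, g2 = 2), the interpolant
    is the quadratic itself, so the minimizer x = 2 is recovered exactly
    (d1 = 1, d2 = 3)."""
    t = _cubic_interpolate(0.0, 4.0, torch.tensor(-4.0),
                           3.0, 1.0, torch.tensor(2.0))
    assert abs(float(t) - 2.0) < 1e-6
    return t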
32 | 33 | 34 | def _strong_wolfe(obj_func, 35 | x, 36 | t, 37 | d, 38 | f, 39 | g, 40 | gtd, 41 | c1=1e-4, 42 | c2=0.9, 43 | tolerance_change=1e-9, 44 | max_ls=25): 45 | # ported from https://github.com/torch/optim/blob/master/lswolfe.lua 46 | d_norm = d.abs().max() 47 | g = g.clone(memory_format=torch.contiguous_format) 48 | # evaluate objective and gradient using initial step 49 | f_new, g_new = obj_func(x, t, d) 50 | ls_func_evals = 1 51 | gtd_new = g_new.dot(d) 52 | 53 | # bracket an interval containing a point satisfying the Wolfe criteria 54 | t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd 55 | done = False 56 | ls_iter = 0 57 | while ls_iter < max_ls: 58 | # check conditions 59 | if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev): 60 | bracket = [t_prev, t] 61 | bracket_f = [f_prev, f_new] 62 | bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] 63 | bracket_gtd = [gtd_prev, gtd_new] 64 | break 65 | 66 | if abs(gtd_new) <= -c2 * gtd: 67 | bracket = [t] 68 | bracket_f = [f_new] 69 | bracket_g = [g_new] 70 | done = True 71 | break 72 | 73 | if gtd_new >= 0: 74 | bracket = [t_prev, t] 75 | bracket_f = [f_prev, f_new] 76 | bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] 77 | bracket_gtd = [gtd_prev, gtd_new] 78 | break 79 | 80 | # interpolate 81 | min_step = t + 0.01 * (t - t_prev) 82 | max_step = t * 10 83 | tmp = t 84 | t = _cubic_interpolate( 85 | t_prev, 86 | f_prev, 87 | gtd_prev, 88 | t, 89 | f_new, 90 | gtd_new, 91 | bounds=(min_step, max_step)) 92 | 93 | # next step 94 | t_prev = tmp 95 | f_prev = f_new 96 | g_prev = g_new.clone(memory_format=torch.contiguous_format) 97 | gtd_prev = gtd_new 98 | f_new, g_new = obj_func(x, t, d) 99 | ls_func_evals += 1 100 | gtd_new = g_new.dot(d) 101 | ls_iter += 1 102 | 103 | # reached max number of iterations? 104 | if ls_iter == max_ls: 105 | bracket = [0, t] 106 | bracket_f = [f, f_new] 107 | bracket_g = [g, g_new] 108 | 109 | # zoom phase: we now have a point satisfying the criteria, or 110 | # a bracket around it. We refine the bracket until we find the 111 | # exact point satisfying the criteria 112 | insuf_progress = False 113 | # find high and low points in bracket 114 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) 115 | while not done and ls_iter < max_ls: 116 | # line-search bracket is so small 117 | if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: 118 | break 119 | 120 | # compute new trial value 121 | t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0], 122 | bracket[1], bracket_f[1], bracket_gtd[1]) 123 | 124 | # test that we are making sufficient progress: 125 | # in case `t` is so close to boundary, we mark that we are making 126 | # insufficient progress, and if 127 | # + we have made insufficient progress in the last step, or 128 | # + `t` is at one of the boundary, 129 | # we will move `t` to a position which is `0.1 * len(bracket)` 130 | # away from the nearest boundary point. 
131 | eps = 0.1 * (max(bracket) - min(bracket)) 132 | if min(max(bracket) - t, t - min(bracket)) < eps: 133 | # interpolation close to boundary 134 | if insuf_progress or t >= max(bracket) or t <= min(bracket): 135 | # evaluate at 0.1 away from boundary 136 | if abs(t - max(bracket)) < abs(t - min(bracket)): 137 | t = max(bracket) - eps 138 | else: 139 | t = min(bracket) + eps 140 | insuf_progress = False 141 | else: 142 | insuf_progress = True 143 | else: 144 | insuf_progress = False 145 | 146 | # Evaluate new point 147 | f_new, g_new = obj_func(x, t, d) 148 | ls_func_evals += 1 149 | gtd_new = g_new.dot(d) 150 | ls_iter += 1 151 | 152 | if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]: 153 | # Armijo condition not satisfied or not lower than lowest point 154 | bracket[high_pos] = t 155 | bracket_f[high_pos] = f_new 156 | bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) 157 | bracket_gtd[high_pos] = gtd_new 158 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0) 159 | else: 160 | if abs(gtd_new) <= -c2 * gtd: 161 | # Wolfe conditions satisfied 162 | done = True 163 | elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0: 164 | # old high becomes new low 165 | bracket[high_pos] = bracket[low_pos] 166 | bracket_f[high_pos] = bracket_f[low_pos] 167 | bracket_g[high_pos] = bracket_g[low_pos] 168 | bracket_gtd[high_pos] = bracket_gtd[low_pos] 169 | 170 | # new point becomes new low 171 | bracket[low_pos] = t 172 | bracket_f[low_pos] = f_new 173 | bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) 174 | bracket_gtd[low_pos] = gtd_new 175 | 176 | # return stuff 177 | t = bracket[low_pos] 178 | f_new = bracket_f[low_pos] 179 | g_new = bracket_g[low_pos] 180 | return f_new, g_new, t, ls_func_evals 181 | 182 | 183 | class LBFGS(Optimizer): 184 | """Implements L-BFGS algorithm, heavily inspired by `minFunc 185 | `. 186 | 187 | .. warning:: 188 | This optimizer doesn't support per-parameter options and parameter 189 | groups (there can be only one). 190 | 191 | .. warning:: 192 | Right now all parameters have to be on a single device. This will be 193 | improved in the future. 194 | 195 | .. note:: 196 | This is a very memory intensive optimizer (it requires additional 197 | ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory 198 | try reducing the history size, or use a different algorithm. 199 | 200 | Args: 201 | lr (float): learning rate (default: 1) 202 | max_iter (int): maximal number of iterations per optimization step 203 | (default: 20) 204 | max_eval (int): maximal number of function evaluations per optimization 205 | step (default: max_iter * 1.25). 206 | tolerance_grad (float): termination tolerance on first order optimality 207 | (default: 1e-5). 208 | tolerance_change (float): termination tolerance on function 209 | value/parameter changes (default: 1e-9). 210 | history_size (int): update history size (default: 100). 211 | line_search_fn (str): either 'strong_wolfe' or None (default: None). 
212 | """ 213 | 214 | def __init__(self, 215 | params, 216 | lr=1, 217 | max_iter=20, 218 | max_eval=None, 219 | tolerance_grad=1e-7, 220 | tolerance_change=1e-9, 221 | history_size=100, 222 | line_search_fn=None): 223 | if max_eval is None: 224 | max_eval = max_iter * 5 // 4 225 | defaults = dict( 226 | lr=lr, 227 | max_iter=max_iter, 228 | max_eval=max_eval, 229 | tolerance_grad=tolerance_grad, 230 | tolerance_change=tolerance_change, 231 | history_size=history_size, 232 | line_search_fn=line_search_fn) 233 | super(LBFGS, self).__init__(params, defaults) 234 | 235 | if len(self.param_groups) != 1: 236 | raise ValueError("LBFGS doesn't support per-parameter options " 237 | "(parameter groups)") 238 | 239 | self._params = self.param_groups[0]['params'] 240 | self._numel_cache = None 241 | 242 | def _numel(self): 243 | if self._numel_cache is None: 244 | self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0) 245 | return self._numel_cache 246 | 247 | def _gather_flat_grad(self): 248 | views = [] 249 | for p in self._params: 250 | if p.grad is None: 251 | view = p.new(p.numel()).zero_() 252 | elif p.grad.is_sparse: 253 | view = p.grad.to_dense().view(-1) 254 | else: 255 | view = p.grad.view(-1) 256 | views.append(view) 257 | return torch.cat(views, 0) 258 | 259 | def _add_grad(self, step_size, update): 260 | offset = 0 261 | for p in self._params: 262 | numel = p.numel() 263 | # view as to avoid deprecated pointwise semantics 264 | p.add_(update[offset:offset + numel].view_as(p), alpha=step_size) 265 | offset += numel 266 | assert offset == self._numel() 267 | 268 | def _clone_param(self): 269 | return [p.clone(memory_format=torch.contiguous_format) for p in self._params] 270 | 271 | def _set_param(self, params_data): 272 | for p, pdata in zip(self._params, params_data): 273 | p.copy_(pdata) 274 | 275 | def _directional_evaluate(self, closure, x, t, d): 276 | self._add_grad(t, d) 277 | loss = float(closure()) 278 | flat_grad = self._gather_flat_grad() 279 | self._set_param(x) 280 | return loss, flat_grad 281 | 282 | @torch.no_grad() 283 | def step(self, closure): 284 | """Performs a single optimization step. 285 | 286 | Args: 287 | closure (callable): A closure that reevaluates the model 288 | and returns the loss. 
289 | """ 290 | assert len(self.param_groups) == 1 291 | 292 | # Make sure the closure is always called with grad enabled 293 | closure = torch.enable_grad()(closure) 294 | 295 | group = self.param_groups[0] 296 | lr = group['lr'] 297 | max_iter = group['max_iter'] 298 | max_eval = group['max_eval'] 299 | tolerance_grad = group['tolerance_grad'] 300 | tolerance_change = group['tolerance_change'] 301 | line_search_fn = group['line_search_fn'] 302 | history_size = group['history_size'] 303 | 304 | # NOTE: LBFGS has only global state, but we register it as state for 305 | # the first param, because this helps with casting in load_state_dict 306 | state = self.state[self._params[0]] 307 | state.setdefault('func_evals', 0) 308 | state.setdefault('n_iter', 0) 309 | 310 | # evaluate initial f(x) and df/dx 311 | orig_loss = closure() 312 | loss = float(orig_loss) 313 | current_evals = 1 314 | state['func_evals'] += 1 315 | 316 | flat_grad = self._gather_flat_grad() 317 | opt_cond = flat_grad.abs().max() <= tolerance_grad 318 | 319 | # optimal condition 320 | if opt_cond: 321 | return orig_loss 322 | 323 | # tensors cached in state (for tracing) 324 | d = state.get('d') 325 | t = state.get('t') 326 | old_dirs = state.get('old_dirs') 327 | old_stps = state.get('old_stps') 328 | ro = state.get('ro') 329 | H_diag = state.get('H_diag') 330 | prev_flat_grad = state.get('prev_flat_grad') 331 | prev_loss = state.get('prev_loss') 332 | 333 | n_iter = 0 334 | # optimize for a max of max_iter iterations 335 | while n_iter < max_iter: 336 | # keep track of nb of iterations 337 | n_iter += 1 338 | state['n_iter'] += 1 339 | 340 | ############################################################ 341 | # compute gradient descent direction 342 | ############################################################ 343 | if state['n_iter'] == 1: 344 | d = flat_grad.neg() 345 | old_dirs = [] 346 | old_stps = [] 347 | ro = [] 348 | H_diag = 1 349 | else: 350 | # do lbfgs update (update memory) 351 | y = flat_grad.sub(prev_flat_grad) 352 | s = d.mul(t) 353 | ys = y.dot(s) # y*s 354 | if ys > 1e-10: 355 | # updating memory 356 | if len(old_dirs) == history_size: 357 | # shift history by one (limited-memory) 358 | old_dirs.pop(0) 359 | old_stps.pop(0) 360 | ro.pop(0) 361 | 362 | # store new direction/step 363 | old_dirs.append(y) 364 | old_stps.append(s) 365 | ro.append(1. 
/ ys) 366 | 367 | # update scale of initial Hessian approximation 368 | H_diag = ys / y.dot(y) # (y*y) 369 | 370 | # compute the approximate (L-BFGS) inverse Hessian 371 | # multiplied by the gradient 372 | num_old = len(old_dirs) 373 | 374 | if 'al' not in state: 375 | state['al'] = [None] * history_size 376 | al = state['al'] 377 | 378 | # iteration in L-BFGS loop collapsed to use just one buffer 379 | q = flat_grad.neg() 380 | for i in range(num_old - 1, -1, -1): 381 | al[i] = old_stps[i].dot(q) * ro[i] 382 | q.add_(old_dirs[i], alpha=-al[i]) 383 | 384 | # multiply by initial Hessian 385 | # r/d is the final direction 386 | d = r = torch.mul(q, H_diag) 387 | for i in range(num_old): 388 | be_i = old_dirs[i].dot(r) * ro[i] 389 | r.add_(old_stps[i], alpha=al[i] - be_i) 390 | 391 | if prev_flat_grad is None: 392 | prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format) 393 | else: 394 | prev_flat_grad.copy_(flat_grad) 395 | prev_loss = loss 396 | 397 | ############################################################ 398 | # compute step length 399 | ############################################################ 400 | # reset initial guess for step size 401 | if state['n_iter'] == 1: 402 | t = min(1., 1. / flat_grad.abs().sum()) * lr 403 | else: 404 | t = lr 405 | 406 | # directional derivative 407 | gtd = flat_grad.dot(d) # g * d 408 | 409 | # directional derivative is below tolerance 410 | if gtd > -tolerance_change: 411 | break 412 | 413 | # optional line search: user function 414 | ls_func_evals = 0 415 | if line_search_fn is not None: 416 | # perform line search, using user function 417 | if line_search_fn != "strong_wolfe": 418 | raise RuntimeError("only 'strong_wolfe' is supported") 419 | else: 420 | x_init = self._clone_param() 421 | 422 | def obj_func(x, t, d): 423 | return self._directional_evaluate(closure, x, t, d) 424 | 425 | loss, flat_grad, t, ls_func_evals = _strong_wolfe( 426 | obj_func, x_init, t, d, loss, flat_grad, gtd) 427 | self._add_grad(t, d) 428 | opt_cond = flat_grad.abs().max() <= tolerance_grad 429 | else: 430 | # no line search, simply move with fixed-step 431 | self._add_grad(t, d) 432 | if n_iter != max_iter: 433 | # re-evaluate function only if not in last iteration 434 | # the reason we do this: in a stochastic setting, 435 | # no use to re-evaluate that function here 436 | with torch.enable_grad(): 437 | loss = float(closure()) 438 | flat_grad = self._gather_flat_grad() 439 | opt_cond = flat_grad.abs().max() <= tolerance_grad 440 | ls_func_evals = 1 441 | 442 | # addition evaluation to recorde the results 443 | loss_extra = float(closure(verbose=True)) 444 | del loss_extra 445 | 446 | # update func eval 447 | current_evals += ls_func_evals 448 | state['func_evals'] += ls_func_evals 449 | 450 | ############################################################ 451 | # check conditions 452 | ############################################################ 453 | if n_iter == max_iter: 454 | break 455 | 456 | if current_evals >= max_eval: 457 | break 458 | 459 | # optimal condition 460 | if opt_cond: 461 | break 462 | 463 | # lack of progress 464 | if d.mul(t).abs().max() <= tolerance_change: 465 | break 466 | 467 | if abs(loss - prev_loss) < tolerance_change: 468 | break 469 | 470 | state['d'] = d 471 | state['t'] = t 472 | state['old_dirs'] = old_dirs 473 | state['old_stps'] = old_stps 474 | state['ro'] = ro 475 | state['H_diag'] = H_diag 476 | state['prev_flat_grad'] = prev_flat_grad 477 | state['prev_loss'] = prev_loss 478 | 479 | return orig_loss 480 | 
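# Illustrative usage sketch added for documentation; `fit_with_lbfgs` is a
# hypothetical helper and is not part of the training pipeline.  Note that
# `step` additionally calls `closure(verbose=True)` once per inner iteration
# to record results, so the closure must accept an optional `verbose` flag.
def fit_with_lbfgs(model, compute_loss, max_iter=15000):
    optimizer = LBFGS(model.parameters(), lr=1.0, max_iter=max_iter,
                      history_size=100, tolerance_grad=1e-8,
                      tolerance_change=0, line_search_fn='strong_wolfe')

    def closure(verbose=False):
        optimizer.zero_grad()
        loss = compute_loss()  # e.g. weighted sum of PINN residual and IC/BC losses
        loss.backward()
        if verbose:
            print(float(loss))
        return loss

    # A single call to `step` runs the full (up to `max_iter` iterations) L-BFGS loop.
    optimizer.step(closure)
    return model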
-------------------------------------------------------------------------------- /source/main.py: -------------------------------------------------------------------------------- 1 | """Run PINNs for convection/reaction/reaction-diffusion with periodic boundary conditions.""" 2 | 3 | import argparse 4 | import numpy as np 5 | import os 6 | import random 7 | import torch 8 | from systems_pbc import * 9 | import torch.backends.cudnn as cudnn 10 | from utils import * 11 | from visualize import * 12 | import matplotlib.pyplot as plt 13 | import sys 14 | from scipy.io import loadmat 15 | import h5py 16 | from net import * 17 | torch.backends.cuda.matmul.allow_tf32 = False 18 | torch.set_default_dtype(torch.float64) 19 | # torch.use_deterministic_algorithms(True) 20 | # torch.set_deterministic(True) 21 | ################ 22 | # Arguments 23 | ################ 24 | parser = argparse.ArgumentParser(description='Adaptive AF for PINNs') 25 | 26 | parser.add_argument('--launcher', default='', type=str) 27 | parser.add_argument('--port', default=29051, type=int) 28 | parser.add_argument('--system', type=str, default='convection', help='System to study.') 29 | parser.add_argument('--data_path', type=str, default=None) 30 | parser.add_argument('--seed', type=int, default=0, help='Random initialization.') 31 | parser.add_argument('--N_f', type=int, default=100, help='Number of collocation points to sample.') 32 | parser.add_argument('--N_f_x', type=int, default=64, help='Number of collocation points to sample.') 33 | parser.add_argument('--N_f_t', type=int, default=100, help='Number of collocation points to sample.') 34 | parser.add_argument('--ratio', type=float, default=1.0, help='ratio for N_f') 35 | parser.add_argument('--optimizer_name', type=str, default='LBFGS', help='Optimizer of choice.') 36 | parser.add_argument('--lr', type=float, default=1.0, help='Learning rate.') 37 | parser.add_argument('--L_f', type=float, default=1.0, help='Multiplier on loss f.') 38 | parser.add_argument('--L_u', type=float, default=1.0, help='Multiplier on loss u.') 39 | parser.add_argument('--L_b', type=float, default=1.0, help='Multiplier on loss b.') 40 | parser.add_argument('--adam_lr', type=float, default=0.0, help='Learning rate for adam.') 41 | parser.add_argument('--coeff_lr', type=float, default=1e-3, help='Learning rate for adam.') 42 | parser.add_argument('--epoch', type=int, default=1000, help='Epoch for adam') 43 | parser.add_argument('--repeat', type=int, default=1) 44 | parser.add_argument('--start_repeat', type=int, default=0) 45 | parser.add_argument('--sample_type', type=str, default='grid') 46 | parser.add_argument('--xgrid', type=int, default=256, help='Number of points in the xgrid.') 47 | parser.add_argument('--nt', type=int, default=100, help='Number of points in the tgrid.') 48 | parser.add_argument('--nu', type=str, default='0', help='nu value that scales the d^2u/dx^2 term. 0 if only doing advection.') 49 | parser.add_argument('--rho', type=str, default='0', help='reaction coefficient for u*(1-u) term.') 50 | parser.add_argument('--beta', type=str, default='0', help='beta value that scales the du/dx term. 0 if only doing diffusion.') 51 | parser.add_argument('--u0_str', default='sin(x)', help='str argument for initial condition if no forcing term.') 52 | parser.add_argument('--source', default=0, type=float, help="If there's a source term, define it here. 
For now, just constant force terms.") 53 | parser.add_argument('--layers', type=str, default='50,50,50,50,1', help='Dimensions/layers of the NN, minus the first layer.') 54 | parser.add_argument('--net', type=str, default='DNN', help='The net architecture that is to be used.') 55 | parser.add_argument('--activation', default='tanh', help='Activation to use in the network.') 56 | 57 | # params for linearpool 58 | parser.add_argument('--linearpool', action='store_true') 59 | parser.add_argument('--llaf', action='store_true') 60 | parser.add_argument('--use_recovery', action='store_true') 61 | parser.add_argument('--scaler', type=float, default=1.0) 62 | parser.add_argument('--poolsize', type=str, default='0') 63 | parser.add_argument('--aggregate', type=str, default='sum') 64 | parser.add_argument('--weight_sharing', action='store_true') 65 | parser.add_argument('--use_auxiliary', action='store_true') 66 | parser.add_argument('--sample_iter', type=int, default=1) 67 | parser.add_argument('--plot_loss', action='store_true') 68 | parser.add_argument('--evaluate', action='store_true') 69 | 70 | parser.add_argument('--num_head', type=int, default=1) 71 | parser.add_argument('--not_adapt_adam_lr', action='store_true') 72 | parser.add_argument('--sample_stage', type=int, default=1) 73 | parser.add_argument('--init', action='store_true') 74 | parser.add_argument('--visualize', default=False, help='Visualize the solution.') 75 | parser.add_argument('--save_model', default=False, help='Save the model for analysis later.') 76 | parser.add_argument('--gpu', default=False, action='store_true') 77 | parser.add_argument('--work_dir', default='debug') 78 | parser.add_argument('--sub_name', default='') 79 | parser.add_argument('--hard_ibc', action='store_true') 80 | parser.add_argument('--channel_wise', action='store_true') 81 | parser.add_argument('--warm_up_iter', type=int, default=1000) 82 | parser.add_argument('--uniform_sample', action='store_true') 83 | parser.add_argument('--use_norm', action='store_true') 84 | parser.add_argument('--exp_alpha', type=float, default=2.0) 85 | parser.add_argument('--sin_alpha', type=float, default=1.0) 86 | parser.add_argument('--fix_sample', action='store_true') 87 | parser.add_argument('--cosine_decay', action='store_true') 88 | parser.add_argument('--disable_lbfgs', action='store_true') 89 | 90 | parser.add_argument('--tau', type=float, default=1.0) 91 | parser.add_argument('--checkpoint', type=str, default='') 92 | parser.add_argument('--clip', type=float, default=0.0) 93 | parser.add_argument('--coeff_clip', type=float, default=0.0) 94 | parser.add_argument('--coeff_clip_type', type=str, default='norm') 95 | parser.add_argument('--recovery_weight', type=float, default=1.0) 96 | parser.add_argument('--enable_scaling', action='store_true') 97 | parser.add_argument('--detach_u', action='store_true') 98 | parser.add_argument('--detach_b', action='store_true') 99 | parser.add_argument('--detach_f', action='store_true') 100 | parser.add_argument('--coeff_beta1', type=float, default=0.9) 101 | parser.add_argument('--coeff_beta2', type=float, default=0.999) 102 | parser.add_argument('--constant_warmup', type=float, default=0.0) 103 | parser.add_argument('--sep_optim', action='store_true') 104 | parser.add_argument('--sep_warm_up_iter', type=int, default=1000) 105 | parser.add_argument('--momentum', type=float, default=0.9) 106 | parser.add_argument('--sep_optimizer', type=str, default='sgd') 107 | parser.add_argument('--coeff_lr_first_layer', type=float, default=1e-3) 108 
| parser.add_argument('--lr_first_layer', type=float, default=1e-3) 109 | parser.add_argument('--lr_second_layer', type=float, default=1e-3) 110 | parser.add_argument('--sep_cosine_decay', action='store_true') 111 | parser.add_argument('--target_seed', type=int, default=None) 112 | parser.add_argument('--print_freq', type=int, default=1000) 113 | parser.add_argument('--valid_freq', type=int, default=1000) 114 | parser.add_argument('--l2_reg', type=float, default=0.0) 115 | parser.add_argument('--weight_decay', type=float, default=0.0) 116 | parser.add_argument('--coeff_weight_decay', type=float, default=0.0) 117 | parser.add_argument('--extra_N_f', type=int, default=0) 118 | parser.add_argument('--range', type=float, default=0.01) 119 | parser.add_argument('--enable_coeff_l2_reg', action='store_true') 120 | 121 | 122 | parser.add_argument('--T_max', type=int, default=0) 123 | parser.add_argument('--line_search_fn', type=str, default=None) 124 | parser.add_argument('--gain_0', type=float, default=1.0) 125 | parser.add_argument('--gain', type=float, default=1.0) 126 | parser.add_argument('--enable_vx', action='store_true') 127 | parser.add_argument('--enable_vt', action='store_true') 128 | parser.add_argument('--include_t0', action='store_true') 129 | 130 | parser.add_argument('--max_iter', type=int, default=15000) 131 | 132 | # for CH 133 | parser.add_argument('--decouple', action='store_true') 134 | parser.add_argument('--high_bc', action='store_true') 135 | parser.add_argument('--four_order', action='store_true') 136 | parser.add_argument('--enable_adaptive_tol', action='store_true') 137 | parser.add_argument('--fine_grid', action='store_true') 138 | parser.add_argument('--segment', type=int, default=-1) 139 | 140 | 141 | parser.add_argument('--dtype_float32', action='store_true') 142 | parser.add_argument('--exp_dir', default='./') 143 | parser.add_argument('--random_init', action='store_true', help='random init for coeff') 144 | parser.add_argument('--resume', type=str, default=None) 145 | parser.add_argument('--taylor', action='store_true') 146 | parser.add_argument('--taylor_scale', action='store_true') 147 | parser.add_argument('--taylor_order', type=int, default=0) 148 | 149 | 150 | args = parser.parse_args() 151 | 152 | if not args.dtype_float32: 153 | torch.set_default_dtype(torch.float64) 154 | 155 | args.name = args.work_dir 156 | logger = init_environ(args) 157 | logger.info(args.name) 158 | logger.info(args) 159 | 160 | # CUDA support 161 | if args.gpu: 162 | device = torch.device('cuda') 163 | else: 164 | device = torch.device('cpu') 165 | 166 | nu = [float(item) for item in args.nu.split(',')] 167 | beta = [float(item) for item in args.beta.split(',')] 168 | rho = [float(item) for item in args.rho.split(',')] 169 | 170 | 171 | # args.name = os.path.join(args.work_dir, args.name) 172 | 173 | 174 | # parse the layers list here 175 | orig_layers = args.layers 176 | layers = [int(item) for item in args.layers.split(',')] 177 | 178 | ############################ 179 | # Process data 180 | ############################ 181 | 182 | if args.data_path is not None: 183 | # for burger, AC, KdV and CH 184 | if args.system == 'burger': 185 | data = np.load(args.data_path) 186 | t, x, Exact = data["t"], data["x"], data["usol"].T 187 | 188 | elif args.system == 'AC': 189 | # from scipy.io import loadmat 190 | data = loadmat(args.data_path) 191 | t,x,Exact = data["t"], data["x"], data["u"] # (1,101), (1,201), (101,201) 192 | t = t.reshape(-1) 193 | x = x.reshape(-1) 194 | elif args.system == 
'KdV': 195 | # from scipy.io import loadmat 196 | data = loadmat(args.data_path) 197 | t,x,Exact = data["tt"], data["x"], data["uu"].T # (1,201), (1,512), (512,201) 198 | t = t.reshape(-1) 199 | x = x.reshape(-1) 200 | 201 | x = np.append(x,1.0) 202 | Exact_init = Exact[:, 0:1] 203 | Exact = np.hstack([Exact, Exact_init]) 204 | 205 | elif args.system == 'CH': 206 | data = loadmat(args.data_path) 207 | t,x,Exact = data["tt"], data["x"], data["uu"].T # (1,201), (1,512), (512,201) 208 | u22 = data['u22'] # (512,1) 209 | t = t.reshape(-1) 210 | # x = x.reshape(-1) 211 | x = np.linspace(-1,1,512,endpoint=False) 212 | x = np.append(x,1.0) 213 | Exact_init = Exact[:, 0:1] 214 | Exact = np.hstack([Exact, Exact_init]) 215 | if args.segment > 0: 216 | t = t[:args.segment] 217 | Exact = Exact[:args.segment, :] 218 | 219 | 220 | X, T = np.meshgrid(x, t) 221 | X_star = np.vstack((np.ravel(X), np.ravel(T))).T 222 | u_star = Exact.flatten()[:, None] 223 | Exact_list = [Exact] 224 | 225 | 226 | if args.sample_type == 'grid': 227 | if args.include_t0: 228 | x_noboundary = np.linspace(-1, 1, args.N_f_x, endpoint=False).reshape(-1, 1) # not inclusive 229 | t_noinitial = np.linspace(0, 1, args.N_f_t).reshape(-1, 1) 230 | else: 231 | x_noboundary = np.linspace(-1, 1, args.N_f_x+1, endpoint=False).reshape(-1, 1)[1:] # not inclusive 232 | if args.fine_grid: 233 | t_noinitial = np.linspace(0,0.05,50+1,endpoint=False)[1:] 234 | t_noinitial2 = np.linspace(0.05, 1, args.N_f_t) 235 | t_noinitial = np.hstack([t_noinitial, t_noinitial2]) 236 | args.N_f_t = args.N_f_t + 50 237 | else: 238 | t_noinitial = np.linspace(0, 1, args.N_f_t+1).reshape(-1, 1)[1:] 239 | 240 | X_noboundary, T_noinitial = np.meshgrid(x_noboundary, t_noinitial) 241 | X_star_noinitial_noboundary = np.hstack((X_noboundary.flatten()[:, None], T_noinitial.flatten()[:, None])) 242 | logger.info(f'sample from grid: {args.N_f_x} {args.N_f_t}') 243 | 244 | else: 245 | X_star_noinitial_noboundary = {'lb':[x.min(), t.min()], 'ub':[x.max(), t.max()]} 246 | logger.info(f'sample given interval: {X_star_noinitial_noboundary}') 247 | # X_f_train= sample_random_interval(X_star_noinitial_noboundary['lb'],X_star_noinitial_noboundary['ub'], args.N_f) 248 | 249 | # sample collocation points only from the interior (where the PDE is enforced) 250 | set_seed(args.seed) 251 | # X_f_train, idx = sample_random_type(X_star_noinitial_noboundary, args.N_f) 252 | X_f_train, idx = sample_random_type(X_star_noinitial_noboundary, args.N_f, args.extra_N_f, args.range) 253 | args.N_f = args.N_f + args.extra_N_f 254 | xx1 = np.hstack((X[0:1,:].T, T[0:1,:].T)) # initial condition, from x = [-end, +end] and t=0 255 | uu1 = Exact[0:1,:].T # u(x, t) at t=0 256 | bc_lb = np.hstack((X[:,0:1], T[:,0:1])) # boundary condition at x = -1, and 257 | 258 | # generate the other BC, now at x=1 259 | bc_ub = np.hstack((X[:,-1:], T[:,-1:])) 260 | u_train = uu1 # just the initial condition 261 | X_u_train = xx1 # (x,t) for initial condition 262 | 263 | G = np.full(1, float(args.source)) 264 | else: 265 | # for convection, reaction, reaction-diffusion 266 | x = np.linspace(0, 2*np.pi, args.xgrid, endpoint=False).reshape(-1, 1) # not inclusive 267 | t = np.linspace(0, 1, args.nt).reshape(-1, 1) 268 | X, T = np.meshgrid(x, t) # all the X grid points T times, all the T grid points X times 269 | X_star = np.hstack((X.flatten()[:, None], T.flatten()[:, None])) # all the x,t "test" data 270 | 271 | if args.sample_type == 'grid': 272 | # # remove initial and boundaty data from X_star 273 | # t_noinitial = 
t[1:] 274 | # # remove boundary at x=0 275 | # x_noboundary = x[1:-1] 276 | if args.include_t0: 277 | x_noboundary = np.linspace(0, 2*np.pi, args.N_f_x, endpoint=False).reshape(-1, 1) # not inclusive 278 | t_noinitial = np.linspace(0, 1, args.N_f_t).reshape(-1, 1) 279 | else: 280 | x_noboundary = np.linspace(0, 2*np.pi, args.N_f_x+1, endpoint=False).reshape(-1, 1)[1:] # not inclusive 281 | t_noinitial = np.linspace(0, 1, args.N_f_t+1).reshape(-1, 1)[1:] 282 | 283 | X_noboundary, T_noinitial = np.meshgrid(x_noboundary, t_noinitial) 284 | X_star_noinitial_noboundary = np.hstack((X_noboundary.flatten()[:, None], T_noinitial.flatten()[:, None])) 285 | logger.info(f'sample from grid: {args.N_f_x} {args.N_f_t}') 286 | else: 287 | X_star_noinitial_noboundary = {'lb':[x.min(), t.min()], 'ub':[x.max(), t.max()]} 288 | logger.info(f'sample given interval: {X_star_noinitial_noboundary}') 289 | 290 | # sample collocation points only from the interior (where the PDE is enforced) 291 | set_seed(args.seed) 292 | X_f_train, idx = sample_random_type(X_star_noinitial_noboundary, args.N_f) 293 | 294 | u_vals_list = list() 295 | for nu_i, beta_i, rho_i in zip(nu, beta, rho): 296 | if 'convection' in args.system or 'diffusion' in args.system: 297 | u_vals = convection_diffusion(args.u0_str, nu_i, beta_i, args.source, args.xgrid, args.nt) 298 | 299 | elif 'rd' in args.system: 300 | u_vals = reaction_diffusion_discrete_solution(args.u0_str, nu_i, rho_i, args.xgrid, args.nt) 301 | 302 | elif 'reaction' in args.system: 303 | u_vals = reaction_solution(args.u0_str, rho_i, args.xgrid, args.nt) 304 | 305 | else: 306 | print("WARNING: System is not specified.") 307 | u_vals_list.append(u_vals) 308 | G = np.full(1, float(args.source)) 309 | 310 | u_star_list = [u_vals.reshape(-1, 1) for u_vals in u_vals_list] # Exact solution reshaped into (n, 1) 311 | Exact_list = [u_star.reshape(len(t), len(x)) for u_star in u_star_list] # Exact on the (x,t) grid 312 | u_star = u_star_list[-1] 313 | Exact = Exact_list[-1] 314 | xx1 = np.hstack((X[0:1,:].T, T[0:1,:].T)) # initial condition, from x = [-end, +end] and t=0 315 | uu1 = Exact[0:1,:].T # u(x, t) at t=0 316 | bc_lb = np.hstack((X[:,0:1], T[:,0:1])) # boundary condition at x = 0, and t = [0, 1] 317 | uu2 = Exact[:,0:1] # u(-end, t) 318 | 319 | # generate the other BC, now at x=2pi 320 | t = np.linspace(0, 1, args.nt).reshape(-1, 1) 321 | x_bc_ub = np.array([2*np.pi]*t.shape[0]).reshape(-1, 1) 322 | bc_ub = np.hstack((x_bc_ub, t)) 323 | 324 | u_train = uu1 # just the initial condition 325 | X_u_train = xx1 # (x,t) for initial condition 326 | 327 | if args.system != 'CH': 328 | u22 = None 329 | 330 | layers.insert(0, X_u_train.shape[-1]) 331 | 332 | ############################ 333 | # Train the model 334 | ############################ 335 | repeat = args.repeat 336 | error_u_relative_all = [] 337 | error_u_abs_all = [] 338 | error_u_linf_all = [] 339 | losses_all = [] 340 | losses_u = [] 341 | losses_f = [] 342 | losses_b = [] 343 | losses_f_test = [] 344 | 345 | linear_pool_coeff = [] 346 | 347 | # numpy array float64 to float32 348 | # X_u_train = X_u_train.astype(np.float32) 349 | # u_train = u_train.astype(np.float32) 350 | # X_f_train = X_f_train.astype(np.float32) 351 | # bc_lb = bc_lb.astype(np.float32) 352 | # bc_ub = bc_ub.astype(np.float32) 353 | # X_star = X_star.astype(np.float32) 354 | 355 | for i in range(args.start_repeat, repeat): 356 | 357 | if args.target_seed is not None and i != args.target_seed: 358 | continue 359 | set_seed(args.seed+i) # for weight 
initialization 360 | if not args.fix_sample: 361 | X_f_train, idx = sample_random_type(X_star_noinitial_noboundary, args.N_f) 362 | X_f_train = X_f_train[:int(args.ratio*args.N_f)] 363 | # X_f_train = X_star_noinitial_noboundary 364 | # X_f_train = X_star 365 | logger.info(X_f_train) 366 | logger.info(X_f_train.shape) 367 | logger.info(f'seed: {args.seed+i}') 368 | model = PhysicsInformedNN_pbc(args, i, logger, device, args.system, X_star_noinitial_noboundary, X_u_train, u_train, X_f_train, u22, bc_lb, bc_ub, X_star, Exact_list, layers, G, nu, beta, rho, 369 | args.optimizer_name, args.lr, args.net, args.activation) 370 | 371 | if args.evaluate: 372 | 373 | u_pred = model.predict(X_star) 374 | loss_f = model.evaluate_loss_f(X_star) 375 | error_u_relative = np.linalg.norm(u_star-u_pred, 2)/np.linalg.norm(u_star, 2) 376 | error_u_abs = np.mean(np.abs(u_star - u_pred)) 377 | error_u_linf = np.linalg.norm(u_star - u_pred, np.inf)/np.linalg.norm(u_star, np.inf) 378 | 379 | logger.info('Error u rel: %e' % (error_u_relative)) 380 | logger.info('Error u abs: %e' % (error_u_abs)) 381 | logger.info('Error u linf: %e' % (error_u_linf)) 382 | data_dict = {'exact':u_star.reshape(len(t), len(x)), 'pred':u_pred.reshape(len(t), len(x))} 383 | np.save(os.path.join(args.work_dir, 'results.npy'), data_dict) 384 | break 385 | 386 | 387 | loss_all, loss_u, loss_b, loss_f = model.train_adam(args.adam_lr, args.epoch, X_star, u_star) 388 | losses_all.append(loss_all) 389 | losses_u.append(loss_u) 390 | losses_b.append(loss_b) 391 | losses_f.append(loss_f) 392 | 393 | u_pred = model.predict(X_star) 394 | f_pred = model.evaluate_loss_f(X_star) 395 | 396 | if args.linearpool: 397 | coeff_all = [] 398 | for name, para in model.dnn.named_parameters(): 399 | if 'coeff' in name: 400 | if args.aggregate == 'sigmoid': 401 | coeff = torch.sigmoid(para).mean(-1) 402 | elif args.aggregate == 'softmax': 403 | coeff = torch.softmax(para/args.tau,dim=0).mean(-1) 404 | elif args.aggregate == 'unlimited': 405 | coeff = para.mean(-1) 406 | else: 407 | coeff = (para / para.abs().sum(dim=0)).mean(-1) 408 | coeff_all.append(coeff) 409 | coeff_all = torch.cat(coeff_all) 410 | linear_pool_coeff.append(coeff_all) 411 | error_u_relative = np.linalg.norm(u_star-u_pred, 2)/np.linalg.norm(u_star, 2) 412 | error_u_abs = np.mean(np.abs(u_star - u_pred)) 413 | error_u_linf = np.linalg.norm(u_star - u_pred, np.inf)/np.linalg.norm(u_star, np.inf) 414 | error_f_test = np.mean(f_pred ** 2) 415 | 416 | logger.info('Error u rel: %e' % (error_u_relative)) 417 | logger.info('Error u abs: %e' % (error_u_abs)) 418 | logger.info('Error u linf: %e' % (error_u_linf)) 419 | logger.info('Loss all: %e' % (loss_all)) 420 | logger.info('Loss u: %e' % (loss_u)) 421 | logger.info('Loss b: %e' % (loss_b)) 422 | logger.info('Loss f: %e' % (loss_f)) 423 | logger.info('Loss f test: %e' % (error_f_test)) 424 | 425 | 426 | error_u_relative_all.append(error_u_relative) 427 | error_u_abs_all.append(error_u_abs) 428 | error_u_linf_all.append(error_u_linf) 429 | losses_f_test.append(error_f_test) 430 | 431 | if i != repeat-1: 432 | del model 433 | 434 | if args.linearpool: 435 | linear_pool_coeff = torch.vstack(linear_pool_coeff) 436 | mean_coeff = linear_pool_coeff.mean(dim=0) 437 | std_coeff = linear_pool_coeff.std(dim=0) 438 | coeff_str = [f'{mean_i:.2f}/{std_i:.2f}' for mean_i, std_i in zip(mean_coeff, std_coeff)] 439 | coeff_str = ', '.join(coeff_str) 440 | logger.info(coeff_str) 441 | logger.info('Error u rel: mean %e, std %e' % (np.mean(error_u_relative_all), 
np.std(error_u_relative_all))) 442 | logger.info('Error u abs: mean %e, std %e' % (np.mean(error_u_abs_all), np.std(error_u_abs_all))) 443 | logger.info('Error u linf: mean %e, std %e' % (np.mean(error_u_linf_all), np.std(error_u_linf_all))) 444 | logger.info('loss_all: mean %e, std %e' % (np.mean(losses_all), np.std(losses_all))) 445 | logger.info('loss_u: mean %e, std %e' % (np.mean(losses_u), np.std(losses_u))) 446 | logger.info('loss_b: mean %e, std %e' % (np.mean(losses_b), np.std(losses_b))) 447 | logger.info('loss_f: mean %e, std %e' % (np.mean(losses_f), np.std(losses_f))) 448 | logger.info('loss_f_test: mean %e, std %e' % (np.mean(losses_f_test), np.std(losses_f_test))) 449 | 450 | if args.visualize: 451 | path = os.path.join(args.work_dir, f"heatmap_results/{args.system}") 452 | if not os.path.exists(path): 453 | os.makedirs(path) 454 | u_pred = u_pred.reshape(len(t), len(x)) 455 | if args.evaluate: 456 | loss_f = loss_f.reshape(len(t), len(x)) 457 | exact_u(loss_f, x, t, nu, beta, rho, orig_layers, args.N_f, args.L_f, args.source, args.u0_str, args.system, path=path) 458 | else: 459 | exact_u(Exact, x, t, nu, beta, rho, orig_layers, args.N_f, args.L_f, args.source, args.u0_str, args.system, path=path) 460 | 461 | u_diff(X_f_train,Exact, u_pred, x, t, nu, beta, rho, args.seed, orig_layers, args.N_f, args.L_f, args.source, args.lr, args.u0_str, args.system, path=path) 462 | u_predict(Exact, u_pred, x, t, nu, beta, rho, args.seed, orig_layers, args.N_f, args.L_f, args.source, args.lr, args.u0_str, args.system, path=path) 463 | 464 | -------------------------------------------------------------------------------- /source/net.py: -------------------------------------------------------------------------------- 1 | from curses.panel import new_panel 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from collections import OrderedDict 6 | import numpy as np 7 | from choose_optimizer import * 8 | from matplotlib import pyplot as plt 9 | import os 10 | from utils import * 11 | import copy 12 | from mpl_toolkits.axes_grid1 import make_axes_locatable 13 | 14 | import math 15 | import random 16 | from torch.utils.tensorboard import SummaryWriter 17 | from tqdm import tqdm 18 | 19 | all_features = [] 20 | enable_retain_grad = False 21 | all_grad = [] 22 | 23 | 24 | class Rational(nn.Module): 25 | 26 | def __init__(self, width=1): 27 | super().__init__() 28 | w_numerator = [ 29 | 2.1172949817857366e-09, 30 | 0.9999942495075363, 31 | 6.276332768876106e-07, 32 | 0.10770864506559906, 33 | 2.946556898117109e-08, 34 | 0.000871124373591946 35 | ] 36 | w_denominator = [ 37 | 6.376908337817277e-07, 38 | 0.44101418051922986, 39 | 2.2747661404467182e-07, 40 | 0.014581039909092108 41 | ] 42 | 43 | # self.numerator = nn.Parameter(torch.tensor(w_numerator).double(),requires_grad=True) 44 | # self.denominator = nn.Parameter(torch.tensor(w_denominator).double(),requires_grad=True) 45 | self.numerator = nn.Parameter(torch.randn(6).double(),requires_grad=True) 46 | self.denominator = nn.Parameter(torch.randn(4).double(),requires_grad=True) 47 | 48 | def _get_xps(self, z, len_numerator, len_denominator): 49 | xps = list() 50 | xps.append(z) 51 | for _ in range(max(len_numerator, len_denominator) - 2): 52 | xps.append(xps[-1].mul(z)) 53 | xps.insert(0, torch.ones_like(z)) 54 | return torch.stack(xps, 1) 55 | 56 | def forward(self, x): 57 | 58 | numerator = sum([self.numerator[order] * torch.pow(x, order) for order in range(self.numerator.shape[0])]) 59 | denominator = 
sum([self.denominator[order] * torch.pow(x, order+1) for order in range(self.denominator.shape[0])]) 60 | return numerator.div(1 + denominator.abs()) 61 | 62 | class AconC(nn.Module): 63 | r""" ACON activation (activate or not). 64 | # AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter 65 | # according to "Activate or Not: Learning Customized Activation" . 66 | """ 67 | 68 | def __init__(self, width=1): 69 | super().__init__() 70 | self.p1 = nn.Parameter(torch.randn(1, width).double()) 71 | self.p2 = nn.Parameter(torch.randn(1, width).double()) 72 | self.beta = nn.Parameter(torch.ones(1, width).double()) 73 | 74 | def forward(self, x): 75 | return (self.p1 * x - self.p2 * x) * torch.sigmoid(self.beta * (self.p1 * x - self.p2 * x)) + self.p2 * x 76 | 77 | 78 | class Sine(torch.nn.Module): 79 | def __init__(self, alpha=1.0): 80 | super(Sine, self).__init__() 81 | self.alpha = alpha 82 | print(self.alpha) 83 | def forward(self, x): 84 | return torch.sin(self.alpha * x) 85 | 86 | class Cosine(torch.nn.Module): 87 | def __init__(self,alpha=1.0): 88 | super(Cosine, self).__init__() 89 | self.alpha = alpha 90 | print(self.alpha) 91 | def forward(self, x): 92 | return torch.cos(self.alpha * x) 93 | 94 | class Exp(torch.nn.Module): 95 | def __init__(self, alpha=2.0): 96 | super(Exp, self).__init__() 97 | self.alpha = alpha 98 | 99 | def forward(self, x): 100 | return torch.exp((x-0.0)/self.alpha)-1.0 101 | 102 | class Pow(torch.nn.Module): 103 | def __init__(self, order, scale=False): 104 | super(Pow, self).__init__() 105 | self.order = order 106 | self.scale = scale 107 | 108 | def forward(self, x): 109 | # print(self.alpha) 110 | # return torch.sign(x) * (torch.exp(-torch.abs(x)) - 1.0) 111 | if self.scale and self.order > 0: 112 | return torch.pow(x, self.order) / (1.0 * self.order) 113 | else: 114 | return torch.pow(x, self.order) 115 | 116 | class Log(torch.nn.Module): 117 | def __init__(self): 118 | super(Log, self).__init__() 119 | 120 | def forward(self, x): 121 | return torch.log(torch.abs(x)+1e-8) 122 | 123 | class MyLinearPool(torch.nn.Module): 124 | def __init__(self, input_channel, output_channel, poolsize='0', aggregate='sum', weight_sharing=True, out_list=True, channel_wise=True, use_norm=False, exp_alpha=2.0, sin_alpha=1.0, tau=1.0, scaler=1.0, enable_scaling=False, random_init=False, taylor=False, taylor_scale=False, taylor_order=5): 125 | # print(exp_alpha) 126 | super(MyLinearPool, self).__init__() 127 | 128 | act_pool = [Sine(alpha=sin_alpha), torch.nn.Tanh(), nn.GELU(), nn.SiLU(), nn.Softplus()] 129 | self.poolsize = [int(i) for i in poolsize.split(',')] 130 | if taylor: 131 | self.activations = torch.nn.ModuleList([Pow(order=order_i, scale=taylor_scale) for order_i in range(taylor_order+1)]) 132 | else: 133 | self.activations = torch.nn.ModuleList([act_pool[i] for i in self.poolsize]) 134 | self.aggregate = aggregate 135 | self.tau = tau 136 | print('tau', self.tau) 137 | 138 | self.use_norm = use_norm 139 | if use_norm: 140 | self.norm = torch.nn.LayerNorm(output_channel,elementwise_affine=False).double() 141 | self.enable_scaling = enable_scaling 142 | self.scaler = scaler 143 | self.weight_sharing = weight_sharing 144 | self.enable_scaling = enable_scaling 145 | self.scaler = scaler 146 | if self.weight_sharing: 147 | self.fc = torch.nn.Linear(input_channel, output_channel).double() 148 | else: 149 | self.fc = torch.nn.ModuleList([torch.nn.Linear(input_channel, output_channel).double() for _ in range(len(self.activations))]) 150 | if 
channel_wise: 151 | if self.aggregate == 'sigmoid' or self.aggregate == 'softmax': 152 | if random_init: 153 | coeff_weight = torch.randn(len(self.activations), output_channel).double() 154 | self.coeff = torch.nn.parameter.Parameter(coeff_weight) 155 | else: 156 | self.coeff = torch.nn.parameter.Parameter(torch.zeros(len(self.activations), output_channel).double()) 157 | else: 158 | 159 | self.coeff = torch.nn.parameter.Parameter(torch.ones(len(self.activations), output_channel).double() / len(self.poolsize)) 160 | 161 | self.coeff2 = torch.nn.parameter.Parameter(1.0 / scaler * torch.ones(len(self.activations), output_channel).double()) 162 | else: 163 | if self.aggregate == 'sigmoid' or self.aggregate == 'softmax': 164 | if random_init: 165 | coeff_weight = torch.randn(len(self.activations), 1).double() 166 | self.coeff = torch.nn.parameter.Parameter(coeff_weight) 167 | else: 168 | self.coeff = torch.nn.parameter.Parameter(torch.zeros(len(self.activations), 1).double()) 169 | else: 170 | 171 | self.coeff = torch.nn.parameter.Parameter(torch.randn(len(self.activations), 1).double() / len(self.poolsize)) 172 | self.coeff2 = torch.nn.parameter.Parameter(1.0 / scaler * torch.ones(len(self.activations), 1).double()) 173 | self.out_list = out_list 174 | 175 | def forward(self, input): 176 | if self.aggregate == 'sigmoid': 177 | coeff = torch.sigmoid(self.coeff) 178 | elif self.aggregate == 'softmax': 179 | coeff = torch.softmax(self.coeff/self.tau, dim=0) 180 | elif self.aggregate == 'unlimited': 181 | coeff = self.coeff 182 | elif self.aggregate == 'l1_norm': 183 | coeff = self.coeff / self.coeff.abs().sum(dim=0) 184 | else: 185 | raise NotImplementedError 186 | 187 | if self.enable_scaling: 188 | coeff2 = self.coeff2 189 | else: 190 | coeff2 = self.coeff2.detach() 191 | 192 | x = input[0] 193 | detach = input[1] 194 | # print(coeff.shape) 195 | if detach: 196 | if self.weight_sharing: 197 | y = self.fc(x) #* self.coeff2 198 | out = [c.detach()*act(self.scaler * c2 * y) for act, c, c2 in zip(self.activations, coeff, coeff2)] 199 | else: 200 | out = [c.detach()*act(self.scaler * c2 * layer(x)) for act, layer, c, c2 in zip(self.activations, self.fc, coeff, coeff2)] 201 | else: 202 | if self.weight_sharing: 203 | y = self.fc(x) 204 | if self.use_norm: 205 | y = self.norm(y) 206 | # print(y.mean()) 207 | out = [c*act(self.scaler * c2 * y) for act, c, c2 in zip(self.activations, coeff, coeff2)] 208 | # out_mean = [o.mean() for o in out] 209 | # print(out_mean) 210 | else: 211 | out = [c*act(self.scaler * c2 * layer(x)) for act, layer, c, c2 in zip(self.activations, self.fc, coeff, coeff2)] 212 | 213 | if self.out_list: 214 | return [sum(out), detach] 215 | else: 216 | return sum(out) 217 | 218 | class LLAF(torch.nn.Module): 219 | def __init__(self, input_channel, output_channel, scaler=10, channel_wise=False): 220 | super(LLAF, self).__init__() 221 | 222 | self.fc = torch.nn.Linear(input_channel, output_channel).double() 223 | self.channel_wise = channel_wise 224 | if self.channel_wise: 225 | self.coeff2 = torch.nn.parameter.Parameter(1.0 / scaler * torch.ones(output_channel).double()) 226 | else: 227 | self.coeff2 = torch.nn.parameter.Parameter(1.0 / scaler * torch.ones(1).double()) 228 | self.scaler = scaler 229 | self.disable = False 230 | print(self.scaler) 231 | def forward(self, x): 232 | if self.disable: 233 | y = self.fc(x) * self.coeff2.detach() * self.scaler 234 | else: 235 | y = self.fc(x) * self.coeff2 * self.scaler 236 | return y 237 | 238 | class DNN(torch.nn.Module): 239 | def 
__init__(self, args, layers, activation, linearpool=False, init=False, poolsize=1, aggregate='sum', llaf=False): 240 | super(DNN, self).__init__() 241 | self.args = args 242 | 243 | # parameters 244 | self.depth = len(layers) - 1 245 | self.aggregate = aggregate 246 | if self.aggregate == 'cat': 247 | for i in range(1, len(layers)-1): 248 | layers[i] = layers[i] * poolsize 249 | # activations = activations.split(',') 250 | if activation == 'identity': 251 | self.activation = torch.nn.Identity 252 | elif activation == 'tanh': 253 | self.activation = torch.nn.Tanh 254 | elif activation == 'relu': 255 | self.activation = torch.nn.ReLU 256 | elif activation == 'softplus': 257 | self.activation = torch.nn.Softplus 258 | elif activation == 'gelu': 259 | self.activation = torch.nn.GELU 260 | elif activation == 'sin': 261 | self.activation = Sine 262 | elif activation == 'cos': 263 | self.activation = Cosine 264 | elif activation == 'exp': 265 | self.activation = Exp 266 | elif activation == 'logsigmoid': 267 | self.activation = torch.nn.LogSigmoid 268 | elif activation == 'silu': 269 | self.activation = torch.nn.SiLU 270 | elif activation == 'sigmoid': 271 | self.activation = torch.nn.Sigmoid 272 | elif activation == 'elu': 273 | self.activation = torch.nn.ELU 274 | elif activation == 'aconc': 275 | self.activation = AconC 276 | elif activation == 'rational': 277 | self.activation = Rational 278 | else: 279 | raise NotImplementedError 280 | 281 | self.detach = False 282 | 283 | layer_list = list() 284 | 285 | if linearpool: 286 | for i in range(self.depth-1): 287 | out_list = False if i == self.depth-2 else True 288 | layer_list.append( 289 | ('layer_%d' % i, MyLinearPool(layers[i], layers[i+1], poolsize=poolsize, aggregate=aggregate, weight_sharing=args.weight_sharing, out_list=out_list, channel_wise=self.args.channel_wise, use_norm=self.args.use_norm, exp_alpha=self.args.exp_alpha, sin_alpha=self.args.sin_alpha, tau=self.args.tau, scaler=self.args.scaler, enable_scaling=self.args.enable_scaling, random_init=self.args.random_init, taylor=args.taylor, taylor_order=args.taylor_order, taylor_scale=args.taylor_scale)) 290 | ) 291 | layer_list.append( 292 | ('layer_%d' % (self.depth - 1), torch.nn.Linear(layers[-2], layers[-1]).double()) 293 | ) 294 | else: 295 | for i in range(self.depth - 1): 296 | if llaf: 297 | layer_list.append( 298 | ('layer_%d' % i, LLAF(layers[i], layers[i+1], self.args.scaler, self.args.channel_wise)) 299 | ) 300 | else: 301 | layer_list.append( 302 | ('layer_%d' % i, torch.nn.Linear(layers[i], layers[i+1]).double()) 303 | ) 304 | 305 | if activation == 'exp': 306 | layer_list.append(('activation_%d' % i, self.activation(alpha=self.args.exp_alpha))) 307 | elif activation == 'sin': 308 | layer_list.append(('activation_%d' % i, self.activation(alpha=self.args.sin_alpha))) 309 | 310 | # elif activation == 'rational': 311 | # layer_list.append(('activation_%d' % i, self.activation().double())) 312 | else: 313 | layer_list.append(('activation_%d' % i, self.activation())) 314 | 315 | 316 | layer_list.append( 317 | ('layer_%d' % (self.depth - 1), torch.nn.Linear(layers[-2], layers[-1]).double()) 318 | ) 319 | 320 | layerDict = OrderedDict(layer_list) 321 | 322 | # deploy layers 323 | self.layers = torch.nn.Sequential(layerDict) 324 | if init: 325 | self.init_weights() 326 | 327 | def init_weights(self): 328 | for name,m in self.layers.named_modules(): 329 | print(name) 330 | if isinstance(m, (nn.Conv2d, nn.Linear)): 331 | print('init_xavier in backbone') 332 | if 'layer_0' in name: 
333 | gain = self.args.gain_0 334 | else: 335 | gain = self.args.gain 336 | nn.init.xavier_uniform_(m.weight, gain) 337 | if 'layer_0' in name: 338 | if self.args.enable_vx: 339 | with torch.no_grad(): 340 | # import pdb 341 | # pdb.set_trace() 342 | m.weight[:,0] = m.weight[:,0] / 8.0 343 | # m.weight[:,1] = m.weight[:,1] * 8.0 344 | if self.args.enable_vt: 345 | with torch.no_grad(): 346 | # m.weight[:,0] = m.weight[:,0] / 8.0 347 | m.weight[:,1] = m.weight[:,1] * 8.0 348 | 349 | if m.bias is not None: 350 | nn.init.zeros_(m.bias) 351 | 352 | def forward(self, x): 353 | if self.args.linearpool: 354 | out = self.layers([x, self.detach]) 355 | else: 356 | out = self.layers(x) 357 | return out 358 | 359 | 360 | class PhysicsInformedNN_pbc(): 361 | """PINNs (convection/diffusion/reaction) for periodic boundary conditions.""" 362 | def __init__(self, args, repeat, logger, device, system, X_star_noinitial_noboundary, X_u_train, u_train, X_f_train, u22, bc_lb, bc_ub, X_star, Exact, layers, G, nu, beta, rho, optimizer_name, lr, 363 | net, activation='tanh'): 364 | self.args = args 365 | self.repeat = repeat 366 | self.logger = logger 367 | self.system = system 368 | self.device = device 369 | self.X_star_noinitial_noboundary = X_star_noinitial_noboundary 370 | self.X_star = X_star 371 | self.Exact = Exact if isinstance(Exact, list) else [Exact] 372 | 373 | # print(device) 374 | # a = torch.randn(2).cuda() 375 | # print(a) 376 | self.x_u = torch.tensor(X_u_train[:, 0:1], requires_grad=True).double().to(device) 377 | # print(self.x_u) 378 | self.t_u = torch.tensor(X_u_train[:, 1:2], requires_grad=True).double().to(device) 379 | self.x_f = torch.tensor(X_f_train[:, 0:1], requires_grad=True).double().to(device) 380 | self.t_f = torch.tensor(X_f_train[:, 1:2], requires_grad=True).double().to(device) 381 | # import pdb 382 | # pdb.set_trace() 383 | self.x_bc_lb = torch.tensor(bc_lb[:, 0:1], requires_grad=True).double().to(device) 384 | self.t_bc_lb = torch.tensor(bc_lb[:, 1:2], requires_grad=True).double().to(device) 385 | self.x_bc_ub = torch.tensor(bc_ub[:, 0:1], requires_grad=True).double().to(device) 386 | self.t_bc_ub = torch.tensor(bc_ub[:, 1:2], requires_grad=True).double().to(device) 387 | 388 | self.net = net 389 | 390 | self.use_auxiliary = args.use_auxiliary 391 | 392 | self.depth = len(layers) - 1 393 | 394 | self.num_head = args.num_head 395 | self.writer = SummaryWriter(log_dir=os.path.join(args.work_dir, f'run_{repeat}')) 396 | 397 | 398 | self.u = torch.tensor(u_train, requires_grad=False).double().to(device) 399 | # if self.system == 'CH' and self.args.decouple: 400 | # self.u22 = torch.tensor(u22, requires_grad=False).double().to(device) 401 | # self.u = torch.cat([self.u, self.u22], dim=-1) 402 | 403 | self.layers = layers 404 | self.nu = nu if isinstance(nu, list) else [nu] 405 | self.beta = beta if isinstance(beta, list) else [beta] 406 | self.rho = rho if isinstance(rho, list) else [rho] 407 | assert len(self.nu) == args.num_head 408 | assert len(self.beta) == args.num_head 409 | assert len(self.rho) == args.num_head 410 | assert len(self.Exact) == args.num_head 411 | print(self.nu) 412 | print(self.beta) 413 | print(self.rho) 414 | self.G = torch.tensor(G, requires_grad=True).double().to(device) 415 | self.G = self.G.reshape(-1, 1) 416 | 417 | self.L_f = args.L_f 418 | self.L_u = args.L_u 419 | self.L_b = args.L_b 420 | 421 | self.lr = lr 422 | self.optimizer_name = optimizer_name 423 | 424 | self.dnn = DNN(args, layers, activation, linearpool=args.linearpool, init=args.init, 
poolsize=args.poolsize, aggregate=args.aggregate, llaf=args.llaf).to(device) 425 | 426 | if args.resume is not None: 427 | checkpoint = torch.load(args.resume, map_location='cpu') 428 | self.logger.info(f'load checkpoint {args.resume}') 429 | # new_state_dict = {} 430 | # for k,v in checkpoint['state_dict'].items(): 431 | # if 'coeff2' in k: 432 | # new_state_dict[k] = v.new_ones((5,1)) 433 | # else: 434 | # new_state_dict[k] = v 435 | # self.dnn.load_state_dict(new_state_dict, strict=False) 436 | 437 | 438 | self.dnn.load_state_dict(checkpoint['state_dict'], strict=False) 439 | 440 | self.dnn_auxiliary = copy.deepcopy(self.dnn) 441 | 442 | self.logger.info(self.dnn) 443 | 444 | self.iter = 0 445 | 446 | 447 | def net_u(self, x, t, use_auxiliary=False, detach=False): 448 | self.dnn.detach = detach 449 | """The standard DNN that takes (x,t) --> u.""" 450 | 451 | input_xt = torch.cat([x, t], dim=1) 452 | 453 | if use_auxiliary: 454 | u = self.dnn_auxiliary(input_xt) 455 | else: 456 | u = self.dnn(input_xt) 457 | if 'AC' in self.system and self.args.hard_ibc: 458 | u = x**2 * torch.cos(np.pi * x) + t * (1 - x**2) * u 459 | if 'convection' in self.system and self.args.hard_ibc: 460 | u = t * u + torch.sin(x) 461 | if 'KdV' in self.system and self.args.hard_ibc: 462 | u = torch.cos(np.pi * x) + t * u 463 | if 'CH' in self.system and self.args.hard_ibc: 464 | if self.args.decouple: 465 | u_init = torch.cos(np.pi * x) - torch.exp(-4* (np.pi * x)**2) 466 | u2_init = - (np.pi ** 2) * torch.cos(np.pi * x) + 8* np.pi**2 * (1-8*(np.pi*x)**2) * torch.exp(-4* (np.pi * x)**2) 467 | y_init = torch.cat([u_init,u2_init],dim=-1) 468 | u = y_init + t * u 469 | else: 470 | u = torch.cos(np.pi * x) - torch.exp(-4* (np.pi * x)**2) + t * u 471 | 472 | if not isinstance(u, list): 473 | u = [u] 474 | return u 475 | 476 | def net_f(self, x, t, use_auxiliary=False, return_gradient=False, detach=False): 477 | """ Autograd for calculating the residual for different systems.""" 478 | u = self.net_u(x, t, use_auxiliary=use_auxiliary, detach=detach) 479 | f_all = [] 480 | u_tx = [] 481 | for output_i, nu, beta, rho in zip(u, self.nu, self.beta, self.rho): 482 | if self.system == 'CH' and self.args.decouple: 483 | u_i = output_i[:, 0:1] 484 | u2_i = output_i[:, 1:2] 485 | else: 486 | u_i = output_i 487 | u_t = torch.autograd.grad( 488 | u_i, t, 489 | grad_outputs=torch.ones_like(u_i), 490 | retain_graph=True, 491 | create_graph=True 492 | )[0] 493 | u_x = torch.autograd.grad( 494 | u_i, x, 495 | grad_outputs=torch.ones_like(u_i), 496 | retain_graph=True, 497 | create_graph=True 498 | )[0] 499 | if 'inviscid' not in self.system: 500 | u_xx = torch.autograd.grad( 501 | u_x, x, 502 | grad_outputs=torch.ones_like(u_x), 503 | retain_graph=True, 504 | create_graph=True 505 | )[0] 506 | u_tx.append([u_t.detach(), u_x.detach()]) 507 | if 'convection' in self.system or 'diffusion' in self.system: 508 | f = u_t - nu*u_xx + beta*u_x - self.G 509 | elif 'rd' in self.system: 510 | f = u_t - nu*u_xx - rho*u_i + rho*u_i**2 511 | elif 'reaction' in self.system: 512 | f = u_t - rho*u_i + rho*u_i**2 513 | elif 'burger' in self.system: 514 | f = u_t + u_i * u_x - 0.01 / math.pi * u_xx 515 | elif 'inviscid' in self.system: 516 | f = u_t + 2 * u_i * u_x 517 | elif 'AC' in self.system: 518 | f = u_t - 0.001 * u_xx - 5 * (u_i - u_i**3) 519 | # f = u_t - u_xx 520 | elif 'KdV' in self.system: 521 | u_xxx = torch.autograd.grad( 522 | u_xx, x, 523 | grad_outputs=torch.ones_like(u_xx), 524 | retain_graph=True, 525 | create_graph=True 526 | )[0] 527 | 
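# For reference, the PDE residuals f assembled in this function, as written in the branches
# above and below (coefficients are hard-coded in this file):
#   convection/diffusion : f = u_t - nu*u_xx + beta*u_x - G
#   reaction-diffusion   : f = u_t - nu*u_xx - rho*u*(1 - u)
#   reaction             : f = u_t - rho*u*(1 - u)
#   Burgers              : f = u_t + u*u_x - (0.01/pi)*u_xx
#   inviscid Burgers     : f = u_t + 2*u*u_x
#   Allen-Cahn (AC)      : f = u_t - 0.001*u_xx - 5*(u - u^3)
#   KdV                  : f = u_t + u*u_x + 0.0025*u_xxx   (assembled just below)
#   Cahn-Hilliard (CH)   : f = u_t - (u^3 - u - 0.02*u_xx)_xx
#                          (the --decouple variant splits this into f1 = u_t - (-0.02*u2 + u^3 - u)_xx
#                           and f2 = u2 - u_xx)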
f = u_t + u_i * u_x + 0.0025 * u_xxx 528 | elif 'CH' in self.system: 529 | if self.args.decouple: 530 | 531 | y_ch = -0.02 * u2_i + u_i**3 - u_i 532 | y_ch_x = torch.autograd.grad( 533 | y_ch, x, 534 | grad_outputs=torch.ones_like(y_ch), 535 | retain_graph=True, 536 | create_graph=True 537 | )[0] 538 | y_ch_xx = torch.autograd.grad( 539 | y_ch_x, x, 540 | grad_outputs=torch.ones_like(y_ch_x), 541 | retain_graph=True, 542 | create_graph=True 543 | )[0] 544 | f1 = u_t - y_ch_xx 545 | f2 = u2_i - u_xx 546 | f = torch.cat([f1,f2], dim=-1) 547 | else: 548 | y_ch = u_i**3 - u_i - 0.02 * u_xx 549 | y_ch_x = torch.autograd.grad( 550 | y_ch, x, 551 | grad_outputs=torch.ones_like(y_ch), 552 | retain_graph=True, 553 | create_graph=True 554 | )[0] 555 | y_ch_xx = torch.autograd.grad( 556 | y_ch_x, x, 557 | grad_outputs=torch.ones_like(y_ch_x), 558 | retain_graph=True, 559 | create_graph=True 560 | )[0] 561 | f = u_t - y_ch_xx 562 | else: 563 | raise NotImplementedError 564 | 565 | f_all.append(f) 566 | if return_gradient: 567 | return f_all, u_tx 568 | else: 569 | return f_all 570 | 571 | def net_b_derivatives(self, u_lb, u_ub, x_bc_lb, x_bc_ub): 572 | """For taking BC derivatives.""" 573 | 574 | u_lb_x = torch.autograd.grad( 575 | u_lb, x_bc_lb, 576 | grad_outputs=torch.ones_like(u_lb), 577 | retain_graph=True, 578 | create_graph=True 579 | )[0] 580 | 581 | u_ub_x = torch.autograd.grad( 582 | u_ub, x_bc_ub, 583 | grad_outputs=torch.ones_like(u_ub), 584 | retain_graph=True, 585 | create_graph=True 586 | )[0] 587 | 588 | return u_lb_x, u_ub_x 589 | 590 | def net_b_derivatives_high_order(self, u_lb, u_ub, x_bc_lb, x_bc_ub): 591 | """For taking BC derivatives.""" 592 | 593 | u_lb_x = torch.autograd.grad( 594 | u_lb, x_bc_lb, 595 | grad_outputs=torch.ones_like(u_lb), 596 | retain_graph=True, 597 | create_graph=True 598 | )[0] 599 | 600 | u_lb_xx = torch.autograd.grad( 601 | u_lb_x, x_bc_lb, 602 | grad_outputs=torch.ones_like(u_lb_x), 603 | retain_graph=True, 604 | create_graph=True 605 | )[0] 606 | 607 | u_lb_xxx = torch.autograd.grad( 608 | u_lb_xx, x_bc_lb, 609 | grad_outputs=torch.ones_like(u_lb_xx), 610 | retain_graph=True, 611 | create_graph=True 612 | )[0] 613 | 614 | u_ub_x = torch.autograd.grad( 615 | u_ub, x_bc_ub, 616 | grad_outputs=torch.ones_like(u_ub), 617 | retain_graph=True, 618 | create_graph=True 619 | )[0] 620 | 621 | u_ub_xx = torch.autograd.grad( 622 | u_ub_x, x_bc_ub, 623 | grad_outputs=torch.ones_like(u_ub_x), 624 | retain_graph=True, 625 | create_graph=True 626 | )[0] 627 | 628 | u_ub_xxx = torch.autograd.grad( 629 | u_ub_xx, x_bc_ub, 630 | grad_outputs=torch.ones_like(u_ub_xx), 631 | retain_graph=True, 632 | create_graph=True 633 | )[0] 634 | 635 | return u_lb_x, u_ub_x, u_lb_xx, u_ub_xx, u_lb_xxx, u_ub_xxx 636 | 637 | 638 | def adapt_sample_range(self): 639 | t_range = np.linspace(0, 1, self.args.sample_stage+1) 640 | stage_iter = self.args.epoch // self.args.sample_stage 641 | if self.iter >= self.args.epoch: 642 | X_sample = self.X_star_noinitial_noboundary 643 | else: 644 | t_range_iter = t_range[self.iter // stage_iter+1] 645 | X_sample = self.X_star_noinitial_noboundary[self.X_star_noinitial_noboundary[:, 1] <= t_range_iter] 646 | # import pdb 647 | # pdb.set_trace() 648 | return X_sample 649 | 650 | def loss_pinn(self, verbose=False, step=False, evaluate=False): 651 | """ Loss function. 
""" 652 | 653 | if torch.is_grad_enabled(): 654 | self.optimizer.zero_grad() 655 | if self.args.sep_optim: 656 | self.optimizer_coeff.zero_grad() 657 | 658 | u_pred = self.net_u(self.x_u, self.t_u, detach=self.args.detach_u) 659 | 660 | u_pred_lb = self.net_u(self.x_bc_lb, self.t_bc_lb, detach=self.args.detach_b) 661 | u_pred_ub = self.net_u(self.x_bc_ub, self.t_bc_ub, detach=self.args.detach_b) 662 | 663 | if 'CH' in self.system and self.args.decouple: 664 | if not self.args.four_order: 665 | u_pred_lb_x, u_pred_ub_x = list(), list() 666 | # u2_pred_lb_x, u2_pred_ub_x = list(), list() 667 | for u_pred_lb_i, u_pred_ub_i in zip(u_pred_lb, u_pred_ub): 668 | 669 | u_pred_lb_x_i, u_pred_ub_x_i = self.net_b_derivatives(u_pred_lb_i[:, 0:1], u_pred_ub_i[:, 0:1], self.x_bc_lb, self.x_bc_ub) 670 | u_pred_lb_x.append(u_pred_lb_x_i) 671 | u_pred_ub_x.append(u_pred_ub_x_i) 672 | else: 673 | # u_pred_lb_x, u_pred_ub_x, u_pred_lb_xx, u_pred_ub_xx, u_pred_lb_xxx, u_pred_ub_xxx = list(), list(), list(), list(), list(), list() 674 | u_pred_lb_x, u_pred_ub_x = list(), list() 675 | u2_pred_lb_x, u2_pred_ub_x = list(), list() 676 | for u_pred_lb_i, u_pred_ub_i in zip(u_pred_lb, u_pred_ub): 677 | 678 | # u_pred_lb_x_i, u_pred_ub_x_i, u_pred_lb_xx_i, u_pred_ub_xx_i, u_pred_lb_xxx_i, u_pred_ub_xxx_i = self.net_b_derivatives_high_order(u_pred_lb_i[:, 0:1], u_pred_ub_i[:, 0:1], self.x_bc_lb, self.x_bc_ub) 679 | # u_pred_lb_x.append(u_pred_lb_x_i) 680 | # u_pred_ub_x.append(u_pred_ub_x_i) 681 | # u_pred_lb_xx.append(u_pred_lb_xx_i) 682 | # u_pred_ub_xx.append(u_pred_ub_xx_i) 683 | # u_pred_lb_xxx.append(u_pred_lb_xxx_i) 684 | # u_pred_ub_xxx.append(u_pred_ub_xxx_i) 685 | u_pred_lb_x_i, u_pred_ub_x_i = self.net_b_derivatives(u_pred_lb_i[:, 0:1], u_pred_ub_i[:, 0:1], self.x_bc_lb, self.x_bc_ub) 686 | u2_pred_lb_x_i, u2_pred_ub_x_i = self.net_b_derivatives(u_pred_lb_i[:, 1:2], u_pred_ub_i[:, 1:2], self.x_bc_lb, self.x_bc_ub) 687 | u2_pred_lb_x.append(u2_pred_lb_x_i) 688 | u2_pred_ub_x.append(u2_pred_ub_x_i) 689 | u_pred_lb_x.append(u_pred_lb_x_i) 690 | u_pred_ub_x.append(u_pred_ub_x_i) 691 | else: 692 | u_pred_lb_x, u_pred_ub_x = list(), list() 693 | for nu, u_pred_lb_i, u_pred_ub_i in zip(self.nu, u_pred_lb, u_pred_ub): 694 | if nu != 0 or 'KdV' in self.system or self.args.high_bc: 695 | u_pred_lb_x_i, u_pred_ub_x_i = self.net_b_derivatives(u_pred_lb_i, u_pred_ub_i, self.x_bc_lb, self.x_bc_ub) 696 | u_pred_lb_x.append(u_pred_lb_x_i) 697 | u_pred_ub_x.append(u_pred_ub_x_i) 698 | 699 | f_pred = self.net_f(self.x_f, self.t_f, detach=self.args.detach_f) 700 | loss_u_list, loss_b_list, loss_f_list = list(), list(), list() 701 | # loss_u_list, loss_b_list = list(), list() 702 | # import pdb; pdb.set_trace() 703 | for idx in range(self.num_head): 704 | 705 | if 'CH' in self.system and self.args.decouple: 706 | loss_u_i = torch.mean((self.u - u_pred[idx][:,0:1]) ** 2) 707 | else: 708 | loss_u_i = torch.mean((self.u - u_pred[idx]) ** 2) 709 | 710 | loss_f_i = torch.mean(f_pred[idx] ** 2) 711 | 712 | if 'burger' in self.system or 'inviscid' in self.system: 713 | loss_b_i = torch.mean(u_pred_lb[idx] ** 2) + torch.mean(u_pred_ub[idx] ** 2) 714 | 715 | elif 'AC' in self.system: 716 | loss_b_i = torch.mean((u_pred_lb[idx]+1) ** 2) + torch.mean((u_pred_ub[idx]+1) ** 2) 717 | elif 'KdV' in self.system: 718 | loss_b_i = torch.mean((u_pred_lb[idx] - u_pred_ub[idx]) ** 2) + torch.mean((u_pred_lb_x[idx] - u_pred_ub_x[idx]) ** 2) 719 | # loss_b_i = 0 720 | elif 'CH' in self.system and self.args.decouple: 721 | # loss_b_i = 2 * 
torch.mean((u_pred_lb[idx] - u_pred_ub[idx]) ** 2) + torch.mean((u_pred_lb_x[idx] - u_pred_ub_x[idx]) ** 2) + torch.mean((u2_pred_lb_x[idx] - u2_pred_ub_x[idx]) ** 2) 722 | if self.args.four_order: 723 | # loss_b_i = torch.mean((u_pred_lb[idx][:,0:1] - u_pred_ub[idx][:,0:1]) ** 2) + torch.mean((u_pred_lb_x[idx] - u_pred_ub_x[idx]) ** 2) \ 724 | # + torch.mean((u_pred_lb_xx[idx] - u_pred_ub_xx[idx]) ** 2) + torch.mean((u_pred_lb_xxx[idx] - u_pred_ub_xxx[idx]) ** 2) 725 | loss_b_i = 2 * torch.mean((u_pred_lb[idx] - u_pred_ub[idx]) ** 2) + torch.mean((u_pred_lb_x[idx] - u_pred_ub_x[idx]) ** 2) + torch.mean((u2_pred_lb_x[idx] - u2_pred_ub_x[idx]) ** 2) 726 | else: 727 | loss_b_i = torch.mean((u_pred_lb[idx][:,0:1] - u_pred_ub[idx][:,0:1]) ** 2) #+ torch.mean((u_pred_lb_x[idx] - u_pred_ub_x[idx]) ** 2) 728 | else: 729 | loss_b_i = torch.mean((u_pred_lb[idx] - u_pred_ub[idx]) ** 2) 730 | 731 | if self.nu[idx] != 0 or self.args.high_bc: 732 | loss_b_i += torch.mean((u_pred_lb_x[idx] - u_pred_ub_x[idx]) ** 2) 733 | 734 | loss_u_list.append(loss_u_i) 735 | loss_f_list.append(loss_f_i) 736 | loss_b_list.append(loss_b_i) 737 | loss_u = sum(loss_u_list) / self.num_head 738 | loss_f = sum(loss_f_list) / self.num_head 739 | loss_b = sum(loss_b_list) / self.num_head 740 | 741 | for name, p in self.dnn.named_parameters(): 742 | if p.requires_grad: 743 | if p.grad is not None: 744 | if verbose: 745 | if (self.iter < self.args.epoch and self.iter % 100== 0) or (self.iter >= self.args.epoch and self.iter % self.args.print_freq == 0): 746 | if 'coeff2' not in name and 'bias' not in name: 747 | grad_u = torch.autograd.grad( 748 | loss_u, p, 749 | grad_outputs=torch.ones_like(loss_u), 750 | retain_graph=True, 751 | create_graph=False, 752 | allow_unused=True, 753 | )[0] 754 | self.writer.add_scalars(f'{name}_grad', {'grad_u_mean':grad_u.detach().abs().mean().item()}, self.iter) 755 | self.writer.add_scalars(f'{name}_grad', {'grad_u_max': grad_u.detach().abs().max().item()}, self.iter) 756 | grad_b = torch.autograd.grad( 757 | loss_b, p, 758 | grad_outputs=torch.ones_like(loss_b), 759 | retain_graph=True, 760 | create_graph=False, 761 | allow_unused=True, 762 | )[0] 763 | self.writer.add_scalars(f'{name}_grad', {'grad_b_mean':grad_b.detach().abs().mean().item()}, self.iter) 764 | self.writer.add_scalars(f'{name}_grad', {'grad_b_max': grad_b.detach().abs().max().item()}, self.iter) 765 | grad_f = torch.autograd.grad( 766 | loss_f, p, 767 | grad_outputs=torch.ones_like(loss_f), 768 | retain_graph=True, 769 | create_graph=False, 770 | allow_unused=True, 771 | )[0] 772 | self.writer.add_scalars(f'{name}_grad', {'grad_f_mean':grad_f.detach().abs().mean().item()}, self.iter) 773 | self.writer.add_scalars(f'{name}_grad', {'grad_f_max': grad_f.detach().abs().max().item()}, self.iter) 774 | 775 | 776 | loss = self.L_u*loss_u + self.L_b*loss_b + self.L_f*loss_f 777 | # loss = self.L_u*loss_u + self.L_b*loss_b 778 | if evaluate: 779 | return loss.detach().cpu().numpy(), loss_u.detach().cpu().numpy(), loss_b.detach().cpu().numpy(), loss_f.detach().cpu().numpy() 780 | 781 | # for l-bfgs 782 | if self.iter >= self.args.epoch and self.args.l2_reg > 0: 783 | l2_reg = 0.0 784 | for name, p in self.dnn.named_parameters(): 785 | if 'coeff' in name and not self.args.enable_coeff_l2_reg: 786 | continue 787 | if p.requires_grad: 788 | l2_reg += 0.5 * p.square().sum() 789 | loss += self.args.l2_reg * l2_reg 790 | 791 | if (self.args.llaf or self.args.linearpool) and self.args.use_recovery: 792 | coeff_term = 
[torch.exp(torch.mean(self.dnn.layers[i].coeff2)) for i in range(len(self.dnn.layers)) if isinstance(self.dnn.layers[i], (LLAF, MyLinearPool))] 793 | 794 | recovery_term = len(coeff_term) / sum(coeff_term) 795 | 796 | loss += self.args.recovery_weight * recovery_term 797 | else: 798 | recovery_term = 0 799 | 800 | if loss.requires_grad: 801 | loss.backward() 802 | if self.args.clip > 0.0: 803 | # torch.nn.utils.clip_grad_norm_(self.dnn.parameters(), self.args.clip) 804 | if self.args.coeff_clip_type == 'norm': 805 | torch.nn.utils.clip_grad_norm_(self.dnn.parameters(), self.args.clip) 806 | elif self.args.coeff_clip_type == 'value': 807 | torch.nn.utils.clip_grad_value_(self.dnn.parameters(), self.args.clip) 808 | else: 809 | raise NotImplementedError 810 | if self.args.coeff_clip > 0.0: 811 | if verbose and self.iter % self.args.print_freq == 0: 812 | pre_grad_norm = 0 813 | 814 | for p in self.optimizer.param_groups[1]['params']: 815 | # print(p.shape) 816 | if p.grad is not None: 817 | param_norm = p.grad.detach().data.norm(2) 818 | pre_grad_norm += param_norm.item() ** 2 819 | pre_grad_norm = pre_grad_norm ** 0.5 820 | if self.args.coeff_clip_type == 'norm': 821 | torch.nn.utils.clip_grad_norm_(self.optimizer.param_groups[1]['params'], self.args.coeff_clip) 822 | elif self.args.coeff_clip_type == 'value': 823 | torch.nn.utils.clip_grad_value_(self.optimizer.param_groups[1]['params'], self.args.coeff_clip) 824 | else: 825 | raise NotImplementedError 826 | if verbose and self.iter % self.args.print_freq == 0: 827 | post_grad_norm = 0 828 | for p in self.optimizer.param_groups[1]['params']: 829 | if p.grad is not None: 830 | param_norm = p.grad.detach().data.norm(2) 831 | post_grad_norm += param_norm.item() ** 2 832 | post_grad_norm = post_grad_norm ** 0.5 833 | self.logger.info(f'pre {pre_grad_norm}, post {post_grad_norm}') 834 | 835 | grad_norm = 0 836 | for name, p in self.dnn.named_parameters(): 837 | if p.requires_grad: 838 | if p.grad is not None: 839 | param_norm = p.grad.detach().data.norm(2) 840 | grad_norm += param_norm.item() ** 2 841 | grad_norm = grad_norm ** 0.5 842 | 843 | if step: 844 | self.optimizer.step() 845 | if self.args.sep_optim: 846 | self.optimizer_coeff.step() 847 | if verbose: 848 | if (self.iter < self.args.epoch and self.iter % 100== 0) or (self.iter >= self.args.epoch and self.iter % self.args.print_freq == 0): 849 | for loss_u_i, loss_b_i, loss_f_i, nu, rho, beta in zip(loss_u_list, loss_b_list, loss_f_list, self.nu, self.rho, self.beta): 850 | loss_i = self.L_u*loss_u_i + self.L_b*loss_b_i + self.L_f*loss_f_i 851 | self.logger.info( 852 | 'epoch %d, nu: %.5e, rho: %.5e, beta: %.5e, gradient: %.5e, loss: %.5e, recovery: %.5e, loss_u: %.5e, L_u: %.5e, loss_b: %.5e, L_b: %.5e, loss_f: %.5e, L_f: %.5e' % (self.iter, nu, rho, beta, grad_norm, loss_i.item(), recovery_term, loss_u_i.item(), self.L_u, loss_b_i.item(), self.L_b, loss_f_i.item(), self.L_f) 853 | ) 854 | 855 | 856 | self.writer.add_scalars(f'nu{nu}_rho{rho}_beta{beta}_loss_all', {'loss_all':loss_i.item()}, self.iter) 857 | self.writer.add_scalars(f'nu{nu}_rho{rho}_beta{beta}_loss_u', {'loss_u':loss_u_i.item()}, self.iter) 858 | self.writer.add_scalars(f'nu{nu}_rho{rho}_beta{beta}_loss_b', {'loss_b':loss_b_i.item()}, self.iter) 859 | self.writer.add_scalars(f'nu{nu}_rho{rho}_beta{beta}_loss_f', {'loss_f':loss_f_i.item()}, self.iter) 860 | if self.args.activation == 'rational': 861 | for name, para in self.dnn.named_parameters(): 862 | if 'denominator' in name or 'numerator' in name: 863 | 
self.logger.info(f'{name} {para}') 864 | if self.args.linearpool or self.args.llaf: 865 | for name, para in self.dnn.named_parameters(): 866 | if 'coeff' in name: 867 | if 'coeff2' in name: 868 | para_log = para.mean(dim=-1) 869 | # self.logger.info(f'{name} {para.mean(dim=-1)}') 870 | elif self.args.aggregate == 'sigmoid': 871 | para_log = torch.sigmoid(para).mean(-1) 872 | # self.logger.info(f'{name} {torch.sigmoid(para).mean(-1)}') 873 | elif self.args.aggregate == 'softmax': 874 | para_log = torch.softmax(para/self.args.tau,dim=0).mean(-1) 875 | # self.logger.info(f'{name} {torch.softmax(para/self.args.tau,dim=0).mean(-1)}') 876 | elif self.args.aggregate == 'unlimited': 877 | para_log = para.mean(-1) 878 | # self.logger.info(f'{name} {para.mean(-1)}') 879 | else: 880 | para_log = (para / para.abs().sum(dim=0)).mean(-1) 881 | # self.logger.info(f'{name} {(para / para.abs().sum(dim=0)).mean(-1)}') 882 | self.logger.info(f'{name} {para_log}') 883 | 884 | if 'coeff2' not in name and not self.args.channel_wise: 885 | self.logger.info(f'{name} {para.flatten().data}') 886 | for ele, para_log_ele in enumerate(para_log): 887 | self.writer.add_scalars(f'{name}', {f'{ele}': para_log_ele.item()}, self.iter) 888 | self.writer.add_scalars(f'{name}_value', {f'{ele}': para[ele].item()}, self.iter) 889 | 890 | 891 | if (self.iter > 0 and self.iter < self.args.epoch and self.iter % 1000 == 0) or (self.iter >= self.args.epoch and self.iter % self.args.valid_freq == 0): 892 | self.validate() 893 | if self.iter >= self.args.epoch: 894 | if self.args.plot_loss: 895 | self.draw_loss(name=f'{self.repeat}_{self.iter}_') 896 | 897 | self.iter += 1 898 | return loss 899 | 900 | def train_adam(self, adam_lr, epoch, X_star, u_star): 901 | # import pdb 902 | # pdb.set_trace() 903 | if self.args.linearpool or self.args.llaf: 904 | params_net = [] 905 | params_net_first_layer = [] 906 | params_net_second_layer = [] 907 | params_coeff = [] 908 | params_coeff_first_layer = [] 909 | for name, param in self.dnn.named_parameters(): 910 | if 'coeff' in name: 911 | if 'layer_0' in name: 912 | params_coeff_first_layer.append(param) 913 | else: 914 | params_coeff.append(param) 915 | else: 916 | # if 'layer_0' in name or 'layer_1' in name: 917 | if 'layer_0' in name: 918 | params_net_first_layer.append(param) 919 | # params_coeff_first_layer.append(param) 920 | elif 'layer_1' in name: 921 | params_net_second_layer.append(param) 922 | else: 923 | params_net.append(param) 924 | if not self.args.sep_optim: 925 | self.optimizer = choose_optimizer('AdamW', [{'params': params_net}, {'params': params_net_first_layer, 'lr': self.args.lr_first_layer}, {'params': params_net_second_layer, 'lr': self.args.lr_second_layer}, {'params': params_coeff, 'lr': self.args.coeff_lr, 'betas':(self.args.coeff_beta1, self.args.coeff_beta2)}, {'params': params_coeff_first_layer, 'lr': self.args.coeff_lr_first_layer, 'betas':(self.args.coeff_beta1, self.args.coeff_beta2)}], adam_lr, weight_decay=self.args.weight_decay) 926 | else: 927 | self.optimizer = choose_optimizer('AdamW', [{'params': params_net}, {'params': params_net_first_layer, 'lr': self.args.lr_first_layer}, {'params': params_net_second_layer, 'lr': self.args.lr_second_layer}], adam_lr, weight_decay=self.args.weight_decay) 928 | # self.optimizer = choose_optimizer('Adam', [{'params': params_net}], adam_lr) 929 | if self.args.sep_optimizer == 'sgd': 930 | self.optimizer_coeff = torch.optim.SGD([{'params': params_coeff}, {'params': params_coeff_first_layer, 'lr': 
self.args.coeff_lr_first_layer}], self.args.coeff_lr, momentum=self.args.momentum) 931 | elif self.args.sep_optimizer == 'adam': 932 | self.optimizer_coeff = torch.optim.Adam([{'params': params_coeff}, {'params': params_coeff_first_layer, 'lr': self.args.coeff_lr_first_layer}], self.args.coeff_lr, betas=(self.args.coeff_beta1, self.args.coeff_beta2), weight_decay=self.args.coeff_weight_decay) 933 | else: 934 | raise NotImplementedError 935 | 936 | else: 937 | if self.args.lr_first_layer>0: 938 | params_net = [] 939 | params_net_first_layer = [] 940 | 941 | for name, param in self.dnn.named_parameters(): 942 | if 'layer_0' in name: 943 | params_net_first_layer.append(param) 944 | # params_coeff_first_layer.append(param) 945 | else: 946 | params_net.append(param) 947 | self.optimizer = choose_optimizer('AdamW', [{'params': params_net}, {'params': params_net_first_layer, 'lr': self.args.lr_first_layer}], adam_lr, weight_decay=self.args.weight_decay) 948 | else: 949 | self.logger.info('use only one optimizer') 950 | self.optimizer = choose_optimizer('AdamW', self.dnn.parameters(), adam_lr, weight_decay=self.args.weight_decay) 951 | 952 | warm_up_iter = self.args.warm_up_iter 953 | lr_min = 1e-3 954 | lr_max = 1 955 | T_max = epoch if self.args.T_max == 0 else self.args.T_max 956 | if self.args.cosine_decay: 957 | if self.args.constant_warmup > 0: 958 | lambda0 = lambda cur_iter: self.args.constant_warmup if cur_iter < warm_up_iter else \ 959 | (lr_min + 0.5*(lr_max-lr_min)*(1.0+math.cos((cur_iter-warm_up_iter)/(T_max-warm_up_iter)*math.pi))) 960 | else: 961 | lambda0 = lambda cur_iter: 1e-3 + (lr_max/2**(cur_iter//T_max) - 1e-3) * (cur_iter%T_max) / warm_up_iter if (cur_iter%T_max) < warm_up_iter else \ 962 | (lr_min + 0.5*(lr_max/2**(cur_iter//T_max)-lr_min)*(1.0+math.cos(((cur_iter%T_max)-warm_up_iter)/(T_max-warm_up_iter)*math.pi))) 963 | else: 964 | lambda0 = lambda cur_iter: 1e-3 + (1 - 1e-3) * cur_iter / warm_up_iter if cur_iter < warm_up_iter else 1.0 965 | schduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda0) 966 | if self.args.sep_optim: 967 | sep_warm_up_iter = self.args.sep_warm_up_iter 968 | lr_min = 1e-3 969 | lr_max = 1.0 970 | T_max = epoch 971 | if self.args.sep_cosine_decay: 972 | if self.args.constant_warmup > 0: 973 | sep_lambda0 = lambda cur_iter: self.args.constant_warmup if cur_iter < sep_warm_up_iter else \ 974 | (lr_min + 0.5*(lr_max-lr_min)*(1.0+math.cos((cur_iter-sep_warm_up_iter)/(T_max-sep_warm_up_iter)*math.pi))) 975 | else: 976 | sep_lambda0 = lambda cur_iter: 1e-3 + (1 - 1e-3) * cur_iter / sep_warm_up_iter if cur_iter < sep_warm_up_iter else \ 977 | (lr_min + 0.5*(lr_max-lr_min)*(1.0+math.cos((cur_iter-sep_warm_up_iter)/(T_max-sep_warm_up_iter)*math.pi))) 978 | else: 979 | sep_lambda0 = lambda cur_iter: 1e-3 + (1 - 1e-3) * cur_iter / sep_warm_up_iter if cur_iter < sep_warm_up_iter else 1.0 980 | scheduler_coeff = torch.optim.lr_scheduler.LambdaLR(self.optimizer_coeff, lr_lambda=sep_lambda0) 981 | # schduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, epoch, eta_min=1e-3*adam_lr) 982 | self.dnn.train() 983 | self.logger.info('>>>>>> Adam optimizer') 984 | 985 | self.args.pre_use_recovery = self.args.use_recovery 986 | 987 | for epoch_n in range(epoch): 988 | loss = self.loss_pinn(verbose=True,step=True) 989 | # self.logger.info(f"lr {self.optimizer.param_groups[0]['lr']}") 990 | # self.writer.add_scalars('lr', {'pg0':self.optimizer.param_groups[0]['lr']}, self.iter) 991 | 992 | if (epoch_n+1) % 10000 == 0: 993 | if 
self.args.plot_loss: 994 | self.draw_loss(name=f'{self.repeat}_{self.iter}_') 995 | if not self.args.not_adapt_adam_lr: 996 | schduler.step() 997 | if self.args.sep_optim: 998 | scheduler_coeff.step() 999 | 1000 | torch.save({'state_dict':self.dnn.state_dict()}, os.path.join(self.args.work_dir,f'model_adam_{self.repeat}.pth.tar') ) 1001 | if self.args.checkpoint != '': 1002 | self.logger.info(f'load model from {self.args.checkpoint}') 1003 | checkpoint = torch.load(self.args.checkpoint) 1004 | self.dnn.load_state_dict(checkpoint['state_dict']) 1005 | 1006 | 1007 | 1008 | self.logger.info(f'>>>>>> {self.optimizer_name} optimizer') 1009 | 1010 | 1011 | self.optimizer = choose_optimizer(self.optimizer_name, self.dnn.parameters(), self.lr, max_iter=self.args.max_iter, line_search_fn=self.args.line_search_fn) 1012 | self.dnn.train() 1013 | if not self.args.disable_lbfgs: 1014 | self.optimizer.step(self.loss_pinn) 1015 | torch.save({'state_dict':self.dnn.state_dict()}, os.path.join(self.args.work_dir,f'model_lbfgs_{self.repeat}.pth.tar') ) 1016 | 1017 | 1018 | if self.args.plot_loss: 1019 | self.draw_loss(name=str(self.repeat)) 1020 | 1021 | return self.loss_pinn(evaluate=True) 1022 | 1023 | def test_loss_f(self, X_f_train): 1024 | self.x_f = torch.tensor(X_f_train[:, 0:1], requires_grad=True).double().to(self.device) 1025 | self.t_f = torch.tensor(X_f_train[:, 1:2], requires_grad=True).double().to(self.device) 1026 | self.dnn.train() 1027 | f_pred = self.net_f(self.x_f, self.t_f) 1028 | loss_f = f_pred ** 2 1029 | self.logger.info(f'{loss_f.mean():.5e}') 1030 | return loss_f.detach().cpu().numpy() 1031 | 1032 | def draw_loss(self, name=''): 1033 | self.dnn.train() 1034 | u_pred = self.net_u(self.x_u, self.t_u) 1035 | u_pred_lb = self.net_u(self.x_bc_lb, self.t_bc_lb) 1036 | u_pred_ub = self.net_u(self.x_bc_ub, self.t_bc_ub) 1037 | if 'CH' in self.system and self.args.decouple: 1038 | u_pred = [u_i[:, 0:1] for u_i in u_pred] 1039 | u_pred_lb = [u_i[:, 0:1] for u_i in u_pred_lb] 1040 | u_pred_ub = [u_i[:, 0:1] for u_i in u_pred_ub] 1041 | x_f = torch.tensor(self.X_star[:, 0:1], requires_grad=True).double().to(self.device) 1042 | t_f = torch.tensor(self.X_star[:, 1:2], requires_grad=True).double().to(self.device) 1043 | 1044 | # iter_num = math.ceil(x_f.shape[0] / self.args.N_f) 1045 | iter_num = 0 1046 | f_pred = [list() for _ in range(self.num_head)] 1047 | u_t_plot = [list() for _ in range(self.num_head)] 1048 | u_x_plot = [list() for _ in range(self.num_head)] 1049 | while(iter_num < x_f.shape[0]): 1050 | f_pred_i, u_tx_i = self.net_f(x_f[int(iter_num):int(iter_num+self.args.N_f)], t_f[int(iter_num):int(iter_num+self.args.N_f)], return_gradient=True) 1051 | iter_num += self.args.N_f 1052 | for f_list, f_pred_i_head in zip(f_pred, f_pred_i): 1053 | f_list.append(f_pred_i_head.detach()) 1054 | del f_pred_i_head 1055 | for u_t_list, u_x_list, (u_t_i, u_x_i) in zip(u_t_plot, u_x_plot, u_tx_i): 1056 | u_t_list.append(u_t_i) 1057 | u_x_list.append(u_x_i) 1058 | del u_t_i, u_x_i 1059 | # f_pred.append(f_pred_i.detach()) 1060 | 1061 | f_pred = [torch.cat(f_list) for f_list in f_pred] 1062 | if 'CH' in self.system and self.args.decouple: 1063 | # f_pred = [f_pred_i.abs().sum(dim=-1) for f_pred_i in f_pred] 1064 | f_pred1 = [f_pred_i[:,0:1] for f_pred_i in f_pred] 1065 | f_pred2 = [f_pred_i[:,1:2] for f_pred_i in f_pred] 1066 | u_t_plot = [torch.cat(u_t_list) for u_t_list in u_t_plot] 1067 | u_x_plot = [torch.cat(u_x_list) for u_x_list in u_x_plot] 1068 | u_pred_f = self.net_u(x_f, t_f) 1069 | if 'CH' 
in self.system and self.args.decouple: 1070 | u_pred_f = [u_i[:, 0:1] for u_i in u_pred_f] 1071 | 1072 | for idx in range(self.num_head): 1073 | if 'CH' in self.system and self.args.decouple: 1074 | loss_u = (self.u[:,0:1] - u_pred[idx]).abs() 1075 | else: 1076 | loss_u = (self.u - u_pred[idx]).abs() 1077 | loss_u = loss_u.detach().cpu().numpy() 1078 | # import pdb 1079 | # pdb.set_trace() 1080 | loss_b = (u_pred_lb[idx] - u_pred_ub[idx]).abs() 1081 | if self.nu[idx] != 0: 1082 | u_pred_lb_x, u_pred_ub_x = self.net_b_derivatives(u_pred_lb[idx], u_pred_ub[idx], self.x_bc_lb, self.x_bc_ub) 1083 | loss_b += (u_pred_lb_x - u_pred_ub_x).abs() 1084 | u_pred_lb_i = u_pred_lb[idx].detach().cpu().numpy() 1085 | u_pred_ub_i = u_pred_ub[idx].detach().cpu().numpy() 1086 | loss_b = loss_b.detach().cpu().numpy() 1087 | u_b = self.Exact[idx][:, 0:1] 1088 | diff_lb = np.abs(u_b - u_pred_lb_i) 1089 | diff_ub = np.abs(u_b - u_pred_ub_i) 1090 | fig, axes = plt.subplots(3,1,figsize=(9, 12)) 1091 | 1092 | axes[0].plot(self.x_u.reshape(-1).detach().cpu().numpy(), loss_u.reshape(-1)) 1093 | axes[0].set_title('loss_initial') 1094 | axes[1].plot(self.t_bc_lb.reshape(-1).detach().cpu().numpy(), loss_b.reshape(-1)) 1095 | axes[1].set_title('loss_boundary') 1096 | axes[2].plot(self.t_bc_lb.reshape(-1).detach().cpu().numpy(), diff_lb.reshape(-1), c='r') 1097 | axes[2].plot(self.t_bc_lb.reshape(-1).detach().cpu().numpy(), diff_ub.reshape(-1), c='b') 1098 | axes[2].legend(['lower bound', 'upper bound']) 1099 | axes[2].set_title('error_boundary') 1100 | plt.savefig(os.path.join(self.args.work_dir, name+f'loss_initial_boundary_nu{self.nu[idx]}_rho{self.rho[idx]}_beta{self.beta[idx]}.png')) 1101 | plt.close(fig) 1102 | 1103 | 1104 | 1105 | if 'CH' in self.system and self.args.decouple: 1106 | loss_f = (f_pred1[idx].abs()).detach().cpu().numpy().reshape(self.args.nt, self.args.xgrid) 1107 | loss_f2 = (f_pred2[idx].abs()).detach().cpu().numpy().reshape(self.args.nt, self.args.xgrid) 1108 | loss_f[0,:] = 0 1109 | loss_f2[0,:] = 0 1110 | fig = plt.figure(figsize=(9, 6)) 1111 | ax0 = fig.add_subplot(111) 1112 | h0 = ax0.imshow(loss_f2.T, interpolation='nearest', cmap='rainbow', 1113 | extent=[t_f.detach().cpu().numpy().min(), t_f.detach().cpu().numpy().max(), x_f.detach().cpu().numpy().min(), x_f.detach().cpu().numpy().max()], 1114 | origin='lower', aspect='auto') 1115 | divider0 = make_axes_locatable(ax0) 1116 | cax0 = divider0.append_axes("right", size="5%", pad=0.10) 1117 | cbar0 = fig.colorbar(h0, cax=cax0) 1118 | cbar0.ax.tick_params(labelsize=15) 1119 | 1120 | ax0.set_xlabel('t', fontweight='bold', size=15) 1121 | ax0.set_ylabel('x', fontweight='bold', size=15) 1122 | ax0.set_title('loss_pde2') 1123 | plt.savefig(os.path.join(self.args.work_dir, name+f'loss_pde2_nu{self.nu[idx]}_rho{self.rho[idx]}_beta{self.beta[idx]}.png')) 1124 | plt.close(fig) 1125 | else: 1126 | loss_f = (f_pred[idx].abs()).detach().cpu().numpy().reshape(self.args.nt, self.args.xgrid) 1127 | loss_f[0,:] = 0 1128 | u_pred_f_i = u_pred_f[idx].reshape(self.args.nt, self.args.xgrid).detach().cpu().numpy() 1129 | diff_u = np.abs(self.Exact[idx]-u_pred_f_i) 1130 | fig = plt.figure(figsize=(9, 18)) 1131 | ax0 = fig.add_subplot(311) 1132 | 1133 | h0 = ax0.imshow(u_pred_f_i.T, interpolation='nearest', cmap='rainbow', 1134 | extent=[t_f.detach().cpu().numpy().min(), t_f.detach().cpu().numpy().max(), x_f.detach().cpu().numpy().min(), x_f.detach().cpu().numpy().max()], 1135 | origin='lower', aspect='auto') 1136 | divider0 = make_axes_locatable(ax0) 1137 | cax0 = 
divider0.append_axes("right", size="5%", pad=0.10) 1138 | cbar0 = fig.colorbar(h0, cax=cax0) 1139 | cbar0.ax.tick_params(labelsize=15) 1140 | ax0.set_title('prediction') 1141 | 1142 | ax0.set_xlabel('t', fontweight='bold', size=15) 1143 | ax0.set_ylabel('x', fontweight='bold', size=15) 1144 | 1145 | ax1 = fig.add_subplot(312) 1146 | 1147 | h1 = ax1.imshow(diff_u.T, interpolation='nearest', cmap='rainbow', 1148 | extent=[t_f.detach().cpu().numpy().min(), t_f.detach().cpu().numpy().max(), x_f.detach().cpu().numpy().min(), x_f.detach().cpu().numpy().max()], 1149 | origin='lower', aspect='auto') 1150 | divider1 = make_axes_locatable(ax1) 1151 | cax1 = divider1.append_axes("right", size="5%", pad=0.10) 1152 | cbar1 = fig.colorbar(h1, cax=cax1) 1153 | cbar1.ax.tick_params(labelsize=15) 1154 | ax1.set_title('error') 1155 | 1156 | ax1.set_xlabel('t', fontweight='bold', size=15) 1157 | ax1.set_ylabel('x', fontweight='bold', size=15) 1158 | 1159 | ax2 = fig.add_subplot(313) 1160 | 1161 | h2 = ax2.imshow(loss_f.T, interpolation='nearest', cmap='rainbow', 1162 | extent=[t_f.detach().cpu().numpy().min(), t_f.detach().cpu().numpy().max(), x_f.detach().cpu().numpy().min(), x_f.detach().cpu().numpy().max()], 1163 | origin='lower', aspect='auto') 1164 | divider2 = make_axes_locatable(ax2) 1165 | cax2 = divider2.append_axes("right", size="5%", pad=0.10) 1166 | cbar2 = fig.colorbar(h2, cax=cax2) 1167 | cbar2.ax.tick_params(labelsize=15) 1168 | 1169 | ax2.set_xlabel('t', fontweight='bold', size=15) 1170 | ax2.set_ylabel('x', fontweight='bold', size=15) 1171 | ax2.set_title('loss_pde') 1172 | 1173 | plt.savefig(os.path.join(self.args.work_dir, name+f'loss_pde_nu{self.nu[idx]}_rho{self.rho[idx]}_beta{self.beta[idx]}.png')) 1174 | plt.close(fig) 1175 | 1176 | fig = plt.figure(figsize=(9, 18)) 1177 | ax0 = fig.add_subplot(211) 1178 | 1179 | u_t_pred = u_t_plot[idx].reshape(self.args.nt, self.args.xgrid).detach().cpu().numpy() 1180 | u_x_pred = u_x_plot[idx].reshape(self.args.nt, self.args.xgrid).detach().cpu().numpy() 1181 | h0 = ax0.imshow(u_t_pred.T, interpolation='nearest', cmap='binary', 1182 | extent=[t_f.detach().cpu().numpy().min(), t_f.detach().cpu().numpy().max(), x_f.detach().cpu().numpy().min(), x_f.detach().cpu().numpy().max()], 1183 | origin='lower', aspect='auto') 1184 | divider0 = make_axes_locatable(ax0) 1185 | cax0 = divider0.append_axes("right", size="5%", pad=0.10) 1186 | cbar0 = fig.colorbar(h0, cax=cax0) 1187 | cbar0.ax.tick_params(labelsize=15) 1188 | ax0.set_title('u_t') 1189 | 1190 | ax0.set_xlabel('t', fontweight='bold', size=15) 1191 | ax0.set_ylabel('x', fontweight='bold', size=15) 1192 | 1193 | ax1 = fig.add_subplot(212) 1194 | 1195 | h1 = ax1.imshow(u_x_pred.T, interpolation='nearest', cmap='binary', 1196 | extent=[t_f.detach().cpu().numpy().min(), t_f.detach().cpu().numpy().max(), x_f.detach().cpu().numpy().min(), x_f.detach().cpu().numpy().max()], 1197 | origin='lower', aspect='auto') 1198 | divider1 = make_axes_locatable(ax1) 1199 | cax1 = divider1.append_axes("right", size="5%", pad=0.10) 1200 | cbar1 = fig.colorbar(h1, cax=cax1) 1201 | cbar1.ax.tick_params(labelsize=15) 1202 | ax1.set_title('u_x') 1203 | 1204 | ax1.set_xlabel('t', fontweight='bold', size=15) 1205 | ax1.set_ylabel('x', fontweight='bold', size=15) 1206 | 1207 | plt.savefig(os.path.join(self.args.work_dir, name+f'gradient_nu{self.nu[idx]}_rho{self.rho[idx]}_beta{self.beta[idx]}.png')) 1208 | plt.close(fig) 1209 | 1210 | def validate(self): 1211 | self.dnn_auxiliary.load_state_dict(self.dnn.state_dict()) 1212 | u_pred
= self.predict(self.X_star, return_all=True, use_auxiliary=True) 1213 | f_pred = self.evaluate_loss_f(self.X_star, return_all=True, use_auxiliary=True) 1214 | 1215 | 1216 | for u_pred_i, f_pred_i, Exact, nu, beta, rho in zip(u_pred, f_pred, self.Exact, self.nu, self.beta, self.rho): 1217 | u_star = Exact.reshape(-1, 1) 1218 | 1219 | error_u_relative = np.linalg.norm(u_star-u_pred_i, 2)/np.linalg.norm(u_star, 2) 1220 | 1221 | error_u_abs = np.mean(np.abs(u_star - u_pred_i)) 1222 | error_u_linf = np.linalg.norm(u_star - u_pred_i, np.inf)/np.linalg.norm(u_star, np.inf) 1223 | error_f_test = np.mean(f_pred_i ** 2) 1224 | self.logger.info(f"lr {self.optimizer.param_groups[0]['lr']}") 1225 | if self.args.sep_optim: 1226 | self.logger.info(f"lr {self.optimizer_coeff.param_groups[0]['lr']}") 1227 | self.logger.info(f'Head for nu {nu}, rho {rho}, beta {beta}') 1228 | self.logger.info('Error u rel: %e' % (error_u_relative)) 1229 | self.logger.info('Error u abs: %e' % (error_u_abs)) 1230 | self.logger.info('Error u linf: %e' % (error_u_linf)) 1231 | self.logger.info('Loss f test: %e' % (error_f_test)) 1232 | # import pdb 1233 | # pdb.set_trace() 1234 | self.writer.add_scalars('error', {'loss_F':error_f_test}, self.iter) 1235 | self.writer.add_scalars('error', {'relative':error_u_relative}, self.iter) 1236 | 1237 | 1238 | def predict(self, X, return_all=False, use_auxiliary=False): 1239 | x = torch.tensor(X[:, 0:1], requires_grad=True).double().to(self.device) 1240 | t = torch.tensor(X[:, 1:2], requires_grad=True).double().to(self.device) 1241 | if use_auxiliary: 1242 | self.dnn_auxiliary.eval() 1243 | 1244 | u = self.net_u(x, t, use_auxiliary=use_auxiliary) 1245 | if 'CH' in self.system and self.args.decouple: 1246 | u = [u_i[:, 0:1].detach().cpu().numpy() for u_i in u] 1247 | else: 1248 | u = [u_i.detach().cpu().numpy() for u_i in u] 1249 | if return_all: 1250 | return u 1251 | else: 1252 | return u[-1] 1253 | 1254 | def evaluate_loss_f(self, X, return_all=False, use_auxiliary=False): 1255 | x = torch.tensor(X[:, 0:1], requires_grad=True).double().to(self.device) 1256 | t = torch.tensor(X[:, 1:2], requires_grad=True).double().to(self.device) 1257 | iter_num = 0 1258 | f_pred = [list() for _ in range(self.num_head)] 1259 | while(iter_num < x.shape[0]): 1260 | f_pred_i = self.net_f(x[int(iter_num):int(iter_num+self.args.N_f)], t[int(iter_num):int(iter_num+self.args.N_f)], use_auxiliary=use_auxiliary) 1261 | iter_num += self.args.N_f 1262 | for f_list, f_pred_i_head in zip(f_pred, f_pred_i): 1263 | f_list.append(f_pred_i_head.detach().cpu().numpy()) 1264 | del f_pred_i_head 1265 | f_pred = [np.vstack(f_list) for f_list in f_pred] 1266 | 1267 | # f_pred = self.net_f(x, t, use_auxiliary=use_auxiliary) 1268 | # f_pred = [f_pred_i.detach().cpu().numpy() for f_pred_i in f_pred] 1269 | if return_all: 1270 | return f_pred 1271 | else: 1272 | return f_pred[-1] 1273 | 1274 | --------------------------------------------------------------------------------
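The Adam stage above drives its learning rate through `torch.optim.lr_scheduler.LambdaLR`: the lambda returns a multiplier on the base learning rate, ramping up linearly during warmup and then following a cosine decay back toward a small floor (the repository's variants additionally support a constant warmup value and per-cycle restarts controlled by `T_max`). The snippet below is a minimal, self-contained sketch of that basic warmup-then-cosine multiplier; the toy model, base learning rate, `warm_up_iter`, and `T_max` values are illustrative, not the repository's configuration.

```
import math
import torch

def warmup_cosine(cur_iter, warm_up_iter=1000, T_max=100000, lr_min=1e-3, lr_max=1.0):
    """Multiplier that LambdaLR applies to the optimizer's base learning rate."""
    if cur_iter < warm_up_iter:
        # linear warmup from lr_min to lr_max
        return lr_min + (lr_max - lr_min) * cur_iter / warm_up_iter
    # cosine decay from lr_max back down to lr_min
    progress = (cur_iter - warm_up_iter) / (T_max - warm_up_iter)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(progress * math.pi))

model = torch.nn.Linear(2, 1)                               # illustrative network
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # illustrative base lr
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)

for step in range(5):
    optimizer.step()    # in a real loop this follows loss.backward()
    scheduler.step()    # effective lr = 1e-3 * warmup_cosine(step + 1)
```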
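After the Adam phase, the network is handed to the second-stage optimizer via `choose_optimizer(self.optimizer_name, ...)` and refined with a single closure-based `optimizer.step(self.loss_pinn)` call. Below is a minimal sketch of that closure pattern using PyTorch's built-in `torch.optim.LBFGS`; the toy model, quadratic placeholder objective, and hyperparameters are assumptions for illustration, not the repository's PINN loss or its custom `lbfgs.py` implementation.

```
import torch

model = torch.nn.Linear(2, 1).double()                  # toy stand-in for the PINN
x = torch.randn(64, 2, dtype=torch.float64)
optimizer = torch.optim.LBFGS(model.parameters(), lr=1.0, max_iter=200,
                              line_search_fn='strong_wolfe')

def closure():
    # L-BFGS re-evaluates the objective internally, so it requires a closure
    optimizer.zero_grad()
    loss = (model(x) ** 2).mean()                       # placeholder objective
    loss.backward()
    return loss

optimizer.step(closure)
print(closure().item())                                 # final objective value
```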
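`validate` logs three error measures per head: a relative L2 error, a mean absolute error, and a relative L-infinity error between the exact solution and the prediction. The snippet below reproduces those formulas on synthetic arrays; `u_star` and `u_pred` here are stand-ins for `Exact.reshape(-1, 1)` and the network output, not data loaded from the repository's datasets.

```
import numpy as np

# Synthetic stand-ins: a "ground truth" column vector and a slightly perturbed prediction.
rng = np.random.default_rng(0)
u_star = np.linspace(-1.0, 1.0, 256).reshape(-1, 1)
u_pred = u_star + 0.01 * rng.standard_normal(u_star.shape)

# Same formulas as logged in validate()
error_u_relative = np.linalg.norm(u_star - u_pred, 2) / np.linalg.norm(u_star, 2)
error_u_abs = np.mean(np.abs(u_star - u_pred))
error_u_linf = np.linalg.norm(u_star - u_pred, np.inf) / np.linalg.norm(u_star, np.inf)

print(f"rel L2: {error_u_relative:.3e}  abs: {error_u_abs:.3e}  rel Linf: {error_u_linf:.3e}")
```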