├── .gitignore ├── LICENSE ├── README.md ├── data ├── add_noise.py ├── generate_CD.py ├── generate_KS.py ├── generate_NLSE.py └── generate_fiber.py ├── input_files ├── CD_train.json ├── KS_train.json ├── NLSE_train.json ├── fiber_train.json └── template.json ├── models ├── pde1d.py ├── pde1d_decoder_only.py ├── pde2d.py └── pde2d_decoder_only.py ├── npz_dataset.py └── run.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Peter Y. Lu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PDE-VAE: Variational Autoencoder for *Extracting Interpretable Physical Parameters from Spatiotemporal Systems using Unsupervised Learning* 2 | 3 | Implementation of a variational autoencoder (VAE)-based method for extracting interpretable physical parameters (from spatiotemporal data) that parameterize the dynamics of a spatiotemporal system, e.g. a system governed by a partial differential equation (PDE). 4 | 5 | Please cite "**Extracting Interpretable Physical Parameters from Spatiotemporal Systems using Unsupervised Learning**" (https://journals.aps.org/prx/abstract/10.1103/PhysRevX.10.031056) and see the paper for more details. This is the official repository for the paper. 6 | 7 | ## Requirements 8 | PyTorch version >= 1.1.0, NumPy 9 | 10 | (Note: Dataset generation scripts have additional requirements in some cases.) 11 | 12 | ## Usage 13 | ### Dataset 14 | The dataset generation scripts for the datasets in the paper are located in the "data/" folder. Data is loaded using the PyTorch dataloader framework. To use the existing dataset loader, format the data as a NumPy array with shape: 15 | 16 | ``` 17 | (dataset size, data channels, propagation dimension, spatial dimension 1, spatial dimension 2, ...) 18 | ``` 19 | Currently, only datasets with 1 or 2 spatial dimensions are supported.
The propagation dimension is usually the time direction. 20 | 21 | ### Training 22 | Hyperparameter and architecture adjustments can be made using the input file. Examples are located in the "input\_files/" folder (see "input\_files/template.json" for a description of each setting). For training, make sure the "train" parameter is set to *true* in the input file, then run: 23 | 24 | ```bash 25 | python run.py input_file.json > out 26 | ``` 27 | 28 | ### Evaluation 29 | To run the provided evaluation script, change the "train" parameter to *false* in the input file, and make sure to set "MODELLOAD" in the input file to the path of the saved trained model. Then, rerun the same input file. Note that even if crop boundaries are used, the evaluation method does not crop to smaller sizes (unlike training) and instead evaluates on the full dataset, so adjustments may need to be made to the boundary conditions and batch size. 30 | 31 | Custom evaluation routines are recommended for detailed data analysis. For example, using the "pde1d_decoder_only" model (or "pde2d_decoder_only" model) with weights loaded from the trained model allows you to manually tune the latent parameters and observe the predicted propagation (given an input initial condition). This may aid in interpreting the extracted relevant parameters. 32 | -------------------------------------------------------------------------------- /data/add_noise.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | add_noise.py 5 | 6 | Script adding white noise to the input dataset (*.npz file) with stdv adjusted below.
7 | 8 | Usage: 9 | python add_noise.py dataset.npz 10 | """ 11 | 12 | import os 13 | import sys 14 | import numpy as np 15 | 16 | FILENAME = sys.argv[1] 17 | 18 | data = np.load(FILENAME) 19 | 20 | x = data['x'] 21 | 22 | noise = 0.1 23 | x += np.random.normal(loc=0, scale=noise, size=x.shape) 24 | 25 | print("Saving to: " + os.path.splitext(FILENAME)[0] + "_noise" + str(noise) + ".npz") 26 | np.savez(os.path.splitext(FILENAME)[0] + "_noise" + str(noise) + ".npz", x=x, params=data['params']) 27 | -------------------------------------------------------------------------------- /data/generate_CD.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | generate_CD.py 5 | 6 | Script for generating the 2D convection-diffusion dataset. 7 | """ 8 | 9 | import os 10 | import numpy as np 11 | from scipy.fftpack import fft2, ifft2 12 | from scipy.stats import truncnorm 13 | 14 | import argparse 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Generate convection-diffusion dataset.") 18 | parser.add_argument('-d', '--datasize', type=int, default=1000, 19 | help="Size of generated dataset--should be a multiple of batch_size. Default: 1000") 20 | parser.add_argument('-b', '--batchsize', type=int, default=50, 21 | help="Batch size for generating dataset in parallel--should divide data_size. Default: 50") 22 | parser.add_argument('-f', '--filename', type=str, default='CD_dataset.npz', 23 | help="Path with file name ending in .npz where dataset is saved. 
Default: CD_dataset.npz") 24 | 25 | args = parser.parse_args() 26 | data_size = args.datasize 27 | batch_size = args.batchsize 28 | FILENAME = args.filename 29 | 30 | l = 16*np.pi # system size 31 | mesh = 256 # mesh 32 | tmax = 4*np.pi # max time 33 | tmesh = 64 34 | dt = tmax/tmesh # time step 35 | shift = 0 # shift time to exclude initial conditions, set to 0 to keep t = 0 36 | 37 | dmean = 0.1 # diffusion constant 38 | dstd = dmean/4 39 | velstd = 0.5 # standard deviation of velocity 40 | 41 | kx = np.expand_dims(2*np.pi * np.fft.fftfreq(mesh, d=l/mesh), axis=-1) 42 | ky = np.expand_dims(2*np.pi * np.fft.fftfreq(mesh, d=l/mesh), axis=0) 43 | 44 | # for use in 1st derivative 45 | kx_1 = kx.copy() 46 | kx_1[int(mesh/2)] = 0 47 | ky_1 = ky.copy() 48 | ky_1[:, int(mesh/2)] = 0 49 | 50 | 51 | ### Generate data 52 | u_list = [] 53 | d_list = [] 54 | velx_list = [] 55 | vely_list = [] 56 | for i in range(int(data_size/batch_size)): 57 | print('Batch ' + str(i+1) + ' of ' + str(int(data_size/batch_size))) 58 | 59 | d = truncnorm.rvs((0 - dmean) / dstd, (2 * dmean - dmean) / dstd, 60 | loc=dmean, scale=dstd, size=(batch_size, 1, 1, 1)) 61 | d_list.append(d.astype(np.float32)) 62 | 63 | velx = np.random.normal(loc=0, scale=velstd, size=(batch_size, 1, 1, 1)) 64 | vely = np.random.normal(loc=0, scale=velstd, size=(batch_size, 1, 1, 1)) 65 | velx_list.append(velx.astype(np.float32)) 66 | vely_list.append(vely.astype(np.float32)) 67 | 68 | ## initial condition 69 | krange = (0.25 * mesh*np.pi/l - 8 * np.pi/l) * np.random.rand(batch_size, 1, 1, 1) + 8 * np.pi/l 70 | envelope = np.exp(-1/(2*krange**2) * (kx**2 + ky**2) ) 71 | v0 = envelope * (np.random.normal(loc=0, scale=1.0, size=(batch_size, 1, mesh, mesh)) 72 | + 1j*np.random.normal(loc=0, scale=1.0, size=(batch_size, 1, mesh, mesh))) 73 | u0 = np.real(ifft2(v0)) 74 | u0 = mesh * u0/np.linalg.norm(u0, axis=(-2,-1), keepdims=True) # normalize 75 | v0 = fft2(u0) 76 | 77 | ## Differential equation 78 | L = -d * (kx**2 + 
ky**2) - 1j * (kx_1 * velx + ky_1 * vely) 79 | 80 | t = np.linspace(shift, tmax + shift, tmesh, endpoint=False) 81 | v = np.exp(np.expand_dims(np.expand_dims(t, -1), -1) * L) * v0 82 | u = np.real(ifft2(v)) 83 | u_list.append(np.expand_dims(u, axis=1).astype(np.float32)) 84 | 85 | 86 | u_list = np.concatenate(u_list) 87 | d_list = np.concatenate(d_list) 88 | velx_list = np.concatenate(velx_list) 89 | vely_list = np.concatenate(vely_list) 90 | 91 | ## shape of u_list = (data_size, 1, tmesh, mesh, mesh) 92 | print(u_list.shape) 93 | print(u_list.dtype) 94 | 95 | ## shape of d_list = (data_size, 1, 1, 1) 96 | print(d_list.shape) 97 | print(d_list.dtype) 98 | 99 | print('Exporting to: ' + FILENAME) 100 | np.savez(FILENAME, x=u_list, 101 | params=np.stack([d_list.flatten(), velx_list.flatten(), vely_list.flatten()], axis=1)) 102 | -------------------------------------------------------------------------------- /data/generate_KS.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | generate_KS.py 5 | 6 | Script for generating the 1D Kuramoto-Sivashinsky dataset. 
7 | """ 8 | 9 | import os 10 | 11 | import numpy as np 12 | from scipy.fftpack import fft, ifft 13 | from scipy.stats import truncnorm 14 | from scipy.signal import resample 15 | 16 | from joblib import Parallel, delayed 17 | import multiprocessing 18 | 19 | import warnings 20 | import argparse 21 | 22 | 23 | def generateData(i, data_size, batch_size, l, out_mesh, tmax, out_tmesh, shift, numean, nustd): 24 | ### Generate data 25 | 26 | ### Uses ETDRK4 method for integrating stiff PDEs 27 | ### https://epubs.siam.org/doi/pdf/10.1137/S1064827502410633 28 | 29 | # np.random.seed() 30 | # print(np.random.rand()) 31 | 32 | print('Batch ' + str(i+1) + ' of ' + str(int(data_size/batch_size))) 33 | 34 | nu = truncnorm.rvs((0.5 - numean) / nustd, (1.5 - numean) / nustd, 35 | loc=numean, scale=nustd, size=(batch_size, 1), 36 | random_state=np.random.RandomState()) 37 | print(nu) 38 | 39 | lamb = 1.0 40 | 41 | pool = max(int(10 * l/(out_mesh * np.amin(nu))), 1) 42 | print('Pooling: ' + str(pool)) 43 | tpool = 2 * pool 44 | mesh = out_mesh * pool 45 | tmesh = out_tmesh * tpool 46 | dt = tmax/tmesh # time step 47 | 48 | k = 2*np.pi * np.fft.fftfreq(mesh, d=l/mesh) 49 | 50 | ## initial condition 51 | krange = ((out_mesh/8)*np.pi/l - 4 * np.pi/l) * np.random.rand(batch_size, 1) + 4 * np.pi/l 52 | envelope = np.exp(-1/(2*krange**2) * k**2) 53 | v0 = envelope * (np.random.normal(loc=0, scale=1.0, size=(batch_size, mesh)) 54 | + 1j*np.random.normal(loc=0, scale=1.0, size=(batch_size, mesh))) 55 | u0 = np.real(ifft(v0)) 56 | u0 = np.sqrt(mesh) * u0/np.expand_dims(np.linalg.norm(u0, axis=-1), axis=-1) # normalize 57 | v0 = fft(u0) 58 | 59 | ## differential equation 60 | L = lamb * k**2 - nu * k**4 61 | N = lambda v: -0.5j * k * fft(np.real(ifft(v))**2) 62 | 63 | ## ETDRK4 method 64 | E = np.exp(dt * L) 65 | E2 = np.exp(dt * L / 2.0) 66 | 67 | contour_radius = 1 68 | M = 16 69 | r = contour_radius*np.exp(1j * np.pi * (np.arange(1, M + 1) - 0.5) / M) 70 | r = r.reshape(1, -1) 71 | 
r_contour = np.repeat(r, mesh, axis=0) 72 | 73 | LR = dt * L 74 | LR = np.expand_dims(LR, axis=-1) 75 | LR = np.repeat(LR, M, axis=-1) 76 | LR = LR + r_contour 77 | 78 | Q = dt*np.real(np.mean( (np.exp(LR/2.0)-1)/LR, axis=-1 )) 79 | f1 = dt*np.real(np.mean( (-4.0-LR + np.exp(LR)*(4.0-3.0*LR+LR**2))/LR**3, axis=-1 )) 80 | f2 = dt*np.real(np.mean( (2.0+LR + np.exp(LR)*(-2.0 + LR))/LR**3, axis=-1 )) 81 | f3 = dt*np.real(np.mean( (-4.0-3.0*LR - LR**2 + np.exp(LR)*(4.0 - LR))/LR**3, axis=-1 )) 82 | 83 | t = 0.0 84 | u = [] 85 | v = v0 86 | tpool_num = 0 87 | 88 | # catch overflow warnings and rerun the data generation 89 | with warnings.catch_warnings(record=True) as w: 90 | warnings.simplefilter(action='ignore', category=FutureWarning) 91 | 92 | while t < tmax + dt + shift: 93 | if t >= shift and len(u) < out_tmesh and tpool_num % tpool == 0: # exclude first 'shift' time 94 | u.append(resample(np.real(ifft(v)), out_mesh, axis=-1)) 95 | 96 | Nv = N(v) 97 | a = E2 * v + Q * Nv 98 | Na = N(a) 99 | b = E2 * v + Q * Na 100 | Nb = N(b) 101 | c = E2 * a + Q * (2.0*Nb - Nv) 102 | Nc = N(c) 103 | v = E*v + Nv*f1 + 2.0*(Na + Nb)*f2 + Nc*f3 104 | 105 | t = t + dt 106 | tpool_num += 1 107 | 108 | if w: 109 | print('Rerunning...') 110 | return generateData(i, data_size, batch_size, l, out_mesh, tmax, out_tmesh, shift, numean, nustd) 111 | 112 | assert len(u) == out_tmesh 113 | 114 | return np.expand_dims(np.stack(u, axis=-2), axis=1).astype(np.float32), nu.astype(np.float32) 115 | 116 | 117 | if __name__ == "__main__": 118 | parser = argparse.ArgumentParser(description="Generate Kuramoto-Sivashinsky dataset.") 119 | parser.add_argument('-d', '--datasize', type=int, default=5000, 120 | help="Size of generated dataset--should be a multiple of batch_size. Default: 5000") 121 | parser.add_argument('-b', '--batchsize', type=int, default=1, 122 | help="Batch size for generating dataset in parallel--should divide data_size. 
Default: 1") 123 | parser.add_argument('-f', '--filename', type=str, default='KS_dataset.npz', 124 | help="Path with file name ending in .npz where dataset is saved. Default: KS_dataset.npz") 125 | 126 | args = parser.parse_args() 127 | data_size = args.datasize 128 | batch_size = args.batchsize 129 | FILENAME = args.filename 130 | 131 | l = 64*np.pi # system size 132 | out_mesh = 256 # mesh 133 | tmax = 32*np.pi # max time 134 | out_tmesh = 256 # time mesh 135 | shift = 0 * tmax/out_mesh # shift time to exclude initial conditions, set to 0 to keep t = 0 136 | 137 | numean = 1.0 138 | nustd = 0.125 139 | 140 | num_cores = multiprocessing.cpu_count() 141 | print('Using ' + str(num_cores) + ' cores...') 142 | out_list = Parallel(n_jobs=num_cores)(delayed(generateData) 143 | (i, data_size, batch_size, l, out_mesh, tmax, out_tmesh, shift, numean, nustd) 144 | for i in range(int(data_size/batch_size))) 145 | 146 | u_list, nu_list = [[data[i] for data in out_list] for i in range(2)] 147 | 148 | u_list = np.concatenate(u_list) 149 | nu_list = np.concatenate(nu_list) 150 | 151 | ## shape of u_list = (data_size, 1, out_tmesh, out_mesh) 152 | print(u_list.shape) 153 | print(u_list.dtype) 154 | 155 | ## shape of nu_list = (data_size, 1) 156 | print(nu_list.shape) 157 | print(nu_list.dtype) 158 | 159 | print('Exporting to: ' + FILENAME) 160 | np.savez(FILENAME, x=u_list, params=np.stack([nu_list.flatten()], axis=1)) 161 | 162 | -------------------------------------------------------------------------------- /data/generate_NLSE.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | generate_NLSE.py 5 | 6 | Script for generating the 1D nonlinear Schrödinger dataset. 
7 | """ 8 | 9 | import os 10 | 11 | import numpy as np 12 | from scipy.fftpack import fft, ifft 13 | from scipy.stats import truncnorm 14 | from scipy.signal import resample 15 | 16 | from joblib import Parallel, delayed 17 | import multiprocessing 18 | 19 | import warnings 20 | import argparse 21 | 22 | 23 | def generateData(i, data_size, batch_size, l, out_mesh, tmax, out_tmesh, shift, kappa_mean, kappa_std): 24 | ### Generate data 25 | 26 | ### Uses ETDRK4 method for integrating stiff PDEs 27 | ### https://epubs.siam.org/doi/pdf/10.1137/S1064827502410633 28 | 29 | # np.random.seed() 30 | # print(np.random.rand()) 31 | 32 | print('Batch ' + str(i+1) + ' of ' + str(int(data_size/batch_size))) 33 | 34 | kappa = truncnorm.rvs((-2 - kappa_mean) / kappa_std, (0 - kappa_mean) / kappa_std, 35 | loc=kappa_mean, scale=kappa_std, size=(batch_size, 1), 36 | random_state=np.random.RandomState()) 37 | 38 | pool = 10 39 | print('Pooling: ' + str(pool)) 40 | tpool = 20 41 | mesh = out_mesh * pool 42 | tmesh = out_tmesh * tpool 43 | dt = tmax/tmesh # time step 44 | 45 | k = 2*np.pi * np.fft.fftfreq(mesh, d=l/mesh) 46 | 47 | ## initial condition 48 | krange = ((out_mesh/8)*np.pi/l - 4 * np.pi/l) * np.random.rand(batch_size, 1) + 4 * np.pi/l 49 | envelope = np.exp(-1/(2*krange**2) * k**2) 50 | v0 = envelope * (np.random.normal(loc=0, scale=1.0, size=(batch_size, mesh)) 51 | + 1j*np.random.normal(loc=0, scale=1.0, size=(batch_size, mesh))) 52 | u0 = ifft(v0) 53 | u0 = np.sqrt(2 * mesh) * u0/np.expand_dims(np.linalg.norm(u0, axis=-1), axis=-1) # normalize 54 | v0 = fft(u0) 55 | 56 | ## differential equation 57 | L = -0.5j * k**2 58 | def N(v): 59 | u = ifft(v) 60 | return -1j * kappa * fft(np.abs(u)**2 * u) 61 | 62 | ## ETDRK4 method 63 | E = np.exp(dt * L) 64 | E2 = np.exp(dt * L / 2.0) 65 | 66 | contour_radius = 1 67 | M = 32 68 | r = contour_radius*np.exp(2j * np.pi * (np.arange(1, M + 1) - 0.5) / M) 69 | r = r.reshape(1, -1) 70 | r_contour = np.repeat(r, mesh, axis=0) 71 | 72 | 
LR = dt * L 73 | LR = np.expand_dims(LR, axis=-1) 74 | LR = np.repeat(LR, M, axis=-1) 75 | LR = LR + r_contour 76 | 77 | Q = dt*np.mean( (np.exp(LR/2.0)-1)/LR, axis=-1 ) 78 | f1 = dt*np.mean( (-4.0-LR + np.exp(LR)*(4.0-3.0*LR+LR**2))/LR**3, axis=-1 ) 79 | f2 = dt*np.mean( (2.0+LR + np.exp(LR)*(-2.0 + LR))/LR**3, axis=-1 ) 80 | f3 = dt*np.mean( (-4.0-3.0*LR - LR**2 + np.exp(LR)*(4.0 - LR))/LR**3, axis=-1 ) 81 | 82 | t = 0.0 83 | u = [] 84 | v = v0 85 | tpool_num = 0 86 | 87 | # catch overflow warnings and rerun the data generation 88 | with warnings.catch_warnings(record=True) as w: 89 | warnings.simplefilter(action='ignore', category=FutureWarning) 90 | 91 | while t < tmax + shift - 1e-8: 92 | if t >= shift and tpool_num % tpool == 0: # exclude first 'shift' time 93 | up = resample(ifft(v), out_mesh, axis=-1) 94 | u.append(np.stack([np.real(up), np.imag(up)], axis=-2)) 95 | 96 | Nv = N(v) 97 | a = E2 * v + Q * Nv 98 | Na = N(a) 99 | b = E2 * v + Q * Na 100 | Nb = N(b) 101 | c = E2 * a + Q * (2.0*Nb - Nv) 102 | Nc = N(c) 103 | v = E*v + Nv*f1 + 2.0*(Na + Nb)*f2 + Nc*f3 104 | 105 | t = t + dt 106 | tpool_num += 1 107 | 108 | if w: 109 | print('Rerunning...') 110 | print(w[-1]) 111 | return generateData(i, data_size, batch_size, l, out_mesh, tmax, out_tmesh, shift, kappa_mean, kappa_std) 112 | 113 | return np.stack(u, axis=-2).astype(np.float32), kappa.astype(np.float32) 114 | 115 | 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser(description="Generate nonlinear Schrödinger dataset.") 118 | parser.add_argument('-d', '--datasize', type=int, default=5000, 119 | help="Size of generated dataset--should be a multiple of batch_size. Default: 5000") 120 | parser.add_argument('-b', '--batchsize', type=int, default=50, 121 | help="Batch size for generating dataset in parallel--should divide data_size. 
Default: 50") 122 | parser.add_argument('-f', '--filename', type=str, default='NLSE_dataset.npz', 123 | help="Path with file name ending in .npz where dataset is saved. Default: NLSE_dataset.npz") 124 | 125 | args = parser.parse_args() 126 | data_size = args.datasize 127 | batch_size = args.batchsize 128 | FILENAME = args.filename 129 | 130 | l = 8*np.pi # system size 131 | out_mesh = 256 # mesh 132 | tmax = 1*np.pi # max time 133 | out_tmesh = 256 # time mesh 134 | shift = 0 # shift time to exclude initial conditions, set to 0 to keep t = 0 135 | 136 | kappa_mean = -1 137 | kappa_std = 0.25 138 | 139 | num_cores = multiprocessing.cpu_count() 140 | print('Using ' + str(num_cores) + ' cores...') 141 | out_list = Parallel(n_jobs=num_cores)(delayed(generateData) 142 | (i, data_size, batch_size, l, out_mesh, tmax, out_tmesh, shift, kappa_mean, kappa_std) 143 | for i in range(int(data_size/batch_size))) 144 | 145 | u_list, kappa_list = [[data[i] for data in out_list] for i in range(2)] 146 | 147 | u_list = np.concatenate(u_list) 148 | kappa_list = np.concatenate(kappa_list) 149 | 150 | ## shape of u_list = (data_size, data_channels, out_tmesh, out_mesh) 151 | print(u_list.shape) 152 | print(u_list.dtype) 153 | 154 | ## shape of kappa_list = (data_size, 1) 155 | print(kappa_list.shape) 156 | print(kappa_list.dtype) 157 | 158 | print('Exporting to: ' + FILENAME) 159 | np.savez(FILENAME, x=u_list, params=np.stack([kappa_list.flatten()], axis=1)) 160 | 161 | -------------------------------------------------------------------------------- /data/generate_fiber.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | generate_fiber.py 5 | 6 | Script for generating the nonlinear fiber propagation dataset. 7 | Requires MEEP (https://meep.readthedocs.io/en/latest/). 
8 | """ 9 | 10 | import shutil 11 | 12 | import meep as mp 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | import h5py 16 | import glob 17 | 18 | from joblib import Parallel, delayed 19 | import multiprocessing 20 | 21 | import warnings 22 | import argparse 23 | 24 | def get_f(total_time, dT, fsrc, batch_size): 25 | mesh = int(total_time/dT) 26 | assert mesh == total_time/dT 27 | T = 1/fsrc 28 | n_periods = 20 29 | 30 | k = 2*np.pi * np.fft.fftfreq(mesh, d=dT) 31 | w = np.pi/(n_periods*T) 32 | envelope = np.exp(-k**2/(2*w**2)) 33 | v0 = envelope * (np.random.normal(size=(batch_size, mesh)) + 1j*np.random.normal(size=(batch_size, mesh))) 34 | u0 = np.fft.ifft(v0) 35 | u0 = np.sqrt(2 * mesh) * u0/np.expand_dims(np.linalg.norm(u0, axis=-1), axis=-1) # normalize 36 | v0 = np.fft.fft(u0) 37 | 38 | sigmoid = lambda x: 1 / (1 + np.exp(-x)) 39 | return u0, lambda t: 1/mesh * np.matmul(v0, np.exp(1j * k * t)) * sigmoid(t/(20*T)-5) * np.exp(-1j*2*np.pi*fsrc*t) # slowly turn on 40 | 41 | def get_f_half(total_time, dT, fsrc, batch_size): 42 | T = 1/fsrc 43 | u0, f = get_f(total_time, dT, fsrc, batch_size) 44 | sigmoid = lambda x: 1 / (1 + np.exp(-x)) 45 | return u0, lambda t: f(t) * sigmoid((total_time/2-t)/(20*T)-5) # turn off before half total_time 46 | 47 | def generateData(i, resolution=5, total_time=10000, dT=10, length=500., pml_thickness=10., outer_radius=10., chi3=0.0, half=False, shift=200): 48 | ### Generate data using MEEP 49 | 50 | print('Run ' + str(i+1) + ' of ' + str(data_size)) 51 | 52 | while True: 53 | dr1 = np.random.normal(scale=0.05) 54 | dr2 = np.random.normal(scale=0.02) 55 | de1 = np.random.normal(scale=1) 56 | de2 = np.random.normal(scale=2) 57 | 58 | inner_core_radius = 0.5 + dr2 59 | core_radius = 1.0 + dr1 60 | 61 | if core_radius > inner_core_radius: 62 | break 63 | 64 | 65 | output_dir = f"out-{i:05d}" 66 | 67 | cell_size = mp.Vector3(outer_radius + pml_thickness, 0, length + 2*pml_thickness) 68 | pml_layers = 
[mp.PML(thickness=pml_thickness)] 69 | default_material = mp.Medium(index=1, chi3=chi3) 70 | geometry = [mp.Block(center=mp.Vector3(), size=mp.Vector3(2*core_radius, mp.inf, mp.inf), material=mp.Medium(epsilon=8 + de1, chi3=chi3)), 71 | mp.Block(center=mp.Vector3(), size=mp.Vector3(2*inner_core_radius, mp.inf, mp.inf), material=mp.Medium(epsilon=30 + de2, chi3=chi3)) 72 | ] 73 | 74 | fsrc = 0.1 75 | if not half: 76 | u0, f = get_f(total_time, dT, fsrc=fsrc, batch_size=1) 77 | else: 78 | u0, f = get_f_half(total_time, dT, fsrc=fsrc, batch_size=1) 79 | sources = [mp.Source(src=mp.CustomSource(src_func=lambda t: f(t)[0]), 80 | center=mp.Vector3(0,0,-length/2.), 81 | size=mp.Vector3(2*(3)), 82 | component=mp.Er)] 83 | 84 | sim = mp.Simulation(cell_size=cell_size, 85 | resolution=resolution, 86 | boundary_layers=pml_layers, 87 | sources=sources, 88 | geometry=geometry, 89 | dimensions=mp.CYLINDRICAL, 90 | m=1 91 | ) 92 | flux_total = sim.add_flux(fsrc, 1.*fsrc, int(fsrc*total_time)+1, mp.FluxRegion(center=mp.Vector3(0, 0, -length/2. 
+ pml_thickness), size=mp.Vector3(2*outer_radius))) 93 | 94 | sim.use_output_directory(output_dir) 95 | sim.run(mp.at_every(dT, mp.in_volume(mp.Volume(center=mp.Vector3(), size=mp.Vector3(0,0,length)), mp.output_efield_r)), until=total_time) 96 | 97 | files = sorted(glob.glob(output_dir + "/*er-*.h5")) 98 | data = [] 99 | for file in files: 100 | f = h5py.File(file, "r") 101 | data.append(np.array(f['er.r']) + 1j * np.array(f['er.i'])) 102 | 103 | data = np.stack(data) 104 | 105 | # Normalize by flux 106 | freqs = np.array(mp.get_flux_freqs(flux_total)) 107 | flux = np.array(mp.get_fluxes(flux_total)) 108 | integrated_flux = np.sum(flux)*(freqs[1]-freqs[0]) 109 | integrated_efield = np.sum(np.abs(data[:, int(resolution*pml_thickness)+1])**2)*dT 110 | norm_factor = np.sqrt(integrated_flux/integrated_efield) 111 | data *= norm_factor 112 | 113 | mean_norm2 = np.mean(np.abs(data[:, int(resolution*pml_thickness)+1])**2) 114 | 115 | # Remove carrier frequency/wavelength 116 | k = np.fft.fftfreq(data.shape[-1], d=1./resolution) 117 | k0 = k[np.argmax(np.abs(np.mean(np.fft.fft(data), 0)))] 118 | psi = data * np.exp(1j * 2*np.pi * (fsrc * dT * np.expand_dims(np.arange(data.shape[0]),1) - k0 * np.arange(data.shape[1])/resolution)) 119 | 120 | psi = psi[shift:, int(resolution*pml_thickness)+1:-1:int(resolution)] # drop region near PML and initial time 200*dT 121 | psi = psi.transpose() # swap t and z axis 122 | 123 | shutil.rmtree(output_dir) 124 | 125 | return np.expand_dims(np.stack([np.real(psi), np.imag(psi)], axis=-3), 0).astype(np.float32), np.array([[dr1, dr2, de1, de2, k0, mean_norm2]]).astype(np.float32) 126 | 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser(description="Generate fiber dataset using MEEP.") 130 | parser.add_argument('-d', '--datasize', type=int, default=200, 131 | help="Size of generated dataset--should be a multiple of batch_size. 
Default: 200") 132 | parser.add_argument('-f', '--filename', type=str, default='fiber_dataset.npz', 133 | help="Path with file name ending in .npz where dataset is saved. Default: fiber_dataset.npz") 134 | parser.add_argument('--half', action='store_true', 135 | help="Stop input pulse before half of the total simulation time to ensure entire pulse passes during simulation.") 136 | 137 | args = parser.parse_args() 138 | data_size = args.datasize 139 | FILENAME = args.filename 140 | 141 | length = 510 142 | total_time = 10000 143 | chi3 = 0.02 144 | shift = 0 if args.half else 200 145 | 146 | num_cores = multiprocessing.cpu_count() # may want to use less than maximum number of CPUs 147 | print('Using ' + str(num_cores) + ' cores...') 148 | out_list = Parallel(n_jobs=num_cores)( 149 | delayed(generateData)(i, total_time=total_time, length=length, chi3=chi3, half=args.half, shift=shift) for i in range(data_size)) 150 | 151 | u_list, params_list = [[data[i] for data in out_list] for i in range(2)] 152 | 153 | u_list = np.concatenate(u_list) 154 | params_list = np.concatenate(params_list) 155 | 156 | # Re-normalize to achieve component variance ~ 1 on average over the dataset 157 | # norm_factor = np.sqrt(2/np.mean(params_list[:,-1])) 158 | norm_factor = 0.8600365727096997 159 | u_list *= norm_factor 160 | 161 | ## shape of u_list = (data_size, data_channels, length-10, total_time/dT - shift) 162 | print(u_list.shape) 163 | print(u_list.dtype) 164 | 165 | ## shape of params_list = (data_size, 6) 166 | print(params_list.shape) 167 | print(params_list.dtype) 168 | 169 | print('Exporting to: ' + FILENAME) 170 | np.savez(FILENAME, x=u_list, params=params_list) 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | -------------------------------------------------------------------------------- /input_files/CD_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "DATAFILE": "data/CD_dataset_size1000.npz", 3 | "OUTFOLDER": 
"CD_run0", 4 | 5 | "cuda_device": 0, 6 | 7 | "train": true, 8 | 9 | "model": "pde2d", 10 | "data_dimension": 2, 11 | "data_channels": 1, 12 | 13 | "dataset_type": "npz_dataset", 14 | "boundary_cond": "crop", 15 | "input_size": 62, 16 | "training_size": 44, 17 | 18 | "input_depth": 45, 19 | "training_depth": 15, 20 | "evaluation_depth": 127, 21 | 22 | "linear_kernel_size": 0, 23 | 24 | "nonlin_kernel_size": 5, 25 | "hidden_channels": 16, 26 | "prop_layers": 1, 27 | 28 | "param_size": 5, 29 | "beta": 1e-4, 30 | 31 | "learning_rate": 1e-3, 32 | "eps": 1e-8, 33 | "num_workers": 4, 34 | "batch_size": 50, 35 | 36 | "max_epochs": 8000, 37 | "save_epochs": 800 38 | } 39 | -------------------------------------------------------------------------------- /input_files/KS_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "DATAFILE": "data/KS_dataset_size5000.npz", 3 | "OUTFOLDER": "KS_run0", 4 | 5 | "cuda_device": 0, 6 | 7 | "train": true, 8 | 9 | "model": "pde1d", 10 | "data_dimension": 1, 11 | "data_channels": 1, 12 | 13 | "dataset_type": "npz_dataset", 14 | "boundary_cond": "crop", 15 | "input_size": 94, 16 | "training_size": 76, 17 | 18 | "input_depth": 64, 19 | "training_depth": 31, 20 | "evaluation_depth": 255, 21 | 22 | "linear_kernel_size": 0, 23 | 24 | "nonlin_kernel_size": 5, 25 | "hidden_channels": 16, 26 | "prop_layers": 1, 27 | 28 | "param_size": 5, 29 | "beta": 2e-2, 30 | 31 | "learning_rate": 1e-3, 32 | "eps": 1e-8, 33 | "num_workers": 2, 34 | "batch_size": 50, 35 | 36 | "max_epochs": 2000, 37 | "save_epochs": 200 38 | } 39 | -------------------------------------------------------------------------------- /input_files/NLSE_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "DATAFILE": "data/NLSE_dataset_size5000.npz", 3 | "OUTFOLDER": "NLSE_run0", 4 | 5 | "cuda_device": 0, 6 | 7 | "train": true, 8 | 9 | "model": "pde1d", 10 | "data_dimension": 1, 11 | 
"data_channels": 2, 12 | 13 | "dataset_type": "npz_dataset", 14 | "boundary_cond": "crop", 15 | "input_size": 94, 16 | "training_size": 76, 17 | 18 | "input_depth": 64, 19 | "training_depth": 63, 20 | "evaluation_depth": 255, 21 | 22 | "linear_kernel_size": 0, 23 | 24 | "nonlin_kernel_size": 5, 25 | "hidden_channels": 16, 26 | "prop_layers": 1, 27 | 28 | "param_size": 5, 29 | "beta": 2e-2, 30 | 31 | "learning_rate": 1e-3, 32 | "eps": 1e-8, 33 | "num_workers": 2, 34 | "batch_size": 50, 35 | 36 | "max_epochs": 2000, 37 | "save_epochs": 200 38 | } 39 | -------------------------------------------------------------------------------- /input_files/fiber_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "DATAFILE": "data/fiber_dataset_size200.npz", 3 | "OUTFOLDER": "fiber_run0", 4 | 5 | "cuda_device": 0, 6 | 7 | "train": true, 8 | 9 | "model": "pde1d", 10 | "data_dimension": 1, 11 | "data_channels": 2, 12 | 13 | "dataset_type": "npz_dataset", 14 | "boundary_cond": "crop", 15 | "input_size": 158, 16 | "training_size": 76, 17 | 18 | "input_depth": 128, 19 | "training_depth": 31, 20 | "evaluation_depth": 499, 21 | 22 | "linear_kernel_size": 0, 23 | 24 | "nonlin_kernel_size": 5, 25 | "hidden_channels": 16, 26 | "prop_layers": 1, 27 | 28 | "prop_noise": 1e-2, 29 | 30 | "param_size": 5, 31 | "beta": 7e-4, 32 | 33 | "learning_rate": 1e-3, 34 | "eps": 1e-8, 35 | "num_workers": 2, 36 | "batch_size": 50, 37 | 38 | "max_epochs": 40000, 39 | "save_epochs": 2000 40 | } 41 | -------------------------------------------------------------------------------- /input_files/template.json: -------------------------------------------------------------------------------- 1 | // REMOVE ALL COMMENTS BEFORE USING THIS FILE 2 | { 3 | "DATAFILE": "path/to/dataset.npz", // Path to dataset 4 | "OUTFOLDER": "path/to/output_folder", // Output folder path (folder must not exist yet and will be created) 5 | 6 | "cuda_device": 0, // CUDA device number to 
use (0 if only one GPU in system), or list of devices for data parallel 7 | "data_parallel": false, // OPTIONAL: Set to true for running in parallel on multiple GPUs, then set cuda_device to GPU list 8 | 9 | "train": true, // Set to true for training and false for evaluating (evaluating requires restart to be true) 10 | 11 | "restart": false, // OPTIONAL: Set to true to restart from existing model parameters in file MODELLOAD 12 | "freeze_encoder": false, // OPTIONAL: Set to true to freeze encoder weights to refine decoder 13 | 14 | "MODELLOAD": "path/to/saved_model_file", // OPTIONAL: Saved model file used to load weights if evaluating or if restart is true 15 | 16 | "model": "pde1d", // Model file from 'models/' folder, use pde1d for 1D PDEs and pde2d for 2D PDEs 17 | "data_dimension": 1, // Number of spatial dimensions of dataset (should match 'model') 18 | "data_channels": 2, // Number of channels in dataset (e.g. 1 for scalar fields, n for n-d vector fields) 19 | 20 | "dataset_type": "npz_dataset", // Dataloader dataset class 21 | "boundary_cond": "crop", // Boundary conditions: "crop" (default data augmentation), "periodic", "dirichlet0" 22 | "input_size": 158, // For crop boundaries, cropped spatial size of encoder inputs 23 | "training_size": 76, // For crop boundaries, cropped spatial size of decoder time-series 24 | 25 | "input_depth": 128, // Temporal size of encoder inputs (can be smaller than dataset size for crop boundaries) 26 | "training_depth": 31, // Time steps (temporal depth) to predict using decoder during training (excluding the initial condition) 27 | "evaluation_depth": 255, // Time steps (temporal depth) to predict using decoder during evaluation (excluding the initial condition) 28 | 29 | "linear_kernel_size": 0, // Linear convolutional kernel size in parallel with nonlinear kernel in decoder 30 | 31 | "nonlin_kernel_size": 5, // Nonlinear convolutional kernel size in decoder 32 | "hidden_channels": 16, // Number of convolutional kernel
channels for inner layers 33 | "prop_layers": 1, // Number of inner layers (excluding input and output conv. layers) 34 | 35 | "discount_rate": 0.0, // OPTIONAL: Discount rate for each successive decoder prediction step 36 | "rate_decay": 0.0, // OPTIONAL: Decay rate of discount_rate per epoch 37 | "param_dropout_prob": 0.0, // OPTIONAL: Dropout probability for latent parameter estimation 38 | "prop_noise": 0.0, // OPTIONAL: Stdv. of noise added after each propagating-decoder (PD) step during training (sometimes improves stability of predictions) 39 | 40 | "param_size": 5, // Number of latent parameters available (should be greater than the expected number of relevant parameters) 41 | "beta": 5e-4, // beta-VAE regularization parameter 42 | 43 | "learning_rate": 1e-3, // ADAM optimizer initial learning rate 44 | "eps": 1e-8, // ADAM optimizer epsilon parameter 45 | "num_workers": 2, // Number of worker processes for loading data 46 | "batch_size": 50, // Batch size to run through the model 47 | 48 | "max_epochs": 20000, // Number of epochs to run 49 | "save_epochs": 2000 // Number of epochs between model saves (no trailing comma: last entry) 50 | } 51 | -------------------------------------------------------------------------------- /models/pde1d.py: -------------------------------------------------------------------------------- 1 | """ 2 | pde1d.py 3 | 4 | PDE VAE model (PDEAutoEncoder module) for fitting data with 1 spatial dimension.
5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from torch.nn.modules.utils import _single 12 | 13 | from torch.nn.parameter import Parameter 14 | 15 | 16 | class PeriodicPad1d(nn.Module): 17 | def __init__(self, pad, dim=-1): 18 | super(PeriodicPad1d, self).__init__() 19 | self.pad = pad 20 | self.dim = dim 21 | 22 | def forward(self, x): 23 | if self.pad > 0: 24 | front_padding = x.narrow(self.dim, x.shape[self.dim]-self.pad, self.pad) 25 | back_padding = x.narrow(self.dim, 0, self.pad) 26 | x = torch.cat((front_padding, x, back_padding), dim=self.dim) 27 | 28 | return x 29 | 30 | class AntiReflectionPad1d(nn.Module): 31 | def __init__(self, pad, dim=-1): 32 | super(PeriodicPad1d, self).__init__() 33 | self.pad = pad 34 | self.dim = dim 35 | 36 | def forward(self, x): 37 | if self.pad > 0: 38 | front_padding = -x.narrow(self.dim, 0, self.pad).flip([self.dim]) 39 | back_padding = -x.narrow(self.dim, x.shape[self.dim]-self.pad, self.pad).flip([self.dim]) 40 | x = torch.cat((front_padding, x, back_padding), dim=self.dim) 41 | 42 | return x 43 | 44 | 45 | class DynamicConv1d(nn.Module): 46 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 47 | padding=0, dilation=1, groups=1, boundary_cond='periodic'): 48 | 49 | super(DynamicConv1d, self).__init__() 50 | 51 | self.kernel_size = _single(kernel_size) 52 | self.stride = _single(stride) # not implemented 53 | self.padding = _single(padding) 54 | self.dilation = _single(dilation) # not implemented 55 | 56 | self.in_channels = in_channels 57 | self.out_channels = out_channels 58 | 59 | self.boundary_cond = boundary_cond 60 | 61 | if self.padding[0] > 0 and boundary_cond == 'periodic': 62 | assert self.padding[0] == int((self.kernel_size[0]-1)/2) 63 | self.pad = PeriodicPad1d(self.padding[0], dim=-2) 64 | else: 65 | self.pad = None 66 | 67 | def forward(self, input, weight, bias): 68 | y = input.transpose(1, 2) 69 | 70 | if self.pad is not None: 71 | 
output_size = input.shape[-1] 72 | y = self.pad(y) 73 | else: 74 | output_size = input.shape[-1] - (self.kernel_size[0]-1) 75 | image_patches = y.unfold(-2, self.kernel_size[0], self.stride[0]).transpose(-1, -2).contiguous() 76 | y = torch.matmul( image_patches.view(-1, output_size, self.in_channels * self.kernel_size[0]), 77 | weight.view(-1, self.in_channels * self.kernel_size[0], self.out_channels) 78 | ) 79 | if bias is not None: 80 | y = y + bias.view(-1, 1, self.out_channels) 81 | 82 | return y.transpose(-1, -2) 83 | 84 | 85 | class ConvPropagator(nn.Module): 86 | def __init__(self, hidden_channels, linear_kernel_size, nonlin_kernel_size, data_channels, stride=1, 87 | linear_padding=0, nonlin_padding=0, dilation=1, groups=1, prop_layers=1, prop_noise=0., boundary_cond='periodic'): 88 | 89 | self.data_channels = data_channels 90 | self.prop_layers = prop_layers 91 | self.prop_noise = prop_noise 92 | self.boundary_cond = boundary_cond 93 | 94 | assert nonlin_padding == int((nonlin_kernel_size-1)/2) 95 | if boundary_cond == 'crop' or boundary_cond == 'dirichlet0': 96 | self.padding = int((2+prop_layers)*nonlin_padding) 97 | 98 | super(ConvPropagator, self).__init__() 99 | 100 | self.conv_linear = DynamicConv1d(data_channels, data_channels, linear_kernel_size, stride, 101 | linear_padding, dilation, groups, boundary_cond) if linear_kernel_size > 0 else None 102 | 103 | self.conv_in = DynamicConv1d(data_channels, hidden_channels, nonlin_kernel_size, stride, 104 | nonlin_padding, dilation, groups, boundary_cond) 105 | 106 | self.conv_out = DynamicConv1d(hidden_channels, data_channels, nonlin_kernel_size, stride, 107 | nonlin_padding, dilation, groups, boundary_cond) 108 | 109 | if prop_layers > 0: 110 | self.conv_prop = nn.ModuleList([DynamicConv1d(hidden_channels, hidden_channels, nonlin_kernel_size, stride, 111 | nonlin_padding, dilation, groups, boundary_cond) 112 | for i in range(prop_layers)]) 113 | 114 | self.cutoff = Parameter(torch.Tensor([1])) 115 | 116 | 
def _target_pad_1d(self, y, y0): 117 | return torch.cat((y0[:,:,:self.padding], y, y0[:,:,-self.padding:]), dim=-1) 118 | 119 | def _antireflection_pad_1d(self, y, dim): 120 | front_padding = -y.narrow(dim, 0, self.padding).flip([dim]) 121 | back_padding = -y.narrow(dim, y.shape[dim]-self.padding, self.padding).flip([dim]) 122 | return torch.cat((front_padding, y, back_padding), dim=dim) 123 | 124 | def _f(self, y, linear_weight, linear_bias, in_weight, in_bias, 125 | out_weight, out_bias, prop_weight, prop_bias): 126 | y_lin = self.conv_linear(y, linear_weight, linear_bias) if self.conv_linear is not None else 0 127 | 128 | y = self.conv_in(y, in_weight, in_bias) 129 | y = F.relu(y, inplace=True) 130 | for j in range(self.prop_layers): 131 | y = self.conv_prop[j](y, prop_weight[:,j], prop_bias[:,j]) 132 | y = F.relu(y, inplace=True) 133 | y = self.conv_out(y, out_weight, out_bias) 134 | 135 | return y + y_lin 136 | 137 | def forward(self, y0, linear_weight, linear_bias, 138 | in_weight, in_bias, out_weight, out_bias, prop_weight, prop_bias, depth): 139 | if self.boundary_cond == 'crop': 140 | # requires entire target solution as y0 for padding purposes 141 | assert len(y0.shape) == 4 142 | assert y0.shape[1] == self.data_channels 143 | assert y0.shape[2] == depth 144 | y_pad = y0[:,:,0] 145 | y = y0[:,:,0, self.padding:-self.padding] 146 | elif self.boundary_cond == 'periodic' or self.boundary_cond == 'dirichlet0': 147 | assert len(y0.shape) == 3 148 | assert y0.shape[1] == self.data_channels 149 | y = y0 150 | else: 151 | raise ValueError("Invalid boundary condition.") 152 | 153 | f = lambda y: self._f(y, linear_weight, linear_bias, in_weight, in_bias, 154 | out_weight, out_bias, prop_weight, prop_bias) 155 | 156 | y_list = [] 157 | for i in range(depth): 158 | if self.boundary_cond == 'crop': 159 | if i > 0: 160 | y_pad = self._target_pad_1d(y, y0[:,:,i]) 161 | elif self.boundary_cond == 'dirichlet0': 162 | y_pad = self._antireflection_pad_1d(y, -1) 163 | elif 
self.boundary_cond == 'periodic': 164 | y_pad = y 165 | 166 | ### Euler integrator 167 | dt = 1e-6 # NOT REAL TIME STEP, JUST HYPERPARAMETER 168 | noise = self.prop_noise * torch.randn_like(y) if (self.training and self.prop_noise > 0) else 0 169 | y = y + self.cutoff * torch.tanh((dt * f(y_pad)) / self.cutoff) + noise 170 | 171 | y_list.append(y) 172 | 173 | return torch.stack(y_list, dim=-2) 174 | 175 | 176 | class PDEAutoEncoder(nn.Module): 177 | def __init__(self, param_size=1, data_channels=1, data_dimension=1, hidden_channels=16, 178 | linear_kernel_size=0, nonlin_kernel_size=5, prop_layers=1, prop_noise=0., 179 | boundary_cond='periodic', param_dropout_prob=0.1, debug=False): 180 | 181 | assert data_dimension == 1 182 | 183 | super(PDEAutoEncoder, self).__init__() 184 | 185 | self.param_size = param_size 186 | self.data_channels = data_channels 187 | self.hidden_channels = hidden_channels 188 | self.linear_kernel_size = linear_kernel_size 189 | self.nonlin_kernel_size = nonlin_kernel_size 190 | self.prop_layers = prop_layers 191 | self.boundary_cond = boundary_cond 192 | self.param_dropout_prob = param_dropout_prob 193 | self.debug = debug 194 | 195 | if param_size > 0: 196 | ### 2D Convolutional Encoder 197 | if boundary_cond =='crop' or boundary_cond == 'dirichlet0': 198 | pad_input = [0, 0, 0, 0] 199 | pad_func = PeriodicPad1d # can be anything since no padding is added 200 | elif boundary_cond == 'periodic': 201 | pad_input = [1, 2, 4, 8] 202 | pad_func = PeriodicPad1d 203 | else: 204 | raise ValueError("Invalid boundary condition.") 205 | 206 | self.encoder = nn.Sequential( pad_func(pad_input[0]), 207 | nn.Conv2d(data_channels, 4, kernel_size=3, dilation=1), 208 | nn.ReLU(inplace=True), 209 | 210 | pad_func(pad_input[1]), 211 | nn.Conv2d(4, 16, kernel_size=3, dilation=2), 212 | nn.ReLU(inplace=True), 213 | 214 | pad_func(pad_input[2]), 215 | nn.Conv2d(16, 64, kernel_size=3, dilation=4), 216 | nn.ReLU(inplace=True), 217 | 218 | pad_func(pad_input[3]), 
219 | nn.Conv2d(64, 64, kernel_size=3, dilation=8), 220 | nn.ReLU(inplace=True), 221 | ) 222 | self.encoder_to_param = nn.Sequential(nn.Conv2d(64, param_size, kernel_size=1, stride=1)) 223 | self.encoder_to_logvar = nn.Sequential(nn.Conv2d(64, param_size, kernel_size=1, stride=1)) 224 | 225 | ### Parameter to weight/bias for dynamic convolutions 226 | if linear_kernel_size > 0: 227 | self.param_to_linear_weight = nn.Sequential( nn.Linear(param_size, 4 * data_channels * data_channels), 228 | nn.ReLU(inplace=True), 229 | nn.Linear(4 * data_channels * data_channels, 230 | data_channels * data_channels * linear_kernel_size) 231 | ) 232 | 233 | self.param_to_in_weight = nn.Sequential( nn.Linear(param_size, 4 * data_channels * hidden_channels), 234 | nn.ReLU(inplace=True), 235 | nn.Linear(4 * data_channels * hidden_channels, 236 | data_channels * hidden_channels * nonlin_kernel_size) 237 | ) 238 | self.param_to_in_bias = nn.Sequential( nn.Linear(param_size, 4 * hidden_channels), 239 | nn.ReLU(inplace=True), 240 | nn.Linear(4 * hidden_channels, hidden_channels) 241 | ) 242 | 243 | self.param_to_out_weight = nn.Sequential( nn.Linear(param_size, 4 * data_channels * hidden_channels), 244 | nn.ReLU(inplace=True), 245 | nn.Linear(4 * data_channels * hidden_channels, 246 | data_channels * hidden_channels * nonlin_kernel_size) 247 | ) 248 | self.param_to_out_bias = nn.Sequential( nn.Linear(param_size, 4 * data_channels), 249 | nn.ReLU(inplace=True), 250 | nn.Linear(4 * data_channels, data_channels) 251 | ) 252 | 253 | if prop_layers > 0: 254 | self.param_to_prop_weight = nn.Sequential( nn.Linear(param_size, 4 * prop_layers * hidden_channels * hidden_channels), 255 | nn.ReLU(inplace=True), 256 | nn.Linear(4 * prop_layers * hidden_channels * hidden_channels, 257 | prop_layers * hidden_channels * hidden_channels * nonlin_kernel_size) 258 | ) 259 | self.param_to_prop_bias = nn.Sequential( nn.Linear(param_size, 4 * prop_layers * hidden_channels), 260 | nn.ReLU(inplace=True), 261 | 
nn.Linear(4 * prop_layers * hidden_channels, prop_layers * hidden_channels) 262 | ) 263 | 264 | ### Decoder/PDE simulator 265 | self.decoder = ConvPropagator(hidden_channels, linear_kernel_size, nonlin_kernel_size, data_channels, 266 | linear_padding=int((linear_kernel_size-1)/2), 267 | nonlin_padding=int((nonlin_kernel_size-1)/2), 268 | prop_layers=prop_layers, prop_noise=prop_noise, boundary_cond=boundary_cond) 269 | 270 | def forward(self, x, y0, depth): 271 | 272 | if self.param_size > 0: 273 | assert len(x.shape) == 4 274 | assert x.shape[1] == self.data_channels 275 | 276 | ### 2D Convolutional Encoder 277 | encoder_out = self.encoder(x) 278 | 279 | logvar = self.encoder_to_logvar(encoder_out) 280 | logvar_size = logvar.shape 281 | logvar = logvar.view(logvar_size[0], logvar_size[1], -1) 282 | params = self.encoder_to_param(encoder_out).view(logvar_size[0], logvar_size[1], -1) 283 | 284 | if self.debug: 285 | raw_params = params 286 | 287 | # Parameter Spatial Averaging Dropout 288 | if self.training and self.param_dropout_prob > 0: 289 | mask = torch.bernoulli(torch.full_like(logvar, self.param_dropout_prob)) 290 | mask[mask > 0] = float("inf") 291 | logvar = logvar + mask 292 | 293 | # Inverse variance weighted average of params 294 | weights = F.softmax(-logvar, dim=-1) 295 | params = (params * weights).sum(dim=-1) 296 | 297 | # Compute logvar for inverse variance weighted average with a correlation length correction 298 | correlation_length = 31 # estimated as receptive field of the convolutional encoder 299 | logvar = -torch.logsumexp(-logvar, dim=-1) \ 300 | + torch.log(torch.tensor( 301 | max(1, (1 - self.param_dropout_prob) 302 | * min(correlation_length, logvar_size[-2]) 303 | * min(correlation_length, logvar_size[-1])), 304 | dtype=logvar.dtype, device=logvar.device)) 305 | 306 | ### Variational autoencoder reparameterization trick 307 | if self.training: 308 | stdv = (0.5 * logvar).exp() 309 | 310 | # Sample from unit normal 311 | z = params + stdv 
* torch.randn_like(stdv) 312 | else: 313 | z = params 314 | 315 | ### Parameter to weight/bias for dynamic convolutions 316 | if self.linear_kernel_size > 0: 317 | linear_weight = self.param_to_linear_weight(z) 318 | linear_bias = None 319 | else: 320 | linear_weight = None 321 | linear_bias = None 322 | 323 | in_weight = self.param_to_in_weight(z) 324 | in_bias = self.param_to_in_bias(z) 325 | 326 | out_weight = self.param_to_out_weight(z) 327 | out_bias = self.param_to_out_bias(z) 328 | 329 | if self.prop_layers > 0: 330 | prop_weight = self.param_to_prop_weight(z).view(-1, self.prop_layers, 331 | self.hidden_channels * self.hidden_channels * self.nonlin_kernel_size) 332 | prop_bias = self.param_to_prop_bias(z).view(-1, self.prop_layers, self.hidden_channels) 333 | else: 334 | prop_weight = None 335 | prop_bias = None 336 | 337 | else: # if no parameter used 338 | linear_weight = None 339 | linear_bias = None 340 | in_weight = None 341 | in_bias = None 342 | out_weight = None 343 | out_bias = None 344 | prop_weight = None 345 | prop_bias = None 346 | params = None 347 | logvar = None 348 | 349 | ### Decoder/PDE simulator 350 | y = self.decoder(y0, linear_weight, linear_bias, in_weight, in_bias, out_weight, out_bias, 351 | prop_weight, prop_bias, depth) 352 | 353 | if self.debug: 354 | return y, params, logvar, [in_weight, in_bias, out_weight, out_bias, prop_weight, prop_bias], \ 355 | weights.view(logvar_size), raw_params.view(logvar_size) 356 | 357 | return y, params, logvar 358 | -------------------------------------------------------------------------------- /models/pde1d_decoder_only.py: -------------------------------------------------------------------------------- 1 | """ 2 | pde1d_decoder_only.py 3 | 4 | Propagating decoder network only (PDEDecoder module) with tunable latent parameters z. 5 | To use, import weights from trained model. 
6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from torch.nn.modules.utils import _single 13 | 14 | from torch.nn.parameter import Parameter 15 | 16 | 17 | class PeriodicPad1d(nn.Module): 18 | def __init__(self, pad, dim=-1): 19 | super(PeriodicPad1d, self).__init__() 20 | self.pad = pad 21 | self.dim = dim 22 | 23 | def forward(self, x): 24 | if self.pad > 0: 25 | front_padding = x.narrow(self.dim, x.shape[self.dim]-self.pad, self.pad) 26 | back_padding = x.narrow(self.dim, 0, self.pad) 27 | x = torch.cat((front_padding, x, back_padding), dim=self.dim) 28 | 29 | return x 30 | 31 | class AntiReflectionPad1d(nn.Module): 32 | def __init__(self, pad, dim=-1): 33 | super(PeriodicPad1d, self).__init__() 34 | self.pad = pad 35 | self.dim = dim 36 | 37 | def forward(self, x): 38 | if self.pad > 0: 39 | front_padding = -x.narrow(self.dim, 0, self.pad).flip([self.dim]) 40 | back_padding = -x.narrow(self.dim, x.shape[self.dim]-self.pad, self.pad).flip([self.dim]) 41 | x = torch.cat((front_padding, x, back_padding), dim=self.dim) 42 | 43 | return x 44 | 45 | 46 | class DynamicConv1d(nn.Module): 47 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 48 | padding=0, dilation=1, groups=1, boundary_cond='periodic'): 49 | 50 | super(DynamicConv1d, self).__init__() 51 | 52 | self.kernel_size = _single(kernel_size) 53 | self.stride = _single(stride) # not implemented 54 | self.padding = _single(padding) 55 | self.dilation = _single(dilation) # not implemented 56 | 57 | self.in_channels = in_channels 58 | self.out_channels = out_channels 59 | 60 | self.boundary_cond = boundary_cond 61 | 62 | if self.padding[0] > 0 and boundary_cond == 'periodic': 63 | assert self.padding[0] == int((self.kernel_size[0]-1)/2) 64 | self.pad = PeriodicPad1d(self.padding[0], dim=-2) 65 | else: 66 | self.pad = None 67 | 68 | def forward(self, input, weight, bias): 69 | y = input.transpose(1, 2) 70 | 71 | if self.pad is not None: 72 | 
output_size = input.shape[-1] 73 | y = self.pad(y) 74 | else: 75 | output_size = input.shape[-1] - (self.kernel_size[0]-1) 76 | image_patches = y.unfold(-2, self.kernel_size[0], self.stride[0]).transpose(-1, -2).contiguous() 77 | y = torch.matmul( image_patches.view(-1, output_size, self.in_channels * self.kernel_size[0]), 78 | weight.view(-1, self.in_channels * self.kernel_size[0], self.out_channels) 79 | ) 80 | if bias is not None: 81 | y = y + bias.view(-1, 1, self.out_channels) 82 | 83 | return y.transpose(-1, -2) 84 | 85 | 86 | class ConvPropagator(nn.Module): 87 | def __init__(self, hidden_channels, linear_kernel_size, nonlin_kernel_size, data_channels, stride=1, 88 | linear_padding=0, nonlin_padding=0, dilation=1, groups=1, prop_layers=1, prop_noise=0., boundary_cond='periodic'): 89 | 90 | self.data_channels = data_channels 91 | self.prop_layers = prop_layers 92 | self.prop_noise = prop_noise 93 | self.boundary_cond = boundary_cond 94 | 95 | assert nonlin_padding == int((nonlin_kernel_size-1)/2) 96 | if boundary_cond == 'crop' or boundary_cond == 'dirichlet0': 97 | self.padding = int((2+prop_layers)*nonlin_padding) 98 | 99 | super(ConvPropagator, self).__init__() 100 | 101 | self.conv_linear = DynamicConv1d(data_channels, data_channels, linear_kernel_size, stride, 102 | linear_padding, dilation, groups, boundary_cond) if linear_kernel_size > 0 else None 103 | 104 | self.conv_in = DynamicConv1d(data_channels, hidden_channels, nonlin_kernel_size, stride, 105 | nonlin_padding, dilation, groups, boundary_cond) 106 | 107 | self.conv_out = DynamicConv1d(hidden_channels, data_channels, nonlin_kernel_size, stride, 108 | nonlin_padding, dilation, groups, boundary_cond) 109 | 110 | if prop_layers > 0: 111 | self.conv_prop = nn.ModuleList([DynamicConv1d(hidden_channels, hidden_channels, nonlin_kernel_size, stride, 112 | nonlin_padding, dilation, groups, boundary_cond) 113 | for i in range(prop_layers)]) 114 | 115 | self.cutoff = Parameter(torch.Tensor([1])) 116 | 117 | 
def _target_pad_1d(self, y, y0): 118 | return torch.cat((y0[:,:,:self.padding], y, y0[:,:,-self.padding:]), dim=-1) 119 | 120 | def _antireflection_pad_1d(self, y, dim): 121 | front_padding = -y.narrow(dim, 0, self.padding).flip([dim]) 122 | back_padding = -y.narrow(dim, y.shape[dim]-self.padding, self.padding).flip([dim]) 123 | return torch.cat((front_padding, y, back_padding), dim=dim) 124 | 125 | def _f(self, y, linear_weight, linear_bias, in_weight, in_bias, 126 | out_weight, out_bias, prop_weight, prop_bias): 127 | y_lin = self.conv_linear(y, linear_weight, linear_bias) if self.conv_linear is not None else 0 128 | 129 | y = self.conv_in(y, in_weight, in_bias) 130 | y = F.relu(y, inplace=True) 131 | for j in range(self.prop_layers): 132 | y = self.conv_prop[j](y, prop_weight[:,j], prop_bias[:,j]) 133 | y = F.relu(y, inplace=True) 134 | y = self.conv_out(y, out_weight, out_bias) 135 | 136 | return y + y_lin 137 | 138 | def forward(self, y0, linear_weight, linear_bias, 139 | in_weight, in_bias, out_weight, out_bias, prop_weight, prop_bias, depth): 140 | if self.boundary_cond == 'crop': 141 | # requires entire target solution as y0 for padding purposes 142 | assert len(y0.shape) == 4 143 | assert y0.shape[1] == self.data_channels 144 | assert y0.shape[2] == depth 145 | y_pad = y0[:,:,0] 146 | y = y0[:,:,0, self.padding:-self.padding] 147 | elif self.boundary_cond == 'periodic' or self.boundary_cond == 'dirichlet0': 148 | assert len(y0.shape) == 3 149 | assert y0.shape[1] == self.data_channels 150 | y = y0 151 | else: 152 | raise ValueError("Invalid boundary condition.") 153 | 154 | f = lambda y: self._f(y, linear_weight, linear_bias, in_weight, in_bias, 155 | out_weight, out_bias, prop_weight, prop_bias) 156 | 157 | y_list = [] 158 | for i in range(depth): 159 | if self.boundary_cond == 'crop': 160 | if i > 0: 161 | y_pad = self._target_pad_1d(y, y0[:,:,i]) 162 | elif self.boundary_cond == 'dirichlet0': 163 | y_pad = self._antireflection_pad_1d(y, -1) 164 | elif 
self.boundary_cond == 'periodic': 165 | y_pad = y 166 | 167 | ### Euler integrator 168 | dt = 1e-6 # NOT REAL TIME STEP, JUST HYPERPARAMETER 169 | noise = self.prop_noise * torch.randn_like(y) if (self.training and self.prop_noise > 0) else 0 170 | y = y + self.cutoff * torch.tanh((dt * f(y_pad)) / self.cutoff) + noise 171 | 172 | y_list.append(y) 173 | 174 | return torch.stack(y_list, dim=-2) 175 | 176 | 177 | class PDEDecoder(nn.Module): 178 | def __init__(self, param_size=1, data_channels=1, data_dimension=1, hidden_channels=16, 179 | linear_kernel_size=0, nonlin_kernel_size=5, prop_layers=1, prop_noise=0., 180 | boundary_cond='periodic', param_dropout_prob=0.1, debug=False): 181 | 182 | assert data_dimension == 1 183 | 184 | super(PDEDecoder, self).__init__() 185 | 186 | self.param_size = param_size 187 | self.data_channels = data_channels 188 | self.hidden_channels = hidden_channels 189 | self.linear_kernel_size = linear_kernel_size 190 | self.nonlin_kernel_size = nonlin_kernel_size 191 | self.prop_layers = prop_layers 192 | self.boundary_cond = boundary_cond 193 | self.param_dropout_prob = param_dropout_prob 194 | self.debug = debug 195 | 196 | if param_size > 0: 197 | ### Parameter to weight/bias for dynamic convolutions 198 | if linear_kernel_size > 0: 199 | self.param_to_linear_weight = nn.Sequential( nn.Linear(param_size, 4 * data_channels * data_channels), 200 | nn.ReLU(inplace=True), 201 | nn.Linear(4 * data_channels * data_channels, 202 | data_channels * data_channels * linear_kernel_size) 203 | ) 204 | 205 | self.param_to_in_weight = nn.Sequential( nn.Linear(param_size, 4 * data_channels * hidden_channels), 206 | nn.ReLU(inplace=True), 207 | nn.Linear(4 * data_channels * hidden_channels, 208 | data_channels * hidden_channels * nonlin_kernel_size) 209 | ) 210 | self.param_to_in_bias = nn.Sequential( nn.Linear(param_size, 4 * hidden_channels), 211 | nn.ReLU(inplace=True), 212 | nn.Linear(4 * hidden_channels, hidden_channels) 213 | ) 214 | 215 | 
self.param_to_out_weight = nn.Sequential( nn.Linear(param_size, 4 * data_channels * hidden_channels), 216 | nn.ReLU(inplace=True), 217 | nn.Linear(4 * data_channels * hidden_channels, 218 | data_channels * hidden_channels * nonlin_kernel_size) 219 | ) 220 | self.param_to_out_bias = nn.Sequential( nn.Linear(param_size, 4 * data_channels), 221 | nn.ReLU(inplace=True), 222 | nn.Linear(4 * data_channels, data_channels) 223 | ) 224 | 225 | if prop_layers > 0: 226 | self.param_to_prop_weight = nn.Sequential( nn.Linear(param_size, 4 * prop_layers * hidden_channels * hidden_channels), 227 | nn.ReLU(inplace=True), 228 | nn.Linear(4 * prop_layers * hidden_channels * hidden_channels, 229 | prop_layers * hidden_channels * hidden_channels * nonlin_kernel_size) 230 | ) 231 | self.param_to_prop_bias = nn.Sequential( nn.Linear(param_size, 4 * prop_layers * hidden_channels), 232 | nn.ReLU(inplace=True), 233 | nn.Linear(4 * prop_layers * hidden_channels, prop_layers * hidden_channels) 234 | ) 235 | 236 | ### Decoder/PDE simulator 237 | self.decoder = ConvPropagator(hidden_channels, linear_kernel_size, nonlin_kernel_size, data_channels, 238 | linear_padding=int((linear_kernel_size-1)/2), 239 | nonlin_padding=int((nonlin_kernel_size-1)/2), 240 | prop_layers=prop_layers, prop_noise=prop_noise, boundary_cond=boundary_cond) 241 | 242 | def forward(self, z, y0, depth): 243 | 244 | if self.param_size > 0: 245 | assert len(z.shape) == 2 246 | assert z.shape[1] == self.param_size 247 | 248 | ### Parameter to weight/bias for dynamic convolutions 249 | if self.linear_kernel_size > 0: 250 | linear_weight = self.param_to_linear_weight(z) 251 | linear_bias = None 252 | else: 253 | linear_weight = None 254 | linear_bias = None 255 | 256 | in_weight = self.param_to_in_weight(z) 257 | in_bias = self.param_to_in_bias(z) 258 | 259 | out_weight = self.param_to_out_weight(z) 260 | out_bias = self.param_to_out_bias(z) 261 | 262 | if self.prop_layers > 0: 263 | prop_weight = 
self.param_to_prop_weight(z).view(-1, self.prop_layers, 264 | self.hidden_channels * self.hidden_channels * self.nonlin_kernel_size) 265 | prop_bias = self.param_to_prop_bias(z).view(-1, self.prop_layers, self.hidden_channels) 266 | else: 267 | prop_weight = None 268 | prop_bias = None 269 | 270 | else: # if no parameter used 271 | linear_weight = None 272 | linear_bias = None 273 | in_weight = None 274 | in_bias = None 275 | out_weight = None 276 | out_bias = None 277 | prop_weight = None 278 | prop_bias = None 279 | 280 | ### Decoder/PDE simulator 281 | y = self.decoder(y0, linear_weight, linear_bias, in_weight, in_bias, out_weight, out_bias, 282 | prop_weight, prop_bias, depth) 283 | 284 | if self.debug: 285 | return y, z, None, [in_weight, in_bias, out_weight, out_bias, prop_weight, prop_bias] 286 | 287 | return y, z, None 288 | -------------------------------------------------------------------------------- /models/pde2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | pde2d.py 3 | 4 | PDE VAE model (PDEAutoEncoder module) for fitting data with 2 spatial dimensions. 
5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from torch.nn.modules.utils import _pair 12 | 13 | from torch.nn.parameter import Parameter 14 | 15 | 16 | class PeriodicPad1d(nn.Module): 17 | def __init__(self, pad, dim=-1): 18 | super(PeriodicPad1d, self).__init__() 19 | self.pad = pad 20 | self.dim = dim 21 | 22 | def forward(self, x): 23 | if self.pad > 0: 24 | front_padding = x.narrow(self.dim, x.shape[self.dim]-self.pad, self.pad) 25 | back_padding = x.narrow(self.dim, 0, self.pad) 26 | x = torch.cat((front_padding, x, back_padding), dim=self.dim) 27 | 28 | return x 29 | 30 | class AntiReflectionPad1d(nn.Module): 31 | def __init__(self, pad, dim=-1): 32 | super(PeriodicPad1d, self).__init__() 33 | self.pad = pad 34 | self.dim = dim 35 | 36 | def forward(self, x): 37 | if self.pad > 0: 38 | front_padding = -x.narrow(self.dim, 0, self.pad).flip([self.dim]) 39 | back_padding = -x.narrow(self.dim, x.shape[self.dim]-self.pad, self.pad).flip([self.dim]) 40 | x = torch.cat((front_padding, x, back_padding), dim=self.dim) 41 | 42 | return x 43 | 44 | 45 | class DynamicConv2d(nn.Module): 46 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 47 | padding=0, dilation=1, groups=1, boundary_cond='periodic'): 48 | 49 | super(DynamicConv2d, self).__init__() 50 | 51 | self.kernel_size = _pair(kernel_size) 52 | self.stride = _pair(stride) # not implemented 53 | self.padding = _pair(padding) 54 | self.dilation = _pair(dilation) # not implemented 55 | 56 | self.in_channels = in_channels 57 | self.out_channels = out_channels 58 | 59 | self.boundary_cond = boundary_cond 60 | 61 | if (self.padding[0] > 0 or self.padding[1] > 0) and boundary_cond == 'periodic': 62 | assert self.padding[0] == int((self.kernel_size[0]-1)/2) 63 | assert self.padding[1] == int((self.kernel_size[1]-1)/2) 64 | self.pad = nn.Sequential( PeriodicPad1d(self.padding[1], dim=-1), 65 | PeriodicPad1d(self.padding[0], dim=-2)) 66 | else: 67 | 
self.pad = None 68 | 69 | def forward(self, input, weight, bias): 70 | y = input 71 | 72 | if self.pad is not None: 73 | output_size = input.shape[-2:] 74 | y = self.pad(y) 75 | else: 76 | output_size = ( input.shape[-2] - (self.kernel_size[0]-1), 77 | input.shape[-1] - (self.kernel_size[1]-1)) 78 | image_patches = F.unfold(y, self.kernel_size, self.dilation, 0, self.stride).transpose(1, 2) 79 | y = image_patches.matmul(weight.view(-1, 80 | self.in_channels * self.kernel_size[0] * self.kernel_size[1], 81 | self.out_channels)) 82 | if bias is not None: 83 | y = y + bias.view(-1, 1, self.out_channels) 84 | 85 | return y.transpose(1, 2).view(-1, self.out_channels, output_size[0], output_size[1]) 86 | 87 | 88 | class ConvPropagator(nn.Module): 89 | def __init__(self, hidden_channels, linear_kernel_size, nonlin_kernel_size, data_channels, stride=1, 90 | linear_padding=0, nonlin_padding=0, dilation=1, groups=1, prop_layers=1, prop_noise=0., boundary_cond='periodic'): 91 | 92 | self.data_channels = data_channels 93 | self.prop_layers = prop_layers 94 | self.prop_noise = prop_noise 95 | self.boundary_cond = boundary_cond 96 | 97 | assert nonlin_padding == int((nonlin_kernel_size-1)/2) 98 | if boundary_cond == 'crop' or boundary_cond == 'dirichlet0': 99 | self.padding = int((2+prop_layers)*nonlin_padding) 100 | 101 | super(ConvPropagator, self).__init__() 102 | 103 | self.conv_linear = DynamicConv2d(data_channels, data_channels, linear_kernel_size, stride, 104 | linear_padding, dilation, groups, boundary_cond) if linear_kernel_size > 0 else None 105 | 106 | self.conv_in = DynamicConv2d(data_channels, hidden_channels, nonlin_kernel_size, stride, 107 | nonlin_padding, dilation, groups, boundary_cond) 108 | 109 | self.conv_out = DynamicConv2d(hidden_channels, data_channels, nonlin_kernel_size, stride, 110 | nonlin_padding, dilation, groups, boundary_cond) 111 | 112 | if prop_layers > 0: 113 | self.conv_prop = nn.ModuleList([DynamicConv2d(hidden_channels, hidden_channels, 
nonlin_kernel_size, stride, 114 | nonlin_padding, dilation, groups, boundary_cond) 115 | for i in range(prop_layers)]) 116 | 117 | self.cutoff = Parameter(torch.Tensor([1])) 118 | 119 | def _target_pad_2d(self, y, y0): 120 | y = torch.cat((y0[:,:,:self.padding, self.padding:-self.padding], 121 | y, y0[:,:,-self.padding:, self.padding:-self.padding]), dim=-2) 122 | return torch.cat((y0[:,:,:,:self.padding], y, y0[:,:,:,-self.padding:]), dim=-1) 123 | 124 | def _antireflection_pad_1d(self, y, dim): 125 | front_padding = -y.narrow(dim, 0, self.padding).flip([dim]) 126 | back_padding = -y.narrow(dim, y.shape[dim]-self.padding, self.padding).flip([dim]) 127 | return torch.cat((front_padding, y, back_padding), dim=dim) 128 | 129 | def _f(self, y, linear_weight, linear_bias, in_weight, in_bias, 130 | out_weight, out_bias, prop_weight, prop_bias): 131 | y_lin = self.conv_linear(y, linear_weight, linear_bias) if self.conv_linear is not None else 0 132 | 133 | y = self.conv_in(y, in_weight, in_bias) 134 | y = F.relu(y, inplace=True) 135 | for j in range(self.prop_layers): 136 | y = self.conv_prop[j](y, prop_weight[:,j], prop_bias[:,j]) 137 | y = F.relu(y, inplace=True) 138 | y = self.conv_out(y, out_weight, out_bias) 139 | 140 | return y + y_lin 141 | 142 | def forward(self, y0, linear_weight, linear_bias, 143 | in_weight, in_bias, out_weight, out_bias, prop_weight, prop_bias, depth): 144 | if self.boundary_cond == 'crop': 145 | # requires entire target solution as y0 for padding purposes 146 | assert len(y0.shape) == 5 147 | assert y0.shape[1] == self.data_channels 148 | assert y0.shape[2] == depth 149 | y_pad = y0[:,:,0] 150 | y = y0[:,:,0, self.padding:-self.padding, self.padding:-self.padding] 151 | elif self.boundary_cond == 'periodic' or self.boundary_cond == 'dirichlet0': 152 | assert len(y0.shape) == 4 153 | assert y0.shape[1] == self.data_channels 154 | y = y0 155 | else: 156 | raise ValueError("Invalid boundary condition.") 157 | 158 | f = lambda y: self._f(y, 
linear_weight, linear_bias, in_weight, in_bias, 159 | out_weight, out_bias, prop_weight, prop_bias) 160 | 161 | y_list = [] 162 | for i in range(depth): 163 | if self.boundary_cond == 'crop': 164 | if i > 0: 165 | y_pad = self._target_pad_2d(y, y0[:,:,i]) 166 | elif self.boundary_cond == 'dirichlet0': 167 | y_pad = self._antireflection_pad_1d(self._antireflection_pad_1d(y, -1), -2) 168 | elif self.boundary_cond == 'periodic': 169 | y_pad = y 170 | 171 | ### Euler integrator 172 | dt = 1e-6 # NOT REAL TIME STEP, JUST HYPERPARAMETER 173 | noise = self.prop_noise * torch.randn_like(y) if (self.training and self.prop_noise > 0) else 0 174 | y = y + self.cutoff * torch.tanh((dt * f(y_pad)) / self.cutoff) + noise 175 | 176 | y_list.append(y) 177 | 178 | return torch.stack(y_list, dim=-3) 179 | 180 | 181 | class PDEAutoEncoder(nn.Module): 182 | def __init__(self, param_size=1, data_channels=1, data_dimension=2, hidden_channels=16, 183 | linear_kernel_size=0, nonlin_kernel_size=5, prop_layers=1, prop_noise=0., 184 | boundary_cond='periodic', param_dropout_prob=0.1, debug=False): 185 | 186 | assert data_dimension == 2 187 | 188 | super(PDEAutoEncoder, self).__init__() 189 | 190 | self.param_size = param_size 191 | self.data_channels = data_channels 192 | self.hidden_channels = hidden_channels 193 | self.linear_kernel_size = linear_kernel_size 194 | self.nonlin_kernel_size = nonlin_kernel_size 195 | self.prop_layers = prop_layers 196 | self.boundary_cond = boundary_cond 197 | self.param_dropout_prob = param_dropout_prob 198 | self.debug = debug 199 | 200 | if param_size > 0: 201 | ### 3D Convolutional Encoder 202 | if boundary_cond =='crop' or boundary_cond == 'dirichlet0': 203 | pad_input = [0, 0, 0, 0] 204 | pad_func = PeriodicPad1d # can be anything since no padding is added 205 | elif boundary_cond == 'periodic': 206 | pad_input = [1, 2, 4, 8] 207 | pad_func = PeriodicPad1d 208 | else: 209 | raise ValueError("Invalid boundary condition.") 210 | 211 | self.encoder = 
nn.Sequential( pad_func(pad_input[0], dim=-1), 212 | pad_func(pad_input[0], dim=-2), 213 | nn.Conv3d(data_channels, 8, kernel_size=3, dilation=1), 214 | nn.ReLU(inplace=True), 215 | 216 | pad_func(pad_input[1], dim=-1), 217 | pad_func(pad_input[1], dim=-2), 218 | nn.Conv3d(8, 64, kernel_size=3, dilation=2), 219 | nn.ReLU(inplace=True), 220 | 221 | pad_func(pad_input[2], dim=-1), 222 | pad_func(pad_input[2], dim=-2), 223 | nn.Conv3d(64, 64, kernel_size=3, dilation=4), 224 | nn.ReLU(inplace=True), 225 | 226 | pad_func(pad_input[3], dim=-1), 227 | pad_func(pad_input[3], dim=-2), 228 | nn.Conv3d(64, 64, kernel_size=3, dilation=8), 229 | nn.ReLU(inplace=True), 230 | ) 231 | self.encoder_to_param = nn.Sequential(nn.Conv3d(64, param_size, kernel_size=1, stride=1)) 232 | self.encoder_to_logvar = nn.Sequential(nn.Conv3d(64, param_size, kernel_size=1, stride=1)) 233 | 234 | ### Parameter to weight/bias for dynamic convolutions 235 | if linear_kernel_size > 0: 236 | self.param_to_linear_weight = nn.Sequential( nn.Linear(param_size, 16 * data_channels * data_channels), 237 | nn.ReLU(inplace=True), 238 | nn.Linear(16 * data_channels * data_channels, 239 | data_channels * data_channels * linear_kernel_size * linear_kernel_size) 240 | ) 241 | 242 | self.param_to_in_weight = nn.Sequential( nn.Linear(param_size, 16 * data_channels * hidden_channels), 243 | nn.ReLU(inplace=True), 244 | nn.Linear(16 * data_channels * hidden_channels, 245 | data_channels * hidden_channels * nonlin_kernel_size * nonlin_kernel_size) 246 | ) 247 | self.param_to_in_bias = nn.Sequential( nn.Linear(param_size, 4 * hidden_channels), 248 | nn.ReLU(inplace=True), 249 | nn.Linear(4 * hidden_channels, hidden_channels) 250 | ) 251 | 252 | self.param_to_out_weight = nn.Sequential( nn.Linear(param_size, 16 * data_channels * hidden_channels), 253 | nn.ReLU(inplace=True), 254 | nn.Linear(16 * data_channels * hidden_channels, 255 | data_channels * hidden_channels * nonlin_kernel_size * nonlin_kernel_size) 256 | ) 257 
| self.param_to_out_bias = nn.Sequential( nn.Linear(param_size, 4 * data_channels), 258 | nn.ReLU(inplace=True), 259 | nn.Linear(4 * data_channels, data_channels) 260 | ) 261 | 262 | if prop_layers > 0: 263 | self.param_to_prop_weight = nn.Sequential( nn.Linear(param_size, 16 * prop_layers * hidden_channels * hidden_channels), 264 | nn.ReLU(inplace=True), 265 | nn.Linear(16 * prop_layers * hidden_channels * hidden_channels, 266 | prop_layers * hidden_channels * hidden_channels * nonlin_kernel_size * nonlin_kernel_size) 267 | ) 268 | self.param_to_prop_bias = nn.Sequential( nn.Linear(param_size, 4 * prop_layers * hidden_channels), 269 | nn.ReLU(inplace=True), 270 | nn.Linear(4 * prop_layers * hidden_channels, prop_layers * hidden_channels) 271 | ) 272 | 273 | ### Decoder/PDE simulator 274 | self.decoder = ConvPropagator(hidden_channels, linear_kernel_size, nonlin_kernel_size, data_channels, 275 | linear_padding=int((linear_kernel_size-1)/2), 276 | nonlin_padding=int((nonlin_kernel_size-1)/2), 277 | prop_layers=prop_layers, prop_noise=prop_noise, boundary_cond=boundary_cond) 278 | 279 | def forward(self, x, y0, depth): 280 | 281 | if self.param_size > 0: 282 | assert len(x.shape) == 5 283 | assert x.shape[1] == self.data_channels 284 | 285 | ### 3D Convolutional Encoder 286 | encoder_out = self.encoder(x) 287 | 288 | logvar = self.encoder_to_logvar(encoder_out) 289 | logvar_size = logvar.shape 290 | logvar = logvar.view(logvar_size[0], logvar_size[1], -1) 291 | params = self.encoder_to_param(encoder_out).view(logvar_size[0], logvar_size[1], -1) 292 | 293 | if self.debug: 294 | raw_params = params 295 | 296 | # Parameter Spatial Averaging Dropout 297 | if self.training and self.param_dropout_prob > 0: 298 | mask = torch.bernoulli(torch.full_like(logvar, self.param_dropout_prob)) 299 | mask[mask > 0] = float("inf") 300 | logvar = logvar + mask 301 | 302 | # Inverse variance weighted average of params 303 | weights = F.softmax(-logvar, dim=-1) 304 | params = (params * 
weights).sum(dim=-1) 305 | 306 | # Compute logvar for inverse variance weighted average with a correlation length correction 307 | correlation_length = 31 # estimated as receptive field of the convolutional encoder 308 | logvar = -torch.logsumexp(-logvar, dim=-1) \ 309 | + torch.log(torch.tensor( 310 | max(1, (1 - self.param_dropout_prob) 311 | * min(correlation_length, logvar_size[-3]) 312 | * min(correlation_length, logvar_size[-2]) 313 | * min(correlation_length, logvar_size[-1])), 314 | dtype=logvar.dtype, device=logvar.device)) 315 | 316 | ### Variational autoencoder reparameterization trick 317 | if self.training: 318 | stdv = (0.5 * logvar).exp() 319 | 320 | # Sample from unit normal 321 | z = params + stdv * torch.randn_like(stdv) 322 | else: 323 | z = params 324 | 325 | ### Parameter to weight/bias for dynamic convolutions 326 | if self.linear_kernel_size > 0: 327 | linear_weight = self.param_to_linear_weight(z) 328 | linear_bias = None 329 | else: 330 | linear_weight = None 331 | linear_bias = None 332 | 333 | in_weight = self.param_to_in_weight(z) 334 | in_bias = self.param_to_in_bias(z) 335 | 336 | out_weight = self.param_to_out_weight(z) 337 | out_bias = self.param_to_out_bias(z) 338 | 339 | if self.prop_layers > 0: 340 | prop_weight = self.param_to_prop_weight(z).view(-1, self.prop_layers, 341 | self.hidden_channels * self.hidden_channels * self.nonlin_kernel_size * self.nonlin_kernel_size) 342 | prop_bias = self.param_to_prop_bias(z).view(-1, self.prop_layers, self.hidden_channels) 343 | else: 344 | prop_weight = None 345 | prop_bias = None 346 | 347 | else: # if no parameter used 348 | linear_weight = None 349 | linear_bias = None 350 | in_weight = None 351 | in_bias = None 352 | out_weight = None 353 | out_bias = None 354 | prop_weight = None 355 | prop_bias = None 356 | params = None 357 | logvar = None 358 | 359 | ### Decoder/PDE simulator 360 | y = self.decoder(y0, linear_weight, linear_bias, in_weight, in_bias, out_weight, out_bias, 361 | 
prop_weight, prop_bias, depth) 362 | 363 | if self.debug: 364 | return y, params, logvar, [in_weight, in_bias, out_weight, out_bias, prop_weight, prop_bias], \ 365 | weights.view(logvar_size), raw_params.view(logvar_size) 366 | 367 | return y, params, logvar 368 | -------------------------------------------------------------------------------- /models/pde2d_decoder_only.py: -------------------------------------------------------------------------------- 1 | """ 2 | pde2d_decoder_only.py 3 | 4 | Propagating decoder network only (PDEDecoder module) with tunable latent parameters z. 5 | To use, import weights from trained model. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from torch.nn.modules.utils import _pair 13 | 14 | from torch.nn.parameter import Parameter 15 | 16 | 17 | class PeriodicPad1d(nn.Module): 18 | def __init__(self, pad, dim=-1): 19 | super(PeriodicPad1d, self).__init__() 20 | self.pad = pad 21 | self.dim = dim 22 | 23 | def forward(self, x): 24 | if self.pad > 0: 25 | front_padding = x.narrow(self.dim, x.shape[self.dim]-self.pad, self.pad) 26 | back_padding = x.narrow(self.dim, 0, self.pad) 27 | x = torch.cat((front_padding, x, back_padding), dim=self.dim) 28 | 29 | return x 30 | 31 | class AntiReflectionPad1d(nn.Module): 32 | def __init__(self, pad, dim=-1): 33 | super(PeriodicPad1d, self).__init__() 34 | self.pad = pad 35 | self.dim = dim 36 | 37 | def forward(self, x): 38 | if self.pad > 0: 39 | front_padding = -x.narrow(self.dim, 0, self.pad).flip([self.dim]) 40 | back_padding = -x.narrow(self.dim, x.shape[self.dim]-self.pad, self.pad).flip([self.dim]) 41 | x = torch.cat((front_padding, x, back_padding), dim=self.dim) 42 | 43 | return x 44 | 45 | 46 | class DynamicConv2d(nn.Module): 47 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 48 | padding=0, dilation=1, groups=1, boundary_cond='periodic'): 49 | 50 | super(DynamicConv2d, self).__init__() 51 | 52 | 
self.kernel_size = _pair(kernel_size) 53 | self.stride = _pair(stride) # not implemented 54 | self.padding = _pair(padding) 55 | self.dilation = _pair(dilation) # not implemented 56 | 57 | self.in_channels = in_channels 58 | self.out_channels = out_channels 59 | 60 | self.boundary_cond = boundary_cond 61 | 62 | if (self.padding[0] > 0 or self.padding[1] > 0) and boundary_cond == 'periodic': 63 | assert self.padding[0] == int((self.kernel_size[0]-1)/2) 64 | assert self.padding[1] == int((self.kernel_size[1]-1)/2) 65 | self.pad = nn.Sequential( PeriodicPad1d(self.padding[1], dim=-1), 66 | PeriodicPad1d(self.padding[0], dim=-2)) 67 | else: 68 | self.pad = None 69 | 70 | def forward(self, input, weight, bias): 71 | y = input 72 | 73 | if self.pad is not None: 74 | output_size = input.shape[-2:] 75 | y = self.pad(y) 76 | else: 77 | output_size = ( input.shape[-2] - (self.kernel_size[0]-1), 78 | input.shape[-1] - (self.kernel_size[1]-1)) 79 | image_patches = F.unfold(y, self.kernel_size, self.dilation, 0, self.stride).transpose(1, 2) 80 | y = image_patches.matmul(weight.view(-1, 81 | self.in_channels * self.kernel_size[0] * self.kernel_size[1], 82 | self.out_channels)) 83 | if bias is not None: 84 | y = y + bias.view(-1, 1, self.out_channels) 85 | 86 | return y.transpose(1, 2).view(-1, self.out_channels, output_size[0], output_size[1]) 87 | 88 | 89 | class ConvPropagator(nn.Module): 90 | def __init__(self, hidden_channels, linear_kernel_size, nonlin_kernel_size, data_channels, stride=1, 91 | linear_padding=0, nonlin_padding=0, dilation=1, groups=1, prop_layers=1, prop_noise=0., boundary_cond='periodic'): 92 | 93 | self.data_channels = data_channels 94 | self.prop_layers = prop_layers 95 | self.prop_noise = prop_noise 96 | self.boundary_cond = boundary_cond 97 | 98 | assert nonlin_padding == int((nonlin_kernel_size-1)/2) 99 | if boundary_cond == 'crop' or boundary_cond == 'dirichlet0': 100 | self.padding = int((2+prop_layers)*nonlin_padding) 101 | 102 | 
super(ConvPropagator, self).__init__() 103 | 104 | self.conv_linear = DynamicConv2d(data_channels, data_channels, linear_kernel_size, stride, 105 | linear_padding, dilation, groups, boundary_cond) if linear_kernel_size > 0 else None 106 | 107 | self.conv_in = DynamicConv2d(data_channels, hidden_channels, nonlin_kernel_size, stride, 108 | nonlin_padding, dilation, groups, boundary_cond) 109 | 110 | self.conv_out = DynamicConv2d(hidden_channels, data_channels, nonlin_kernel_size, stride, 111 | nonlin_padding, dilation, groups, boundary_cond) 112 | 113 | if prop_layers > 0: 114 | self.conv_prop = nn.ModuleList([DynamicConv2d(hidden_channels, hidden_channels, nonlin_kernel_size, stride, 115 | nonlin_padding, dilation, groups, boundary_cond) 116 | for i in range(prop_layers)]) 117 | 118 | self.cutoff = Parameter(torch.Tensor([1])) 119 | 120 | def _target_pad_2d(self, y, y0): 121 | y = torch.cat((y0[:,:,:self.padding, self.padding:-self.padding], 122 | y, y0[:,:,-self.padding:, self.padding:-self.padding]), dim=-2) 123 | return torch.cat((y0[:,:,:,:self.padding], y, y0[:,:,:,-self.padding:]), dim=-1) 124 | 125 | def _antireflection_pad_1d(self, y, dim): 126 | front_padding = -y.narrow(dim, 0, self.padding).flip([dim]) 127 | back_padding = -y.narrow(dim, y.shape[dim]-self.padding, self.padding).flip([dim]) 128 | return torch.cat((front_padding, y, back_padding), dim=dim) 129 | 130 | def _f(self, y, linear_weight, linear_bias, in_weight, in_bias, 131 | out_weight, out_bias, prop_weight, prop_bias): 132 | y_lin = self.conv_linear(y, linear_weight, linear_bias) if self.conv_linear is not None else 0 133 | 134 | y = self.conv_in(y, in_weight, in_bias) 135 | y = F.relu(y, inplace=True) 136 | for j in range(self.prop_layers): 137 | y = self.conv_prop[j](y, prop_weight[:,j], prop_bias[:,j]) 138 | y = F.relu(y, inplace=True) 139 | y = self.conv_out(y, out_weight, out_bias) 140 | 141 | return y + y_lin 142 | 143 | def forward(self, y0, linear_weight, linear_bias, 144 | in_weight, 
in_bias, out_weight, out_bias, prop_weight, prop_bias, depth): 145 | if self.boundary_cond == 'crop': 146 | # requires entire target solution as y0 for padding purposes 147 | assert len(y0.shape) == 5 148 | assert y0.shape[1] == self.data_channels 149 | assert y0.shape[2] == depth 150 | y_pad = y0[:,:,0] 151 | y = y0[:,:,0, self.padding:-self.padding, self.padding:-self.padding] 152 | elif self.boundary_cond == 'periodic' or self.boundary_cond == 'dirichlet0': 153 | assert len(y0.shape) == 4 154 | assert y0.shape[1] == self.data_channels 155 | y = y0 156 | else: 157 | raise ValueError("Invalid boundary condition.") 158 | 159 | f = lambda y: self._f(y, linear_weight, linear_bias, in_weight, in_bias, 160 | out_weight, out_bias, prop_weight, prop_bias) 161 | 162 | y_list = [] 163 | for i in range(depth): 164 | if self.boundary_cond == 'crop': 165 | if i > 0: 166 | y_pad = self._target_pad_2d(y, y0[:,:,i]) 167 | elif self.boundary_cond == 'dirichlet0': 168 | y_pad = self._antireflection_pad_1d(self._antireflection_pad_1d(y, -1), -2) 169 | elif self.boundary_cond == 'periodic': 170 | y_pad = y 171 | 172 | ### Euler integrator 173 | dt = 1e-6 # NOT REAL TIME STEP, JUST HYPERPARAMETER 174 | noise = self.prop_noise * torch.randn_like(y) if (self.training and self.prop_noise > 0) else 0 175 | y = y + self.cutoff * torch.tanh((dt * f(y_pad)) / self.cutoff) + noise 176 | 177 | y_list.append(y) 178 | 179 | return torch.stack(y_list, dim=-3) 180 | 181 | 182 | class PDEDecoder(nn.Module): 183 | def __init__(self, param_size=1, data_channels=1, data_dimension=2, hidden_channels=16, 184 | linear_kernel_size=0, nonlin_kernel_size=5, prop_layers=1, prop_noise=0., 185 | boundary_cond='periodic', param_dropout_prob=0.1, debug=False): 186 | 187 | assert data_dimension == 2 188 | 189 | super(PDEDecoder, self).__init__() 190 | 191 | self.param_size = param_size 192 | self.data_channels = data_channels 193 | self.hidden_channels = hidden_channels 194 | self.linear_kernel_size = 
linear_kernel_size 195 | self.nonlin_kernel_size = nonlin_kernel_size 196 | self.prop_layers = prop_layers 197 | self.boundary_cond = boundary_cond 198 | self.param_dropout_prob = param_dropout_prob 199 | self.debug = debug 200 | 201 | if param_size > 0: 202 | ### Parameter to weight/bias for dynamic convolutions 203 | if linear_kernel_size > 0: 204 | self.param_to_linear_weight = nn.Sequential( nn.Linear(param_size, 16 * data_channels * data_channels), 205 | nn.ReLU(inplace=True), 206 | nn.Linear(16 * data_channels * data_channels, 207 | data_channels * data_channels * linear_kernel_size * linear_kernel_size) 208 | ) 209 | 210 | self.param_to_in_weight = nn.Sequential( nn.Linear(param_size, 16 * data_channels * hidden_channels), 211 | nn.ReLU(inplace=True), 212 | nn.Linear(16 * data_channels * hidden_channels, 213 | data_channels * hidden_channels * nonlin_kernel_size * nonlin_kernel_size) 214 | ) 215 | self.param_to_in_bias = nn.Sequential( nn.Linear(param_size, 4 * hidden_channels), 216 | nn.ReLU(inplace=True), 217 | nn.Linear(4 * hidden_channels, hidden_channels) 218 | ) 219 | 220 | self.param_to_out_weight = nn.Sequential( nn.Linear(param_size, 16 * data_channels * hidden_channels), 221 | nn.ReLU(inplace=True), 222 | nn.Linear(16 * data_channels * hidden_channels, 223 | data_channels * hidden_channels * nonlin_kernel_size * nonlin_kernel_size) 224 | ) 225 | self.param_to_out_bias = nn.Sequential( nn.Linear(param_size, 4 * data_channels), 226 | nn.ReLU(inplace=True), 227 | nn.Linear(4 * data_channels, data_channels) 228 | ) 229 | 230 | if prop_layers > 0: 231 | self.param_to_prop_weight = nn.Sequential( nn.Linear(param_size, 16 * prop_layers * hidden_channels * hidden_channels), 232 | nn.ReLU(inplace=True), 233 | nn.Linear(16 * prop_layers * hidden_channels * hidden_channels, 234 | prop_layers * hidden_channels * hidden_channels * nonlin_kernel_size * nonlin_kernel_size) 235 | ) 236 | self.param_to_prop_bias = nn.Sequential( nn.Linear(param_size, 4 * 
prop_layers * hidden_channels), 237 | nn.ReLU(inplace=True), 238 | nn.Linear(4 * prop_layers * hidden_channels, prop_layers * hidden_channels) 239 | ) 240 | 241 | ### Decoder/PDE simulator 242 | self.decoder = ConvPropagator(hidden_channels, linear_kernel_size, nonlin_kernel_size, data_channels, 243 | linear_padding=int((linear_kernel_size-1)/2), 244 | nonlin_padding=int((nonlin_kernel_size-1)/2), 245 | prop_layers=prop_layers, prop_noise=prop_noise, boundary_cond=boundary_cond) 246 | 247 | def forward(self, z, y0, depth): 248 | 249 | if self.param_size > 0: 250 | assert len(z.shape) == 2 251 | assert z.shape[1] == self.param_size 252 | 253 | ### Parameter to weight/bias for dynamic convolutions 254 | if self.linear_kernel_size > 0: 255 | linear_weight = self.param_to_linear_weight(z) 256 | linear_bias = None 257 | else: 258 | linear_weight = None 259 | linear_bias = None 260 | 261 | in_weight = self.param_to_in_weight(z) 262 | in_bias = self.param_to_in_bias(z) 263 | 264 | out_weight = self.param_to_out_weight(z) 265 | out_bias = self.param_to_out_bias(z) 266 | 267 | if self.prop_layers > 0: 268 | prop_weight = self.param_to_prop_weight(z).view(-1, self.prop_layers, 269 | self.hidden_channels * self.hidden_channels * self.nonlin_kernel_size * self.nonlin_kernel_size) 270 | prop_bias = self.param_to_prop_bias(z).view(-1, self.prop_layers, self.hidden_channels) 271 | else: 272 | prop_weight = None 273 | prop_bias = None 274 | 275 | else: # if no parameter used 276 | linear_weight = None 277 | linear_bias = None 278 | in_weight = None 279 | in_bias = None 280 | out_weight = None 281 | out_bias = None 282 | prop_weight = None 283 | prop_bias = None 284 | 285 | ### Decoder/PDE simulator 286 | y = self.decoder(y0, linear_weight, linear_bias, in_weight, in_bias, out_weight, out_bias, 287 | prop_weight, prop_bias, depth) 288 | 289 | if self.debug: 290 | return y, z, None, [in_weight, in_bias, out_weight, out_bias, prop_weight, prop_bias] 291 | 292 | return y, z, None 293 
| -------------------------------------------------------------------------------- /npz_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | npz_dataset.py 3 | 4 | Dataset class (for PyTorch DataLoader) for data saved in *.npz or *.npy format. 5 | If using a *.npz file, it must contain an array 'x' that stores all the data and 6 | can contain an optional array 'params' of known parameters for comparison. 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | 12 | from torch.utils.data import Dataset, DataLoader 13 | 14 | class PDEDataset(Dataset): 15 | """PDE dataset with inputs x and targets also x.""" 16 | 17 | def __init__(self, data_file=None, transform=None, data_size=None): 18 | """ 19 | Args: 20 | data_file (numpy save): file with all data 21 | transform (callable, optional): Optional transform to be applied 22 | on a sample. 23 | """ 24 | data = np.load(data_file) 25 | 26 | if type(data) is np.ndarray: 27 | self.data_x = data 28 | self.params = None 29 | elif 'x' in data.files: 30 | self.data_x = data['x'] 31 | self.params = data['params'] if 'params' in data.files else None 32 | else: 33 | raise ValueError("Dataset import failed. 
NPZ files must include 'x' array containing data.") 34 | 35 | self.transform = transform 36 | 37 | 38 | def __len__(self): 39 | return len(self.data_x) 40 | 41 | def __getitem__(self, idx): 42 | 43 | x = torch.from_numpy(self.data_x[idx]) 44 | 45 | if self.params is None: 46 | sample = [x, x, torch.tensor(float('nan'))] 47 | else: 48 | sample = [x, x, torch.from_numpy(self.params[idx])] 49 | 50 | if self.transform: 51 | sample = self.transform(sample) 52 | 53 | return sample 54 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | ''''exec python -u -- "$0" "$@" # ''' 3 | # vi: syntax=python 4 | 5 | """ 6 | run.py 7 | 8 | Main script for training or evaluating a PDE-VAE model specified by the input file (JSON format). 9 | 10 | Usage: 11 | python run.py input_file.json > out 12 | """ 13 | 14 | import os 15 | import sys 16 | from shutil import copy2 17 | import json 18 | from types import SimpleNamespace 19 | import warnings 20 | 21 | import numpy as np 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | from torch.utils.tensorboard import SummaryWriter 28 | 29 | def setup(in_file): 30 | # Load configuration from json 31 | with open(in_file) as f: 32 | s = json.load(f, object_hook=lambda d: SimpleNamespace(**d)) 33 | 34 | # Some defaults 35 | if not hasattr(s, 'train'): 36 | raise NameError("'train' must be set to True for training or False for evaluation.") 37 | elif s.train == False and not hasattr(s, 'MODELLOAD'): 38 | raise NameError("'MODELLOAD' file name required for evaluation.") 39 | 40 | if not hasattr(s, 'restart'): 41 | s.restart = not s.train 42 | warnings.warn("Automatically setting 'restart' to " + str(s.restart)) 43 | if s.restart and not hasattr(s, 'MODELLOAD'): 44 | raise NameError("'MODELLOAD' file name required for restart.") 45 | 46 | if not hasattr(s, 
'freeze_encoder'): 47 | s.freeze_encoder = False 48 | elif s.freeze_encoder and not s.restart: 49 | raise ValueError("Freeezing encoder weights requires 'restart' set to True with encoder weights loaded from file.") 50 | 51 | if not hasattr(s, 'data_parallel'): 52 | s.data_parallel = False 53 | if not hasattr(s, 'debug'): 54 | s.debug = False 55 | if not hasattr(s, 'discount_rate'): 56 | s.discount_rate = 0. 57 | if not hasattr(s, 'rate_decay'): 58 | s.rate_decay = 0. 59 | if not hasattr(s, 'param_dropout_prob'): 60 | s.param_dropout_prob = 0. 61 | if not hasattr(s, 'prop_noise'): 62 | s.prop_noise = 0. 63 | 64 | if not hasattr(s, 'boundary_cond'): 65 | raise NameError("Boundary conditions 'boundary_cond' not set. Options include: 'crop', 'periodic', 'dirichlet0'") 66 | elif s.boundary_cond == 'crop' and (not hasattr(s, 'input_size') or not hasattr(s, 'training_size')): 67 | raise NameError("'input_size' or 'training_size' not set for crop boundary conditions.") 68 | 69 | # Create output folder 70 | if not os.path.exists(s.OUTFOLDER): 71 | print("Creating output folder: " + s.OUTFOLDER) 72 | os.makedirs(s.OUTFOLDER) 73 | elif s.train and os.listdir(s.OUTFOLDER): 74 | raise FileExistsError("Output folder " + s.OUTFOLDER + " is not empty.") 75 | 76 | # Make a copy of the configuration file in the output folder 77 | copy2(in_file, s.OUTFOLDER) 78 | 79 | # Print configuration 80 | print(s) 81 | 82 | # Import class for dataset type 83 | dataset = __import__(s.dataset_type, globals(), locals(), ['PDEDataset']) 84 | s.PDEDataset = dataset.PDEDataset 85 | 86 | # Import selected model from models as PDEModel 87 | models = __import__('models.' 
+ s.model, globals(), locals(), ['PDEAutoEncoder']) 88 | PDEModel = models.PDEAutoEncoder 89 | 90 | # Initialize model 91 | model = PDEModel(param_size=s.param_size, data_channels=s.data_channels, data_dimension=s.data_dimension, 92 | hidden_channels=s.hidden_channels, linear_kernel_size=s.linear_kernel_size, 93 | nonlin_kernel_size=s.nonlin_kernel_size, prop_layers=s.prop_layers, prop_noise=s.prop_noise, 94 | boundary_cond=s.boundary_cond, param_dropout_prob=s.param_dropout_prob, debug=s.debug) 95 | 96 | # Set CUDA device 97 | s.use_cuda = torch.cuda.is_available() 98 | if s.use_cuda: 99 | print("Using cuda device(s): " + str(s.cuda_device)) 100 | torch.cuda.set_device(s.cuda_device) 101 | model.cuda() 102 | else: 103 | warnings.warn("Warning: Using CPU only. This is untested.") 104 | 105 | print("\nModel parameters:") 106 | for name, param in model.named_parameters(): 107 | if param.requires_grad: 108 | print("\t{:<40}{}".format(name + ":", param.shape)) 109 | 110 | return model, s 111 | 112 | 113 | def _periodic_pad_1d(x, dim, pad): 114 | back_padding = x.narrow(dim, 0, pad) 115 | return torch.cat((x, back_padding), dim=dim) 116 | 117 | 118 | def _random_crop_1d(sample, depth, crop_size): 119 | sample_size = sample[0].shape 120 | crop_t = [np.random.randint(sample_size[-2]-depth[0]+1), np.random.randint(sample_size[-2]-depth[1]+1)] 121 | crop_x = [np.random.randint(sample_size[-1]), np.random.randint(sample_size[-1])] 122 | 123 | if crop_size[0] > 1: 124 | sample[0] = _periodic_pad_1d(sample[0], -1, crop_size[0]-1) 125 | if crop_size[1] > 1: 126 | sample[1] = _periodic_pad_1d(sample[1], -1, crop_size[1]-1) 127 | 128 | if len(sample_size) == 3: 129 | sample[0] = sample[0][:, crop_t[0]:(crop_t[0]+depth[0]), crop_x[0]:(crop_x[0]+crop_size[0])] 130 | sample[1] = sample[1][:, crop_t[1]:(crop_t[1]+depth[1]), crop_x[1]:(crop_x[1]+crop_size[1])] 131 | elif len(sample_size) == 2: 132 | sample[0] = sample[0][crop_t[0]:(crop_t[0]+depth[0]), 
crop_x[0]:(crop_x[0]+crop_size[0])] 133 | sample[1] = sample[1][crop_t[1]:(crop_t[1]+depth[1]), crop_x[1]:(crop_x[1]+crop_size[1])] 134 | else: 135 | raise ValueError('Sample is the wrong shape.') 136 | 137 | return sample 138 | 139 | 140 | def _random_crop_2d(sample, depth, crop_size): 141 | sample_size = sample[0].shape 142 | crop_t = [np.random.randint(sample_size[-3]-depth[0]+1), np.random.randint(sample_size[-3]-depth[1]+1)] 143 | crop_x = [np.random.randint(sample_size[-2]), np.random.randint(sample_size[-2])] 144 | crop_y = [np.random.randint(sample_size[-1]), np.random.randint(sample_size[-1])] 145 | 146 | if crop_size[0] > 1: 147 | sample[0] = _periodic_pad_1d(_periodic_pad_1d(sample[0], -1, crop_size[0]-1), -2, crop_size[0]-1) 148 | if crop_size[1] > 1: 149 | sample[1] = _periodic_pad_1d(_periodic_pad_1d(sample[1], -1, crop_size[1]-1), -2, crop_size[1]-1) 150 | 151 | if len(sample_size) == 4: 152 | sample[0] = sample[0][:, crop_t[0]:(crop_t[0]+depth[0]), crop_x[0]:(crop_x[0]+crop_size[0]), crop_y[0]:(crop_y[0]+crop_size[0])] 153 | sample[1] = sample[1][:, crop_t[1]:(crop_t[1]+depth[1]), crop_x[1]:(crop_x[1]+crop_size[1]), crop_y[1]:(crop_y[1]+crop_size[1])] 154 | elif len(sample_size) == 3: 155 | sample[0] = sample[0][crop_t[0]:(crop_t[0]+depth[0]), crop_x[0]:(crop_x[0]+crop_size[0]), crop_y[0]:(crop_y[0]+crop_size[0])] 156 | sample[1] = sample[1][crop_t[1]:(crop_t[1]+depth[1]), crop_x[1]:(crop_x[1]+crop_size[1]), crop_y[1]:(crop_y[1]+crop_size[1])] 157 | else: 158 | raise ValueError('Sample is the wrong shape.') 159 | 160 | return sample 161 | 162 | 163 | def train(model, s): 164 | ### Train model on training set 165 | print("\nTraining...") 166 | 167 | if s.restart: # load model to restart training 168 | print("Loading model from: " + s.MODELLOAD) 169 | strict_load = not s.freeze_encoder 170 | if s.use_cuda: 171 | state_dict = torch.load(s.MODELLOAD, map_location=torch.device('cuda', torch.cuda.current_device())) 172 | else: 173 | state_dict = 
torch.load(s.MODELLOAD) 174 | model.load_state_dict(state_dict, strict=strict_load) 175 | 176 | if s.freeze_encoder: # freeze encoder weights 177 | print("Freezing weights:") 178 | for name, param in model.encoder.named_parameters(): 179 | param.requires_grad = False 180 | print("\t{:<40}{}".format("encoder." + name + ":", param.size())) 181 | for name, param in model.encoder_to_param.named_parameters(): 182 | param.requires_grad = False 183 | print("\t{:<40}{}".format("encoder_to_param." + name + ":", param.size())) 184 | for name, param in model.encoder_to_logvar.named_parameters(): 185 | param.requires_grad = False 186 | print("\t{:<40}{}".format("encoder_to_logvar." + name + ":", param.size())) 187 | 188 | if s.data_parallel: 189 | model = nn.DataParallel(model, device_ids=s.cuda_device) 190 | 191 | if s.boundary_cond == 'crop': 192 | if s.data_dimension == 1: 193 | transform = lambda x: _random_crop_1d(x, (s.input_depth, s.training_depth+1), (s.input_size, s.training_size)) 194 | elif s.data_dimension == 2: 195 | transform = lambda x: _random_crop_2d(x, (s.input_depth, s.training_depth+1), (s.input_size, s.training_size)) 196 | 197 | pad = int((2+s.prop_layers)*(s.nonlin_kernel_size-1)/2) #for cropping targets 198 | 199 | elif s.boundary_cond == 'periodic' or s.boundary_cond == 'dirichlet0': 200 | transform = None 201 | 202 | else: 203 | raise ValueError("Invalid boundary condition.") 204 | 205 | train_loader = torch.utils.data.DataLoader( 206 | s.PDEDataset(data_file=s.DATAFILE, transform=transform), 207 | batch_size=s.batch_size, shuffle=True, num_workers=s.num_workers, pin_memory=True, 208 | worker_init_fn=lambda _: np.random.seed()) 209 | 210 | optimizer = torch.optim.Adam(model.parameters(), lr=s.learning_rate, eps=s.eps) 211 | 212 | model.train() 213 | 214 | writer = SummaryWriter(log_dir=os.path.join(s.OUTFOLDER, 'data')) 215 | 216 | # Initialize training variables 217 | loss_list = [] 218 | recon_loss_list = [] 219 | mse_list = [] 220 | acc_loss = 0 
221 | acc_recon_loss = 0 222 | acc_latent_loss = 0 223 | acc_mse = 0 224 | best_mse = None 225 | step = 0 226 | current_discount_rate = s.discount_rate 227 | 228 | ### Training loop 229 | for epoch in range(1, s.max_epochs+1): 230 | print('\nEpoch: ' + str(epoch)) 231 | 232 | # Introduce a discount rate to favor predicting better in the near future 233 | current_discount_rate = s.discount_rate * np.exp(-s.rate_decay * (epoch-1)) # discount rate decay every epoch 234 | print('discount rate = ' + str(current_discount_rate)) 235 | if current_discount_rate > 0: 236 | w = torch.tensor(np.exp(-current_discount_rate * np.arange(s.training_depth)).reshape( 237 | [s.training_depth] + s.data_dimension * [1]), dtype=torch.float32, device='cuda' if s.use_cuda else 'cpu') 238 | w = w * s.training_depth/w.sum(dim=0, keepdim=True) 239 | else: 240 | w = None 241 | 242 | # Load batch and train 243 | for data, target, data_params in train_loader: 244 | step += 1 245 | 246 | if s.use_cuda: 247 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 248 | 249 | data = data[:,:,:s.input_depth] 250 | if s.boundary_cond == 'crop': 251 | target0 = target[:,:,:s.training_depth] 252 | if s.data_dimension == 1: 253 | target = target[:,:,1:s.training_depth+1, pad:-pad] 254 | elif s.data_dimension == 2: 255 | target = target[:,:,1:s.training_depth+1, pad:-pad, pad:-pad] 256 | 257 | elif s.boundary_cond == 'periodic' or s.boundary_cond == 'dirichlet0': 258 | target0 = target[:,:,0] 259 | target = target[:,:,1:s.training_depth+1] 260 | 261 | else: 262 | raise ValueError("Invalid boundary condition.") 263 | 264 | # Run model 265 | output, params, logvar = model(data, target0, depth=s.training_depth) 266 | 267 | # Reset gradients 268 | optimizer.zero_grad() 269 | 270 | # Calculate loss 271 | if s.data_parallel: 272 | output = output.cpu() 273 | recon_loss = F.mse_loss(output * w, target * w) if w is not None else F.mse_loss(output, target) 274 | if s.param_size > 0: 275 | 
latent_loss = s.beta * 0.5 * torch.mean(torch.sum(params * params + logvar.exp() - logvar - 1, dim=-1)) 276 | else: 277 | latent_loss = 0 278 | loss = recon_loss + latent_loss 279 | 280 | mse = F.mse_loss(output.detach(), target.detach()).item() if w is not None else recon_loss.item() 281 | 282 | loss_list.append(loss.item()) 283 | recon_loss_list.append(recon_loss.item()) 284 | mse_list.append(mse) 285 | 286 | acc_loss += loss.item() 287 | acc_recon_loss += recon_loss.item() 288 | acc_latent_loss += latent_loss.item() 289 | acc_mse += mse 290 | 291 | # Calculate gradients 292 | loss.backward() 293 | 294 | # Clip gradients 295 | # grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 1e0) 296 | 297 | # Update gradients 298 | optimizer.step() 299 | 300 | # Output every 100 steps 301 | if step % 100 == 0: 302 | # Check every 500 steps and save checkpoint if new model is at least 2% better than previous best 303 | if (step > 1 and step % 500 == 0) and ((best_mse is None) or (acc_mse/100 < 0.98*best_mse)): 304 | best_mse = acc_mse/100 305 | torch.save(model.state_dict(), os.path.join(s.OUTFOLDER, "best.tar")) 306 | print('New Best MSE at Step {}: {:.4f}'.format(step, best_mse)) 307 | 308 | # Output losses and weights 309 | if s.param_size > 0: 310 | if step > 1: 311 | # Write losses to summary 312 | writer.add_scalars('losses', {'loss': acc_loss/100, 313 | 'recon_loss': acc_recon_loss/100, 314 | 'latent_loss': acc_latent_loss/100, 315 | 'mse': acc_mse/100}, step) 316 | 317 | acc_loss = 0 318 | acc_recon_loss = 0 319 | acc_latent_loss = 0 320 | acc_mse = 0 321 | 322 | # Write mean model weights to summary 323 | weight_dict = {} 324 | for name, param in model.named_parameters(): 325 | if param.requires_grad: 326 | weight_dict[name] = param.detach().abs().mean().item() 327 | writer.add_scalars('weight_avg', weight_dict, step) 328 | 329 | print('Train Step: {}\tTotal Loss: {:.4f}\tRecon. 
Loss: {:.4f}\tRecon./Latent: {:.1f}\tMSE: {:.4f}' 330 | .format(step, loss.item(), recon_loss.item(), recon_loss.item()/latent_loss.item(), mse)) 331 | 332 | # Save current set of extracted latent parameters 333 | np.savez(os.path.join(s.OUTFOLDER, "training_params.npz"), data_params=data_params.numpy(), 334 | params=params.detach().cpu().numpy()) 335 | else: 336 | print('Train Step: {}\tTotal Loss: {:.4f}\tRecon. Loss: {:.4f}\tMSE: {:.4f}' 337 | .format(step, loss.item(), recon_loss.item(), mse)) 338 | 339 | # Export checkpoints and loss history after every s.save_epochs epochs 340 | if s.save_epochs > 0 and epoch % s.save_epochs == 0: 341 | torch.save(model.state_dict(), os.path.join(s.OUTFOLDER, "epoch{:06d}.tar".format(epoch))) 342 | np.savez(os.path.join(s.OUTFOLDER, "loss.npz"), loss=np.array(loss_list), 343 | recon_loss=np.array(recon_loss_list), 344 | mse=np.array(mse_list)) 345 | 346 | return model 347 | 348 | 349 | def evaluate(model, s, params_filename="params.npz", rmse_filename="rmse_with_depth.npy"): 350 | ### Evaluate model on test set 351 | print("\nEvaluating...") 352 | 353 | if rmse_filename is not None and os.path.exists(os.path.join(s.OUTFOLDER, rmse_filename)): 354 | raise FileExistsError(rmse_filename + " already exists.") 355 | if os.path.exists(os.path.join(s.OUTFOLDER, params_filename)): 356 | raise FileExistsError(params_filename + " already exists.") 357 | 358 | if not s.train: 359 | print("Loading model from: " + s.MODELLOAD) 360 | if s.use_cuda: 361 | state_dict = torch.load(s.MODELLOAD, map_location=torch.device('cuda', torch.cuda.current_device())) 362 | else: 363 | state_dict = torch.load(s.MODELLOAD) 364 | model.load_state_dict(state_dict) 365 | 366 | pad = int((2+s.prop_layers)*(s.nonlin_kernel_size-1)/2) #for cropping targets (if necessary) 367 | 368 | test_loader = torch.utils.data.DataLoader( 369 | s.PDEDataset(data_file=s.DATAFILE, transform=None), 370 | batch_size=s.batch_size, num_workers=s.num_workers, pin_memory=True) 371 | 
372 | model.eval() 373 | torch.set_grad_enabled(False) 374 | 375 | ### Evaluation loop 376 | loss = 0 377 | if rmse_filename is not None: 378 | rmse_with_depth = torch.zeros(s.evaluation_depth, device='cuda' if s.use_cuda else 'cpu') 379 | params_list = [] 380 | logvar_list = [] 381 | data_params_list = [] 382 | step = 0 383 | for data, target, data_params in test_loader: 384 | step += 1 385 | 386 | if s.use_cuda: 387 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 388 | 389 | if s.boundary_cond == 'crop': 390 | target0 = target[:,:,:s.evaluation_depth] 391 | if s.data_dimension == 1: 392 | target = target[:,:,1:s.evaluation_depth+1, pad:-pad] 393 | elif s.data_dimension == 2: 394 | target = target[:,:,1:s.evaluation_depth+1, pad:-pad, pad:-pad] 395 | 396 | elif s.boundary_cond == 'periodic' or s.boundary_cond == 'dirichlet0': 397 | target0 = target[:,:,0] 398 | target = target[:,:,1:s.evaluation_depth+1] 399 | 400 | else: 401 | raise ValueError("Invalid boundary condition.") 402 | 403 | # Run model 404 | if s.debug: 405 | output, params, logvar, _, weights, raw_params = model(data.contiguous(), target0, depth=s.evaluation_depth) 406 | else: 407 | output, params, logvar = model(data.contiguous(), target0, depth=s.evaluation_depth) 408 | 409 | data_params = data_params.numpy() 410 | data_params_list.append(data_params) 411 | 412 | if s.param_size > 0: 413 | params = params.detach().cpu().numpy() 414 | params_list.append(params) 415 | logvar_list.append(logvar.detach().cpu().numpy()) 416 | 417 | assert output.shape[2] == s.evaluation_depth 418 | loss += F.mse_loss(output, target).item() 419 | 420 | if rmse_filename is not None: 421 | rmse_with_depth += torch.sqrt(torch.mean((output - target).transpose(2,1).contiguous() 422 | .view(target.size()[0], s.evaluation_depth, -1) ** 2, 423 | dim=-1)).mean(0) 424 | 425 | rmse_with_depth = rmse_with_depth.cpu().numpy()/step 426 | print('\nTest Set: Recon. 
Loss: {:.4f}'.format(loss/step)) 427 | 428 | if rmse_filename is not None: 429 | np.save(os.path.join(s.OUTFOLDER, rmse_filename), rmse_with_depth) 430 | 431 | np.savez(os.path.join(s.OUTFOLDER, params_filename), params=np.concatenate(params_list), 432 | logvar=np.concatenate(logvar_list), 433 | data_params=np.concatenate(data_params_list)) 434 | 435 | 436 | if __name__ == "__main__": 437 | in_file = sys.argv[1] 438 | if not os.path.exists(in_file): 439 | raise FileNotFoundError("Input file " + in_file + " not found.") 440 | 441 | model, s = setup(in_file) 442 | if s.train: 443 | model = train(model, s) 444 | else: 445 | evaluate(model, s) 446 | --------------------------------------------------------------------------------