├── .gitignore ├── LICENSE ├── README.md ├── data_generation ├── KS │ ├── GRF1.m │ ├── KS.m │ └── ks.m └── lorenz │ ├── odelibrary.py │ └── run_ode_solver.py ├── dissipative_utils.py ├── lorenz.ipynb ├── models ├── densenet.py └── fno_2d.py ├── scripts ├── NS_fno_baseline.py ├── NS_mno_dissipative.py ├── lorenz_densenet.py └── lorenz_dissipative_densenet.py ├── utilities.py └── visualize_navier_stokes2d.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Miguel Liu-Schiaffini 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Markov Neural Operator (MNO) 2 | 3 | This repository contains the code for the paper ["Learning Dissipative Dynamics in Chaotic Systems,"](https://arxiv.org/abs/2106.06898) published in NeurIPS 2022. 
4 | 5 | In this work, we propose a machine learning framework, which we call the Markov Neural Operator (MNO), to learn the underlying solution operator for dissipative chaotic systems, showing that the resulting learned operator accurately captures short-time trajectories and long-time statistical behavior. Using this framework, we are able to predict various statistics of the invariant measure for the turbulent Kolmogorov Flow dynamics with Reynolds numbers up to 5000. 6 | 7 | ## Requirements 8 | * Neural operator code is based on the [Fourier Neural Operator (FNO)](https://github.com/zongyi-li/fourier_neural_operator), which requires PyTorch 1.8.0 or later. 9 | 10 | ## Files 11 | * ``utilities.py``: basic utilities including a reader for .mat files and Sobolev (Hk) and Lp losses. 12 | * ``dissipative_utils.py``: helper functions for encouraging (regularization loss) and enforcing dissipative dynamics (postprocessing). 13 | * ``models/``: model architectures 14 | * ``densenet.py``: simple feedforward neural network 15 | * ``fno_2d.py``: FNO architecture for operators acting on a function space with two spatial dimensions. 16 | * ``data_generation/``: directory containing data generation code for our toy Lorenz-63 dataset and the 1D Kuramoto–Sivashinsky PDE. 17 | * ``scripts/``: scripts for training Lorenz-63 model, 1D KS, and 2D NS equations. 18 | * ``NS_fno_baseline.py``: FNO baseline trained on 2D NS with Reynolds number 500. No dissipativity or Sobolev loss. 19 | * ``NS_mno_dissipative.py``: MNO model built on FNO architecture with dissipativity encouraged and Sobolev loss. 20 | * ``lorenz_densenet.py``: simple feedforward neural network learning Markovian solution operator for Lorenz-63 system. 21 | * ``lorenz_dissipative_densenet.py``: simple feedforward neural network with dissipativity encouraged trained on Lorenz-63 system. 22 | * `lorenz.ipynb`: Jupyter notebook with examples to reproduce plots and figures for our Lorenz-63 examples in the paper. 
23 | * `visualize_navier_stokes2d.ipynb`: Jupyter notebook with examples to reproduce plots and figures for our 2D Navier-Stokes case study in the paper. 24 | 25 | ## Datasets 26 | In our work, we train and evaluate on datasets from the Lorenz-63 system (finite-dimensional ODE), Kuramoto–Sivashinsky equation (1D PDE system), and the 2D Navier-Stokes equations (Kolmogorov flow, 2D PDE). Our datasets can be found online under DOI [10.5281/zenodo.7495555](https://zenodo.org/record/7495555). 27 | * Lorenz: Data generation code can be found in the `data_generation/lorenz` directory. 28 | * KS: Data generation code can be found in the `data_generation/KS` directory. 29 | * Data generation for 2D Navier-Stokes is based on the data generation scripts in the [FNO repository](https://github.com/zongyi-li/fourier_neural_operator/tree/master/data_generation/navier_stokes). 30 | 31 | ## Models 32 | In our work, we use three different models to learn the Markovian solution operator. These can be found under the ``models/`` folder in the repository. 33 | * **Lorenz:** Since the Lorenz-63 system is a finite-dimensional ODE system, we use a standard feedforward neural network to learn the Markov solution operator. 34 | * **1D KS and 2D NS equations:** We interpret PDEs as function-space ODEs, and we adopt the 1D and 2D FNO architecture (resp.) to learn the Markov solution operator for the 1D KS and 2D NS equations. 
35 | 36 | ## Citation 37 | ``` 38 | @article{MNO, 39 | title={Learning chaotic dynamics in dissipative systems}, 40 | author={Li, Zongyi and Liu-Schiaffini, Miguel and Kovachki, Nikola and Azizzadenesheli, Kamyar and Liu, Burigede and Bhattacharya, Kaushik and Stuart, Andrew and Anandkumar, Anima}, 41 | journal={Advances in Neural Information Processing Systems}, 42 | volume={35}, 43 | pages={16768--16781}, 44 | year={2022} 45 | } 46 | ``` 47 | -------------------------------------------------------------------------------- /data_generation/KS/GRF1.m: -------------------------------------------------------------------------------- 1 | %Random function from N(m, C) on [0 1] where 2 | %C = sigma^2(-Delta + tau^2 I)^(-gamma) 3 | %with periodic, zero dirichlet, and zero neumann boundary. 4 | %Dirichlet only supports m = 0. 5 | %N is the # of Fourier modes, usually, grid size / 2. 6 | function u = GRF1(N, m, gamma, tau, sigma, type) 7 | 8 | L = 1; 9 | 10 | if type == "dirichlet" 11 | m = 0; 12 | end 13 | 14 | if type == "periodic" 15 | my_const = 2*pi/L; 16 | else 17 | my_const = pi; 18 | end 19 | 20 | my_eigs = sqrt(2)*(abs(sigma).*((my_const.*(1:N)').^2 + tau^2).^(-gamma/2)); 21 | 22 | if type == "dirichlet" 23 | alpha = zeros(N,1); 24 | else 25 | xi_alpha = randn(N,1); 26 | alpha = my_eigs.*xi_alpha; 27 | end 28 | 29 | if type == "neumann" 30 | beta = zeros(N,1); 31 | else 32 | xi_beta = randn(N,1); 33 | beta = my_eigs.*xi_beta; 34 | end 35 | 36 | a = alpha/2; 37 | b = -beta/2; 38 | 39 | c = [flipud(a) - flipud(b).*1i;m + 0*1i;a + b.*1i]; 40 | 41 | if type == "periodic" 42 | uu = chebfun(c, [0 L], 'trig', 'coeffs'); 43 | u = chebfun(@(t) uu(t - L/2), [0 L], 'trig'); 44 | else 45 | uu = chebfun(c, [-pi pi], 'trig', 'coeffs'); 46 | u = chebfun(@(t) uu(pi*t), [0 1]); 47 | end -------------------------------------------------------------------------------- /data_generation/KS/KS.m: -------------------------------------------------------------------------------- 1 | %u - 
initial condition 2 | %l - length of interval [0,l) or [-l/2,l/2) 3 | %T - final time 4 | %N - number of solutions to record 5 | %h - internal time step 6 | function [uu, tt] = KS(u, l, T, N, h) 7 | s = length(u(:)); 8 | 9 | v = fft(u); 10 | 11 | k = (2*pi/l)*[0:s/2-1 0 -s/2+1:-1]'; 12 | L = k.^2 - k.^4; 13 | E = exp(h*L); E2 = exp(h*L/2); 14 | M = 64; 15 | r = exp(1i*pi*((1:M)-.5)/M); 16 | LR = h*L(:,ones(M,1)) + r(ones(s,1),:); 17 | Q = h*real(mean((exp(LR/2)-1)./LR,2)); 18 | f1 = h*real(mean( (-4-LR+exp(LR).*(4-3*LR+LR.^2))./LR.^3 ,2)); 19 | f2 = h*real(mean((2+LR+exp(LR).*(-2+LR))./LR.^3,2)); 20 | f3 = h*real(mean( (-4-3*LR-LR.^2+exp(LR).*(4-LR))./LR.^3 ,2)); 21 | 22 | uu = zeros(N,s); 23 | tt = zeros(N,1); 24 | nmax = round(T/h); 25 | nrec = floor((T/N)/h); 26 | g = -0.5i*k; 27 | q = 1; 28 | for n = 1:nmax 29 | t = n*h; 30 | Nv = g.*fft(real(ifft(v)).^2); 31 | a = E2.*v + Q.*Nv; 32 | Na = g.*fft(real(ifft(a)).^2); 33 | b = E2.*v + Q.*Na; 34 | Nb = g.*fft(real(ifft(b)).^2); 35 | c = E2.*a + Q.*(2*Nb-Nv); 36 | Nc = g.*fft(real(ifft(c)).^2); 37 | v = E.*v + Nv.*f1 + 2*(Na+Nb).*f2 + Nc.*f3; 38 | if mod(n,nrec)==0 39 | u = real(ifft(v)); 40 | uu(q,:) = u; tt(q) = t; 41 | q = q + 1; 42 | end 43 | %disp(n); 44 | end -------------------------------------------------------------------------------- /data_generation/KS/ks.m: -------------------------------------------------------------------------------- 1 | s = 512; 2 | %Pretend this is on the interval [0,2*pi*L) 3 | x = (1:s)'/s; 4 | 5 | N = 200; % number of case 6 | T = 1000; % time 7 | t = 10000; % time steps 8 | u_out = zeros(N, t, s); 9 | 10 | for i=1:N 11 | disp(i); 12 | u = GRF1(s/2, 0, 2, 5, 5^2, "periodic"); 13 | u = u(x); 14 | 15 | [uu, tt] = KS(u, 2*pi*32, T, t, 0.1); 16 | u_out(i,:,:) = uu; 17 | 18 | % surf(tt,x,uu'); 19 | % view([90 -90]); 20 | % shading interp; 21 | % colormap jet; 22 | % axis tight; 23 | % colorbar; 24 | end 25 | 26 | 27 | 28 | 
-------------------------------------------------------------------------------- /data_generation/lorenz/odelibrary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot 3 | # from numba import jitclass # import the decorator 4 | # from numba import boolean, int64, float32, float64 # import the types 5 | 6 | import pdb 7 | # Correspondence with Dima via Whatsapp on Feb 24, 2020: 8 | # RK45 (explicit) for slow-system-only 9 | # RK45 (implicit) aka Radau for multi-scale-system 10 | # In both cases, set abstol to 1e-6, reltol to 1e-3, dtmax to 1e-3 11 | # L96spec = [ 12 | # ('K', int64), # a simple scalar field 13 | # ('J', int64), # a simple scalar field 14 | # ('hx', float64[:]), # a simple scalar field 15 | # ('hy', float64), # a simple scalar field 16 | # ('F', float64), # a simple scalar field 17 | # ('eps', float64), # a simple scalar field 18 | # ('k0', float64), # a simple scalar field 19 | # ('slow_only', boolean), # a simple scalar field 20 | # ('xk_star', float64[:]) # a simple scalar field 21 | # ] 22 | 23 | class LDS_COUPLED_X: 24 | """ 25 | A simple class that implements a coupled linear dynamical system 26 | 27 | The class computes RHS's to make use of scipy's ODE solvers. 
28 | 29 | Parameters: 30 | A 31 | 32 | """ 33 | 34 | def __init__(_s, 35 | A = np.array([[0, 1], [-1, 0]]), 36 | eps_min = 0.001, 37 | eps_max = 0.05, 38 | h = 3.0, 39 | share_gp=True, 40 | add_closure=False): 41 | ''' 42 | Initialize an instance: setting parameters and xkstar 43 | ''' 44 | _s.share_gp = share_gp 45 | _s.A = A 46 | _s.hx = h # just useful when re-using L96 code 47 | _s.eps_min = eps_min 48 | _s.eps_max = eps_max 49 | _s.K = _s.A.shape[0] # slow state dims 50 | _s.J = _s.A.shape[0] # fast state dims 51 | _s.slow_only = False 52 | _s.exchangeable_states = False 53 | _s.add_closure = add_closure 54 | 55 | def get_inits(_s): 56 | state_inits = np.random.uniform(low=-1, high=1, size=_s.K+_s.J) 57 | # normalize inits so that slow and fast system both start on unit circle 58 | state_inits[:_s.K] /= np.sqrt(np.sum(state_inits[:_s.K]**2)) 59 | state_inits[_s.K:] /= np.sqrt(np.sum(state_inits[_s.K:]**2)) 60 | return state_inits 61 | 62 | def get_state_names(_s): 63 | return ['X_'+ str(k+1) for k in range(_s.K)] 64 | 65 | def plot_state_indices(_s): 66 | return [0, _s.K] 67 | 68 | def slow(_s, x, t): 69 | ''' Full system RHS ''' 70 | foo_rhs = _s.A @ x 71 | return foo_rhs 72 | 73 | def eps_f(_s, x): 74 | return _s.eps_min + 2 * (_s.eps_max - _s.eps_min) * (np.prod(x))**2 75 | 76 | def full(_s, z, t): 77 | ''' Full system RHS ''' 78 | x = z[:_s.K] 79 | y = z[_s.K:] 80 | foo_rhs = np.empty(_s.K + _s.J) 81 | foo_rhs[:_s.K] = _s.A @ x + _s.hx*y 82 | foo_rhs[_s.K:] = _s.A @ y / _s.eps_f(x) 83 | return foo_rhs 84 | 85 | def rhs(_s, z, t): 86 | if _s.slow_only: 87 | foo_rhs = _s.slow(z, t) 88 | else: 89 | foo_rhs = _s.full(z, t) 90 | if _s.add_closure: 91 | foo_rhs += _s.simulate(z) 92 | return foo_rhs 93 | 94 | 95 | def regressed(_s, x, t): 96 | ''' Only slow variables with RHS learned from data ''' 97 | rhs = _s.rhs(x,t) 98 | # add data-learned coupling term 99 | rhs += _s.simulate(x) 100 | return rhs 101 | 102 | def set_stencil(_s, left = 0, right = 0): 103 | 
_s.stencil = np.arange(left, 1 + right) 104 | 105 | def single_step_implied_Ybar(_s, Xnow, Xnext, delta_t): 106 | # use an euler scheme to back-out the implied avg Ybar_t from X_t and X_t+1 107 | Ybar = (Xnext - Xnow)/delta_t - _s.rhs(S=Xnow, t=None) 108 | 109 | return Ybar 110 | 111 | def implied_Ybar(_s, X_in, X_out, delta_t): 112 | # the idea is that X_in are true data coming from a test/training set 113 | # Xout(k) is the 1-step-ahed prediction associated to Xin(k). 114 | # In other words Xout(k) = Psi-ML(Xin(k)) 115 | T = X_in.shape[0] 116 | Ybar = np.zeros( (T, _s.K) ) 117 | for t in range(T): 118 | Ybar[t,:] = _s.single_step_implied_Ybar(Xnow=X_in[t,:], Xnext=X_out[t,:], delta_t=delta_t) 119 | return Ybar 120 | 121 | def get_state_limits(_s): 122 | lims = (None,None) 123 | return lims 124 | 125 | def set_predictor(_s, predictor): 126 | _s.predictor = predictor 127 | 128 | # def set_G0_predictor(_s): 129 | # _s.predictor = lambda x: _s.hy * x 130 | 131 | def set_null_predictor(_s): 132 | _s.predictor = lambda x: 0 133 | 134 | def simulate(_s, slow): 135 | if _s.share_gp: 136 | return np.reshape(_s.predictor(_s.apply_stencil(slow)), (-1,)) 137 | else: 138 | return np.reshape(_s.predictor(slow.reshape(1,-1)), (-1,)) 139 | 140 | def apply_stencil(_s, slow): 141 | # behold: the blackest of all black magic! 142 | # (in a year, I will not understand what this does) 143 | # the idea: shift xk's so that each row corresponds to the stencil: 144 | # (x_{k-1}, x_{k}, x_{k+1}), for example, 145 | # based on '_s.stencil' and 'slow' array (which is (x1,...,xK) ) 146 | return slow[np.add.outer(np.arange(_s.K), _s.stencil) % _s.K] 147 | 148 | 149 | 150 | 151 | 152 | class LDS_COUPLED: 153 | """ 154 | A simple class that implements a coupled linear dynamical system 155 | 156 | The class computes RHS's to make use of scipy's ODE solvers. 
157 | 158 | Parameters: 159 | A 160 | 161 | """ 162 | 163 | def __init__(_s, 164 | A = np.array([[0, 1], [-1, 0]]), 165 | eps = 0.05, 166 | h = 0.1, 167 | share_gp=True, 168 | add_closure=False): 169 | ''' 170 | Initialize an instance: setting parameters and xkstar 171 | ''' 172 | _s.share_gp = share_gp 173 | _s.A = A 174 | _s.hx = h # just useful when re-using L96 code 175 | _s.eps = eps 176 | _s.K = _s.A.shape[0] # slow state dims 177 | _s.J = _s.A.shape[0] # fast state dims 178 | _s.slow_only = False 179 | _s.exchangeable_states = False 180 | _s.add_closure = add_closure 181 | 182 | def get_inits(_s): 183 | state_inits = np.random.uniform(low=-1, high=1, size=_s.K+_s.J) 184 | # normalize inits so that slow and fast system both start on unit circle 185 | state_inits[:_s.K] /= np.sqrt(np.sum(state_inits[:_s.K]**2)) 186 | state_inits[_s.K:] /= np.sqrt(np.sum(state_inits[_s.K:]**2)) 187 | return state_inits 188 | 189 | def get_state_names(_s): 190 | return ['X_'+ str(k+1) for k in range(_s.K)] 191 | 192 | def plot_state_indices(_s): 193 | return [0, _s.K] 194 | 195 | def slow(_s, x, t): 196 | ''' Full system RHS ''' 197 | foo_rhs = _s.A @ x 198 | return foo_rhs 199 | 200 | def full(_s, z, t): 201 | ''' Full system RHS ''' 202 | x = z[:_s.K] 203 | y = z[_s.K:] 204 | foo_rhs = np.empty(_s.K + _s.J) 205 | foo_rhs[:_s.K] = _s.A @ x + _s.hx*y 206 | foo_rhs[_s.K:] = _s.A @ y / _s.eps 207 | return foo_rhs 208 | 209 | def rhs(_s, z, t): 210 | if _s.slow_only: 211 | foo_rhs = _s.slow(z, t) 212 | else: 213 | foo_rhs = _s.full(z, t) 214 | if _s.add_closure: 215 | foo_rhs += _s.simulate(z) 216 | return foo_rhs 217 | 218 | 219 | def regressed(_s, x, t): 220 | ''' Only slow variables with RHS learned from data ''' 221 | rhs = _s.rhs(x,t) 222 | # add data-learned coupling term 223 | rhs += _s.simulate(x) 224 | return rhs 225 | 226 | def set_stencil(_s, left = 0, right = 0): 227 | _s.stencil = np.arange(left, 1 + right) 228 | 229 | def single_step_implied_Ybar(_s, Xnow, Xnext, 
delta_t): 230 | # use an euler scheme to back-out the implied avg Ybar_t from X_t and X_t+1 231 | Ybar = (Xnext - Xnow)/delta_t - _s.rhs(S=Xnow, t=None) 232 | 233 | return Ybar 234 | 235 | def implied_Ybar(_s, X_in, X_out, delta_t): 236 | # the idea is that X_in are true data coming from a test/training set 237 | # Xout(k) is the 1-step-ahed prediction associated to Xin(k). 238 | # In other words Xout(k) = Psi-ML(Xin(k)) 239 | T = X_in.shape[0] 240 | Ybar = np.zeros( (T, _s.K) ) 241 | for t in range(T): 242 | Ybar[t,:] = _s.single_step_implied_Ybar(Xnow=X_in[t,:], Xnext=X_out[t,:], delta_t=delta_t) 243 | return Ybar 244 | 245 | def get_state_limits(_s): 246 | lims = (None,None) 247 | return lims 248 | 249 | def set_predictor(_s, predictor): 250 | _s.predictor = predictor 251 | 252 | # def set_G0_predictor(_s): 253 | # _s.predictor = lambda x: _s.hy * x 254 | 255 | def set_null_predictor(_s): 256 | _s.predictor = lambda x: 0 257 | 258 | def simulate(_s, slow): 259 | if _s.share_gp: 260 | return np.reshape(_s.predictor(_s.apply_stencil(slow)), (-1,)) 261 | else: 262 | return np.reshape(_s.predictor(slow.reshape(1,-1)), (-1,)) 263 | 264 | def apply_stencil(_s, slow): 265 | # behold: the blackest of all black magic! 266 | # (in a year, I will not understand what this does) 267 | # the idea: shift xk's so that each row corresponds to the stencil: 268 | # (x_{k-1}, x_{k}, x_{k+1}), for example, 269 | # based on '_s.stencil' and 'slow' array (which is (x1,...,xK) ) 270 | return slow[np.add.outer(np.arange(_s.K), _s.stencil) % _s.K] 271 | 272 | 273 | 274 | class LDS: 275 | """ 276 | A simple class that implements a linear dynamical system 277 | 278 | The class computes RHS's to make use of scipy's ODE solvers. 
279 | 280 | Parameters: 281 | A 282 | 283 | """ 284 | 285 | def __init__(_s, 286 | A = np.array([[0, 5], [-5, 0]]), share_gp=True, add_closure=False): 287 | ''' 288 | Initialize an instance: setting parameters and xkstar 289 | ''' 290 | _s.share_gp = share_gp 291 | _s.A = A 292 | _s.K = _s.A.shape[0] # state dims 293 | _s.hx = 1 # just useful when re-using L96 code 294 | _s.slow_only = False 295 | _s.exchangeable_states = False 296 | _s.add_closure = add_closure 297 | 298 | def get_inits(_s): 299 | state_inits = np.random.randn(_s.K) 300 | return state_inits 301 | 302 | def get_state_names(_s): 303 | return ['X_'+ str(k+1) for k in range(_s.K)] 304 | 305 | def plot_state_indices(_s): 306 | return [0, _s.K] 307 | 308 | def slow(_s, y, t): 309 | return _s.rhs(y,t) 310 | 311 | def rhs(_s, S, t): 312 | ''' Full system RHS ''' 313 | foo_rhs = _s.A @ S 314 | if _s.add_closure: 315 | foo_rhs += _s.simulate(S) 316 | return foo_rhs 317 | 318 | def regressed(_s, x, t): 319 | ''' Only slow variables with RHS learned from data ''' 320 | rhs = _s.rhs(x,t) 321 | # add data-learned coupling term 322 | rhs += _s.simulate(x) 323 | return rhs 324 | 325 | def set_stencil(_s, left = 0, right = 0): 326 | _s.stencil = np.arange(left, 1 + right) 327 | 328 | def single_step_implied_Ybar(_s, Xnow, Xnext, delta_t): 329 | # use an euler scheme to back-out the implied avg Ybar_t from X_t and X_t+1 330 | Ybar = (Xnext - Xnow)/delta_t - _s.rhs(S=Xnow, t=None) 331 | 332 | return Ybar 333 | 334 | def implied_Ybar(_s, X_in, X_out, delta_t): 335 | # the idea is that X_in are true data coming from a test/training set 336 | # Xout(k) is the 1-step-ahed prediction associated to Xin(k). 
337 | # In other words Xout(k) = Psi-ML(Xin(k)) 338 | T = X_in.shape[0] 339 | Ybar = np.zeros( (T, _s.K) ) 340 | for t in range(T): 341 | Ybar[t,:] = _s.single_step_implied_Ybar(Xnow=X_in[t,:], Xnext=X_out[t,:], delta_t=delta_t) 342 | return Ybar 343 | 344 | def get_state_limits(_s): 345 | lims = (None,None) 346 | return lims 347 | 348 | def set_predictor(_s, predictor): 349 | _s.predictor = predictor 350 | 351 | # def set_G0_predictor(_s): 352 | # _s.predictor = lambda x: _s.hy * x 353 | 354 | def set_null_predictor(_s): 355 | _s.predictor = lambda x: 0 356 | 357 | def simulate(_s, slow): 358 | if _s.share_gp: 359 | return np.reshape(_s.predictor(_s.apply_stencil(slow)), (-1,)) 360 | else: 361 | return np.reshape(_s.predictor(slow.reshape(1,-1)), (-1,)) 362 | 363 | def apply_stencil(_s, slow): 364 | # behold: the blackest of all black magic! 365 | # (in a year, I will not understand what this does) 366 | # the idea: shift xk's so that each row corresponds to the stencil: 367 | # (x_{k-1}, x_{k}, x_{k+1}), for example, 368 | # based on '_s.stencil' and 'slow' array (which is (x1,...,xK) ) 369 | return slow[np.add.outer(np.arange(_s.K), _s.stencil) % _s.K] 370 | 371 | 372 | 373 | 374 | # @jitclass(L96spec) 375 | class L96M: 376 | """ 377 | A simple class that implements Lorenz '96M model w/ slow and fast variables 378 | 379 | The class computes RHS's to make use of scipy's ODE solvers. 380 | 381 | Parameters: 382 | K, J, hx, hy, F, eps 383 | 384 | The convention is that the first K variables are slow, while the rest K*J 385 | variables are fast. 
386 | """ 387 | 388 | def __init__(_s, 389 | K = 9, J = 8, hx = -0.8, hy = 1, F = 10, eps = 2**(-7), k0 = 0, slow_only=False, dima_style=False, share_gp=True, add_closure=False): 390 | ''' 391 | Initialize an instance: setting parameters and xkstar 392 | ''' 393 | hx = hx * np.ones(K) 394 | if hx.size != K: 395 | raise ValueError("'hx' must be a 1D-array of size 'K'") 396 | _s.predictor = None 397 | _s.dima_style = dima_style 398 | _s.share_gp = share_gp # if true, then GP is R->R and is applied to each state independently. 399 | # if share_gp=False, then GP is R^K -> R^K and is applied to the whole state vector at once. 400 | _s.slow_only = slow_only 401 | _s.K = K 402 | _s.J = J 403 | _s.hx = hx 404 | _s.hy = hy 405 | _s.F = F 406 | _s.eps = eps 407 | _s.k0 = k0 # for filtered integration 408 | _s.exchangeable_states = True 409 | # 0 410 | #_s.xk_star = np.random.rand(K) * 15 - 5 411 | # 1 412 | #_s.xk_star = np.ones(K) * 5 413 | # 2 414 | #_s.xk_star = np.ones(K) * 2 415 | #_s.xk_star[K//2:] = -0.2 416 | # 3 417 | _s.xk_star = 0.0 * np.zeros(K) 418 | _s.xk_star[0] = 5 419 | _s.xk_star[1] = 5 420 | _s.xk_star[-1] = 5 421 | _s.add_closure = add_closure 422 | 423 | def get_inits(_s, sigma = 15, mu = -5): 424 | z0 = np.zeros((_s.K + _s.K * _s.J)) 425 | z0[:_s.K] = mu + np.random.rand(_s.K) * sigma 426 | if _s.slow_only: 427 | return z0[:_s.K] 428 | else: 429 | for k_ in range(_s.K): 430 | z0[_s.K + k_*_s.J : _s.K + (k_+1)*_s.J] = z0[k_] 431 | return z0 432 | 433 | def get_state_limits(_s): 434 | if _s.K==4 and _s.J==4: 435 | lims = (-27.5, 36.5) 436 | elif _s.K==9 and _s.J==8: 437 | lims = (-9.5, 14.5) 438 | else: 439 | lims = (None,None) 440 | return lims 441 | 442 | def get_fast_state_names(_s): 443 | state_names = [] 444 | for k in range(_s.K): 445 | state_names += ['Y_' + str(j+1) + ',' + str(k+1) for j in range(_s.J)] 446 | return state_names 447 | 448 | def get_slow_state_names(_s): 449 | state_names = ['X_'+ str(k+1) for k in range(_s.K)] 450 | return 
state_names 451 | 452 | def get_state_names(_s, get_all=False): 453 | state_names = _s.get_slow_state_names() 454 | if get_all or not _s.slow_only: 455 | state_names += _s.get_fast_state_names() 456 | return state_names 457 | 458 | def get_fast_state_indices(_s): 459 | return np.arange(_s.K, _s.K + _s.K * _s.J) 460 | 461 | def plot_state_indices(_s): 462 | if _s.slow_only: 463 | return [0, 1, _s.K-1, _s.K-2] # return a 4 coupled slow variables 464 | else: 465 | return [0, _s.K] # return 1st slow variable and 1st coupled fast variable 466 | 467 | def set_predictor(_s, predictor): 468 | _s.predictor = predictor 469 | 470 | def set_G0_predictor(_s): 471 | _s.predictor = lambda x: _s.hy * x 472 | 473 | def set_null_predictor(_s): 474 | _s.predictor = lambda x: 0 475 | 476 | def set_stencil(_s, left = 0, right = 0): 477 | _s.stencil = np.arange(left, 1 + right) 478 | 479 | def hit_value(_s, k, val): 480 | return lambda t, z: z[k] - val 481 | 482 | def rhs(_s, z, t): 483 | if _s.slow_only: 484 | foo_rhs = _s.slow(z, t) 485 | else: 486 | foo_rhs = _s.full(z, t) 487 | if _s.add_closure: 488 | foo_rhs += _s.simulate(z) 489 | return foo_rhs 490 | 491 | def full(_s, z, t): 492 | ''' Full system RHS ''' 493 | K = _s.K 494 | J = _s.J 495 | rhs = np.empty(K + K*J) 496 | x = z[:K] 497 | y = z[K:] 498 | 499 | ### slow variables subsystem ### 500 | # compute Yk averages 501 | Yk = _s.compute_Yk(z) 502 | 503 | # three boundary cases 504 | rhs[0] = -x[K-1] * (x[K-2] - x[1]) - x[0] 505 | rhs[1] = -x[0] * (x[K-1] - x[2]) - x[1] 506 | rhs[K-1] = -x[K-2] * (x[K-3] - x[0]) - x[K-1] 507 | 508 | # general case 509 | rhs[2:K-1] = -x[1:K-2] * (x[0:K-3] - x[3:K]) - x[2:K-1] 510 | 511 | # add forcing 512 | rhs[:K] += _s.F 513 | 514 | # add coupling w/ fast variables via averages 515 | # XXX verify this (twice: sign and vector-vector multiplication) 516 | rhs[:K] += _s.hx * Yk 517 | #rhs[:K] -= _s.hx * Yk 518 | 519 | ### fast variables subsystem ### 520 | # three boundary cases 521 | rhs[K] = 
-y[1] * (y[2] - y[-1]) - y[0] 522 | rhs[-2] = -y[-1] * (y[0] - y[-3]) - y[-2] 523 | rhs[-1] = -y[0] * (y[1] - y[-2]) - y[-1] 524 | 525 | # general case 526 | rhs[K+1:-2] = -y[2:-1] * (y[3:] - y[:-3]) - y[1:-2] 527 | 528 | # add coupling w/ slow variables 529 | for k in range(K): 530 | rhs[K + k*J : K + (k+1)*J] += _s.hy * x[k] 531 | 532 | # divide by epsilon 533 | rhs[K:] /= _s.eps 534 | 535 | return rhs 536 | 537 | def decoupled(_s, z, t): 538 | ''' Only fast variables with fixed slow ones to verify ergodicity ''' 539 | K = _s.K 540 | J = _s.J 541 | _i = _s.fidx_dec 542 | rhs = np.empty(K*J) 543 | 544 | ## boundary: k = 0 545 | # boundary: j = 0, j = J-2, j = J-1 546 | rhs[_i(0,0)] = \ 547 | -z[_i(1,0)] * (z[_i(2,0)] - z[_i(J-1,K-1)]) - z[_i(0,0)] 548 | rhs[_i(J-2,0)] = \ 549 | -z[_i(J-1,0)] * (z[_i(0,1)] - z[_i(J-3,0)]) - z[_i(J-2,0)] 550 | rhs[_i(J-1,0)] = \ 551 | -z[_i(0,1)] * (z[_i(1,1)] - z[_i(J-2,0)]) - z[_i(J-1,0)] 552 | # general (for k = 0) 553 | for j in range(1, J-2): 554 | rhs[_i(j,0)] = \ 555 | -z[_i(j+1,0)] * (z[_i(j+2,0)] - z[_i(j-1,0)]) - z[_i(j,0)] 556 | ## boundary: k = 0 (end) 557 | 558 | ## boundary: k = K-1 559 | # boundary: j = 0, j = J-2, j = J-1 560 | rhs[_i(0,K-1)] = \ 561 | -z[_i(1,K-1)] * (z[_i(2,K-1)] - z[_i(J-1,K-2)]) - z[_i(0,K-1)] 562 | rhs[_i(J-2,K-1)] = \ 563 | -z[_i(J-1,K-1)] * (z[_i(0,0)] - z[_i(J-3,K-1)]) - z[_i(J-2,K-1)] 564 | rhs[_i(J-1,K-1)] = \ 565 | -z[_i(0,0)] * (z[_i(1,0)] - z[_i(J-2,K-1)]) - z[_i(J-1,K-1)] 566 | # general (for k = K-1) 567 | for j in range(1, J-2): 568 | rhs[_i(j,K-1)] = \ 569 | -z[_i(j+1,K-1)] * (z[_i(j+2,K-1)] - z[_i(j-1,K-1)]) - z[_i(j,K-1)] 570 | ## boundary: k = K-1 (end) 571 | 572 | ## general case for k (w/ corresponding inner boundary conditions) 573 | for k in range(1, K-1): 574 | # boundary: j = 0, j = J-2, j = J-1 575 | rhs[_i(0,k)] = \ 576 | -z[_i(1,k)] * (z[_i(2,k)] - z[_i(J-1,k-1)]) - z[_i(0,k)] 577 | rhs[_i(J-2,k)] = \ 578 | -z[_i(J-1,k)] * (z[_i(0,k+1)] - z[_i(J-3,k)]) - z[_i(J-2,k)] 579 
| rhs[_i(J-1,k)] = \ 580 | -z[_i(0,k+1)] * (z[_i(1,k+1)] - z[_i(J-2,k)]) - z[_i(J-1,k)] 581 | # general case for j 582 | for j in range(1, J-2): 583 | rhs[_i(j,k)] = \ 584 | -z[_i(j+1,k)] * (z[_i(j+2,k)] - z[_i(j-1,k)]) - z[_i(j,k)] 585 | 586 | ## add coupling w/ slow variables 587 | for k in range(0, K): 588 | rhs[k*J : (k+1)*J] += _s.hy * _s.xk_star[k] 589 | 590 | ## divide by epsilon 591 | rhs /= _s.eps 592 | 593 | return rhs 594 | 595 | def balanced(_s, x, t): 596 | ''' Only slow variables with balanced RHS ''' 597 | K = _s.K 598 | rhs = np.empty(K) 599 | 600 | # three boundary cases: k = 0, k = 1, k = K-1 601 | rhs[0] = -x[K-1] * (x[K-2] - x[1]) - (1 - _s.hx[0]*_s.hy) * x[0] 602 | rhs[1] = -x[0] * (x[K-1] - x[2]) - (1 - _s.hx[1]*_s.hy) * x[1] 603 | rhs[K-1] = -x[K-2] * (x[K-3] - x[0]) - (1 - _s.hx[K-1]*_s.hy) * x[K-1] 604 | 605 | # general case 606 | for k in range(2, K-1): 607 | rhs[k] = -x[k-1] * (x[k-2] - x[k+1]) - (1 - _s.hx[k]*_s.hy) * x[k] 608 | 609 | # add forcing 610 | rhs += _s.F 611 | 612 | return rhs 613 | 614 | def slow(_s, x, t): 615 | ''' Only slow variables with RHS learned from data ''' 616 | K = _s.K 617 | rhs = np.empty(K) 618 | 619 | # three boundary cases: k = 0, k = 1, k = K-1 620 | rhs[0] = -x[K-1] * (x[K-2] - x[1]) - x[0] 621 | rhs[1] = -x[0] * (x[K-1] - x[2]) - x[1] 622 | rhs[K-1] = -x[K-2] * (x[K-3] - x[0]) - x[K-1] 623 | 624 | # general case 625 | for k in range(2, K-1): 626 | rhs[k] = -x[k-1] * (x[k-2] - x[k+1]) - x[k] 627 | 628 | # add forcing 629 | rhs += _s.F 630 | 631 | return rhs 632 | 633 | def regressed(_s, x, t): 634 | ''' Only slow variables with RHS learned from data ''' 635 | K = _s.K 636 | rhs = np.empty(K) 637 | 638 | # three boundary cases: k = 0, k = 1, k = K-1 639 | rhs[0] = -x[K-1] * (x[K-2] - x[1]) - x[0] 640 | rhs[1] = -x[0] * (x[K-1] - x[2]) - x[1] 641 | rhs[K-1] = -x[K-2] * (x[K-3] - x[0]) - x[K-1] 642 | 643 | # general case 644 | for k in range(2, K-1): 645 | rhs[k] = -x[k-1] * (x[k-2] - x[k+1]) - x[k] 646 | 647 
def filtered(_s, t, z):
    ''' RHS for all slow variables plus the single set of fast variables at k0.

    Vector z is of size (K + J): z[:K] are the slow variables and z[K:] are
    the J fast variables attached to the slow index _s.k0.

    The slow variable at k0 is coupled to the mean of its own fast variables;
    every other slow variable is coupled to the data-learned prediction
    returned by _s.simulate.

    Note the (t, z) argument order, which differs from the (x, t) order of
    balanced / slow / regressed in this class.

    Returns an array of length (K + J) holding d(z)/dt.
    '''
    K = _s.K
    J = _s.J
    rhs = np.empty(K + J)

    ### slow variables subsystem ###
    # compute Yk average for k0 (mean of the J fast variables)
    Yk0 = z[K:].sum() / J

    # three boundary cases: k = 0, k = 1, k = K-1
    # (periodic in k: indices below 0 / above K-1 wrap around)
    rhs[0] = -z[K-1] * (z[K-2] - z[1]) - z[0]
    rhs[1] = -z[0] * (z[K-1] - z[2]) - z[1]
    rhs[K-1] = -z[K-2] * (z[K-3] - z[0]) - z[K-1]

    # general case
    for k in range(2, K-1):
        rhs[k] = -z[k-1] * (z[k-2] - z[k+1]) - z[k]

    # add forcing (slow variables only)
    rhs[:K] += _s.F

    # add coupling w/ fast variables via average for k0
    # NOTE This has to be tested; maybe predictor everywhere is better
    rhs[_s.k0] += _s.hx[_s.k0] * Yk0
    # add coupling w/ the rest via simulation
    # wo_k0: every slow index except k0
    wo_k0 = np.r_[:_s.k0, _s.k0+1:K]
    Yk_simul = _s.simulate(z[:K])
    rhs[wo_k0] += _s.hx[wo_k0] * Yk_simul[wo_k0]
    #rhs[_s.k0] += _s.hx[_s.k0] * Yk_simul[_s.k0]

    ### fast variables subsystem ###
    # boundary: j = 0, j = J-2, j = J-1 (periodic in j within the block z[K:])
    # z[-1] is the last fast variable -- assumes len(z) == K + J
    rhs[K] = -z[K+1] * (z[K+2] - z[-1]) - z[K]
    rhs[K+J-2] = -z[K+J-1] * (z[K] - z[K+J-3]) - z[K+J-2]
    rhs[K+J-1] = -z[K] * (z[K+1] - z[K+J-2]) - z[K+J-1]
    # general case for j
    for j in range(1, J-2):
        rhs[K+j] = -z[K+j+1] * (z[K+j+2] - z[K+j-1]) - z[K+j]

    ## add coupling w/ the k0 slow variable
    rhs[K:] += _s.hy * z[_s.k0]

    ## divide by epsilon (applies to the fast block only)
    rhs[K:] /= _s.eps

    return rhs
"""Fast-index evaluation (based on the convention, see class description)""" 712 | return _s.K + k*_s.J + j 713 | 714 | def fidx_dec(_s, j, k): 715 | """Fast-index evaluation for the decoupled system""" 716 | return k*_s.J + j 717 | 718 | def simulate(_s, slow): 719 | if _s.share_gp: 720 | return np.reshape(_s.predictor(_s.apply_stencil(slow)), (-1,)) 721 | else: 722 | return np.reshape(_s.predictor(slow.reshape(1,-1)), (-1,)) 723 | 724 | def simulate_OLD(_s, slow): 725 | return np.reshape(_s.predictor(_s.apply_stencil(slow)), (-1,)) 726 | 727 | def single_step_implied_Ybar(_s, Xnow, Xnext, delta_t): 728 | # use an euler scheme to back-out the implied avg Ybar_t from X_t and X_t+1 729 | Ybar = (Xnext - Xnow)/delta_t - _s.slow(x=Xnow, t=None) 730 | 731 | # divide by hx 732 | Ybar /= _s.hx 733 | 734 | return Ybar 735 | 736 | def implied_Ybar(_s, X_in, X_out, delta_t): 737 | # the idea is that X_in are true data coming from a test/training set 738 | # Xout(k) is the 1-step-ahed prediction associated to Xin(k). 
def gather_pairs_k0(_s, tseries):
    """Collect (slow value at k0, mean of fast values) for every snapshot.

    ``tseries`` is indexed as ``tseries[state_component, snapshot]``. Row
    ``_s.k0`` is read as the slow variable and rows ``_s.K:`` as the fast
    variables; their sum is divided by ``_s.J``.

    Returns an array of shape ``(n_snapshots, 2)``: column 0 holds the slow
    value, column 1 the fast-variable average.
    """
    n_snapshots = tseries.shape[1]
    out = np.empty((n_snapshots, 2))
    for col in range(n_snapshots):
        snapshot = tseries[:, col]
        out[col, 0] = snapshot[_s.k0]
        out[col, 1] = snapshot[_s.K:].sum() / _s.J
    return out
class L63:
    """
    Lorenz 63 model.

    Provides right-hand-side callables with the ``(state, t)`` argument order,
    plus the predictor / stencil helpers that mirror the L96M class so the
    same driver code can be reused.

    Parameters:
        a, b, c         -- model coefficients used in ``rhs``
        share_gp        -- if True, ``simulate`` feeds stencil rows to the
                           predictor; otherwise the whole state as one row
        add_closure     -- if True, ``rhs`` adds ``simulate(S)`` to the
                           Lorenz 63 tendencies
    """

    def __init__(_s,
            a = 10, b = 28, c = 8/3, share_gp=True, add_closure=False):
        '''
        Store the model coefficients and the bookkeeping flags.
        '''
        _s.a, _s.b, _s.c = a, b, c
        _s.share_gp = share_gp
        _s.add_closure = add_closure
        _s.K = 3    # state dims
        _s.hx = 1   # just useful when re-using L96 code
        _s.slow_only = False
        _s.exchangeable_states = False

    def get_inits(_s):
        """Draw a random initial state from a fixed box (x, y, z order)."""
        # (min, max) per coordinate; one np.random.random() draw per coordinate
        bounds = ((-10, 10), (-20, 30), (10, 40))
        return np.array([lo + (hi - lo) * np.random.random() for lo, hi in bounds])

    def get_state_names(_s):
        """Return the state labels."""
        return list('xyz')

    def plot_state_indices(_s):
        """Return the indices of the state components to plot."""
        return list(range(3))

    def slow(_s, y, t):
        """Alias of ``rhs`` (kept for interface parity with L96M)."""
        return _s.rhs(y, t)

    def rhs(_s, S, t):
        ''' Full system RHS; ``t`` is unused. '''
        x, y, z = S[0], S[1], S[2]

        dS = np.empty(3)
        dS[0] = -_s.a*x + _s.a*y
        dS[1] = _s.b*x - y - x*z
        dS[2] = -_s.c*z + x*y

        if _s.add_closure:
            # learned correction on top of the Lorenz 63 tendencies
            dS += _s.simulate(S)
        return dS

    def regressed(_s, x, t):
        ''' ``rhs`` plus the data-learned term from ``simulate``.

        NOTE(review): when ``add_closure`` is True, ``rhs`` has already added
        ``simulate(x)``, so the term is counted twice here -- confirm intent.
        '''
        total = _s.rhs(x, t)
        # add data-learned coupling term
        total += _s.simulate(x)
        return total

    def set_stencil(_s, left = 0, right = 0):
        """Set the stencil offsets to ``left, ..., right`` (inclusive)."""
        _s.stencil = np.arange(left, right + 1)

    def single_step_implied_Ybar(_s, Xnow, Xnext, delta_t):
        """Back out the residual tendency implied by one forward-Euler step."""
        euler_tendency = (Xnext - Xnow) / delta_t
        return euler_tendency - _s.rhs(S=Xnow, t=None)

    def implied_Ybar(_s, X_in, X_out, delta_t):
        """Apply ``single_step_implied_Ybar`` row by row.

        ``X_in[t]`` is a state and ``X_out[t]`` the one-step-ahead state
        associated with it. Returns an array of shape ``(T, K)``.
        """
        T = X_in.shape[0]
        residuals = np.zeros((T, _s.K))
        for row in range(T):
            residuals[row, :] = _s.single_step_implied_Ybar(
                Xnow=X_in[row, :], Xnext=X_out[row, :], delta_t=delta_t)
        return residuals

    def get_state_limits(_s):
        """Return the (lower, upper) state limits; both are None."""
        return (None, None)

    def set_predictor(_s, predictor):
        """Install the callable used by ``simulate``."""
        _s.predictor = predictor

    def set_null_predictor(_s):
        """Install a predictor that always returns 0."""
        _s.predictor = lambda x: 0

    def simulate(_s, slow):
        """Evaluate the predictor on ``slow`` and flatten the result to 1-D."""
        if _s.share_gp:
            features = _s.apply_stencil(slow)
        else:
            features = slow.reshape(1, -1)
        return np.reshape(_s.predictor(features), (-1,))

    def apply_stencil(_s, slow):
        """Build one row per state index k holding ``slow[(k + s) % K]`` for
        every offset ``s`` in ``_s.stencil`` (periodic neighbourhood)."""
        offsets = np.add.outer(np.arange(_s.K), _s.stencil)
        return slow[offsets % _s.K]
# shape is the tuple shape of each instance
def sample_uniform_spherical_shell(npoints: int, radii: tuple, shape: tuple):
    """Sample ``npoints`` random points lying in a spherical shell.

    Each point is built from a radius drawn uniformly from
    ``[inner_radius, outer_radius]`` and an independent direction obtained by
    normalising a standard Gaussian vector (uniform on the unit sphere).
    The radius is uniform in length, so the points are not uniform in volume.

    Parameters
    ----------
    npoints : int
        Number of points to draw.
    radii : tuple
        ``(inner_radius, outer_radius)`` of the shell, in Euclidean norm of
        the flattened point.
    shape : tuple
        Shape of a single point; a bare int is accepted as well.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(npoints, *shape)``; this also holds for
        ``npoints == 0``.
    """
    inner_radius, outer_radius = radii
    shape = tuple(int(s) for s in np.atleast_1d(shape))
    # np.prod returns a NumPy scalar (a float for an empty shape); randn needs an int
    ndim = int(np.prod(shape))

    pts = np.empty((npoints,) + shape)
    for i in range(npoints):
        # draw order (radius first, then direction) is kept so that seeded
        # runs reproduce the same samples
        samp_radius = np.random.uniform(inner_radius, outer_radius)
        vec = np.random.randn(ndim)  # ref: https://mathworld.wolfram.com/SpherePointPicking.html
        vec /= np.linalg.norm(vec)
        pts[i] = np.reshape(samp_radius * vec, shape)

    return pts
def part_unity_post_process(x, model, rho, diss):
    """Blend the model prediction with a baseline dissipative map.

    Computes ``rho(|x|) * model(x) + (1 - rho(|x|)) * diss(x)`` where ``|x|``
    is ``torch.norm(x)`` and the model output is reshaped to
    ``(x.shape[0],)``.

    x: input point as torch tensor
    model: torch model
    rho: partition of unity, called on the norm of x
    diss: baseline dissipative map applied to x
    """
    model_part = rho(torch.norm(x)) * model(x).reshape(x.shape[0],)
    diss_part = (1 - rho(torch.norm(x))) * diss(x)
    return model_part + diss_part
def compl_mul2d(a, b):
    """Contract the input-channel axis of ``a`` with the weights ``b``.

    (batch, in_channel, x, y), (in_channel, out_channel, x, y) -> (batch, out_channel, x, y)
    """
    return torch.einsum("bixy,ioxy->boxy", a, b)


class SpectralConv2d(nn.Module):
    """2-D Fourier layer: FFT, multiply the retained modes by learned complex
    weights, inverse FFT.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        modes1: number of Fourier modes kept along the first spatial axis
            (taken from both ends of that axis), at most floor(N/2) + 1.
        modes2: number of Fourier modes kept along the second spatial axis.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        # weights1 acts on the first modes1 rows, weights2 on the last modes1 rows
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    def forward(self, x, size=None):
        """Apply the spectral convolution.

        Args:
            x: tensor indexed as (batch, channel, x, y).
            size: side length of the (square) output grid; defaults to the
                last dimension of ``x``.

        Returns:
            Real tensor of shape (batch, out_channels, size, size).
        """
        if size is None:
            size = x.size(-1)

        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfftn(x, dim=[2, 3], norm="ortho")

        # Multiply relevant Fourier modes; all other modes stay zero
        out_ft = torch.zeros(batchsize, self.out_channels, size, size//2 + 1, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        # Return to physical space
        x = torch.fft.irfftn(out_ft, s=(size, size), dim=[2, 3], norm="ortho")
        return x
class Net2d(nn.Module):
    """Thin wrapper around a single ``SimpleBlock2d``.

    Args:
        in_dim: passed to ``SimpleBlock2d`` as ``in_dim``.
        out_dim: passed to ``SimpleBlock2d`` as ``out_dim``.
        domain_size: passed to ``SimpleBlock2d`` as ``domain_size``.
        modes: used for both ``modes1`` and ``modes2`` of the block.
        width: passed to ``SimpleBlock2d`` as ``width``.
    """

    def __init__(self, in_dim, out_dim, domain_size, modes, width):
        super(Net2d, self).__init__()
        self.conv1 = SimpleBlock2d(in_dim, out_dim, domain_size, modes, modes, width)

    def forward(self, x):
        """Return the output of the wrapped block for ``x``."""
        return self.conv1(x)

    def count_params(self):
        """Return the total number of scalar entries over all parameters."""
        # numel() also handles 0-dim parameters, for which reducing an empty
        # size list with operator.mul would raise
        return sum(p.numel() for p in self.parameters())
torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True) 64 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False) 65 | 66 | t2 = default_timer() 67 | 68 | print('preprocessing finished, time used:', t2-t1) 69 | device = torch.device('cuda') 70 | 71 | # Model 72 | model = Net2d(in_dim, out_dim, S, modes, width).cuda() 73 | print(model.count_params()) 74 | 75 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 76 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 77 | 78 | lploss = LpLoss(size_average=False) 79 | h1loss = HsLoss(k=1, group=False, size_average=False) 80 | h2loss = HsLoss(k=2, group=False, size_average=False) 81 | myloss = HsLoss(k=loss_k, group=loss_group, size_average=False) 82 | 83 | # Training 84 | for ep in range(1, epochs + 1): 85 | model.train() 86 | t1 = default_timer() 87 | train_loss = 0 88 | for x, y in train_loader: 89 | x = x.to(device).view(batch_size, S, S, in_dim) 90 | y = y.to(device).view(batch_size, S, S, out_dim) 91 | 92 | out = model(x).reshape(batch_size, S, S, out_dim) 93 | loss = myloss(out, y) 94 | train_loss += loss.item() 95 | 96 | optimizer.zero_grad() 97 | loss.backward() 98 | optimizer.step() 99 | 100 | test_l2 = 0 101 | test_h1 = 0 102 | test_h2 = 0 103 | with torch.no_grad(): 104 | for x, y in test_loader: 105 | x = x.to(device).view(batch_size, S, S, in_dim) 106 | y = y.to(device).view(batch_size, S, S, out_dim) 107 | 108 | out = model(x).reshape(batch_size, S, S, out_dim) 109 | test_l2 += lploss(out, y).item() 110 | test_h1 += h1loss(out, y).item() 111 | test_h2 += h2loss(out, y).item() 112 | 113 | t2 = default_timer() 114 | scheduler.step() 115 | print("Epoch " + str(ep) + " completed in " + "{0:.{1}f}".format(t2-t1, 3) + " seconds. 
# Trains a 2-D Fourier neural operator on Kolmogorov-flow vorticity data with
# an additional dissipativity regularisation term, then rolls the trained
# model out for a long-time prediction and saves it as a .mat file.
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import sys
sys.path.append('../')
from utilities import *
from dissipative_utils import sample_uniform_spherical_shell, linear_scale_dissipative_target

sys.path.append('../models')
from fno_2d import *

from timeit import default_timer
import scipy.io

torch.manual_seed(0)
np.random.seed(0)

# Main
ntrain = 900
ntest = 100

sub = 1 # spatial subsample
S = 64

# DISSIPATIVE REGULARIZATION PARAMETERS
# below, the number before multiplication by S is the radius in the L2 norm of the function space
radius = 156.25 * S # radius of inner ball
scale_down = 0.5 # rate at which to linearly scale down inputs
loss_weight = 0.01 * (S**2) # normalized by L2 norm in function space
radii = (radius, (525 * S) + radius) # inner and outer radii, in L2 norm of function space
sampling_fn = sample_uniform_spherical_shell #numsampled is batch size
target_fn = linear_scale_dissipative_target

modes = 20
width = 64

in_dim = 1
out_dim = 1

batch_size = 50

epochs = 50
learning_rate = 0.0005
scheduler_step = 10
scheduler_gamma = 0.5

loss_k = 1
loss_group = True

print(epochs, learning_rate, scheduler_step, scheduler_gamma)

path = 'NS_fourier_MNO_dissipative_N_'+str(ntrain)+'_k' + str(loss_k)+'_g' + str(loss_group)+'_ep' + str(epochs) + '_m' + str(modes) + '_w' + str(width)
path_model = 'model/'+path
# single definition of the output location, used for both saving and logging
pred_path = 'pred/'+path+'.mat'

T_in = 100 # skip first 100 seconds of each trajectory to let trajectory reach attractor
T = 400 # seconds to extract from each trajectory in data
T_out = T_in + T
step = 1 # Seconds to learn solution operator

t1 = default_timer()
data = np.load('../data/KFvorticity_Re500_N1000_T500.npy')
data = torch.tensor(data, dtype=torch.float)[..., ::sub, ::sub]

# (input, target) pairs are consecutive snapshots of each trajectory
train_a = data[:ntrain,T_in-1:T_out-1].reshape(ntrain*T, S, S)
train_u = data[:ntrain,T_in:T_out].reshape(ntrain*T, S, S)

test_a = data[-ntest:,T_in-1:T_out-1].reshape(ntest*T, S, S)
test_u = data[-ntest:,T_in:T_out].reshape(ntest*T, S, S)

assert (S == train_u.shape[2])

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')

# Model
model = Net2d(in_dim, out_dim, S, modes, width).cuda()
print(model.count_params())

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

lploss = LpLoss(size_average=False)
h1loss = HsLoss(k=1, group=False, size_average=False)
h2loss = HsLoss(k=2, group=False, size_average=False)
myloss = HsLoss(k=loss_k, group=loss_group, size_average=False)
dissloss = nn.MSELoss(reduction='mean')

# Training
for ep in range(1, epochs + 1):
    model.train()
    t1 = default_timer()
    train_loss = 0
    diss_l2 = 0
    for x, y in train_loader:
        # NOTE: the fixed batch_size in view() requires every batch to be full
        x = x.to(device).view(batch_size, S, S, in_dim)
        y = y.to(device).view(batch_size, S, S, out_dim)

        out = model(x).reshape(batch_size, S, S, out_dim)
        data_loss = myloss(out, y)
        train_loss += data_loss.item()

        # dissipativity term: random inputs from the shell should be mapped
        # to a scaled-down copy of themselves
        x_diss = torch.tensor(sampling_fn(x.shape[0], radii, (S, S, 1)), dtype=torch.float).to(device)
        assert(x_diss.shape == x.shape)
        y_diss = torch.tensor(target_fn(x_diss, scale_down), dtype=torch.float).to(device)
        out_diss = model(x_diss).reshape(-1, out_dim)
        diss_loss = (1/(S**2)) * loss_weight * dissloss(out_diss, y_diss.reshape(-1, out_dim)) # weighted by 1 / (S**2)
        diss_l2 += diss_loss.item()

        loss = data_loss + diss_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    test_l2 = 0
    test_h1 = 0
    test_h2 = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device).view(batch_size, S, S, in_dim)
            y = y.to(device).view(batch_size, S, S, out_dim)

            out = model(x).reshape(batch_size, S, S, out_dim)
            test_l2 += lploss(out, y).item()
            test_h1 += h1loss(out, y).item()
            test_h2 += h2loss(out, y).item()

    t2 = default_timer()
    scheduler.step()
    print("Epoch " + str(ep) + " completed in " + "{0:.{1}f}".format(t2-t1, 3) + " seconds. Train err:", "{0:.{1}f}".format(train_loss/(ntrain*T), 3), "Test L2 err:", "{0:.{1}f}".format(test_l2/(ntest*T), 3), "Test H1 err:", "{0:.{1}f}".format(test_h1/(ntest*T), 3), "Test H2 err:", "{0:.{1}f}".format(test_h2/(ntest*T), 3), "Train diss err:", "{0:.{1}f}".format(diss_l2/(ntrain), 3))
    print(ep, t2 - t1, train_loss/(ntrain*T), test_l2/(ntest*T), test_h1/(ntest*T), test_h2/(ntest*T), diss_l2/(ntrain))

torch.save(model, path_model)
print("Weights saved to", path_model)

# Long-time prediction
model.eval()
test_a = test_a[0,:,:]

T = 10000
pred = torch.zeros(S,S,T)
out = test_a.reshape(1,S,S).cuda()
with torch.no_grad():
    for i in range(T):
        # feed the previous prediction back in as the next input
        out = model(out.reshape(1,S,S,in_dim))
        pred[:,:,i] = out.view(S,S)

scipy.io.savemat(pred_path, mdict={'pred': pred.cpu().numpy()})
print("10000 seconds of predictions saved to", pred_path)
39 | 40 | print() 41 | print("Epochs:", epochs) 42 | print("Learning rate:", learning_rate) 43 | print("Scheduler step:", scheduler_step) 44 | print("Scheduler gamma:", scheduler_gamma) 45 | print() 46 | 47 | path = 'lorenz_densenet_relu_dt_0_05'+str(ntrain)+'_ep' + str(epochs) + '_lr' + str(learning_rate).replace('.','_') + '_schedstep' + str(scheduler_step).replace('.','_') + '_relLp' + str(rel_loss) + '_layers' + str(layers)[1:-1].replace(', ', '_') 48 | path_model = 'weights/'+path 49 | print(path) 50 | 51 | # Data 52 | sub = 6 # temporal subsampling rate 53 | steps_per_sec = 21 # given temporal subsampling, num of time-steps per second 54 | t1 = default_timer() 55 | 56 | predloader = MatReader('../data/L63T10000.mat') 57 | data = predloader.read_field('u')[::sub] 58 | data = torch.tensor(data, dtype=torch.float) 59 | 60 | train_a = data[:ntrain] 61 | train_u = data[1:ntrain+1] 62 | 63 | train_mean = torch.mean(train_a) 64 | train_max = torch.max(train_a) 65 | train_min = torch.min(train_a) 66 | 67 | if scale_inputs: 68 | train_a = (train_a - train_mean)/(train_max - train_min) 69 | train_u = (train_u - train_mean)/(train_max - train_min) 70 | 71 | test_a = data[-ntest:-1] 72 | test_u = data[-ntest + 1:] 73 | 74 | if scale_inputs: 75 | test_a = (test_a - train_mean)/(train_max - train_min) 76 | test_u = (test_u - train_mean)/(train_max - train_min) 77 | 78 | assert train_a.shape == train_u.shape 79 | assert test_a.shape == test_u.shape 80 | 81 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True) 82 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False) 83 | 84 | t2 = default_timer() 85 | 86 | print('preprocessing finished, time used:', t2-t1) 87 | print() 88 | 89 | # Model 90 | model = DenseNet(layers, nonlinearity).cuda() 91 | print("Model parameters:", model.count_params()) 92 | print() 93 | 94 | # Training 
# Create the output folders before training starts, so that a missing
# directory cannot make torch.save / savemat fail after all epochs have run.
import os
os.makedirs(os.path.dirname(path_model), exist_ok=True)
os.makedirs('pred', exist_ok=True)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

# All losses are sum-reduced; the per-sample averages are formed when printing.
if rel_loss:
    trainloss = LpLoss(size_average=False)
    testloss = LpLoss(size_average=False)
    testloss_1sec = LpLoss(size_average=False)
else:
    trainloss = nn.MSELoss(reduction='sum')
    testloss = nn.MSELoss(reduction='sum')
    testloss_1sec = nn.MSELoss(reduction='sum')

# test_a/test_u hold ntest - 1 pairs, so normalise by the real dataset size
# rather than by ntest.
n_test = len(test_loader.dataset)

# Begin train
for ep in range(1, epochs + 1):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    one_sec_count = 0
    for x, y in train_loader:
        x = x.to(device).view(-1, out_dim)
        y = y.to(device).view(-1, out_dim)

        out = model(x).reshape(-1, out_dim)
        loss = trainloss(out, y)
        train_l2 += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    test_l2 = 0
    test_l2_1_sec = 0
    # Evaluate in inference mode; model.train() is restored at the top of the
    # next epoch.
    model.eval()
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device).view(-1, out_dim)
            y = y.to(device).view(-1, out_dim)

            out = model(x).reshape(-1, out_dim)
            test_l2 += testloss(out, y).item()

            # One-second rollout: states steps_per_sec apart within the batch
            # form (input, ground truth) pairs; the model is composed
            # steps_per_sec times to cover the gap.
            x_subsample = x[::steps_per_sec]
            x_1sec = x_subsample[:-2]  # inputs
            y_1sec = x_subsample[1:-1]  # ground truth
            out = x_1sec
            for i in range(steps_per_sec):
                out = model(out).reshape(-1, out_dim)
            test_1_sec_loss = testloss_1sec(out, y_1sec)
            test_l2_1_sec += test_1_sec_loss.item()
            one_sec_count += int(y_1sec.shape[0])

    t2 = default_timer()
    scheduler.step()
    print(
        f"Epoch {ep} completed in {t2 - t1:.3f} seconds. Train L2 err:",
        f"{train_l2 / ntrain:.3f}",
        "Test L2 err:", f"{test_l2 / n_test:.3f}",
        "Test L2 err over 1 sec:", f"{test_l2_1_sec / one_sec_count:.3f}",
    )

# Note: this pickles the whole model object, not only its state_dict.
torch.save(model, path_model)
print("Weights saved to", path_model)

# Long-time prediction
model.eval()
test_a = test_a[0]  # first test state seeds the autoregressive rollout

T = 10000 * steps_per_sec
pred = torch.zeros(T, out_dim)
out = test_a.reshape(1, in_dim).cuda()
with torch.no_grad():
    for i in range(T):
        # Feed the previous prediction back in as the next input.
        out = model(out.reshape(1, in_dim))
        pred[i] = out.view(out_dim,)

scipy.io.savemat('pred/' + path + '.mat', mdict={'pred': pred.cpu().numpy()})
print("10000 seconds of predictions saved to", 'pred/' + path + '.mat')

# --------------------------------------------------------------------------------
# /scripts/lorenz_dissipative_densenet.py:
# --------------------------------------------------------------------------------
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from timeit import default_timer
import scipy.io

import sys
sys.path.append('../')
from utilities import *
from dissipative_utils import sample_uniform_spherical_shell, linear_scale_dissipative_target

sys.path.append('../models')
from densenet import *

torch.manual_seed(0)
np.random.seed(0)

# Main
ntrain = 160000
ntest = 38000
scale_inputs = False

# DISSIPATIVE REGULARIZATION PARAMETERS
scale_down = 0.5  # rate at which to linearly scale down inputs
radius = 90  # radius of inner ball
loss_weight = 1  # weighting term between data loss and dissipativity term
radii = (radius, radius + 40)  # inner and outer radii
sampling_fn = sample_uniform_spherical_shell  # numsamples is batch size
target_fn = linear_scale_dissipative_target

in_dim = 3
out_dim = 3

batch_size = 256
epochs = 1000
learning_rate = 0.0005

# Six hidden layers, each of width 50 * in_dim.
layers = [in_dim] + [in_dim * 50] * 6 + [out_dim]
nonlinearity = nn.ReLU

rel_loss = True  # relative Lp loss

scheduler_step = 100
scheduler_gamma = 0.5

print()
print("Epochs:", epochs)
print("Learning rate:", learning_rate)
print("Scheduler step:", scheduler_step)
print("Scheduler gamma:", scheduler_gamma)
print()

# Run identifier assembled from the hyperparameters; it names both the saved
# model and the prediction file written at the end of the script.
path = (
    'lorenz_dissipative_densenet_dt_0_05_inner_rad' + str(int(radius))
    + '_outer_rad' + str(int(radii[1]))
    + '_lambda' + str(scale_down).replace('.', '_')
    + '_diss_weight' + str(loss_weight).replace('.', '_')
    + '_time' + str(ntrain)
    + '_ep' + str(epochs)
    + '_lr' + str(learning_rate).replace('.', '_')
    + '_schedstep' + str(scheduler_step).replace('.', '_')
    + '_relLp' + str(rel_loss)
    + '_layers' + str(layers)[1:-1].replace(', ', '_')
)
path_model = 'model/' + path
print(path)

# Create the output folders before training starts, so that a missing
# directory cannot make torch.save / savemat fail after all epochs have run.
import os
os.makedirs(os.path.dirname(path_model), exist_ok=True)
os.makedirs('pred', exist_ok=True)

# Data
sub = 6  # temporal subsampling rate
steps_per_sec = 21  # given temporal subsampling, num of time-steps per second
t1 = default_timer()

predloader = MatReader('../data/L63T10000.mat')
data = predloader.read_field('u')[::sub]
# read_field returns a torch.Tensor by default; torch.tensor(<tensor>) would
# emit a copy-construct warning. as_tensor converts without it and
# contiguous() compacts the strided [::sub] slice.
data = torch.as_tensor(data, dtype=torch.float).contiguous()

# One-step-ahead pairs: train_u[i] is the state that follows train_a[i].
train_a = data[:ntrain]
train_u = data[1:ntrain + 1]

# Scaling statistics come from the training inputs only.
train_mean = torch.mean(train_a)
train_max = torch.max(train_a)
train_min = torch.min(train_a)
train_range = train_max - train_min

if scale_inputs:
    train_a = (train_a - train_mean) / train_range
    train_u = (train_u - train_mean) / train_range

# The last ntest samples give ntest - 1 consecutive (input, target) pairs.
test_a = data[-ntest:-1]
test_u = data[-ntest + 1:]

if scale_inputs:
    test_a = (test_a - train_mean) / train_range
    test_u = (test_u - train_mean) / train_range

assert train_a.shape == train_u.shape
assert test_a.shape == test_u.shape

train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(train_a, train_u),
    batch_size=batch_size, shuffle=True)
# The test loader is not shuffled, so consecutive samples stay in time order.
test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(test_a, test_u),
    batch_size=batch_size, shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2 - t1)
print()
device = torch.device('cuda')

# Model
model = DenseNet(layers, nonlinearity).cuda()
print("Model parameters:", model.count_params())
print()

# Training
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

# All losses are sum-reduced; per-sample averages are formed when printing.
if rel_loss:
    trainloss = LpLoss(size_average=False)
    dissloss = LpLoss(size_average=False)
    testloss = LpLoss(size_average=False)
    test_dissloss = LpLoss(size_average=False)
    testloss_1sec = LpLoss(size_average=False)
else:
    trainloss = nn.MSELoss(reduction='sum')
    dissloss = nn.MSELoss(reduction='sum')
    testloss = nn.MSELoss(reduction='sum')
    test_dissloss = nn.MSELoss(reduction='sum')
    testloss_1sec = nn.MSELoss(reduction='sum')

# test_a/test_u hold ntest - 1 pairs, so normalise by the real dataset size
# rather than by ntest.
n_test = len(test_loader.dataset)

for ep in range(1, epochs + 1):
    model.train()
    t1 = default_timer()
    one_sec_count = 0
    train_l2 = 0
    diss_l2 = 0
    for x, y in train_loader:
        x = x.to(device).view(-1, out_dim)
        y = y.to(device).view(-1, out_dim)

        out = model(x).reshape(-1, out_dim)
        data_loss = trainloss(out, y)
        train_l2 += data_loss.item()

        # Dissipativity term: draw one extra input per data sample with
        # sampling_fn (given the radii) and regress the model output onto
        # target_fn(input, scale_down).
        x_diss = torch.tensor(sampling_fn(x.shape[0], radii, (in_dim,)), dtype=torch.float).to(device)
        assert(x_diss.shape == x.shape)
        y_diss = torch.tensor(target_fn(x_diss, scale_down), dtype=torch.float).to(device)
        out_diss = model(x_diss).reshape(-1, out_dim)
        diss_loss = loss_weight * dissloss(out_diss, y_diss)  # weighted
        diss_l2 += diss_loss.item()

        loss = data_loss + diss_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    test_l2 = 0
    test_diss_l2 = 0
    test_l2_1_sec = 0
    # Evaluate in inference mode; model.train() is restored at the top of the
    # next epoch.
    model.eval()
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device).view(-1, out_dim)
            y = y.to(device).view(-1, out_dim)

            out = model(x).reshape(-1, out_dim)
            test_l2 += testloss(out, y).item()

            x_diss = torch.tensor(sampling_fn(x.shape[0], radii, (in_dim,)), dtype=torch.float).to(device)
            assert(x_diss.shape == x.shape)
            y_diss = torch.tensor(target_fn(x_diss, scale_down), dtype=torch.float).to(device)
            out_diss = model(x_diss).reshape(-1, out_dim)
            test_diss_loss = test_dissloss(out_diss, y_diss)  # unweighted
            test_diss_l2 += test_diss_loss.item()

            # One-second rollout: states steps_per_sec apart within the batch
            # form (input, ground truth) pairs; the model is composed
            # steps_per_sec times to cover the gap.
            x_subsample = x[::steps_per_sec]
            x_1sec = x_subsample[:-2]  # inputs
            y_1sec = x_subsample[1:-1]  # ground truth
            out = x_1sec
            for i in range(steps_per_sec):
                out = model(out).reshape(-1, out_dim)
            test_1_sec_loss = testloss_1sec(out, y_1sec)
            test_l2_1_sec += test_1_sec_loss.item()
            one_sec_count += int(y_1sec.shape[0])

    t2 = default_timer()
    scheduler.step()
    print(
        f"Epoch {ep} completed in {t2 - t1:.3f} seconds. Train L2 err:",
        f"{train_l2 / ntrain:.3f}",
        "Test L2 err:", f"{test_l2 / n_test:.3f}",
        "Train diss. err:", f"{diss_l2 / ntrain:.3f}",
        "Test diss. err:", f"{test_diss_l2 / n_test:.3f}",
        "Test L2 err over 1 sec:", f"{test_l2_1_sec / one_sec_count:.3f}",
    )

# Note: this pickles the whole model object, not only its state_dict.
torch.save(model, path_model)
print("Weights saved to", path_model)

# Long-time prediction
model.eval()
test_a = test_a[0]  # first test state seeds the autoregressive rollout

T = 10000 * steps_per_sec
pred = torch.zeros(T, out_dim)
out = test_a.reshape(1, in_dim).cuda()
with torch.no_grad():
    for i in range(T):
        # Feed the previous prediction back in as the next input.
        out = model(out.reshape(1, in_dim))
        pred[i] = out.view(out_dim,)

scipy.io.savemat('pred/' + path + '.mat', mdict={'pred': pred.cpu().numpy()})
print("10000 seconds of predictions saved to", 'pred/' + path + '.mat')

# --------------------------------------------------------------------------------
# /utilities.py:
# --------------------------------------------------------------------------------
import torch
import numpy as np
import scipy.io
import h5py
import torch.nn as nn
import operator
from functools import reduce

#################################################
#
# Utilities:
#
#################################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# PCA
class PCA(object):
    """Linear PCA encoder/decoder built from the SVD of a 2-D data matrix.

    Arguments
    ---------
    x : torch.Tensor
        Data matrix; must be 2-D (asserted). The mean is taken over dim 0.
    dim : int
        Number of components to keep; must not exceed ``min(x.shape)``.
    subtract_mean : bool
        If true, centre ``x`` by its mean over dim 0 before the SVD.
    """

    def __init__(self, x, dim, subtract_mean=True):
        super(PCA, self).__init__()

        # Input size
        x_size = list(x.size())

        # Input data is a matrix
        assert len(x_size) == 2

        # Reducing dimension is less than the minimum of the
        # number of observations and the feature dimension
        assert dim <= min(x_size)

        self.reduced_dim = dim

        if subtract_mean:
            self.x_mean = torch.mean(x, dim=0).view(1, -1)
        else:
            self.x_mean = torch.zeros((x_size[1],), dtype=x.dtype, layout=x.layout, device=x.device)

        # SVD
        U, S, V = torch.svd(x - self.x_mean)
        V = V.t()

        # Flip sign to ensure deterministic output: each component is given
        # the sign of the largest-magnitude entry of its column of U.
        max_abs_cols = torch.argmax(torch.abs(U), dim=0)
        signs = torch.sign(U[max_abs_cols, range(U.size()[1])]).view(-1, 1)
        V *= signs

        self.W = V.t()[:, 0:self.reduced_dim]
        self.sing_vals = S.view(-1, )

    def cuda(self):
        """Move the projection matrix, mean and singular values to the GPU (in place)."""
        self.W = self.W.cuda()
        self.x_mean = self.x_mean.cuda()
        self.sing_vals = self.sing_vals.cuda()

    def encode(self, x):
        """Project centred ``x`` onto the retained components."""
        return (x - self.x_mean).mm(self.W)

    def decode(self, x):
        """Map reduced coordinates back to the original feature space."""
        return x.mm(self.W.t()) + self.x_mean

    def forward(self, x):
        """Return the PCA reconstruction ``decode(encode(x))``."""
        return self.decode(self.encode(x))

    def __call__(self, x):
        return self.forward(x)


# reading data
class MatReader(object):
    """Read named fields from a ``.mat`` or ``.h5`` file.

    Arguments
    ---------
    file_path : str
        File to open. Paths ending in ``.h5`` are opened with h5py; anything
        else is tried with ``scipy.io.loadmat`` first and h5py second.
    to_torch : bool
        Convert fields to ``torch.Tensor`` in :meth:`read_field`.
    to_cuda : bool
        Move converted tensors to the GPU.
    to_float : bool
        Cast fields to ``float32``.
    """

    def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True):
        super(MatReader, self).__init__()

        self.to_torch = to_torch
        self.to_cuda = to_cuda
        self.to_float = to_float

        self.file_path = file_path

        self.data = None
        self.old_mat = True
        self.h5 = False
        self._load_file()

    def _load_file(self):
        """Open ``self.file_path`` and record which reader succeeded."""
        # Reset the format flags so that load_file() on a different kind of
        # file does not inherit the flags of the previously opened one.
        self.old_mat = True
        self.h5 = False

        if self.file_path.endswith('.h5'):
            self.data = h5py.File(self.file_path, 'r')
            self.h5 = True
        else:
            try:
                self.data = scipy.io.loadmat(self.file_path)
            except (NotImplementedError, ValueError):
                # loadmat rejects MATLAB v7.3 (HDF5-based) files; fall back to
                # h5py. Other errors (e.g. a missing file) now propagate
                # directly instead of being swallowed by a bare except.
                self.data = h5py.File(self.file_path, 'r')
                self.old_mat = False

    def load_file(self, file_path):
        """Switch the reader to a different file."""
        self.file_path = file_path
        self._load_file()

    def read_field(self, field):
        """Return the array stored under ``field``, converted per the reader flags."""
        x = self.data[field]

        if self.h5:
            x = x[()]

        if not self.old_mat:
            x = x[()]
            # Reverse the axis order of data read through the h5py fallback
            # (presumably to match the layout loadmat would give -- confirm).
            x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1))

        if self.to_float:
            x = x.astype(np.float32)

        if self.to_torch:
            x = torch.from_numpy(x)

            if self.to_cuda:
                x = x.cuda()

        return x

    def set_cuda(self, to_cuda):
        self.to_cuda = to_cuda

    def set_torch(self, to_torch):
        self.to_torch = to_torch

    def set_float(self, to_float):
        self.to_float = to_float


# normalization, pointwise gaussian
class UnitGaussianNormalizer(object):
    """Standardise data with a mean and std computed per entry over dim 0."""

    def __init__(self, x, eps=0.00001):
        super(UnitGaussianNormalizer, self).__init__()

        # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T
        self.mean = torch.mean(x, 0)
        self.std = torch.std(x, 0)
        self.eps = eps  # added to std to avoid division by zero

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        x = (x - self.mean) / (self.std + self.eps)
        return x

    def decode(self, x, sample_idx=None):
        """Invert :meth:`encode`, optionally for a subset of entries.

        ``sample_idx`` selects entries of the stored statistics: along the
        first axis when ``sample_idx[0]`` has the same rank as the mean, along
        the second axis when it has lower rank.

        Raises
        ------
        ValueError
            If ``sample_idx[0]`` has higher rank than the stored mean
            (previously this surfaced as an ``UnboundLocalError``).
        """
        if sample_idx is None:
            std = self.std + self.eps  # n
            mean = self.mean
        elif len(self.mean.shape) == len(sample_idx[0].shape):
            std = self.std[sample_idx] + self.eps  # batch*n
            mean = self.mean[sample_idx]
        elif len(self.mean.shape) > len(sample_idx[0].shape):
            std = self.std[:, sample_idx] + self.eps  # T*batch*n
            mean = self.mean[:, sample_idx]
        else:
            raise ValueError(
                "sample_idx[0] has more dimensions than the stored statistics")

        # x is in shape of batch*n or T*batch*n
        x = (x * std) + mean
        return x

    def cuda(self):
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()

    def cpu(self):
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()


# normalization, Gaussian
class GaussianNormalizer(object):
    """Standardise data with a single scalar mean and std over all entries."""

    def __init__(self, x, eps=0.00001):
        super(GaussianNormalizer, self).__init__()

        self.mean = torch.mean(x)
        self.std = torch.std(x)
        self.eps = eps  # added to std to avoid division by zero

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        x = (x - self.mean) / (self.std + self.eps)
        return x

    def decode(self, x, sample_idx=None):
        """Invert :meth:`encode`; ``sample_idx`` is accepted but unused."""
        x = (x * (self.std + self.eps)) + self.mean
        return x

    def cuda(self):
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()

    def cpu(self):
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()


# normalization, scaling by range
class RangeNormalizer(object):
    """Affine map sending each entry's min/max over dim 0 to ``low``/``high``."""

    def __init__(self, x, low=0.0, high=1.0):
        super(RangeNormalizer, self).__init__()
        mymin = torch.min(x, 0)[0].view(-1)
        mymax = torch.max(x, 0)[0].view(-1)

        # encode(x) = a * x + b.
        # NOTE(review): a is infinite where mymax == mymin (constant entries).
        self.a = (high - low) / (mymax - mymin)
        self.b = -self.a * mymax + high

    def encode(self, x):
        """Apply the affine map on ``x`` flattened to ``(x.size(0), -1)``."""
        s = x.size()
        x = x.view(s[0], -1)
        x = self.a * x + self.b
        x = x.view(s)
        return x

    def decode(self, x):
        """Invert :meth:`encode`."""
        s = x.size()
        x = x.view(s[0], -1)
        x = (x - self.b) / self.a
        x = x.view(s)
        return x


# loss function with rel/abs Lp loss
class LpLoss(object):
    """Absolute or relative Lp loss over a batch; calling the object gives the relative loss.

    Arguments
    ---------
    d : int
        Dimension used in the mesh-size scaling of :meth:`abs`.
    p : int
        Order of the norm.
    size_average : bool
        With ``reduction``: average (True) or sum (False) over the batch.
    reduction : bool
        If false, return the per-example values without reducing.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        # Dimension and Lp-norm type are positive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def abs(self, x, y):
        """Mesh-scaled Lp norm of ``x - y`` per example, reduced per the flags."""
        num_examples = x.size()[0]

        # Assume uniform mesh
        # NOTE(review): divides by zero when x.size(1) == 1.
        h = 1.0 / (x.size()[1] - 1.0)

        all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(all_norms)
            else:
                return torch.sum(all_norms)

        return all_norms

    def rel(self, x, y, std):
        """Per-example ``||x - y||_p / ||y||_p``, reduced per the flags.

        If ``std`` is truthy, the standard deviation of the per-example ratios
        is returned instead of their mean/sum.
        """
        num_examples = x.size()[0]

        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)

        if std:
            return torch.std(diff_norms / y_norms)

        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms / y_norms)
            else:
                return torch.sum(diff_norms / y_norms)
        return diff_norms / y_norms

    def __call__(self, x, y, std=False):
        return self.rel(x, y, std)


class HsLoss(object):
    """Relative loss in Fourier space, weighted by wavenumber powers up to order ``k``.

    Arguments
    ---------
    d, p : int
        Dimension and norm order (``p`` is the norm used by :meth:`rel`).
    k : int
        Highest derivative order included (only orders 1 and 2 add terms).
    a : list of float, optional
        Weight per derivative order; defaults to ``[1] * k``.
    group : bool
        If true, average separate relative losses per order ("balanced");
        otherwise combine all orders into a single weight.
    size_average, reduction : bool
        Batch reduction flags, as in :class:`LpLoss`.
    """

    def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True):
        super(HsLoss, self).__init__()

        # Dimension and Lp-norm type are positive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.k = k
        self.balanced = group
        self.reduction = reduction
        self.size_average = size_average

        # Identity check: `a == None` would compare element-wise for array-like a.
        if a is None:
            a = [1,] * k
        self.a = a

    def rel(self, x, y):
        """Per-example ``||x - y||_p / ||y||_p``, reduced per the flags."""
        num_examples = x.size()[0]
        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)
        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms/y_norms)
            else:
                return torch.sum(diff_norms/y_norms)
        return diff_norms/y_norms

    def __call__(self, x, y, a=None):
        """Evaluate the loss between ``x`` and ``y``.

        Dims 1 and 2 of the inputs are the transformed axes. ``a`` optionally
        overrides the per-order weights stored at construction; previously
        this argument was accepted but silently ignored.
        """
        nx = x.size()[1]
        ny = x.size()[2]
        k = self.k
        balanced = self.balanced
        a = self.a if a is None else a
        x = x.view(x.shape[0], nx, ny, -1)
        y = y.view(y.shape[0], nx, ny, -1)

        # Absolute integer wavenumbers along each transformed axis.
        k_x = torch.cat((torch.arange(start=0, end=nx//2, step=1),torch.arange(start=-nx//2, end=0, step=1)), 0).reshape(nx,1).repeat(1,ny)
        k_y = torch.cat((torch.arange(start=0, end=ny//2, step=1),torch.arange(start=-ny//2, end=0, step=1)), 0).reshape(1,ny).repeat(nx,1)
        k_x = torch.abs(k_x).reshape(1,nx,ny,1).to(x.device)
        k_y = torch.abs(k_y).reshape(1,nx,ny,1).to(x.device)

        x = torch.fft.fftn(x, dim=[1, 2])
        y = torch.fft.fftn(y, dim=[1, 2])

        if not balanced:
            # Single combined weight sqrt(1 + a0^2 |k|^2 + a1^2 |k|^4).
            weight = 1
            if k >= 1:
                weight += a[0]**2 * (k_x**2 + k_y**2)
            if k >= 2:
                weight += a[1]**2 * (k_x**4 + 2*k_x**2*k_y**2 + k_y**4)
            weight = torch.sqrt(weight)
            loss = self.rel(x*weight, y*weight)
        else:
            # One relative loss per order, averaged over the k + 1 terms.
            loss = self.rel(x, y)
            if k >= 1:
                weight = a[0] * torch.sqrt(k_x**2 + k_y**2)
                loss += self.rel(x*weight, y*weight)
            if k >= 2:
                weight = a[1] * torch.sqrt(k_x**4 + 2*k_x**2*k_y**2 + k_y**4)
                loss += self.rel(x*weight, y*weight)
            loss = loss / (k+1)

        return loss


def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    r"""Compute the matrix of all pairwise distances (not squared).

    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        The first sample, should be of shape ``(n_1, d)``.
    sample_2 : torch.Tensor or Variable
        The second sample, should be of shape ``(n_2, d)``.
    norm : float
        The l_p norm to be used.
    eps : float
        Small constant added inside the root, so entries are
        ``(eps + sum |.|^p)^(1/p)`` and never exactly zero.

    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape (n_1, n_2). The [i, j]-th entry is (up to ``eps``)
        ``|| sample_1[i, :] - sample_2[j, :] ||_p``."""
    n_1, n_2 = sample_1.size(0), sample_2.size(0)
    norm = float(norm)
    if norm == 2.:
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; abs() guards against small
        # negative values caused by round-off.
        norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True)
        norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True)
        norms = (norms_1.expand(n_1, n_2) +
                 norms_2.transpose(0, 1).expand(n_1, n_2))
        distances_squared = norms - 2 * sample_1.mm(sample_2.t())
        return torch.sqrt(eps + torch.abs(distances_squared))
    else:
        dim = sample_1.size(1)
        expanded_1 = sample_1.unsqueeze(1).expand(n_1, n_2, dim)
        expanded_2 = sample_2.unsqueeze(0).expand(n_1, n_2, dim)
        differences = torch.abs(expanded_1 - expanded_2) ** norm
        inner = torch.sum(differences, dim=2, keepdim=False)
        return (eps + inner) ** (1. / norm)


def _permutation_test_mat(matrix, n_1, n_2, n_permutations, a00=1.0, a11=1.0, a01=0.0):
    """Permutation p-value for the MMD statistic on a pooled kernel matrix.

    ``matrix`` is the ``(n_1 + n_2, n_1 + n_2)`` NumPy kernel matrix whose
    first ``n_1`` rows/columns belong to sample 1. The statistic is the same
    expression as in ``MMDStatistic.__call__``. Group labels are shuffled with
    NumPy's global RNG, and the returned value is the fraction of shuffles
    whose statistic is at least the observed one.
    """
    if n_permutations <= 0:
        raise ValueError("n_permutations must be positive")

    in_second = np.zeros(n_1 + n_2, dtype=bool)
    in_second[n_1:] = True

    def statistic(mask):
        k_1 = matrix[np.ix_(~mask, ~mask)]
        k_2 = matrix[np.ix_(mask, mask)]
        k_12 = matrix[np.ix_(~mask, mask)]
        return (2 * a01 * k_12.sum() +
                a00 * (k_1.sum() - np.trace(k_1)) +
                a11 * (k_2.sum() - np.trace(k_2)))

    observed = statistic(in_second)
    larger = sum(
        statistic(np.random.permutation(in_second)) >= observed
        for _ in range(n_permutations)
    )
    return float(larger) / n_permutations


class MMDStatistic:
    r"""The *unbiased* MMD test of :cite:`gretton2012kernel`.
    The kernel used is equal to:
    .. math ::
        k(x, x') = \sum_{j=1}^k e^{-\alpha_j\|x - x'\|^2},
    for the :math:`\alpha_j` provided in :py:meth:`~.MMDStatistic.__call__`.
    Arguments
    ---------
    n_1: int
        The number of points in the first sample.
    n_2: int
        The number of points in the second sample."""

    def __init__(self, n_1, n_2):
        self.n_1 = n_1
        self.n_2 = n_2

        # The three constants used in the test.
        self.a00 = 1. / (n_1 * (n_1 - 1))
        self.a11 = 1. / (n_2 * (n_2 - 1))
        self.a01 = - 1. / (n_1 * n_2)

    def __call__(self, sample_1, sample_2, alphas, ret_matrix=False):
        r"""Evaluate the statistic.
        The kernel used is
        .. math::
            k(x, x') = \sum_{j=1}^k e^{-\alpha_j \|x - x'\|^2},
        for the provided ``alphas``.
        Arguments
        ---------
        sample_1: :class:`torch:torch.autograd.Variable`
            The first sample, of size ``(n_1, d)``.
        sample_2: variable of shape (n_2, d)
            The second sample, of size ``(n_2, d)``.
        alphas : list of :class:`float`
            The kernel parameters.
        ret_matrix: bool
            If set, the call will also return a second variable.
            This variable can be then used to compute a p-value using
            :py:meth:`~.MMDStatistic.pval`.
        Returns
        -------
        :class:`float`
            The test statistic.
        :class:`torch:torch.autograd.Variable`
            Returned only if ``ret_matrix`` was set to true."""
        sample_12 = torch.cat((sample_1, sample_2), 0)
        distances = pdist(sample_12, sample_12, norm=2)

        # Sum of Gaussian kernels, one per alpha.
        kernels = None
        for alpha in alphas:
            kernels_a = torch.exp(- alpha * distances ** 2)
            if kernels is None:
                kernels = kernels_a
            else:
                kernels = kernels + kernels_a

        k_1 = kernels[:self.n_1, :self.n_1]
        k_2 = kernels[self.n_1:, self.n_1:]
        k_12 = kernels[:self.n_1, self.n_1:]

        # Traces are removed so that self-similarities do not enter.
        mmd = (2 * self.a01 * k_12.sum() +
               self.a00 * (k_1.sum() - torch.trace(k_1)) +
               self.a11 * (k_2.sum() - torch.trace(k_2)))
        if ret_matrix:
            return mmd, kernels
        else:
            return mmd

    def pval(self, distances, n_permutations=1000):
        r"""Compute a p-value using a permutation test.

        The previous body referenced ``Variable`` and ``permutation_test_mat``,
        neither of which is defined or imported in this module, so every call
        raised ``NameError``; the test is now done by a NumPy helper.

        Arguments
        ---------
        distances: :class:`torch.Tensor`
            The matrix computed using :py:meth:`~.MMDStatistic.__call__`
            (its second return value when ``ret_matrix`` is set).
        n_permutations: int
            The number of random draws from the permutation null.
        Returns
        -------
        float
            The estimated p-value."""
        matrix = distances.detach().cpu().numpy()
        return _permutation_test_mat(matrix,
                                     self.n_1, self.n_2,
                                     n_permutations,
                                     a00=self.a00, a11=self.a11, a01=self.a01)


# Compute stream function from vorticity (Fourier space)
def stream_function(w, real_space=False):
    """Solve the Poisson problem for the stream function spectrally.

    The FFT is taken over the last two axes of ``w`` and the grid size is read
    from ``w.shape[1]``. ``real_space`` is accepted but unused. The zero mode
    is left unscaled (its divisor is set to 1).

    Uses the ``torch.fft`` module, as the other spectral helpers in this file
    do; the ``torch.rfft``/``torch.irfft`` functions used previously were
    removed in PyTorch 1.8.
    """
    device = w.device
    s = w.shape[1]
    w_h = torch.fft.fft2(w)

    # Wavenumbers in y and x directions
    k_y = torch.cat((torch.arange(start=0, end=s // 2, step=1, dtype=torch.float32, device=device),
                     torch.arange(start=-s // 2, end=0, step=1, dtype=torch.float32, device=device)),
                    0).repeat(s, 1)
    k_x = k_y.clone().transpose(0, 1)

    # Negative inverse Laplacian in Fourier space
    inv_lap = (k_x ** 2 + k_y ** 2)
    inv_lap[0, 0] = 1.0
    inv_lap = 1.0 / inv_lap

    # Stream function in Fourier space: solve Poisson equation
    psi_h = inv_lap * w_h

    # Real inverse transform from the non-redundant half of the spectrum.
    return torch.fft.irfft2(psi_h[..., :s // 2 + 1], s=(s, s))


# Compute velocity field from stream function (Fourier space)
def velocity_field(stream, real_space=True):
    """Differentiate the stream function spectrally to get the velocity.

    Returns ``torch.stack([q, v], dim=3)`` with ``q = psi_y`` and
    ``v = -psi_x``, so ``stream`` is expected to be 3-D
    (presumably ``(batch, s, s)`` -- confirm against callers).
    ``real_space`` is accepted but unused.
    """
    device = stream.device
    s = stream.shape[1]

    stream_f = torch.fft.fft2(stream)
    # Wavenumbers in y and x directions
    k_y = torch.cat((torch.arange(start=0, end=s // 2, step=1, dtype=torch.float32, device=device),
                     torch.arange(start=-s // 2, end=0, step=1, dtype=torch.float32, device=device)),
                    0).repeat(s, 1)
    k_x = k_y.clone().transpose(0, 1)

    # Velocity field in x-direction = psi_y
    q_h = 1j * k_y * stream_f

    # Velocity field in y-direction = -psi_x
    v_h = -1j * k_x * stream_f

    q = torch.fft.irfft2(q_h[..., :s // 2 + 1], s=(s, s)).squeeze(-1)
    v = torch.fft.irfft2(v_h[..., :s // 2 + 1], s=(s, s)).squeeze(-1)
    return torch.stack([q, v], dim=3)


def _integer_wavenumbers(n, device):
    """Return the 1-D integer FFT wavenumbers [0, ..., n//2 - 1, -n//2, ..., -1].

    The result has ``2 * (n // 2)`` entries, i.e. ``n`` entries for even ``n``.
    """
    k_max = n // 2
    return torch.cat((torch.arange(start=0, end=k_max, step=1, device=device),
                      torch.arange(start=-k_max, end=0, step=1, device=device)), 0)


def curl3d(u):
    """Spectral curl of a vector field of shape ``(s, s, s, 3)``.

    Component 1 of the last axis is treated as the x-component and component 0
    as the y-component, exactly as in the original code. The result is a
    float32 tensor of shape ``(s, s, s, 3)``. ``s`` must be even for the
    wavenumber grid to match the field.
    """
    u = u.permute(-1, 0, 1, 2)

    s = u.shape[1]
    kmax = s // 2
    device = u.device

    uh = torch.fft.fftn(u, dim=(-3, -2, -1))

    xh = uh[1]
    yh = uh[0]
    zh = uh[2]

    k = _integer_wavenumbers(s, device)
    k_x = k.reshape(s, 1, 1)
    k_y = k.reshape(1, s, 1)
    k_z = k.reshape(1, 1, s)

    def derivative(f_h, k_dir):
        # d/dx_j in Fourier space is multiplication by i * k_j; invert from the
        # non-redundant half of the spectrum.
        return torch.fft.irfftn((1j * k_dir * f_h)[..., :kmax + 1],
                                s=(s, s, s), dim=(-3, -2, -1))

    xdy = derivative(xh, k_y)
    xdz = derivative(xh, k_z)
    ydx = derivative(yh, k_x)
    ydz = derivative(yh, k_z)
    zdx = derivative(zh, k_x)
    zdy = derivative(zh, k_y)

    w = torch.zeros((s, s, s, 3)).to(device)
    w[..., 0] = zdy - ydz
    w[..., 1] = xdz - zdx
    w[..., 2] = ydx - xdy

    return w


def w_to_u(w):
    """Velocity from vorticity on a square grid.

    ``w`` is reshaped to ``(batch, nx, ny, -1)``; the x- and y-velocity are
    concatenated along the last axis. The grid size is taken from ``nx`` for
    both directions.
    """
    batchsize = w.size(0)
    nx = w.size(1)
    ny = w.size(2)

    device = w.device
    w = w.reshape(batchsize, nx, ny, -1)

    w_h = torch.fft.fft2(w, dim=[1, 2])
    k_max = nx // 2
    N = nx
    k = _integer_wavenumbers(N, device)
    k_x = k.reshape(1, N, 1, 1)
    k_y = k.reshape(1, 1, N, 1)
    # Negative Laplacian in Fourier space
    lap = (k_x ** 2 + k_y ** 2)
    lap[0, 0, 0, 0] = 1.0
    f_h = w_h / lap

    ux_h = 1j * k_y * f_h
    uy_h = -1j * k_x * f_h

    ux = torch.fft.irfft2(ux_h[:, :, :k_max + 1], dim=[1, 2])
    uy = torch.fft.irfft2(uy_h[:, :, :k_max + 1], dim=[1, 2])
    u = torch.cat([ux, uy], dim=-1)
    return u


def w_to_f(w):
    """Stream function from vorticity; returns shape ``(batch, nx, ny, 1)``."""
    batchsize = w.size(0)
    nx = w.size(1)
    ny = w.size(2)

    device = w.device
    w = w.reshape(batchsize, nx, ny, 1)

    w_h = torch.fft.fft2(w, dim=[1, 2])
    k_max = nx // 2
    N = nx
    k = _integer_wavenumbers(N, device)
    k_x = k.reshape(1, N, 1, 1)
    k_y = k.reshape(1, 1, N, 1)
    # Negative Laplacian in Fourier space
    lap = (k_x ** 2 + k_y ** 2)
    lap[0, 0, 0, 0] = 1.0
    f_h = w_h / lap

    f = torch.fft.irfft2(f_h[:, :, :k_max + 1], dim=[1, 2])
    return f.reshape(batchsize, nx, ny, 1)


def u_to_w(u):
    """Vorticity ``d(uy)/dx - d(ux)/dy`` from a velocity field.

    ``u`` is reshaped to ``(batch, nx, ny, 2)``; the result has shape
    ``(batch, nx, ny)``.
    """
    batchsize = u.size(0)
    nx = u.size(1)
    ny = u.size(2)

    device = u.device
    u = u.reshape(batchsize, nx, ny, 2)
    ux = u[..., 0]
    uy = u[..., 1]

    ux_h = torch.fft.fft2(ux, dim=[1, 2])
    uy_h = torch.fft.fft2(uy, dim=[1, 2])
    k_max = nx // 2
    N = nx
    k = _integer_wavenumbers(N, device)
    k_x = k.reshape(1, N, 1)
    k_y = k.reshape(1, 1, N)
    # Spectral derivatives
    uxdy_h = 1j * k_y * ux_h
    uydx_h = 1j * k_x * uy_h

    uxdy = torch.fft.irfft2(uxdy_h[:, :, :k_max + 1], dim=[1, 2])
    uydx = torch.fft.irfft2(uydx_h[:, :, :k_max + 1], dim=[1, 2])
    w = uydx - uxdy
    return w


def u_to_f(u):
    """Stream function from velocity, via the vorticity."""
    return w_to_f(u_to_w(u))


def f_to_u(f):
    """Velocity from a stream function.

    ``f`` is reshaped to ``(batch, nx, ny, -1)``; the x- and y-velocity are
    stacked on a new last axis.
    """
    batchsize = f.size(0)
    nx = f.size(1)
    ny = f.size(2)

    device = f.device
    f = f.reshape(batchsize, nx, ny, -1)

    f_h = torch.fft.fft2(f, dim=[1, 2])
    k_max = nx // 2
    N = nx
    k = _integer_wavenumbers(N, device)
    k_x = k.reshape(1, N, 1, 1)
    k_y = k.reshape(1, 1, N, 1)
    # Spectral derivatives: ux = df/dy, uy = -df/dx
    ux_h = 1j * k_y * f_h
    uy_h = -1j * k_x * f_h

    ux = torch.fft.irfft2(ux_h[:, :, :k_max + 1], dim=[1, 2])
    uy = torch.fft.irfft2(uy_h[:, :, :k_max + 1], dim=[1, 2])
    u = torch.stack([ux, uy], dim=-1)
    return u


def f_to_w(f):
    """Vorticity from a stream function, via the velocity."""
    return u_to_w(f_to_u(f))


# print the number of parameters
def count_params(model):
    """Return the total number of elements over all parameters of ``model``.

    Uses ``numel()`` so that 0-dimensional (scalar) parameters count as one
    element; reducing over their empty shape used to raise ``TypeError``.
    """
    return sum(p.numel() for p in model.parameters())

# --------------------------------------------------------------------------------