├── .gitignore ├── README.md ├── environment.yml ├── examples ├── generate_data.py ├── l96_EnKF_demo.py ├── l96_NN_demo.py ├── l96_correction_demo.py ├── l96_multiscale_param_est.py ├── l96_param_est_demo.py └── utils.py └── torchEnKF ├── __init__.py ├── da_methods.py ├── misc.py ├── nn_templates.py └── noise.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.pyc 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | **/*.ipynb_checkpoints/ 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # pytype static type analyzer 138 | .pytype/ 139 | 140 | # Cython debug symbols 141 | cython_debug/ 142 | 143 | # Pycharm 144 | .idea/* 145 | 146 | # Mac 147 | .DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Auto-differentiable Ensemble Kalman Filters (AD-EnKF) 2 | 3 | Joint learning of latent dynamics and states from noisy observations by auto-differentiating through an Ensemble Kalman Filter (EnKF), implemented in PyTorch. 4 | 5 | Getting started: 6 | - `l96_EnKF_demo.py`: Computation of parameter log-likelihood and gradient estimates with the EnKF (Lorenz-96 model). 7 | - `l96_param_est_demo.py`: Parameter estimation in the Lorenz-96 model with AD-EnKF (cf. Section 5.2.1 of the paper). 8 | - `l96_NN_demo.py`: Learning Lorenz-96 dynamics and states with a neural network and AD-EnKF (cf. Section 5.2.2 of the paper). 9 | - `l96_correction_demo.py`: Correcting an imperfect Lorenz-96 model with a neural network and AD-EnKF (cf. Section 5.2.3 of the paper). 10 | - `l96_multiscale_param_est.py`: Parameter estimation in the multiscale Lorenz-96 model with AD-EnKF (working paper).
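A condensed sketch of the basic pattern, adapted from `l96_EnKF_demo.py` (the full demo additionally warms `x0` up onto the L96 attractor and reports the filter RMSE):

```python
# Estimate the data log-likelihood with an EnKF and backpropagate through it.
# Run from the repository root, as the demo scripts do.
import os, sys
sys.path.append(os.getcwd())

import torch
from torchEnKF import da_methods, nn_templates, noise
from examples import generate_data

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x_dim = 40
coeff = torch.tensor([8., 0., 0., -1., 0., 0., 0., 0., 0.,
                      0., 0., -1., 0., 0., 0., 0., 1., 0.], device=device)
ode_func = nn_templates.Lorenz96_dict_param(coeff, device, x_dim).to(device)
obs_func = nn_templates.Linear(x_dim, x_dim, H=torch.eye(x_dim)).to(device)  # observe all coordinates
noise_R = noise.AddGaussian(x_dim, torch.tensor(1.), param_type='scalar').to(device)

x0 = torch.randn(x_dim, device=device)          # the demo instead spins x0 up onto the attractor
t_obs = 0.05 * torch.arange(1, 301).to(device)  # 300 observation times
with torch.no_grad():
    _, y_obs = generate_data.generate(ode_func, obs_func, t_obs, x0, None, noise_R,
                                      device=device, ode_method='rk4',
                                      ode_options=dict(step_size=0.01))

init_m = torch.zeros(x_dim, device=device)
init_C = noise.AddGaussian(x_dim, 50 * torch.eye(x_dim), 'full').to(device)
X, res, log_likelihood = da_methods.EnKF(ode_func, obs_func, t_obs, y_obs, 50, init_m, init_C,
                                         None, noise_R, device, save_filter_step={'mean'},
                                         localization_radius=5)
log_likelihood.backward()  # gradients w.r.t. the L96 coefficients land in ode_func.coeff.grad
```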
11 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: AD-EnKF-env 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - blas=1.0=mkl 7 | - ca-certificates=2021.7.5=haa95532_1 8 | - certifi=2021.5.30=py37haa95532_0 9 | - cudatoolkit=10.2.89=h74a9793_1 10 | - cycler=0.10.0=py37_0 11 | - freetype=2.10.4=hd328e21_0 12 | - icc_rt=2019.0.0=h0cc432a_1 13 | - icu=58.2=ha925a31_3 14 | - intel-openmp=2021.3.0=haa95532_3372 15 | - jpeg=9b=hb83a4c4_2 16 | - kiwisolver=1.3.1=py37hd77b12b_0 17 | - libpng=1.6.37=h2a8f88b_0 18 | - libtiff=4.2.0=hd0e1b90_0 19 | - libuv=1.40.0=he774522_0 20 | - lz4-c=1.9.3=h2bbff1b_0 21 | - matplotlib=3.3.4=py37haa95532_0 22 | - matplotlib-base=3.3.4=py37h49ac443_0 23 | - mkl=2021.3.0=haa95532_524 24 | - mkl-service=2.4.0=py37h2bbff1b_0 25 | - mkl_fft=1.3.0=py37h277e83a_2 26 | - mkl_random=1.2.2=py37hf11a4ad_0 27 | - ninja=1.10.2=h6d14046_1 28 | - numpy=1.20.3=py37ha4e8547_0 29 | - numpy-base=1.20.3=py37hc2deb75_0 30 | - olefile=0.46=py37_0 31 | - openssl=1.1.1k=h2bbff1b_0 32 | - pillow=8.3.1=py37h4fa10fc_0 33 | - pip=21.1.3=py37haa95532_0 34 | - pyparsing=2.4.7=pyhd3eb1b0_0 35 | - pyqt=5.9.2=py37h6538335_2 36 | - python=3.7.10=h6244533_0 37 | - python-dateutil=2.8.2=pyhd3eb1b0_0 38 | - pytorch=1.9.0=py3.7_cuda10.2_cudnn7_0 39 | - qt=5.9.7=vc14h73c81de_0 40 | - scipy=1.6.2=py37h66253e8_1 41 | - setuptools=52.0.0=py37haa95532_0 42 | - sip=4.19.8=py37h6538335_0 43 | - six=1.16.0=pyhd3eb1b0_0 44 | - sqlite=3.36.0=h2bbff1b_0 45 | - tk=8.6.10=he774522_0 46 | - torchaudio=0.9.0=py37 47 | - torchvision=0.10.0=py37_cu102 48 | - tornado=6.1=py37h2bbff1b_0 49 | - tqdm=4.61.2=pyhd3eb1b0_1 50 | - typing_extensions=3.10.0.0=pyh06a4308_0 51 | - vc=14.2=h21ff451_1 52 | - vs2015_runtime=14.27.29016=h5e58377_2 53 | - wheel=0.36.2=pyhd3eb1b0_0 54 | - wincertstore=0.2=py37_0 55 | - xz=5.2.5=h62dcd97_0 56 | - zlib=1.2.11=h62dcd97_4 57 | - zstd=1.4.9=h19a0ad4_0 58 | - pip: 59 | - torchdiffeq==0.2.2 60 | variables: 61 | KMP_DUPLICATE_LIB_OK: 'True' 62 | prefix: C:\Users\cheny\Anaconda3\envs\AD-EnKF-env 63 | -------------------------------------------------------------------------------- /examples/generate_data.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torchEnKF import nn_templates 4 | from torchdiffeq import odeint_adjoint 5 | from torchdiffeq import odeint 6 | from tqdm import tqdm 7 | import os 8 | import sys 9 | sys.path.append(os.getcwd()) # Fix Python path 10 | 11 | 12 | def generate(ode_func, obs_func, t_obs, x0, model_Q_param, noise_R_param, device, 13 | ode_method='rk4', ode_options=None, t0=0., time_varying_obs=False, tqdm=None): 14 | """ 15 | Generate state and observation data from a given state space model. 16 | 17 | Key args: 18 | ode_func (torch.nn.Module): Vector field f(t,x). 19 | obs_func (torch.nn.Module): Observation model h(x). 20 | If time varying_obs==True, can take a list of torch.nn.Module's 21 | t_obs (tensor): 1D-Tensor of shape (n_obs,). Time points where observations are available 22 | x0 (tensor): Tensor of shape (*bs, x_dim). Initial positions x0. 23 | '*bs' can be arbitrary batch dimension (or empty). 24 | model_Q_param (Noise.AddGaussian): model error covariance 25 | noise_R_param (Noise.AddGaussian): observation error covariance 26 | 27 | Optional args: 28 | ode_method: Numerical scheme for forward equation. 
We use 'euler' or 'rk4'. Other solvers are available. See https://github.com/rtqichen/torchdiffeq 29 | ode_options: Set it to dict(step_size=...) for fixed step solvers. Adaptive solvers are also available - see the link above. 30 | 31 | Returns: 32 | x_truth (tensor): Tensor of shape (n_obs, *bs, x_dim). States of the reference model. 33 | y_obs (tensor): Tensor of shape (n_obs, *bs, y_dim). Observations of the reference model. 34 | """ 35 | 36 | x_dim = x0.shape[-1] 37 | n_obs = t_obs.shape[0] 38 | bs = x0.shape[:-1] 39 | 40 | x_truth = torch.empty(n_obs, *bs, x_dim, dtype=x0.dtype, device=device) # (n_obs, *bs, x_dim) 41 | X = x0 42 | 43 | y_obs = None 44 | if obs_func is not None: 45 | obs_func_0 = obs_func[0] if time_varying_obs else obs_func 46 | y_dim = obs_func_0(x0).shape[-1] 47 | y_obs = torch.empty(n_obs, *bs, y_dim, dtype=x0.dtype, device=device) 48 | 49 | 50 | t_cur = t0 51 | 52 | pbar = tqdm(range(n_obs), desc="Generating data", leave=True) if tqdm is not None else range(n_obs) 53 | for j in pbar: 54 | _, X = odeint(ode_func, X, torch.tensor([t_cur, t_obs[j]], device=device), method=ode_method, options=ode_options) 55 | t_cur = t_obs[j] 56 | 57 | if model_Q_param is not None: 58 | X = model_Q_param(X) 59 | 60 | x_truth[j] = X 61 | 62 | if obs_func is not None: 63 | obs_func_j = obs_func[j] if time_varying_obs else obs_func 64 | HX = obs_func_j(X) 65 | y_obs[j] = noise_R_param(HX) 66 | 67 | return x_truth, y_obs 68 | 69 | -------------------------------------------------------------------------------- /examples/l96_EnKF_demo.py: -------------------------------------------------------------------------------- 1 | # from core_training import train_loop_diff, train_loop_em_new 2 | from tqdm import tqdm 3 | import os 4 | import sys 5 | sys.path.append(os.getcwd()) # Fix Python path 6 | 7 | from torchEnKF import da_methods, nn_templates, noise 8 | from examples import generate_data, utils 9 | 10 | import random 11 | import torch 12 | import numpy as np 13 | from torchdiffeq import odeint 14 | 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | # device="cpu" 17 | print(f"device: {device}") 18 | 19 | 20 | seed = 40 21 | torch.manual_seed(seed) 22 | np.random.seed(seed) 23 | random.seed(seed) 24 | 25 | 26 | ######### Define reference model ######### 27 | x_dim = 40 28 | true_F = 8. 29 | true_coeff = torch.tensor([8., 0., 0., -1, 0., 0., 0., 0., 0., 0., 0., -1., 0., 0., 0., 0., 1., 0.], device=device) 30 | true_ode_func = nn_templates.Lorenz96_dict_param(true_coeff, device, x_dim).to(device) 31 | 32 | 33 | ######### Warmup: Draw an initial point x0 from the L96 limit cycle. Can be ignored for problems with a smaller scale. ######### 34 | with torch.no_grad(): 35 | x0_warmup = torch.distributions.MultivariateNormal(torch.zeros(x_dim), covariance_matrix=50 * torch.eye(x_dim)).sample().to(device) 36 | t_warmup = torch.tensor([0., 120.]).to(device) 37 | x0 = odeint(true_ode_func, x0_warmup, t_warmup, method='rk4', options=dict(step_size=0.05))[1] # Shape (x_dim,) 38 | 39 | 40 | ######### Generate data from the reference model ######### 41 | t0 = 0. 
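# The block below places the n_obs observations at times t_obs_step, 2*t_obs_step, ...,
# n_obs*t_obs_step (t0 itself is not observed). H_true = torch.eye(x_dim)[indices] selects the
# observed coordinates; with indices = range(x_dim) the observation is full (H = I, y_dim = x_dim).
# A partial observation would use the same mechanism, e.g. indices = list(range(0, x_dim, 2))
# to observe every other coordinate (illustration only, not used in this demo).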
42 | t_obs_step = 0.05 43 | n_obs = 300 44 | t_obs = t_obs_step * torch.arange(1, n_obs+1).to(device) 45 | model_Q_true = None # No noise in dynamics 46 | indices = [i for i in range(x_dim)] 47 | y_dim = len(indices) 48 | H_true = torch.eye(x_dim)[indices] 49 | true_obs_func = nn_templates.Linear(x_dim, y_dim, H=H_true).to(device) # Full observation 50 | noise_R_true = noise.AddGaussian(y_dim, torch.tensor(1.), param_type='scalar').to(device) # Gaussian observation noise with std 1 51 | with torch.no_grad(): 52 | x_truth, y_obs = generate_data.generate(true_ode_func, true_obs_func, t_obs, x0, model_Q_true, noise_R_true, device=device, 53 | ode_method='rk4', ode_options=dict(step_size=0.01), tqdm=tqdm) 54 | 55 | 56 | ########## Run EnKF with the reference model, compute log-likelihood and gradient ######### 57 | N_ensem = 50 58 | init_m = torch.zeros(x_dim, device=device) 59 | init_C_param = noise.AddGaussian(x_dim, 50 * torch.eye(x_dim), 'full').to(device) 60 | X, res, log_likelihood = da_methods.EnKF(true_ode_func,true_obs_func, t_obs, y_obs, N_ensem, init_m, init_C_param, model_Q_true, noise_R_true,device, 61 | save_filter_step={'mean'}, localization_radius=5, tqdm=tqdm) 62 | print(f"log-likelihood estimate: {log_likelihood}") 63 | burn_in = n_obs // 5 64 | print(f"Filter accuracy (RMSE): {torch.sqrt(utils.mse_loss(res['mean'][burn_in:], x_truth[burn_in:]))}") 65 | print("Computing gradient...") 66 | log_likelihood.backward() 67 | print(f"Gradient: {true_ode_func.coeff.grad}") 68 | 69 | ### The log-likelihood estimates and gradient estimates can be used for various purposes! -------------------------------------------------------------------------------- /examples/l96_NN_demo.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | import sys 4 | sys.path.append(os.getcwd()) # Fix Python path 5 | 6 | from torchEnKF import da_methods, nn_templates, noise 7 | from examples import generate_data, utils 8 | 9 | import random 10 | import torch 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from torchdiffeq import odeint 14 | 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | # device="cpu" 17 | print(f"device: {device}") 18 | 19 | 20 | seed = 42 21 | torch.manual_seed(seed) 22 | np.random.seed(seed) 23 | random.seed(seed) 24 | 25 | ######### Define reference model ######### 26 | x_dim = 40 27 | true_F = 8. 28 | true_coeff = torch.tensor([8., 0., 0., -1, 0., 0., 0., 0., 0., 0., 0., -1., 0., 0., 0., 0., 1., 0.], device=device) 29 | true_ode_func = nn_templates.Lorenz96_dict_param(true_coeff, device, x_dim).to(device) 30 | 31 | 32 | ######### Warmup: Draw x0 from the L96 limit cycle. Can be ignored for problems with a smaller scale. ######### 33 | train_size = 8 34 | test_size = 4 35 | with torch.no_grad(): 36 | x0_warmup = torch.distributions.MultivariateNormal(torch.zeros(x_dim), covariance_matrix=25 * torch.eye(x_dim)).sample().to(device) 37 | t_warmup = 60 * torch.arange(0., train_size + test_size + 1).to(device) 38 | x0_train_and_test = odeint(true_ode_func, x0_warmup, t_warmup, method='rk4', options=dict(step_size=0.05))[1:] 39 | x0_train = x0_train_and_test[:train_size] # Shape: (train_size, x_dim) 40 | x0_test = x0_train_and_test[train_size:] # Shape: (test_size, x_dim) 41 | 42 | ######### For computing forecast RMSE only: Draw P=500 points from the L96 limit cycle to computer forecast RMSE. 
######## 43 | with torch.no_grad(): 44 | x0_warmup = torch.distributions.MultivariateNormal(torch.zeros(x_dim), covariance_matrix=25 * torch.eye(x_dim)).sample((50,)).to(device) 45 | t_warmup = torch.cat((torch.tensor([0.]), torch.arange(80.,280.,20.))).to(device) 46 | x0_fc = odeint(true_ode_func, x0_warmup, t_warmup, method='rk4', options=dict(step_size=0.05))[1:].reshape(-1, x_dim) # Shape: (P, x_dim) 47 | 48 | 49 | ######### Generate training data from the reference model ######### 50 | t0 = 0. 51 | t_obs_step = 0.05 52 | n_obs = 1200 53 | t_obs = t_obs_step * torch.arange(1, n_obs+1).to(device) 54 | model_Q_true = None # No noise in dynamics 55 | indices = [i for i in range(x_dim)] 56 | y_dim = len(indices) 57 | H_true = torch.eye(x_dim)[indices] 58 | true_obs_func = nn_templates.Linear(x_dim, y_dim, H=H_true).to(device) # Observe every coordinates 59 | noise_R_true = noise.AddGaussian(y_dim, torch.tensor(1.), param_type='scalar').to(device) # Gaussian perturbation with std 1 60 | with torch.no_grad(): 61 | x_truth_train, y_obs_train = generate_data.generate(true_ode_func, true_obs_func, t_obs, x0_train, model_Q_true, noise_R_true, device=device, 62 | ode_method='rk4', ode_options=dict(step_size=0.01), tqdm=tqdm) # Shape: (n_obs, train_size, y_dim) 63 | x_truth_test, y_obs_test = generate_data.generate(true_ode_func, true_obs_func, t_obs, x0_test, model_Q_true, noise_R_true, device=device, 64 | ode_method='rk4', ode_options=dict(step_size=0.01),tqdm=tqdm) # Shape: (n_obs, test_size, y_dim) 65 | 66 | ######### NN training ######### 67 | N_ensem = 50 68 | init_m = torch.zeros(x_dim, device=device) 69 | init_C_param = noise.AddGaussian(x_dim, 25 * torch.eye(x_dim), 'full').to(device) 70 | init_Q = 2 * torch.ones(x_dim) 71 | learned_ode_func = nn_templates.L96_ODE_Net_2(x_dim).to(device) 72 | learned_model_Q = noise.AddGaussian(x_dim, init_Q, 'diag').to(device) 73 | optimizer = torch.optim.Adam([{'params':learned_ode_func.parameters(), 'lr':1e-2}, 74 | {'params':learned_model_Q.parameters(), 'lr':1e-1}]) 75 | lambda1 = lambda2 = lambda epoch: (epoch-9)**(-1) if epoch >=10 else 1 76 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) 77 | L = 20 # subsequence length in AD-EnKF-T 78 | monitor = [] 79 | for epoch in tqdm(range(50), desc="Training", leave=False): 80 | train_log_likelihood = torch.zeros(train_size, device=device) 81 | train_state_est_loss = torch.tensor(0., device=device) 82 | t_start = t0 83 | X = init_C_param(init_m.expand(train_size, N_ensem, x_dim)) 84 | 85 | # Training phase 86 | for start in range(0, n_obs, L): 87 | optimizer.zero_grad() 88 | end = min(start + L, n_obs) 89 | X, res, log_likelihood = da_methods.EnKF(learned_ode_func,true_obs_func, t_obs[start:end], y_obs_train[start:end], N_ensem, init_m, init_C_param, learned_model_Q, noise_R_true,device, 90 | save_filter_step={'mean'}, t0=t_start, init_X=X, ode_options=dict(step_size=0.01), adjoint_options=dict(step_size=0.05), localization_radius=5, tqdm=None) 91 | t_start = t_obs[end - 1] 92 | (-log_likelihood).mean().backward() 93 | train_log_likelihood += log_likelihood.detach().clone() 94 | train_state_est_loss += utils.mse_loss(res['mean'], x_truth_train[start:end]) * (end-start) 95 | optimizer.step() 96 | scheduler.step() 97 | 98 | # Testing 99 | with torch.no_grad(): 100 | filter_rmse = torch.sqrt(train_state_est_loss / n_obs) 101 | _, _, test_log_likelihood = da_methods.EnKF(learned_ode_func, true_obs_func, t_obs, y_obs_test, N_ensem, init_m, init_C_param, learned_model_Q, 
noise_R_true, device, 102 | save_filter_step={}, linear_obs=True, ode_options=dict(step_size=0.01), adjoint_options=dict(step_size=0.05), localization_radius=5, tqdm=None) 103 | true_fc, _ = generate_data.generate(true_ode_func, None, torch.tensor([t_obs_step], device=device), x0_fc, None, None, device, ode_options=dict(step_size=0.01)) 104 | fc, _ = generate_data.generate(learned_ode_func, None, torch.tensor([t_obs_step], device=device), x0_fc, None, None, device, ode_options=dict(step_size=0.01)) 105 | forecast_rmse = torch.sqrt(utils.mse_loss(true_fc, fc)) 106 | 107 | curr_output = [forecast_rmse.item(), filter_rmse.item(), test_log_likelihood.mean().item()] 108 | monitor.append(curr_output) 109 | 110 | 111 | # Printing 112 | if epoch % 1 == 0: 113 | tqdm.write(f"Epoch {epoch}, Training log-likelihood: {train_log_likelihood.mean().item()}") 114 | tqdm.write(f"Epoch {epoch}, Test log-likelihood: {test_log_likelihood.mean().item()}") 115 | tqdm.write(f"Filter RMSE: {filter_rmse.item()}") 116 | tqdm.write(f"Forecast RMSE: {forecast_rmse.item()}") 117 | 118 | # Reproducing Figure 7, EnKF results 119 | monitor = np.asarray(monitor) 120 | fig, axes = plt.subplots(1, 3, figsize=(18, 6)) 121 | axes[0].plot(monitor[:,0]) 122 | axes[1].plot(monitor[:,1]) 123 | axes[2].plot(monitor[:,2]) 124 | axes[0].set_ylabel("Forecast RMSE") 125 | axes[1].set_ylabel("Filter RMSE") 126 | axes[2].set_ylabel("Test log likelihood") 127 | plt.show() -------------------------------------------------------------------------------- /examples/l96_correction_demo.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | import sys 4 | sys.path.append(os.getcwd()) # Fix Python path 5 | 6 | from torchEnKF import da_methods, nn_templates, noise 7 | from examples import generate_data, utils 8 | 9 | import random 10 | import torch 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from torchdiffeq import odeint 14 | 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | # device="cpu" 17 | print(f"device: {device}") 18 | 19 | 20 | seed = 42 21 | torch.manual_seed(seed) 22 | np.random.seed(seed) 23 | random.seed(seed) 24 | 25 | ######### Define reference model ######### 26 | x_dim = 40 27 | true_F = 8. 28 | true_coeff = torch.tensor([8., 0., 0., -1, 0., 0., 0., 0., 0., 0., 0., -1., 0., 0., 0., 0., 1., 0.], device=device) 29 | true_ode_func = nn_templates.Lorenz96_dict_param(true_coeff, device, x_dim).to(device) 30 | 31 | 32 | ######### Warmup: Draw x0 from the L96 limit cycle. Can be ignored for problems with a smaller scale. ######### 33 | train_size = 8 34 | test_size = 4 35 | with torch.no_grad(): 36 | x0_warmup = torch.distributions.MultivariateNormal(torch.zeros(x_dim), covariance_matrix=25 * torch.eye(x_dim)).sample().to(device) 37 | t_warmup = 60 * torch.arange(0., train_size + test_size + 1).to(device) 38 | x0_train_and_test = odeint(true_ode_func, x0_warmup, t_warmup, method='rk4', options=dict(step_size=0.05))[1:] 39 | x0_train = x0_train_and_test[:train_size] # Shape: (train_size, x_dim) 40 | x0_test = x0_train_and_test[train_size:] # Shape: (test_size, x_dim) 41 | 42 | ######### For computing forecast RMSE only: Draw P=500 points from the L96 limit cycle to computer forecast RMSE. 
######## 43 | with torch.no_grad(): 44 | x0_warmup = torch.distributions.MultivariateNormal(torch.zeros(x_dim), covariance_matrix=25 * torch.eye(x_dim)).sample((50,)).to(device) 45 | t_warmup = torch.cat((torch.tensor([0.]), torch.arange(80.,280.,20.))).to(device) 46 | x0_fc = odeint(true_ode_func, x0_warmup, t_warmup, method='rk4', options=dict(step_size=0.05))[1:].reshape(-1, x_dim) # Shape: (P, x_dim) 47 | 48 | 49 | ######### Generate training data from the reference model ######### 50 | t0 = 0. 51 | t_obs_step = 0.05 52 | n_obs = 1200 53 | t_obs = t_obs_step * torch.arange(1, n_obs+1).to(device) 54 | model_Q_true = None # No noise in dynamics 55 | indices = [i for i in range(x_dim)] 56 | y_dim = len(indices) 57 | H_true = torch.eye(x_dim)[indices] 58 | true_obs_func = nn_templates.Linear(x_dim, y_dim, H=H_true).to(device) # Observe every coordinates 59 | noise_R_true = noise.AddGaussian(y_dim, torch.tensor(1.), param_type='scalar').to(device) # Gaussian perturbation with std 1 60 | with torch.no_grad(): 61 | x_truth_train, y_obs_train = generate_data.generate(true_ode_func, true_obs_func, t_obs, x0_train, model_Q_true, noise_R_true, device=device, 62 | ode_method='rk4', ode_options=dict(step_size=0.01), tqdm=tqdm) # Shape: (n_obs, train_size, y_dim) 63 | x_truth_test, y_obs_test = generate_data.generate(true_ode_func, true_obs_func, t_obs, x0_test, model_Q_true, noise_R_true, device=device, 64 | ode_method='rk4', ode_options=dict(step_size=0.01),tqdm=tqdm) # Shape: (n_obs, test_size, y_dim) 65 | 66 | ######## Define NN correction model ############# 67 | pert_0, pert_1, pert_2 = 1, 0.1, 0.01 68 | pert_diag = torch.cat( (pert_0 * torch.ones(1), pert_1 * torch.ones(5), pert_2 * torch.ones(12)) ).to(device) 69 | pert_cov = torch.diag(pert_diag) 70 | pert = torch.distributions.MultivariateNormal(torch.zeros(18, device=device), covariance_matrix=pert_cov).sample() 71 | init_coeff = true_coeff + pert 72 | learned_ode_func = nn_templates.Lorenz96_correction(init_coeff, x_dim).to(device) 73 | 74 | ######### NN training ######### 75 | N_ensem = 50 76 | init_m = torch.zeros(x_dim, device=device) 77 | init_C_param = noise.AddGaussian(x_dim, 25 * torch.eye(x_dim), 'full').to(device) 78 | init_Q = 2 * torch.ones(x_dim) 79 | learned_model_Q = noise.AddGaussian(x_dim, init_Q, 'diag').to(device) 80 | optimizer = torch.optim.Adam([{'params':learned_ode_func.parameters(), 'lr':1e-3}, 81 | {'params':learned_model_Q.parameters(), 'lr':1e-1}]) 82 | lambda1 = lambda2 = lambda epoch: (epoch-9)**(-0.75) if epoch >=10 else 1 83 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) 84 | L = 20 # subsequence length in AD-EnKF-T 85 | monitor = [] 86 | for epoch in tqdm(range(50), desc="Training", leave=False): 87 | train_log_likelihood = torch.zeros(train_size, device=device) 88 | train_state_est_loss = torch.tensor(0., device=device) 89 | t_start = t0 90 | X = init_C_param(init_m.expand(train_size, N_ensem, x_dim)) 91 | 92 | # Training phase 93 | for start in range(0, n_obs, L): 94 | optimizer.zero_grad() 95 | end = min(start + L, n_obs) 96 | X, res, log_likelihood = da_methods.EnKF(learned_ode_func,true_obs_func, t_obs[start:end], y_obs_train[start:end], N_ensem, init_m, init_C_param, learned_model_Q, noise_R_true,device, 97 | save_filter_step={'mean'}, t0=t_start, init_X=X, ode_options=dict(step_size=0.01), adjoint_options=dict(step_size=0.05), localization_radius=5, tqdm=None) 98 | t_start = t_obs[end - 1] 99 | (-log_likelihood).mean().backward() 100 | train_log_likelihood 
+= log_likelihood.detach().clone() 101 | train_state_est_loss += utils.mse_loss(res['mean'], x_truth_train[start:end]) * (end-start) 102 | optimizer.step() 103 | scheduler.step() 104 | 105 | # Testing 106 | with torch.no_grad(): 107 | filter_rmse = torch.sqrt(train_state_est_loss / n_obs) 108 | _, _, test_log_likelihood = da_methods.EnKF(learned_ode_func, true_obs_func, t_obs, y_obs_test, N_ensem, init_m, init_C_param, learned_model_Q, noise_R_true, device, 109 | save_filter_step={}, ode_options=dict(step_size=0.01), adjoint_options=dict(step_size=0.05), localization_radius=5, tqdm=None) 110 | true_fc, _ = generate_data.generate(true_ode_func, None, torch.tensor([t_obs_step], device=device), x0_fc, None, None, device, ode_options=dict(step_size=0.01)) 111 | fc, _ = generate_data.generate(learned_ode_func, None, torch.tensor([t_obs_step], device=device), x0_fc, None, None, device, ode_options=dict(step_size=0.01)) 112 | forecast_rmse = torch.sqrt(utils.mse_loss(true_fc, fc)) 113 | 114 | curr_output = [forecast_rmse.item(), filter_rmse.item(), test_log_likelihood.mean().item()] 115 | monitor.append(curr_output) 116 | 117 | 118 | # Printing 119 | if epoch % 1 == 0: 120 | tqdm.write(f"Epoch {epoch}, Training log-likelihood: {train_log_likelihood.mean().item()}") 121 | tqdm.write(f"Epoch {epoch}, Test log-likelihood: {test_log_likelihood.mean().item()}") 122 | tqdm.write(f"Filter RMSE: {filter_rmse.item()}") 123 | tqdm.write(f"Forecast RMSE: {forecast_rmse.item()}") 124 | 125 | 126 | # Reproducing Figure 7, EnKF results 127 | monitor = np.asarray(monitor) 128 | fig, axes = plt.subplots(1, 3, figsize=(18, 6)) 129 | axes[0].plot(monitor[:,0]) 130 | axes[1].plot(monitor[:,1]) 131 | axes[2].plot(monitor[:,2]) 132 | axes[0].set_ylabel("Forecast RMSE") 133 | axes[1].set_ylabel("Filter RMSE") 134 | axes[2].set_ylabel("Test log likelihood") 135 | plt.show() -------------------------------------------------------------------------------- /examples/l96_multiscale_param_est.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | import sys 4 | sys.path.append(os.getcwd()) # Fix Python path 5 | 6 | from torchEnKF import da_methods, nn_templates, noise 7 | from examples import generate_data, utils 8 | 9 | import random 10 | import torch 11 | import numpy as np 12 | 13 | from torchdiffeq import odeint 14 | 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | # device="cpu" 17 | print(f"device: {device}") 18 | # torch.backends.cudnn.benchmark = True 19 | 20 | seed = 44 21 | torch.manual_seed(seed) 22 | np.random.seed(seed) 23 | random.seed(seed) 24 | 25 | ######### Define reference model ######### 26 | xx_dim = 36 27 | xy_dim = 10 28 | x_dim = xx_dim * (xy_dim + 1) 29 | true_param = torch.tensor([10., 1., 10., 10.]) 30 | true_ode_func = nn_templates.Lorenz96_FS(true_param, device, xx_dim, xy_dim).to(device) 31 | 32 | ######### Warmup: Draw x0 from the L96 limit cycle. Can be ignored for problems with a smaller scale. 
######### 33 | print("Warming up...") 34 | train_size = 10 35 | with torch.no_grad(): 36 | xx_cov = 25 * torch.eye(xx_dim) 37 | xy_cov = 0.25 * torch.eye(xx_dim * xy_dim) 38 | cov = torch.block_diag(xx_cov, xy_cov).to(device) 39 | x0_warmup = torch.distributions.MultivariateNormal(torch.zeros(x_dim, device=device), covariance_matrix=cov).sample().to(device) # <- 40 | t_warmup = 20 * torch.arange(0., train_size + 1).to(device) 41 | x0 = odeint(true_ode_func, x0_warmup, t_warmup, method='rk4', options=dict(step_size=0.005))[1:] 42 | 43 | ######### Generate training data from the reference model ######### 44 | t0 = 0. 45 | t_obs_step = 0.1 46 | n_obs = 100 47 | t_obs = t_obs_step * torch.arange(1, n_obs+1).to(device) 48 | model_Q_true = None # No noise in dynamics 49 | 50 | def true_obs_func(X): # define the observation model 51 | # (*bs, x_dim) -> (*bs, y_dim) 52 | bs = X.shape[:-1] 53 | to_cat = [] 54 | to_cat.append(X[..., :xx_dim]) 55 | Y_bar = X[..., xx_dim:].reshape(*bs, xx_dim, xy_dim).mean(dim=-1) 56 | to_cat.append(Y_bar) 57 | to_cat.append(X[..., :xx_dim]**2) 58 | to_cat.append(X[..., :xx_dim] * Y_bar) 59 | to_cat.append(((X[..., xx_dim:].reshape(*bs, xx_dim, xy_dim))**2).mean(dim=-1).view(*bs, xx_dim)) 60 | return torch.cat(to_cat, dim=-1) 61 | y_dim = true_obs_func(x0).shape[-1] 62 | xo_cov = 0.1 * torch.eye(xx_dim) 63 | yo_cov = 0.1 * torch.eye(xx_dim) 64 | xxo_cov = 0.1 * torch.eye(xx_dim) 65 | xyo_cov = 0.1 * torch.eye(xx_dim) 66 | yyo_cov = 0.1 * torch.eye(xx_dim) 67 | obs_cov = torch.block_diag(xo_cov, yo_cov, xxo_cov, xyo_cov, yyo_cov).to(device) 68 | noise_R_true = noise.AddGaussian(y_dim, obs_cov, param_type='full').to(device) 69 | 70 | 71 | with torch.no_grad(): 72 | x_truth, y_obs = generate_data.generate(true_ode_func, true_obs_func, t_obs, x0, model_Q_true, noise_R_true, device=device, 73 | ode_method='rk4', ode_options=dict(step_size=0.0025), tqdm=tqdm) # Shape: (n_obs, train_size, y_dim) 74 | 75 | ######### Parameter estimation ######### 76 | N_ensem = 100 77 | init_m = torch.zeros(x_dim, device=device) 78 | init_C_param = noise.AddGaussian(x_dim, cov, 'full').to(device) 79 | init_coeff = torch.tensor([8., 0., 2., 2.],device=device) 80 | init_Q = 0.2 * torch.ones(x_dim) 81 | learned_ode_func = nn_templates.Lorenz96_FS(init_coeff, device, xx_dim, xy_dim).to(device) 82 | learned_model_Q = noise.AddGaussian(x_dim, init_Q, 'diag').to(device) 83 | optimizer = torch.optim.Adam([{'params':learned_ode_func.parameters(), 'lr':1e-1}, 84 | {'params':learned_model_Q.parameters(), 'lr':1e-1}]) 85 | lambda1 = lambda2 = lambda epoch: (epoch-49)**(-0.5) if epoch >=50 else 1 86 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) 87 | L = 1 # subsequence length in AD-EnKF-T 88 | warm_up = 0 89 | for epoch in tqdm(range(100), desc="Training", leave=False): 90 | train_log_likelihood = torch.zeros(train_size, device=device) 91 | t_start = t0 92 | X = init_C_param(init_m.expand(train_size, N_ensem, x_dim)) 93 | 94 | # Warm-up phase. Time interval at the beginning that the gradients will not be recorded. But the filtered states will. 
(This is not presented in paper) 95 | with torch.no_grad(): 96 | X, res, log_likelihood = da_methods.EnKF(learned_ode_func, true_obs_func, t_obs[:warm_up], y_obs[:warm_up], N_ensem, init_m, init_C_param, learned_model_Q, noise_R_true, device, 97 | save_filter_step={}, t0=t_start, init_X=X, ode_options=dict(step_size=0.0025), adjoint_options=dict(step_size=0.01), linear_obs=False) 98 | train_log_likelihood += log_likelihood 99 | t_start = t_obs[warm_up - 1] if warm_up >= 1 else t0 100 | 101 | for start in range(warm_up, n_obs, L): 102 | optimizer.zero_grad() 103 | end = min(start + L, n_obs) 104 | X, res, log_likelihood = da_methods.EnKF(learned_ode_func,true_obs_func, t_obs[start:end], y_obs[start:end], N_ensem, init_m, init_C_param, learned_model_Q, noise_R_true,device, 105 | save_filter_step={}, t0=t_start, init_X=X, ode_options=dict(step_size=0.0025), adjoint_options=dict(step_size=0.01), linear_obs=False) 106 | t_start = t_obs[end - 1] 107 | (-log_likelihood).mean().backward() 108 | train_log_likelihood += log_likelihood.detach().clone() 109 | optimizer.step() 110 | scheduler.step() 111 | 112 | if epoch % 1 == 0: 113 | tqdm.write(f"Epoch {epoch}, Training log-likelihood: {train_log_likelihood.mean().item()}") 114 | tqdm.write(f"Learned coefficients: {learned_ode_func.param.data.cpu().numpy()}") 115 | tqdm.write(f"Learned q: {torch.sqrt(torch.trace(learned_model_Q.full())/x_dim).item()}") -------------------------------------------------------------------------------- /examples/l96_param_est_demo.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | import sys 4 | sys.path.append(os.getcwd()) # Fix Python path 5 | 6 | from torchEnKF import da_methods, nn_templates, noise 7 | from examples import generate_data, utils 8 | 9 | import random 10 | import torch 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from torchdiffeq import odeint 14 | 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | # device="cpu" 17 | print(f"device: {device}") 18 | 19 | 20 | seed = 42 21 | torch.manual_seed(seed) 22 | np.random.seed(seed) 23 | random.seed(seed) 24 | 25 | ######### Define reference model ######### 26 | x_dim = 40 27 | true_F = 8. 28 | true_coeff = torch.tensor([8., 0., 0., -1, 0., 0., 0., 0., 0., 0., 0., -1., 0., 0., 0., 0., 1., 0.], device=device) 29 | true_ode_func = nn_templates.Lorenz96_dict_param(true_coeff, device, x_dim).to(device) 30 | 31 | 32 | ######### Warmup: Draw x0 from the L96 limit cycle. Can be ignored for problems with a smaller scale. ######### 33 | train_size = 4 34 | with torch.no_grad(): 35 | x0_warmup = torch.distributions.MultivariateNormal(torch.zeros(x_dim), covariance_matrix=25 * torch.eye(x_dim)).sample().to(device) # <- 36 | t_warmup = 120 * torch.arange(0., train_size + 1).to(device) 37 | x0 = odeint(true_ode_func, x0_warmup, t_warmup, method='rk4', options=dict(step_size=0.05))[1:] # Shape: (train_size, x_dim) 38 | 39 | 40 | ######### Generate training data from the reference model ######### 41 | t0 = 0. 
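# Unlike l96_EnKF_demo.py, x0 here has shape (train_size, x_dim), so the data generation below
# produces train_size independent trajectories in parallel: x_truth has shape
# (n_obs, train_size, x_dim) and y_obs has shape (n_obs, train_size, y_dim), following the '*bs'
# batch convention documented in generate_data.generate. The EnKF later runs one filter per
# trajectory, with ensembles of shape (train_size, N_ensem, x_dim).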
42 | t_obs_step = 0.05 43 | n_obs = 150 44 | t_obs = t_obs_step * torch.arange(1, n_obs+1).to(device) 45 | model_Q_true = None # No noise in dynamics 46 | indices = [i for i in range(x_dim)] 47 | y_dim = len(indices) 48 | H_true = torch.eye(x_dim)[indices] 49 | true_obs_func = nn_templates.Linear(x_dim, y_dim, H=H_true).to(device) # Observe every coordinates 50 | noise_R_true = noise.AddGaussian(y_dim, torch.tensor(1.), param_type='scalar').to(device) # Gaussian perturbation with std 1 51 | with torch.no_grad(): 52 | x_truth, y_obs = generate_data.generate(true_ode_func, true_obs_func, t_obs, x0, model_Q_true, noise_R_true, device=device, 53 | ode_method='rk4', ode_options=dict(step_size=0.01), tqdm=tqdm) # Shape: (n_obs, train_size, y_dim) 54 | 55 | 56 | ######### Parameter estimation ######### 57 | N_ensem = 50 58 | init_m = torch.zeros(x_dim, device=device) 59 | init_C_param = noise.AddGaussian(x_dim, 25 * torch.eye(x_dim), 'full').to(device) 60 | init_coeff = torch.zeros(18) 61 | init_Q = 2 * torch.ones(x_dim) 62 | learned_ode_func = nn_templates.Lorenz96_dict_param(init_coeff, device, x_dim).to(device) 63 | learned_model_Q = noise.AddGaussian(x_dim, init_Q, 'diag').to(device) 64 | optimizer = torch.optim.Adam([{'params':learned_ode_func.parameters(), 'lr':1e-1}, 65 | {'params':learned_model_Q.parameters(), 'lr':1e-1}]) 66 | lambda1 = lambda2 = lambda epoch: (epoch-9)**(-0.5) if epoch >=10 else 1 67 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) 68 | L = 20 # subsequence length in AD-EnKF-T 69 | monitor = [] 70 | for epoch in tqdm(range(150), desc="Training", leave=False): 71 | train_log_likelihood = torch.zeros(train_size, device=device) 72 | t_start = t0 73 | X = init_C_param(init_m.expand(train_size, N_ensem, x_dim)) 74 | 75 | # Training phase 76 | for start in range(0, n_obs, L): 77 | optimizer.zero_grad() 78 | end = min(start + L, n_obs) 79 | X, res, log_likelihood = da_methods.EnKF(learned_ode_func,true_obs_func, t_obs[start:end], y_obs[start:end], N_ensem, init_m, init_C_param, learned_model_Q, noise_R_true,device, 80 | save_filter_step={}, t0=t_start, init_X=X, ode_options=dict(step_size=0.01), adjoint_options=dict(step_size=0.05), localization_radius=5, tqdm=None) 81 | t_start = t_obs[end - 1] 82 | (-log_likelihood).mean().backward() 83 | train_log_likelihood += log_likelihood.detach().clone() 84 | optimizer.step() 85 | scheduler.step() 86 | 87 | # Printing 88 | if epoch % 5 == 0: 89 | tqdm.write(f"Epoch {epoch}, Training log-likelihood: {train_log_likelihood.mean().item()}") 90 | tqdm.write(f"Learned coefficients: {learned_ode_func.coeff.data.cpu().numpy()}") 91 | with torch.no_grad(): 92 | q_scale = torch.sqrt(torch.trace(learned_model_Q.full())/x_dim) 93 | curr_output = learned_ode_func.coeff.tolist() + [q_scale.item()] + [train_log_likelihood.mean().item()] 94 | monitor.append(curr_output) 95 | 96 | # Reproducing Figure 6, EnKF results 97 | monitor = np.asarray(monitor) 98 | fig, axes = plt.subplots(2, 2, figsize=(12, 10)) 99 | for i in range(18): 100 | if i in {0,3,11,16}: 101 | axes[0,0].plot(monitor[:,i]) 102 | else: 103 | axes[0,1].plot(monitor[:,i]) 104 | axes[1,0].plot(monitor[:,-2]) 105 | axes[1,1].plot(monitor[:,-1]) 106 | axes[0,0].set_ylabel("Coefficients, non-zero entries") 107 | axes[0,1].set_ylabel("Coefficients, zero entries") 108 | axes[1,0].set_ylabel("Error level") 109 | axes[1,1].set_ylabel("Training log-likelihood") 110 | plt.show() -------------------------------------------------------------------------------- 
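For intuition, a standalone toy example of the pattern shared by the demo scripts above: because the filter's log-likelihood estimate is built from differentiable operations, its gradient with respect to model parameters comes directly from autodiff. The sketch below uses an exact Kalman filter on a one-dimensional linear-Gaussian model and plain PyTorch only; it illustrates the idea and is not code from torchEnKF, which instead differentiates through an ensemble Kalman filter with ODE/NN dynamics.

import math
import torch

torch.manual_seed(0)
T, true_a, q, r = 200, 0.9, 0.1, 0.5

# Simulate y_t from x_t = a*x_{t-1} + N(0, q),  y_t = x_t + N(0, r), starting at x_0 = 0.
x, ys = torch.tensor(0.), []
for _ in range(T):
    x = true_a * x + math.sqrt(q) * torch.randn(())
    ys.append(x + math.sqrt(r) * torch.randn(()))
ys = torch.stack(ys)

def kf_log_likelihood(a, ys, q=q, r=r):
    # Exact log p(y_1:T | a): every step is a differentiable torch op, so autograd can push
    # gradients back to the parameter a, just as the demos above push gradients through the EnKF.
    m, P, ll = torch.tensor(0.), torch.tensor(1.), torch.tensor(0.)
    for y in ys:
        m_pred, P_pred = a * m, a * a * P + q                # predict
        S = P_pred + r                                       # innovation variance
        ll = ll - 0.5 * (torch.log(2 * math.pi * S) + (y - m_pred) ** 2 / S)
        K = P_pred / S                                       # Kalman gain
        m, P = m_pred + K * (y - m_pred), (1 - K) * P_pred   # update
    return ll

a = torch.tensor(0.5, requires_grad=True)
log_lik = kf_log_likelihood(a, ys)
log_lik.backward()
print(log_lik.item(), a.grad.item())  # this gradient can drive torch.optim.Adam, as in the demos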
/examples/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | from matplotlib import cm 5 | from mpl_toolkits.axes_grid1 import make_axes_locatable 6 | from matplotlib.backends.backend_pdf import PdfPages 7 | import math 8 | 9 | import matplotlib 10 | 11 | 12 | 13 | 14 | from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import 15 | import time 16 | 17 | class Timer(object): 18 | def __init__(self, name=None): 19 | if name is None: 20 | self.name = 'foo' 21 | else: 22 | self.name = name 23 | 24 | def __enter__(self): 25 | self.tstart = time.time() 26 | 27 | def __exit__(self, type, value, traceback): 28 | if self.name is None: 29 | print('[%s]' % self.name,) 30 | print(f"{self.name}, Elapsed: {time.time() - self.tstart}s") 31 | 32 | def softplus(t): 33 | return torch.log(1. + torch.exp(t)) 34 | 35 | def softplus_inv(t): 36 | return torch.log(-1. + torch.exp(t)) 37 | 38 | def softplus_grad(t): 39 | return torch.exp(t) / (1. + torch.exp(t)) 40 | 41 | def flat2matrix(t, truth=None): 42 | # t: (*batch_dims, d*d) --> (*batch_dim, d, d) 43 | batch_dim = t.shape[:-1] 44 | mat_dim = int(math.sqrt(t.shape[-1])) 45 | return torch.linalg.norm(t.view(*batch_dim, mat_dim, mat_dim) - truth, 'fro', dim=(-1,-2)) 46 | 47 | 48 | def visualize_matrix(mat, symmetric_error_bar=False): 49 | mat_plot = mat.detach().cpu().numpy() 50 | fig_ratio = mat_plot.shape[0] / mat_plot.shape[1] 51 | if symmetric_error_bar: 52 | vmax = max(mat_plot.max(), -mat_plot.min()) 53 | fig = plt.figure(figsize=(8/(fig_ratio) , 8)) 54 | ax = fig.add_subplot(111) 55 | if symmetric_error_bar: 56 | cf = ax.imshow(mat_plot, cmap=cm.bwr, vmax=vmax, vmin=-vmax) 57 | else: 58 | cf = ax.imshow(mat_plot, cmap=cm.bwr) 59 | divider = make_axes_locatable(ax) 60 | cax = divider.append_axes("right", size="5%", pad=0.05) 61 | plt.colorbar(cf, cax=cax) 62 | plt.show() 63 | 64 | 65 | def unique_labels(ax): 66 | handles, labels = ax.get_legend_handles_labels() 67 | labels, ids = np.unique(labels, return_index=True) 68 | handles = [handles[i] for i in ids] 69 | return handles, labels 70 | 71 | def mse_loss(input, target): 72 | return torch.mean((input - target) ** 2) 73 | 74 | def mse_loss_last_dim(input, target): 75 | last_dim = input.shape[-1] 76 | return torch.mean((input.reshape(-1, last_dim) - target.reshape(-1, last_dim)) ** 2, dim=0) 77 | 78 | def particle_mse_loss(input, target, weight): 79 | weight = weight.unsqueeze(-1) 80 | mean = torch.sum(weight * input, dim=-2) # (n_obs, *bs, x_dim) 81 | return mse_loss(mean, target) 82 | 83 | def particle_mse_loss_last_dim(input, target, weight): 84 | weight = weight.unsqueeze(-1) 85 | mean = torch.sum(weight * input, dim=-2) # (n_obs, *bs, x_dim) 86 | return mse_loss_last_dim(mean, target) 87 | 88 | def weighted_mse_loss(input, target, weight): 89 | return torch.mean(weight * (input - target) ** 2) 90 | 91 | def ess(weight): 92 | # (*bdims, weight) -> (*bdims) 93 | return 1 / (weight**2).sum(dim=-1) 94 | 95 | def mean_and_std(t, axis=None): 96 | t_np = t.detach().cpu().numpy() 97 | if axis is None: 98 | return np.mean(t_np), np.std(t_np) 99 | else: 100 | return np.mean(t_np, axis=axis), np.std(t_np, axis=axis) 101 | 102 | def construct_exp(x_dim): 103 | exp = torch.zeros(x_dim, x_dim) 104 | for i in range(x_dim): 105 | for j in range(x_dim): 106 | exp[i, j] = -1. 
* abs(i-j) 107 | return exp 108 | 109 | def shrink_batch_dim(t): 110 | # (a, *bs, b, c) --> (a, -1, b, c) 111 | return t.view(t.shape[0], -1, t.shape[-2], t.shape[-1]) 112 | 113 | def mean_over_all_but_last_k_dims(t, k): 114 | num_dims = len(t.shape) 115 | for _ in range(num_dims - k): 116 | t = t.mean(dim=0) 117 | return t 118 | 119 | def set_lim_ticks(ax, low, high, n_ticks, ratio=0.05): 120 | res = (high-low)*ratio 121 | ax.set_ylim(low-res, high+res) 122 | ax.yaxis.set_ticks(np.linspace(low, high, n_ticks)) 123 | return 124 | 125 | def plot_monitor_res_new(monitor_res, methods, titles, truths, logscale, logscalex=None, truths_legends=None, truths2=None, truths2_legends=None, plot_truth_first=True, start_from_one=None, gridspec_kw=None, legend_order=None, groups=None, error_bar=False, error_bar_style="std", plots_to_show=None, n_cols=3, 126 | subplots_adjust=None, subplot_width=8, subplot_height=6, subplot_groups=None, ax_d=None, linewidth=2, colors_list=None, custom_axes=None, x_axiss=None, save_location=None, file_name=None, load_location=None): 127 | ''' 128 | Args: 129 | monitor_res: (list of n_methods *) (n_runs, n_epochs, n_monitors) 130 | e.g., for a 'method' Diff-EnKF, I have 4 repeated runs, statistics over 100 training epochs, of 6 different statistics (training loss, test loss, etc.) 131 | For each 'method', n_runs, n_epochs, n_monitors are the same 132 | For different 'method's, make sure each monitor has the same meaning 133 | method: (list of) method_names 134 | titles: (list of) monitor_names 135 | 136 | ''' 137 | 138 | if monitor_res is None and load_location is not None: 139 | d = torch.load(load_location) 140 | monitor_res = d['monitor_res'] 141 | x_axiss = d['x_axiss'] 142 | 143 | if not isinstance(monitor_res, list): 144 | monitor_res = [monitor_res] 145 | if not isinstance(methods, list): 146 | methods = [methods] 147 | 148 | n_methods = len(monitor_res) 149 | n_subplots = 0 150 | for me in range(n_methods): 151 | if torch.is_tensor(monitor_res[me]): 152 | monitor_res[me] = monitor_res[me].detach().cpu().numpy() 153 | n_subplots = max(n_subplots, monitor_res[me].shape[2]) 154 | 155 | if groups is None: 156 | groups = [(i,i+1) for i in range(n_subplots)] 157 | else: 158 | new_group = [] 159 | for g in range(len(groups)): 160 | if g == 0: 161 | new_group += [(i,i+1) for i in range(groups[g][0])] 162 | else: 163 | new_group += [(i,i+1) for i in range(groups[g-1][1], groups[g][0])] 164 | new_group += [groups[g]] 165 | new_group += [(i,i+1) for i in range(groups[-1][1], n_subplots)] 166 | groups = new_group 167 | n_subplots = len(groups) 168 | 169 | if plots_to_show is not None: 170 | n_subplots = len(plots_to_show) 171 | groups = [groups[i] for i in plots_to_show] 172 | titles = [titles[i] for i in plots_to_show] 173 | def unravel(i, n_cols): 174 | return (i//n_cols, i%n_cols) 175 | 176 | if subplot_groups is not None: 177 | row_range, col_range = subplot_groups 178 | tmp = n_subplots - 1 + (row_range[1]-row_range[0])*(col_range[1]-col_range[0]) 179 | n_rows = (tmp-1)//n_cols + 1 180 | else: 181 | n_rows = (n_subplots-1)//n_cols + 1 182 | 183 | 184 | fig, axes = plt.subplots(n_rows, n_cols, figsize=(subplot_width * n_cols, subplot_height * n_rows), constrained_layout=False, gridspec_kw=gridspec_kw) 185 | 186 | if n_rows == 1: 187 | axes = np.array([axes]) 188 | if n_cols == 1: 189 | axes = axes[:,np.newaxis] 190 | 191 | 192 | if subplot_groups is not None: 193 | gs = axes[row_range[0],col_range[0]].get_gridspec() 194 | for row in range(row_range[0], row_range[1]): 195 | 
for col in range(col_range[0], col_range[1]): 196 | axes[row, col].remove() 197 | axbig = fig.add_subplot(gs[row_range[0]:row_range[1], col_range[0]:col_range[1]]) 198 | ax_d = {sp:axes[ax_d[sp]] for sp in range(1, n_subplots)} 199 | ax_d[0] = axbig 200 | else: 201 | ax_d = {sp:axes[unravel(sp,n_cols)] for sp in range(n_subplots)} 202 | 203 | 204 | # colors_list = [cm.Wistia, cm.Blues, cm.Greens] 205 | # colors_list = [cm.Reds, cm.Blues, cm.Greens] 206 | colors_list = [cm.Blues, cm.Greens, cm.Reds] if colors_list is None else colors_list 207 | 208 | for sp in range(n_subplots): 209 | if start_from_one is not None and sp in start_from_one: 210 | start = 1 211 | else: 212 | start = 0 213 | idx = unravel(sp, n_cols) 214 | for me in range(n_methods): 215 | cur_res = monitor_res[me] # (n_runs, n_epochs, n_monitors) 216 | cur_res_mean = cur_res.mean(axis=0) 217 | cur_res_std = cur_res.std(axis=0) 218 | cur_res_75_quantile = np.quantile(cur_res, 0.75, axis=0) 219 | cur_res_50_quantile = np.quantile(cur_res, 0.5, axis=0) 220 | cur_res_25_quantile = np.quantile(cur_res, 0.25, axis=0) 221 | n_runs, n_epochs, n_monitors = cur_res.shape 222 | if x_axiss is None: 223 | x_axis = np.arange(n_epochs) 224 | else: 225 | x_axis = x_axiss[sp] 226 | 227 | for m in range(groups[sp][0], groups[sp][1]): 228 | if m >= n_monitors: 229 | continue 230 | colors = colors_list[me](np.linspace(0.5,0.7,n_runs+1)) 231 | if plot_truth_first: 232 | if m < len(truths) and truths[m] is not None: 233 | if (truths_legends is None or truths_legends[m] is None): 234 | label=None 235 | elif truths_legends[m] == 't': 236 | label = "Truth" 237 | else: 238 | label = truths_legends[m] 239 | ax_d[sp].axhline(y=truths[m], color='r', linestyle='--', linewidth=2.5, label=label) 240 | if truths2 is not None and m < len(truths2) and truths2[m] is not None: 241 | if (truths2_legends is None or truths2_legends[m] is None): 242 | label=None 243 | elif truths2_legends[m] == 't': 244 | label = "Truth" 245 | else: 246 | label = truths2_legends[m] 247 | ax_d[sp].axhline(y=truths2[m], color='r', linestyle='--', linewidth=2.5, label=label) 248 | 249 | 250 | 251 | if not error_bar: 252 | for r in range(n_runs): 253 | ax_d[sp].plot(x_axis[start:], cur_res[r,start:x_axis.shape[0],m], color=colors[-r-1],linewidth=linewidth, label=methods[me]) 254 | else: 255 | if error_bar_style == "std": 256 | ax_d[sp].plot(x_axis[start:], cur_res_mean[start:x_axis.shape[0],m], color=colors[-1],linewidth=linewidth, label=methods[me]) 257 | ax_d[sp].fill_between(x_axis[start:], cur_res_mean[start:x_axis.shape[0],m]+2*cur_res_std[start:x_axis.shape[0],m], cur_res_mean[start:x_axis.shape[0],m]-2*cur_res_std[start:x_axis.shape[0],m], facecolor=colors[-1], edgecolor=None, alpha=0.2) 258 | elif error_bar_style == "quantile": 259 | ax_d[sp].plot(x_axis[start:], cur_res_50_quantile[start:x_axis.shape[0],m], color=colors[-1],linewidth=linewidth, label=methods[me]) 260 | ax_d[sp].fill_between(x_axis[start:], cur_res_75_quantile[start:x_axis.shape[0],m], cur_res_25_quantile[start:x_axis.shape[0],m], facecolor=colors[-1], edgecolor=None, alpha=0.2) 261 | 262 | if not plot_truth_first: 263 | if m < len(truths) and truths[m] is not None: 264 | if (truths_legends is None or truths_legends[m] is None): 265 | label=None 266 | elif truths_legends[m] == 't': 267 | label = "Truth" 268 | else: 269 | label = truths_legends[m] 270 | ax_d[sp].axhline(y=truths[m], color='r', linestyle='--', linewidth=2.5, label=label) 271 | if truths2 is not None and m < len(truths2) and truths2[m] is not None: 
272 | if (truths2_legends is None or truths2_legends[m] is None): 273 | label=None 274 | elif truths2_legends[m] == 't': 275 | label = "Truth" 276 | else: 277 | label = truths2_legends[m] 278 | ax_d[sp].axhline(y=truths2[m], color='k', linestyle='-', linewidth=2.5, label=label) 279 | if titles[sp] is not None: 280 | ax_d[sp].set_title(f"{titles[sp]}")#, {n_runs} independent run(s) 281 | if sp in logscale: 282 | ax_d[sp].set_yscale('log') 283 | if logscalex is not None and sp in logscalex: 284 | ax_d[sp].set_xscale('log') 285 | ax_d[sp].invert_xaxis() 286 | # if truths[i] is not None: 287 | # axes[idx].axhline(y=truths[i], color='r', linestyle='--', linewidth=3, label="True value") 288 | 289 | # Handle legends 290 | handles, labels = unique_labels(ax_d[sp]) 291 | if legend_order is None: 292 | ax_d[sp].legend(handles, labels, loc='upper right') 293 | else: 294 | ax_d[sp].legend([handles[idx] for idx in legend_order],[labels[idx] for idx in legend_order], loc='upper right') 295 | 296 | 297 | if custom_axes is not None: 298 | custom_axes(axes, ax_d) 299 | 300 | if subplots_adjust is not None: 301 | plt.subplots_adjust(wspace=subplots_adjust) 302 | 303 | if save_location is not None: 304 | # torch.save({'monitor_res': monitor_res, 'x_axiss':x_axiss}, 305 | # save_location) 306 | plt.savefig(save_location+".pdf", bbox_inches='tight') 307 | 308 | plt.show() 309 | 310 | 311 | 312 | def plot_dynamic(t_eval, out, t_obs=None, y_obs=None, online_color=None, fig_num_limit=None, text="obs", contour_plot=False): 313 | t_eval = t_eval.cpu().numpy() 314 | if t_obs is None: 315 | t_obs = t_eval 316 | x_dim = out.shape[1] 317 | if fig_num_limit is not None and x_dim > fig_num_limit: 318 | x_dim = fig_num_limit 319 | n_eval = t_eval.shape[0] 320 | out_plot = out.detach().cpu().numpy() 321 | if y_obs is not None: 322 | obs_dim = y_obs.shape[1] 323 | if fig_num_limit is not None and obs_dim > fig_num_limit: 324 | obs_dim = fig_num_limit 325 | obs_plot = y_obs.detach().cpu().numpy() 326 | fig, axes = plt.subplots(x_dim, 1, sharex=True, figsize=(6, 3*x_dim), constrained_layout=True) 327 | if x_dim == 1: 328 | axes = np.array([axes]) 329 | fig.suptitle('True dynamic (for each coordinate)', fontsize=15) 330 | for i, ax in enumerate(axes): 331 | # Note that: ax = axes[i] 332 | ax.plot(t_eval, out_plot[:, i], 'r', linewidth=2, label="Truth") 333 | y_lim=ax.get_ylim() 334 | ax.set_ylim(y_lim) 335 | if y_obs is not None and i < obs_dim: 336 | ax.plot(t_obs, obs_plot[:, i], 'y', label=text) 337 | ax.set_xlabel('$t$', fontsize=10) 338 | # ax.set_ylabel('$u_1$', rotation=0, fontsize=15) 339 | ax.legend(loc='upper right') 340 | 341 | if x_dim == 2 or x_dim == 3: 342 | fig2 = plt.figure(figsize=(12,6)) 343 | fig2.suptitle('True dynamic', fontsize=15) 344 | if x_dim == 2: 345 | ax21 = fig2.add_subplot(121) 346 | ax21.plot(out_plot[:, 0], out_plot[:, 1], 'r', linewidth=2, label="Truth") 347 | elif x_dim == 3: 348 | ax21 = fig2.add_subplot(121, projection='3d') 349 | ax21.plot(out_plot[:, 0], out_plot[:, 1], out_plot[:, 2], c='r', linewidth=2, label="Truth") 350 | ax21.scatter(out_plot[0, 0], out_plot[0, 1], out_plot[0, 2], marker='*', s=25, linestyle="None", color="k", label="start") 351 | ax21.scatter(out_plot[-1, 0], out_plot[-1, 1], out_plot[-1, 2], marker='^', s=25, linestyle="None", color="k", label="end") 352 | ax21.legend(loc='upper right') 353 | if y_obs is not None: 354 | if obs_dim == 1: 355 | ax22 = fig2.add_subplot(122) 356 | ax22.plot(t_obs, obs_plot, 'y', label=text) 357 | elif obs_dim == 2: 358 | ax22 = 
fig2.add_subplot(122) 359 | # ax22.plot(obs_plot[:, 0], obs_plot[:, 1], 'y', label=text) 360 | ax22.plot(obs_plot[0, 0], obs_plot[0, 1], marker='*', markersize=12, linestyle="None", color="k", label=f"start") 361 | ax22.plot(obs_plot[-1, 0], obs_plot[-1, 1], marker='^', markersize=12, linestyle="None", color="k", label=f"end") 362 | ax22.quiver(obs_plot[:-1, 0], obs_plot[:-1, 1], obs_plot[1:, 0]-obs_plot[:-1, 0], obs_plot[1:, 1]-obs_plot[:-1, 1], np.linspace(0,1,len(t_obs)-1), scale_units='xy', angles='xy', scale=1) 363 | if x_dim == 2: 364 | plt.setp(ax22, xlim=ax21.get_xlim(),ylim=ax21.get_ylim()) 365 | elif obs_dim == 3: 366 | ax22 = fig2.add_subplot(122, projection='3d') 367 | if online_color is None: 368 | ax22.plot(obs_plot[:, 0], obs_plot[:, 1], obs_plot[:, 2], c='y', label=text) 369 | else: 370 | online_color = online_color.detach().cpu().numpy().squeeze() 371 | online_color = (online_color - np.min(online_color))/np.ptp(online_color) 372 | n_obs = obs_plot.shape[0] 373 | for i in range(0, n_obs-1): 374 | ax22.plot(obs_plot[i:i+2, 0], obs_plot[i:i+2, 1], obs_plot[i:i+2, 2], color=plt.cm.Blues(online_color[i]*0.75+0.25)) 375 | # for i in range(len(t_obs)-1): 376 | # ax22.plot(obs_plot[i:i+2, 0], obs_plot[i:i+2, 1], obs_plot[i:i+2, 2], color=plt.cm.jet(255*i/len(t_obs))) 377 | ax22.scatter(obs_plot[0, 0], obs_plot[0, 1], obs_plot[0, 2], marker='*', s=25, linestyle="None", color="k", label=f"start") 378 | ax22.scatter(obs_plot[-1, 0], obs_plot[-1, 1], obs_plot[-1, 2], marker='^', s=25, linestyle="None", color="k", label=f"end") 379 | # ax22.quiver(obs_plot[:-1, 0], obs_plot[:-1, 1], obs_plot[:-1, 2], obs_plot[1:, 0]-obs_plot[:-1, 0], obs_plot[1:, 1]-obs_plot[:-1, 1], obs_plot[1:, 2]-obs_plot[:-1, 2], np.linspace(0,1,n_eval-1)) 380 | if x_dim == 3: 381 | plt.setp(ax22, xlim=ax21.get_xlim(),ylim=ax21.get_ylim(), zlim=ax21.get_zlim()) 382 | ax22.legend(loc='upper right') 383 | 384 | if contour_plot: 385 | vert = list(range(out.shape[1])) 386 | fig3 = plt.figure(figsize=(18,6)) 387 | fig3.suptitle('Time evolution', fontsize=15) 388 | ax31 = fig3.add_subplot(111) 389 | vmin = out_plot.min() 390 | vmax = out_plot.max() 391 | cf1 = ax31.contourf(t_eval, vert, out_plot.T, levels=np.linspace(vmin, vmax, 8)) 392 | 393 | plt.colorbar(cf1) 394 | if y_obs is not None and y_obs.shape[1] == out.shape[1]: 395 | vert_obs = list(range(y_obs.shape[1])) 396 | fig4 = plt.figure(figsize=(18,6)) 397 | ax41 = fig4.add_subplot(111) 398 | cf2 = ax41.contourf(t_eval, vert_obs, obs_plot.T, levels=np.linspace(vmin, vmax, 8), extend='both') 399 | plt.colorbar(cf2) 400 | fig5 = plt.figure(figsize=(18,6)) 401 | ax51 = fig5.add_subplot(111) 402 | cf3 = ax51.contourf(t_eval, vert_obs, out_plot.T - obs_plot.T, cmap=cm.bwr, levels=np.linspace(-vmax, vmax, 20), extend='both') 403 | plt.colorbar(cf3) 404 | plt.show() 405 | 406 | # def plot_filter_old(t_eval, out, E, name="Ensemble", plot_all=True, compare_F_S=False, fig_num_limit=None): 407 | # t_eval = t_eval.cpu().numpy() 408 | # x_dim = out.shape[1] 409 | # if fig_num_limit is not None and x_dim > fig_num_limit: 410 | # x_dim = fig_num_limit 411 | # out_plot = out.detach().cpu().numpy() 412 | # n_cols = 2 if compare_F_S else 1 413 | # fig, axes = plt.subplots(x_dim, n_cols, sharex='col', sharey='row', figsize=(8*n_cols, 4*x_dim), constrained_layout=True) 414 | 415 | # if x_dim == 1: 416 | # axes = np.array([axes]) 417 | # if not compare_F_S: 418 | # axes = axes[:, np.newaxis] 419 | 420 | # if name == "Ensemble": 421 | # X_track_plot = E.X_track.detach().cpu().numpy() 
422 | # if compare_F_S: 423 | # X_smooth_plot = E.X_smooth.detach().cpu().numpy() 424 | # else: 425 | # X_smooth_plot = E.X_track.detach().cpu().numpy() 426 | # N_ensem = X_track_plot.shape[1] 427 | # track_mean = X_track_plot.mean(axis=1) 428 | # track_std = np.std(X_track_plot, axis=1) 429 | # smooth_mean = X_smooth_plot.mean(axis=1) 430 | # smooth_std = np.std(X_smooth_plot, axis=1) 431 | # fig.suptitle("Ensemble Kalman", fontsize=15) 432 | # elif name == "Kalman": 433 | # track_mean = E.mu_track.detach().cpu().numpy() 434 | # V_track = E.V_track.detach().cpu().numpy() 435 | # track_std = np.sqrt(np.diagonal(V_track, axis1=1, axis2=2)) 436 | # smooth_mean = E.mu_smooth.detach().cpu().numpy() 437 | # V_smooth = E.V_smooth.detach().cpu().numpy() 438 | # smooth_std = np.sqrt(np.diagonal(V_smooth, axis1=1, axis2=2)) 439 | # fig.suptitle("Kalman", fontsize=15) 440 | 441 | # for i, ax in enumerate(axes): 442 | # if name == "Ensemble" and plot_all: 443 | # for n in range(N_ensem): 444 | # ax[0].plot(np.insert(t_eval, 0, 0.), X_smooth_plot[:, n, i], 'b', linewidth=0.5, label='EnKS') 445 | # if compare_F_S: 446 | # ax[1].plot(np.insert(t_eval, 0, 0.), X_track_plot[:, n, i], 'b', linewidth=0.5, label='EnKF') 447 | # else: 448 | # ax[0].plot(np.insert(t_eval, 0, 0.), smooth_mean[:, i], 'b--', linewidth=2, label='Smoother mean') 449 | # ax[0].plot(np.insert(t_eval, 0, 0.), smooth_mean[:, i] + 2*smooth_std[:, i], 'b', linewidth=0.5, label='+2 std') 450 | # ax[0].plot(np.insert(t_eval, 0, 0.), smooth_mean[:, i] - 2*smooth_std[:, i], 'b', linewidth=0.5, label='-2 std') 451 | # if compare_F_S: 452 | # ax[1].plot(np.insert(t_eval, 0, 0.), track_mean[:, i], 'b--', linewidth=2, label='Filter mean') 453 | # ax[1].plot(np.insert(t_eval, 0, 0.), track_mean[:, i] + 2*track_std[:, i], 'b', linewidth=0.5, label='+2 std') 454 | # ax[1].plot(np.insert(t_eval, 0, 0.), track_mean[:, i] - 2*track_std[:, i], 'b', linewidth=0.5, label='-2 std') 455 | # ax[0].plot(t_eval, out_plot[:, i], 'r', linewidth=3, label="Truth") 456 | # ax[0].set_xlabel('$t$', fontsize=10) 457 | # handles, labels = unique_labels(ax[0]) 458 | # ax[0].legend(handles, labels, loc='upper right') 459 | # # ax[0].set_ylim(-30, 60) 460 | 461 | # if compare_F_S: 462 | # ax[1].plot(t_eval, out_plot[:, i], 'r', linewidth=3, label="Truth") 463 | # ax[1].set_xlabel('$t$', fontsize=10) 464 | # handles, labels = unique_labels(ax[1]) 465 | # ax[1].legend(handles, labels, loc='upper right') 466 | # plt.show() 467 | 468 | def plot_loss(n_epochs, loss_1_track, loss_2_track): 469 | fig = plt.figure(figsize=(16, 6)) 470 | ax1 = fig.add_subplot(121) 471 | ax2 = fig.add_subplot(122) 472 | ax1.plot(list(range(n_epochs)), loss_1_track, label='observation model loss') 473 | ax2.plot(list(range(n_epochs)), loss_2_track, label='dynamic model loss') 474 | ax1.set_xlabel('$EM iter$', fontsize=10) 475 | ax2.set_xlabel('$EM iter$', fontsize=20) 476 | ax1.legend() 477 | ax2.legend() 478 | # ax1.set_yscale('log') 479 | # ax2.set_yscale('log') 480 | plt.show() 481 | 482 | def plot_monitor_res(monitor_res, method, monitor_res_cmp, method_cmp, titles, truths, logscale): 483 | ''' 484 | Args: 485 | monitor_res: list of n_methods * (n_runs, n_epochs, n_monitors) or (n_runs, n_epochs, n_monitors) 486 | method: list of method_names 487 | titles: list of monitor_names 488 | ''' 489 | 490 | # monitor_res: (n_runs, n_epochs, n_monitors) 491 | monitor_res = monitor_res.detach().cpu().numpy() 492 | n_runs = monitor_res.shape[0] 493 | n_epochs = monitor_res.shape[1] 494 | n_monitors = 
monitor_res.shape[2] 495 | 496 | n_subplots = n_monitors 497 | if monitor_res_cmp is not None: 498 | monitor_res_cmp = monitor_res_cmp.detach().cpu().numpy() 499 | n_subplots = max(n_subplots, monitor_res_cmp.shape[2]) 500 | 501 | n_cols = 3 502 | n_rows = (n_subplots-1)//n_cols + 1 503 | def unravel(i, n_cols): 504 | return (i//n_cols, i%n_cols) 505 | 506 | fig, axes = plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 6 * n_rows), constrained_layout=False) 507 | if n_rows == 1: 508 | axes = np.array([axes]) 509 | colors = cm.Blues(np.linspace(0.5,1,n_runs)) 510 | for i in range(n_monitors): 511 | idx = unravel(i, n_cols) 512 | for j in range(n_runs): 513 | axes[idx].plot(np.arange(n_epochs), monitor_res[j, :, i], color=colors[j], label=method) 514 | axes[idx].set_title(f"{titles[i]}, {n_runs} run(s)", fontsize=20) 515 | if i in logscale: 516 | axes[idx].set_yscale('log') 517 | if truths[i] is not None: 518 | axes[idx].axhline(y=truths[i], color='r', linestyle='--', linewidth=3, label="True value") 519 | 520 | if monitor_res_cmp is not None: 521 | colors_cmp = cm.Wistia(np.linspace(0.5, 1, n_runs)) 522 | n_monitors_cmp = monitor_res_cmp.shape[2] 523 | for i in range(n_monitors_cmp): 524 | idx = unravel(i, n_cols) 525 | for j in range(n_runs): 526 | axes[idx].plot(np.arange(n_epochs), monitor_res_cmp[j, :, i], color=colors_cmp[j], label=method_cmp) 527 | axes[idx].set_title(f"{titles[i]}, {n_runs} run(s)", fontsize=20) 528 | 529 | # Handle legends 530 | for i in range(n_monitors): 531 | idx = unravel(i, n_cols) 532 | handles, labels = unique_labels(axes[idx]) 533 | axes[idx].legend(handles, labels, loc='upper right') 534 | 535 | # axes[0,0].set_ylim(-0.4, 0.7) 536 | # axes[0,2].set_ylim(30, 60) 537 | # axes[1,0].set_ylim(30, 60) 538 | # axes[1,1].set_ylim(7500, 20000) 539 | # axes[1,2].set_ylim(0.75, 3) 540 | # axes[2,0].set_ylim(0.75, 3) 541 | # axes[2,1].set_ylim(0, 70) 542 | plt.show() 543 | 544 | 545 | def plot_1d_nll_with_grad(x_axis, nll, nll_kalman, N_ensem_list, x_label, x_truth, grad=None, grad_kalman=None, recon=None, dx=0.7, grad_every=10, delta=10, x_scale=None): 546 | # nll: (Ne, n_pts, n_trial) 547 | x_axis = x_axis.detach().cpu().numpy() 548 | nll = nll.detach().cpu().numpy() 549 | # recon = recon.detach().cpu().numpy() 550 | ne_trials = nll.shape[0] 551 | n_pts = nll.shape[1] 552 | n_trials = nll.shape[2] 553 | indices_grad = list(range(0,n_pts,grad_every)) 554 | if grad is not None: 555 | grad = grad.detach().cpu().numpy() 556 | 557 | 558 | fig = plt.figure(figsize=(24, 9)) 559 | ax1 = fig.add_subplot(121) 560 | ax2 = fig.add_subplot(122) 561 | # ax3 = fig.add_subplot(223) 562 | colors = cm.viridis_r(np.linspace(0,1,ne_trials)) 563 | 564 | for ne in range(ne_trials): 565 | nll_mean = nll[ne].mean(axis=1) 566 | nll_std = nll[ne].std(axis=1) 567 | # recon_mean = recon.mean(axis=0) 568 | # recon_std = np.std(recon, axis=0) 569 | if grad is not None: 570 | grad_mean = grad[ne].mean(axis=1) 571 | grad_std = grad[ne].std(axis=1) 572 | 573 | ax1.plot(x_axis, nll_mean, linewidth=3, label=f"EnKF, N = {N_ensem_list[ne]}") 574 | ax1.fill_between(x_axis, nll_mean - 2*nll_std, nll_mean + 2*nll_std, alpha=0.3) 575 | ax1.set_xlabel(x_label, fontsize=15) 576 | ax1.set_title(f"nll versus {x_label}, averaged {n_trials} EnKF trials", fontsize=20) 577 | # ax3.plot(x_axis, recon_mean, 'k', label='recon $\pm$ 2std') 578 | # ax3.fill_between(x_axis, recon_mean - 2*recon_std, recon_mean + 2*recon_std, alpha=0.3) 579 | # ax3.set_xlabel(x_label, fontsize=15) 580 | # ax3.set_title(f"recon_error 
versus F, averaged {n_trials} EnKF trials", fontsize=15) 581 | if delta is not None and ne == ne_trials - 1: 582 | ylim1 = [(nll_mean - 2 * nll_std).min() - delta[0], (nll_mean + 2 * nll_std).max()+delta[0]] 583 | 584 | 585 | 586 | if grad is not None: 587 | ax2.plot(x_axis, grad_mean, linewidth=3, label=f"EnKF, N = {N_ensem_list[ne]}") 588 | ax2.fill_between(x_axis, grad_mean - 2*grad_std, grad_mean + 2*grad_std, alpha=0.3) 589 | ax2.set_xlabel(x_label, fontsize=15) 590 | ax2.set_title(f"nll_grad versus {x_label}, w/ autograd, averaged {n_trials} EnKF trials", fontsize=20) 591 | if delta is not None and ne == ne_trials - 1: 592 | ylim2 = [(grad_mean - 2 * grad_std).min() - delta[1], (grad_mean + 2 * grad_std).max()+delta[1]] 593 | 594 | # for i in indices_grad: 595 | # ax3.quiver(x_axis[i]*np.ones(n_trials), nll_mean[i]*np.ones(n_trials), dx*np.ones(n_trials), dx*grad[ne, i], angles='xy', scale_units='xy', scale=1, width=0.005, headwidth=3, headlength=5, color="blue") 596 | # ax3.fill_between(x_axis, nll_mean - 2*nll_std, nll_mean + 2*nll_std, alpha=0.3) 597 | # ax3.set_xlabel(x_label, fontsize=15) 598 | # ax3.set_title(f"nll_grad versus {x_label}, {n_trials} EnKF trials (w/ backprop)", fontsize=15) 599 | 600 | if nll_kalman is not None: 601 | nllk = nll_kalman.detach().cpu().numpy() 602 | ax1.plot(x_axis, nllk, 'r', linewidth=3, label=f"KF") 603 | if grad_kalman is not None: 604 | gradk = grad_kalman.detach().cpu().numpy() 605 | ax2.plot(x_axis, gradk, 'r', linewidth=3, label=f"KF") 606 | 607 | ax1.axvline(x=x_truth, color='k', linestyle='--', linewidth=3, label=f"True {x_label}") 608 | ax2.axvline(x=x_truth, color='k', linestyle='--', linewidth=3, label=f"True {x_label}") 609 | # ax3.axvline(x=x_truth, color='r', linestyle='--', linewidth=3, label=f"True {x_label}") 610 | 611 | if delta is not None: 612 | ax1.set_ylim(ylim1) 613 | ax1.legend(loc='upper right') 614 | if grad is not None: 615 | if delta is not None: 616 | ax2.set_ylim(ylim2) 617 | ax2.legend(loc='upper right') 618 | # if delta is not None: 619 | # ax3.set_ylim(ylim1) 620 | 621 | ax1.xaxis.set_tick_params(labelsize=20) 622 | ax1.yaxis.set_tick_params(labelsize=20) 623 | ax2.xaxis.set_tick_params(labelsize=20) 624 | ax2.yaxis.set_tick_params(labelsize=20) 625 | if x_scale == "log": 626 | ax1.set_xscale('log') 627 | ax2.set_xscale('log') 628 | # ax3.xaxis.set_tick_params(labelsize=20) 629 | # ax3.yaxis.set_tick_params(labelsize=20) 630 | plt.show() 631 | 632 | 633 | def plot_2d_nll_with_grad(x_axis, y_axis, nll, nll_kalman, N_ensem, xy_label, xy_truth, grad=None, grad_kalman=None, recon=None, arrow_scale=1, dx=0.7, grad_every=10, xy_scale=[None,None]): 634 | # nll: (n_ptsx, n_ptsy, n_trial) 635 | # grad: (n_ptsx, n_ptsy, 2, n_trial) 636 | x_axis = x_axis.detach().cpu().numpy() 637 | y_axis = y_axis.detach().cpu().numpy() 638 | nll = nll.detach().cpu().numpy() 639 | # recon = recon.detach().cpu().numpy() 640 | n_ptsx = nll.shape[0] 641 | n_ptsy = nll.shape[1] 642 | n_trials = nll.shape[2] 643 | indices_gradx = np.arange(0,n_ptsx, grad_every) 644 | indices_grady = np.arange(0,n_ptsy, grad_every) 645 | if grad is not None: 646 | grad = grad.detach().cpu().numpy() 647 | 648 | fig, axes = plt.subplots(2, 3, figsize=(24, 16), constrained_layout=False) 649 | 650 | # colors = cm.viridis_r(np.linspace(0,1,ne_trials)) 651 | 652 | nll_mean = nll.mean(axis=-1) 653 | nll_std = nll.std(axis=-1) 654 | # recon_mean = recon.mean(axis=0) 655 | # recon_std = np.std(recon, axis=0) 656 | if grad is not None: 657 | gradx_mean = 
grad.mean(axis=-1)[:, :, 0] 658 | grady_mean = grad.mean(axis=-1)[:, :, 1] 659 | grad_std = np.sqrt((grad**2).sum(axis=-2)).std(axis=-1) 660 | 661 | 662 | vmax = nll_mean.max() 663 | vmin = nll_mean.min() 664 | if nll_kalman is not None: 665 | nllk = nll_kalman.detach().cpu().numpy() 666 | vmax = nllk.max() 667 | vmin = nllk.min() 668 | cf02 = axes[0,2].contourf(x_axis, y_axis, nllk.T, levels=np.linspace(vmin, vmax, 25)) 669 | axes[0,2].set_xlabel(xy_label[0], fontsize=15) 670 | axes[0,2].set_ylabel(xy_label[1], fontsize=15) 671 | axes[0,2].set_title(f"nll versus ({xy_label[0]}, {xy_label[1]}), Kalman Filter", fontsize=20) 672 | divider = make_axes_locatable(axes[0,2]) 673 | cax = divider.append_axes("right", size="5%", pad=0.05) 674 | plt.colorbar(cf02, cax=cax) 675 | if grad_kalman is not None: 676 | gradkx = grad_kalman[:, :, 0] 677 | gradky = grad_kalman[:, :, 1] 678 | axes[1,2].quiver(x_axis[0:n_ptsx:grad_every], y_axis[0:n_ptsy:grad_every], -gradkx[0:n_ptsx:grad_every,0:n_ptsy:grad_every].T, -gradky[0:n_ptsx:grad_every,0:n_ptsy:grad_every].T, scale_units='xy', angles='xy', scale=arrow_scale) 679 | axes[1,2].set_xlabel(xy_label[0], fontsize=15) 680 | axes[1,2].set_ylabel(xy_label[1], fontsize=15) 681 | axes[1,2].set_title(f"nll_grad versus ({xy_label[0]}, {xy_label[1]}), Kalman Filter", fontsize=20) 682 | 683 | cf00 = axes[0,0].contourf(x_axis, y_axis, nll_mean.T, levels=np.linspace(vmin, vmax, 25), extend='both') 684 | axes[0,0].set_xlabel(xy_label[0], fontsize=15) 685 | axes[0,0].set_ylabel(xy_label[1], fontsize=15) 686 | axes[0,0].set_title(f"nll versus ({xy_label[0]}, {xy_label[1]}), N={N_ensem}, averaged {n_trials} EnKF trials", fontsize=20) 687 | divider = make_axes_locatable(axes[0,0]) 688 | cax = divider.append_axes("right", size="5%", pad=0.05) 689 | plt.colorbar(cf00, cax=cax) 690 | cf01 = axes[0,1].contourf(x_axis, y_axis, nll_std.T) 691 | axes[0,1].set_xlabel(xy_label[0], fontsize=15) 692 | axes[0,1].set_ylabel(xy_label[1], fontsize=15) 693 | axes[0,1].set_title(f"std", fontsize=20) 694 | divider = make_axes_locatable(axes[0,1]) 695 | cax = divider.append_axes("right", size="5%", pad=0.05) 696 | plt.colorbar(cf01, cax=cax) 697 | # ax3.plot(x_axis, recon_mean, 'k', label='recon $\pm$ 2std') 698 | # ax3.fill_between(x_axis, recon_mean - 2*recon_std, recon_mean + 2*recon_std, alpha=0.3) 699 | # ax3.set_xlabel(x_label, fontsize=15) 700 | # ax3.set_title(f"recon_error versus F, averaged {n_trials} EnKF trials", fontsize=15) 701 | 702 | if grad is not None: 703 | axes[1,0].quiver(x_axis[0:n_ptsx:grad_every], y_axis[0:n_ptsy:grad_every], -gradx_mean[0:n_ptsx:grad_every,0:n_ptsy:grad_every].T, -grady_mean[0:n_ptsx:grad_every,0:n_ptsy:grad_every].T, scale_units='xy', angles='xy', scale=arrow_scale) 704 | axes[1,0].set_xlabel(xy_label[0], fontsize=15) 705 | axes[1,0].set_ylabel(xy_label[1], fontsize=15) 706 | axes[1,0].set_title(f"nll_grad versus ({xy_label[0]}, {xy_label[1]}), N={N_ensem}, averaged {n_trials} EnKF trials", fontsize=20) 707 | cf11 = axes[1,1].contourf(x_axis, y_axis, grad_std.T) 708 | axes[1,1].set_xlabel(xy_label[0], fontsize=15) 709 | axes[1,1].set_ylabel(xy_label[1], fontsize=15) 710 | axes[1,1].set_title(f"std of norm", fontsize=20) 711 | divider = make_axes_locatable(axes[1,1]) 712 | cax = divider.append_axes("right", size="5%", pad=0.05) 713 | plt.colorbar(cf11, cax=cax) 714 | 715 | for i in range(axes.shape[0]): 716 | for j in range(axes.shape[1]): 717 | axes[i,j].axvline(x=xy_truth[0], color='k', linestyle='--', linewidth=3, label=f"True {xy_label[0]}") 718 
| axes[i,j].axhline(y=xy_truth[1], color='k', linestyle='-.', linewidth=3, label=f"True {xy_label[1]}") 719 | axes[i,j].xaxis.set_tick_params(labelsize=20) 720 | axes[i,j].yaxis.set_tick_params(labelsize=20) 721 | if xy_scale[0] == "log": 722 | axes[i,j].set_xscale('log') 723 | if xy_scale[1] == "log": 724 | axes[i,j].set_yscale('log') 725 | 726 | # f = plt.figure(figsize=(9,9)) 727 | # ax = plt.axes(projection='3d') 728 | # cf = ax.contour3D(x_axis, y_axis, nll_mean.T, levels=np.linspace(vmin, vmax, 25), extend='both') 729 | # divider = make_axes_locatable(ax) 730 | # cax = divider.append_axes("right", size="5%", pad=0.05) 731 | # plt.colorbar(cf, cax=cax) 732 | 733 | plt.show() 734 | return 735 | 736 | def plot_filter(t_eval, out, filter_track, name="Ensemble", plot_all=False, compare_F_S=False, fig_num_limit=None): 737 | 738 | t_eval = t_eval.cpu().numpy() 739 | x_dim = out.shape[1] 740 | if fig_num_limit is not None and x_dim > fig_num_limit: 741 | x_dim = fig_num_limit 742 | out_plot = out.detach().cpu().numpy() 743 | n_cols = 2 if compare_F_S else 1 744 | fig, axes = plt.subplots(x_dim, n_cols, sharex='col', sharey='row', figsize=(8*n_cols, 4*x_dim), constrained_layout=True) 745 | 746 | if x_dim == 1: 747 | axes = np.array([axes]) 748 | if not compare_F_S: 749 | axes = axes[:, np.newaxis] 750 | 751 | if name == "Ensemble": 752 | X_track_plot = filter_track.detach().cpu().numpy() 753 | X_smooth_plot = filter_track.detach().cpu().numpy() 754 | N_ensem = X_track_plot.shape[1] 755 | track_mean = X_track_plot.mean(axis=1) 756 | track_std = np.std(X_track_plot, axis=1) 757 | smooth_mean = X_smooth_plot.mean(axis=1) 758 | smooth_std = np.std(X_smooth_plot, axis=1) 759 | fig.suptitle("Ensemble Kalman", fontsize=15) 760 | elif name == "Kalman": 761 | track_mean = filter_track[0].detach().cpu().numpy() 762 | V_track = filter_track[1].detach().cpu().numpy() 763 | track_std = np.sqrt(np.diagonal(V_track, axis1=1, axis2=2)) 764 | smooth_mean = filter_track[0].detach().cpu().numpy() 765 | V_smooth = filter_track[1].detach().cpu().numpy() 766 | smooth_std = np.sqrt(np.diagonal(V_smooth, axis1=1, axis2=2)) 767 | fig.suptitle("Kalman", fontsize=15) 768 | elif name == "Particle": 769 | w_track = filter_track[0].detach().cpu().numpy() # n_obs, N_ensem 770 | X_track = filter_track[1].detach().cpu().numpy() # n_obs, N_ensem, x_dim 771 | w_track = w_track[...,np.newaxis] # (n_obs, N_ensem, 1) 772 | smooth_mean = np.sum(w_track * X_track, axis=1, keepdims=True) # (n_obs, 1, x_dim) 773 | smooth_std = np.sum(w_track * (X_track - smooth_mean)**2, axis=1) # (n_obs, x_dim) 774 | smooth_mean = np.squeeze(smooth_mean, 1) 775 | fig.suptitle("Particle", fontsize=15) 776 | 777 | for i, ax in enumerate(axes): 778 | ax[0].plot(t_eval, out_plot[:, i], 'r', linewidth=3, label="Truth") 779 | xlim = ax[0].get_xlim() 780 | ylim = ax[0].get_ylim() 781 | if name == "Ensemble" and plot_all: 782 | for n in range(N_ensem): 783 | ax[0].plot(t_eval, X_smooth_plot[:, n, i], 'b', linewidth=0.5, label='EnKS') 784 | if compare_F_S: 785 | ax[1].plot(t_eval, X_track_plot[:, n, i], 'b', linewidth=0.5, label='EnKF') 786 | else: 787 | ax[0].plot(t_eval, smooth_mean[:, i], 'b--', linewidth=1, label='Smoother mean') 788 | ax[0].plot(t_eval, smooth_mean[:, i] + 2*smooth_std[:, i], 'b', linewidth=0.5, label='+2 std') 789 | ax[0].plot(t_eval, smooth_mean[:, i] - 2*smooth_std[:, i], 'b', linewidth=0.5, label='-2 std') 790 | if compare_F_S: 791 | ax[1].plot(t_eval, track_mean[:, i], 'b--', linewidth=1, label='Filter mean') 792 | 
ax[1].plot(t_eval, track_mean[:, i] + 2*track_std[:, i], 'b', linewidth=0.5, label='+2 std') 793 | ax[1].plot(t_eval, track_mean[:, i] - 2*track_std[:, i], 'b', linewidth=0.5, label='-2 std') 794 | plt.setp(ax[0], xlim=xlim,ylim=ylim) 795 | ax[0].set_xlabel('$t$', fontsize=10) 796 | handles, labels = unique_labels(ax[0]) 797 | ax[0].legend(handles, labels, loc='upper right') 798 | # ax[0].set_ylim(-30, 60) 799 | 800 | if compare_F_S: 801 | ax[1].plot(t_eval, out_plot[:, i], 'r', linewidth=3, label="Truth") 802 | ax[1].set_xlabel('$t$', fontsize=10) 803 | handles, labels = unique_labels(ax[1]) 804 | ax[1].legend(handles, labels, loc='upper right') 805 | plt.show() 806 | 807 | 808 | 809 | 810 | 811 | 812 | 813 | 814 | 815 | 816 | -------------------------------------------------------------------------------- /torchEnKF/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ymchen0/torchEnKF/016b4f8412310c195671c81790d372bd6cd9dc95/torchEnKF/__init__.py -------------------------------------------------------------------------------- /torchEnKF/da_methods.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torchEnKF import misc 4 | from torchdiffeq import odeint_adjoint 5 | from torchdiffeq import odeint 6 | from torch.nn.functional import normalize 7 | 8 | from tqdm import tqdm 9 | import time 10 | 11 | def construct_Gaspari_Cohn(loc_radius, x_dim, device): 12 | def G(z): 13 | if z >= 0 and z < 1: 14 | return 1. - 5./3*z**2 + 5./8*z**3 + 1./2*z**4 - 1./4*z**5 15 | elif z >= 1 and z < 2: 16 | return 4. - 5.*z + 5./3*z**2 + 5./8*z**3 - 1./2*z**4 + 1./12*z**5 - 2./(3*z) 17 | else: 18 | return 0 19 | taper = torch.zeros(x_dim, x_dim, device=device) 20 | for i in range(x_dim): 21 | for j in range(x_dim): 22 | dist = min(abs(i-j), x_dim - abs(i-j)) 23 | taper[i, j] = G(dist/loc_radius) 24 | return taper 25 | 26 | def power_iter(A, n_iter=1): 27 | device = A.device 28 | u_shape = A[...,0:1].shape # (*bs, N_ensem, 1) 29 | v_shape = A[...,0:1,:].shape # (*bs, 1, y_dim) 30 | u = normalize(A.new_empty(u_shape).normal_(0, 1), dim=-2) # (*bs, N_ensem, 1) 31 | v = normalize(A.new_empty(v_shape).normal_(0, 1), dim=-1) # (*bs, 1, y_dim) 32 | for i in range(n_iter): 33 | v = normalize(A.transpose(-1, -2) @ u, dim=-2) # (*bs, y_dim, 1) 34 | u = normalize(A @ v, dim=-2) # (*bs, N_ensem, 1) 35 | sigma = u.transpose(-1, -2) @ A @ v 36 | # A = A / sigma 37 | v = v.transpose(-1,-2) # (*bs, 1, y_dim) 38 | return sigma 39 | 40 | def inv_logdet(v, Y_ct, R, R_inv, logdet_R): 41 | # Returns matrix-vector product (Y Y^T + R)^{-1} times v for matrix Y=Y_ct and any choice of vector/matrix v. Also returns the log-determinant of (Y Y^T + R). 
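# When N_ensem >= y_dim, the (y_dim, y_dim) system is factorized directly with a Cholesky decomposition.
# When N_ensem < y_dim, the Woodbury identity and the matrix determinant lemma are applied to the smaller
# (N_ensem, N_ensem) matrix I + Y_ct R^{-1} Y_ct^T, so only a small Cholesky factorization is needed;
# a one-step power iteration estimates its largest singular value and rescales it before factorization for numerical stability.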
42 | # Supports batch operation 43 | # Y_ct: (*bs, N_ensem, y_dim), R: (y_dim, y_dim), v: (*bs, bs2, y_dim) 44 | # out (invv): (*bs, bs2, y_dim) 45 | device = Y_ct.device 46 | N_ensem = Y_ct.shape[-2] 47 | y_dim = Y_ct.shape[-1] 48 | if N_ensem >= y_dim: 49 | YYT_R = Y_ct.transpose(-1, -2) @ Y_ct + R 50 | YYT_R_chol = torch.linalg.cholesky(YYT_R) 51 | logdet = 2 * YYT_R_chol.diagonal(dim1=-2, dim2=-1).log().sum(-1) 52 | invv = torch.cholesky_solve(v.transpose(-1,-2), YYT_R_chol).transpose(-1,-2) 53 | else: 54 | YTRinv = Y_ct @ R_inv # (*bs, N_ensem, y_dim) 55 | YTRinvv = YTRinv @ v.transpose(-1, -2) # (*bs, N_ensem, bs2) 56 | I_YTRinvY = torch.eye(N_ensem, device=device) + YTRinv @ Y_ct.transpose(-1, -2) # (*bs, N_ensem, N_ensem) 57 | sc = power_iter(I_YTRinvY, n_iter=1) 58 | I_YTRinvY_sc = I_YTRinvY/sc 59 | I_YTRinvY_chol_sc = torch.linalg.cholesky(I_YTRinvY_sc) # (*bs, N_ensem, N_ensem) 60 | I_YTRinvY_inv_YTRinvv = 1/sc * torch.cholesky_solve(YTRinvv, I_YTRinvY_chol_sc) # (*bs, N_ensem, bs2) 61 | invv = v @ R_inv - I_YTRinvY_inv_YTRinvv.transpose(-1,-2) @ YTRinv # (*bs, bs2, y_dim) 62 | logdet = N_ensem * torch.log(sc).squeeze(-1).squeeze(-1) + 2 * I_YTRinvY_chol_sc.diagonal(dim1=-2, dim2=-1).log().sum(-1) + logdet_R # det(I_YTRinvY) = sc**N_ensem * det(I_YTRinvY/sc) 63 | return invv, logdet # (*bs, bs2, y_dim), (*bs) 64 | 65 | def EnKF(ode_func, obs_func, t_obs, y_obs, N_ensem, init_m, init_C_param, model_Q_param, noise_R_param, device, 66 | init_X=None, ode_method='rk4', ode_options=None, adjoint=True, adjoint_method='rk4', adjoint_options=None, save_filter_step={'mean'}, 67 | smooth_lag=0, t0=0., var_inflation=None, localization_radius=None, compute_likelihood=True, linear_obs=True, time_varying_obs=False, 68 | save_first=False, tqdm=None, **ode_kwargs): 69 | """ 70 | EnKF with stochastic perturbation. 71 | 72 | Key args: 73 | ode_func (torch.nn.Module): Vector field f(t,x). 74 | Tip: Wrap every parameter whose gradient you want to evaluate in torch.nn.Parameter(). 75 | NOTE: This implicitly assumes the underlying latent model is an ODE. For a generic latent evolution x_{t+1}=F(x_t), slight modifications of the forecast step are required. 76 | obs_func (torch.nn.Module): Observation model h(x), assumed to be linear h(x) = Hx. 77 | If time_varying_obs==True, this can be a list of torch.nn.Module's. 78 | t_obs (tensor): 1D-Tensor of shape (n_obs,). Time points where observations are available. 79 | This does NOT need to be time-uniform. By default, t0 is NOT included. Must be monotonically increasing. 80 | y_obs (tensor): Tensor of shape (n_obs, *bs, y_dim). Observed values at t_obs. 81 | '*bs' can be an arbitrary batch dimension (or empty). 82 | Observations are assumed to share the same dimension 'y_dim'. However, the observation model can be time-varying. 83 | N_ensem: Number of particles. 84 | init_m (tensor): Tensor of shape (x_dim, ). Mean of the initial distribution. 85 | init_C_param (noise.AddGaussian): covariance of the initial distribution 86 | model_Q_param (noise.AddGaussian): model error covariance 87 | noise_R_param (noise.AddGaussian): observation error covariance 88 | 89 | Optional args: 90 | init_X (tensor): Tensor of shape (*bs, N_ensem, x_dim). Initial ensemble if pre-specified. 91 | ode_method: Numerical scheme for the forward equation. We use 'euler' or 'rk4'. Other solvers are available. See https://github.com/rtqichen/torchdiffeq 92 | ode_options: Set it to dict(step_size=...) for fixed-step solvers for the forward equation. Adaptive solvers are also available - see the link above.
93 | adjoint (bool): Whether to compute gradients via the adjoint equation or by direct backpropagation through the solver. 94 | adjoint_method: Numerical scheme for the adjoint equation if adjoint==True. 95 | adjoint_options: Set it to dict(step_size=...) for fixed-step solvers for the adjoint equation. Adaptive solvers are also available - see the link above. 96 | ode_kwargs: Additional kwargs passed to the neural ODE solver. 97 | save_filter_step: 98 | If it contains 'mean', then particle means will be saved. 99 | If it contains 'particles', then all particles will be saved. 100 | (Note: the up-to-date/final particles will always be returned separately.) 101 | t0: The timestamp at which the ensemble is initialized. 102 | By default, we DO NOT assume an observation is available at t0. Slight modifications of the code are needed to handle this situation. 103 | var_inflation: See discussion in paper. A typical value is between 1 and 1.1. None by default. 104 | localization_radius: See discussion in paper. A typical value is 5. None by default. 105 | compute_likelihood: Whether to compute the data log-likelihood in the filtering process. 106 | Must be set to True for AD-EnKF. 107 | linear_obs: If set to True, then obs_func must be an instance of the 'nn_templates.Linear' class. The observation model is y = Hx + noise, where H is a matrix. 108 | If set to False, then obs_func can be any differentiable function/module in PyTorch. The observation model is y = obs_func(x) + noise. 109 | time_varying_obs: If set to False, the observation model is time-invariant. A single nn.Module/function is sufficient for the obs_func argument. 110 | If set to True, the observation model can differ across time. A list of nn.Module's/functions with the same length as t_obs is needed for the obs_func argument. 111 | save_first: Set it to True to save the initial ensemble. 112 | tqdm: Pass tqdm=tqdm to display a progress bar. 113 | 114 | Returns: 115 | X (tensor): Tensor of shape (*bs, N_ensem, x_dim). Final ensemble. 116 | res (dict): 117 | If save_filter_step contains 'mean', then res['mean'] will be a tensor of shape (n_obs, *bs, x_dim). 118 | If save_filter_step contains 'particles', then res['particles'] will be a tensor of shape (n_obs, *bs, N_ensem, x_dim). 119 | log_likelihood (tensor): Log-likelihood estimate # (*bs) 120 | """ 121 | 122 | ode_integrator = odeint_adjoint if adjoint else odeint 123 | 124 | x_dim = init_m.shape[0] 125 | y_dim = y_obs.shape[-1] 126 | n_obs = y_obs.shape[0] 127 | bs = y_obs.shape[1:-1] 128 | 129 | if ode_options is None: 130 | if n_obs > 0: 131 | step_size = (t_obs[1:] - t_obs[:-1]).min() # This uses the minimum spacing of t_obs as the step size. However, it is preferable to provide step_size manually to avoid issues such as non-divisibility. 132 | ode_options = dict(step_size=step_size) 133 | 134 | log_likelihood = torch.zeros(bs, device=device) if compute_likelihood else None # (*bs), tensor(0.)
if no batch dimension 135 | 136 | if linear_obs and localization_radius is not None: 137 | taper = construct_Gaspari_Cohn(localization_radius, x_dim, device) 138 | 139 | if init_X is not None: 140 | X = init_X.detach() 141 | else: 142 | X = init_C_param(init_m.expand(*bs, N_ensem, x_dim)) 143 | 144 | 145 | res = {} 146 | if 'particles' in save_filter_step: 147 | res['particles'] = torch.empty(n_obs + 1, *bs, N_ensem, x_dim, dtype=init_m.dtype, device=device) 148 | res['particles'][0] = X 149 | if 'mean' in save_filter_step: 150 | X_m = X.mean(dim=-2) 151 | res['mean'] = torch.empty(n_obs + 1, *bs, x_dim, dtype=init_m.dtype, device=device) 152 | res['mean'][0] = X_m.detach() 153 | 154 | 155 | step_size = ode_options['step_size'] 156 | 157 | t_cur = t0 158 | 159 | pbar = tqdm(range(n_obs), desc="Running EnKF", leave=False) if tqdm is not None else range(n_obs) 160 | for j in pbar: 161 | ################ Forecast step ################## 162 | n_intermediate_j = round(((t_obs[j] - t_cur) / step_size).item()) 163 | if adjoint: 164 | X = ode_integrator(ode_func, X, torch.linspace(t_cur, t_obs[j], n_intermediate_j + 1, device=device), method=ode_method, adjoint_method=adjoint_method, adjoint_options=adjoint_options, **ode_kwargs)[-1] 165 | else: 166 | X = ode_integrator(ode_func, X, torch.linspace(t_cur, t_obs[j], n_intermediate_j + 1, device=device), method=ode_method, **ode_kwargs)[-1] 167 | t_cur = t_obs[j] 168 | 169 | if model_Q_param is not None: 170 | X = model_Q_param(X) 171 | 172 | X_m = X.mean(dim=-2).unsqueeze(-2) # (*bs, 1, x_dim) 173 | X_ct = X - X_m 174 | 175 | if var_inflation is not None: 176 | X = var_inflation * (X - X_m) + X_m 177 | 178 | ################ Analysis step ################## 179 | obs_func_j = obs_func[j] if time_varying_obs else obs_func 180 | y_obs_j = y_obs[j].unsqueeze(-2) # (*bs, 1, y_dim) 181 | # Noise perturbation of observed data (key in stochastic EnKF) 182 | obs_perturb = noise_R_param(y_obs_j.expand(*bs, N_ensem, y_dim)) 183 | noise_R = noise_R_param.full() 184 | noise_R_inv = noise_R_param.inv() 185 | logdet_noise_R = noise_R_param.logdet() 186 | 187 | 188 | 189 | if linear_obs and localization_radius is not None: 190 | H = obs_func_j.H # (y_dim, x_dim) 191 | C_uu = 1 / (N_ensem - 1) * X_ct.transpose(-1,-2) @ X_ct # (*bs, x_dim, x_dim). Note: It can be made memory-efficient by not computing this explicity. See discussion in paper. 
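# Covariance localization: multiplying elementwise by the Gaspari-Cohn taper damps spurious long-range sample correlations caused by the finite ensemble size.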
192 | C_uu = taper * C_uu 193 | HX = X @ H.transpose(-1, -2) # (*bs, N_ensem, y_dim) 194 | HX_m = X_m @ H.transpose(-1, -2) # (*bs, 1, y_dim) 195 | HC = H @ C_uu # (*bs, y_dim, x_dim) 196 | HCH_T = HC @ H.transpose(-1, -2) # (*bs, y_dim, y_dim) 197 | HCH_TR_chol = torch.linalg.cholesky(HCH_T + noise_R) # (*bs, y_dim, y_dim), lower-tril 198 | if compute_likelihood: 199 | d = torch.distributions.MultivariateNormal(HX_m.squeeze(-2), scale_tril=HCH_TR_chol) # (*bs, y_dim) and (*bs, y_dim, y_dim) 200 | log_likelihood += d.log_prob(y_obs_j.squeeze(-2)) # (*bs) 201 | pre = (obs_perturb - HX) @ torch.cholesky_inverse(HCH_TR_chol) # (*bs, N_ensem, y_dim) 202 | X = X + pre @ HC # (*bs, N_ensem, x_dim) 203 | else: 204 | HX = obs_func_j(X) # (*bs, N_ensem, y_dim) 205 | HX_m = HX.mean(dim=-2).unsqueeze(-2) # (*bs, 1, y_dim) 206 | HX_ct = HX - HX_m 207 | C_ww_sqrt = 1/math.sqrt(N_ensem-1) * HX_ct # (*bs, N_ensem, y_dim) 208 | v1 = obs_perturb - HX # (*bs, N_ensem, y_dim) 209 | v2 = y_obs_j - HX_m # (*bs, 1, y_dim) 210 | v = torch.cat((v1, v2), dim=-2) # (*bs, N_ensem+1, y_dim) 211 | C_ww_R_invv, C_ww_R_logdet = inv_logdet(v, C_ww_sqrt, noise_R, noise_R_inv, logdet_noise_R) # (*bs, N_ensem+1, y_dim), (*bs) 212 | pre = C_ww_R_invv[..., :N_ensem, :] # (*bs, N_ensem, y_dim) # (*bs, N_ensem, y_dim) 213 | if compute_likelihood: 214 | part1 = -1 / 2 * (y_dim * math.log(2 * math.pi) + C_ww_R_logdet) # (1,) 215 | part2 = -1 / 2 * C_ww_R_invv[..., N_ensem:, :] @ (y_obs_j - HX_m).transpose(-1, -2) # (*bs, 1, 1,) 216 | log_likelihood += (part1 + part2.squeeze(-1).squeeze(-1)) # (*bs) 217 | X = X + 1 / math.sqrt(N_ensem - 1) * (pre @ C_ww_sqrt.transpose(-1, -2)) @ X_ct # (*bs, N_ensem, x_dim) 218 | 219 | if 'particles' in save_filter_step: 220 | res['particles'][j+1] = X 221 | if 'mean' in save_filter_step: 222 | X_m = X.mean(dim=-2) 223 | res['mean'][j+1] = X_m.detach() 224 | 225 | if not save_first: 226 | for key in res.keys(): 227 | res[key] = res[key][1:] 228 | return X, res, log_likelihood 229 | 230 | -------------------------------------------------------------------------------- /torchEnKF/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def ess(weight): 5 | # (*bdims, weight) -> (*bdims) 6 | return 1 / (weight**2).sum(dim=-1) 7 | 8 | def softplus(t): 9 | return torch.log(1. + torch.exp(t)) 10 | 11 | def softplus_inv(t): 12 | return torch.log(-1. + torch.exp(t)) 13 | 14 | def softplus_grad(t): 15 | return torch.exp(t) / (1. 
+ torch.exp(t)) -------------------------------------------------------------------------------- /torchEnKF/nn_templates.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Linear_ODE(nn.Module): 5 | def __init__(self, x_dim, a, param=None): 6 | super().__init__() 7 | self.a = nn.Parameter(a) 8 | self.x_dim = x_dim 9 | self.param=param 10 | if param is None: 11 | self.A = self.a[:] 12 | else: 13 | self.A = param(self.a, x_dim) 14 | 15 | def forward(self, t, u): 16 | # du/dt = f(u, t), input: N * x_dim, output: N * x_dim 17 | if self.param is None: 18 | A = self.a[:] 19 | else: 20 | A = self.param(self.a, self.x_dim) 21 | out = u @ (A - torch.eye(self.x_dim)).t() 22 | return out 23 | 24 | class Linear_ODE_single_var(nn.Module): 25 | def __init__(self, x_dim, a): 26 | super().__init__() 27 | self.a = nn.Parameter(a) 28 | self.x_dim = x_dim 29 | self.exp = self.construct_exp(x_dim) 30 | # self.A = torch.pow(self.a, self.exp) 31 | 32 | def construct_exp(self, x_dim): 33 | exp = torch.zeros(x_dim, x_dim) 34 | for i in range(x_dim): 35 | for j in range(x_dim): 36 | exp[i, j] = abs(i-j)+1 37 | return exp 38 | 39 | def A(self): 40 | return torch.pow(self.a, self.exp) 41 | 42 | def forward(self, t, u): 43 | # du/dt = f(u, t), input: N * x_dim, output: N * x_dim 44 | # A = torch.pow(self.a, self.exp) 45 | out = u @ (self.A() - torch.eye(self.x_dim)).t() 46 | return out 47 | 48 | class Linear_ODE_diag(nn.Module): 49 | def __init__(self, x_dim, a): 50 | super().__init__() 51 | self.a = nn.Parameter(a) 52 | self.x_dim = x_dim 53 | self.num_a = a.shape[0] 54 | 55 | def A(self): 56 | A = torch.zeros(self.x_dim, self.x_dim) 57 | for i in range(self.num_a): 58 | diagonal = (i+1)//2 * (-1)**i # 0, -1, 1, -2, 2,... 59 | len_one = self.x_dim - (i+1)//2 # (d, d-1, d-1, d-2, d-2...) 60 | A = A + torch.diag(self.a[i] * torch.ones(len_one), diagonal=diagonal) 61 | return A 62 | 63 | def forward(self, t, u): 64 | # du/dt = f(u, t), input: N * x_dim, output: N * x_dim 65 | # A = torch.pow(self.a, self.exp) 66 | out = u @ (self.A() - torch.eye(self.x_dim)).t() 67 | return out 68 | 69 | class Lorenz63(nn.Module): 70 | def __init__(self, coeff, x_dim=3): 71 | super().__init__() 72 | self.coeff = nn.Parameter(coeff) 73 | self.x_dim = x_dim 74 | 75 | def forward(self, t, u): 76 | # (*bs * x_dim) -> (*bs * x_dim) 77 | sigma, beta, rho = self.coeff 78 | out = torch.stack((sigma * (u[...,1] - u[...,0]), rho * u[...,0] - u[...,1] - u[...,0] * u[...,2], u[...,0] * u[...,1] - beta * u[...,2]), dim=-1) 79 | return out 80 | 81 | class Linear(nn.Module): 82 | def __init__(self, x_dim, y_dim, H): 83 | super().__init__() 84 | self.H = nn.Parameter(H) 85 | self.x_dim = x_dim 86 | self.y_dim = y_dim 87 | 88 | def forward(self, u): 89 | # du/dt = f(u, t), input: N * x_dim, output: N * x_dim 90 | out = u @ self.H.t() 91 | return out 92 | 93 | class ODE_Net(nn.Module): 94 | def __init__(self, x_dim, hidden_layer_widths, scaled_layer=False): 95 | # e.g. 
hidden_layer_widths = [40, 128, 64, 40] 96 | super().__init__() 97 | self.hidden_layer_widths = hidden_layer_widths 98 | self.num_hidden_layers = len(hidden_layer_widths) - 2 99 | self.layers = nn.ModuleList() 100 | for i in range(self.num_hidden_layers+1): 101 | layer = nn.Linear(hidden_layer_widths[i], hidden_layer_widths[i+1]) 102 | if scaled_layer: 103 | layer.weight.data.mul_(1/math.sqrt(layer_dims[i])) 104 | self.layers.append(layer) 105 | 106 | def forward(self, t, u): 107 | for layer in self.layers[:-1]: 108 | u = torch.relu(layer(u)) 109 | out = self.layers[-1](u) 110 | return out 111 | 112 | class FC_Net(nn.Module): 113 | def __init__(self, x_dim, hidden_layer_widths, scaled_layer=False): 114 | # e.g. hidden_layer_widths = [40, 128, 64, 40] 115 | super().__init__() 116 | self.num_hidden_layers = len(hidden_layer_widths) - 2 117 | self.layers = nn.ModuleList() 118 | for i in range(self.num_hidden_layers+1): 119 | layer = nn.Linear(hidden_layer_widths[i], hidden_layer_widths[i+1]) 120 | if scaled_layer: 121 | layer.weight.data.mul_(1/math.sqrt(layer_dims[i])) 122 | self.layers.append(layer) 123 | 124 | def forward(self, u): 125 | for layer in self.layers[:-1]: 126 | u = torch.relu(layer(u)) 127 | out = self.layers[-1](u) 128 | return out 129 | 130 | class ODE_Net_from_basenet(nn.Module): 131 | # e.g. base_net = [3,40], this_hidden_layer_widths=[40,40,3] 132 | def __init__(self, base, hidden_layer_widths): 133 | super().__init__() 134 | self.base = base 135 | self.num_hidden_layers = len(hidden_layer_widths) - 2 136 | self.layers = nn.ModuleList() 137 | for i in range(self.num_hidden_layers+1): 138 | layer = nn.Linear(hidden_layer_widths[i], hidden_layer_widths[i+1]) 139 | self.layers.append(layer) 140 | 141 | def forward(self, t, u): 142 | u = torch.relu(self.base(u)) 143 | for layer in self.layers[:-1]: 144 | u = torch.relu(layer(u)) 145 | out = self.layers[-1](u) 146 | return out 147 | 148 | 149 | 150 | class L96_ODE_Net(nn.Module): 151 | def __init__(self, x_dim): 152 | super().__init__() 153 | self.x_dim = x_dim 154 | self.layer1 = nn.Conv1d(1, 6, 5, padding=2, padding_mode='circular') 155 | self.layer2 = nn.Conv1d(12, 1, 1) 156 | 157 | # self.layer1 = nn,Conv1d(1,6,5) 158 | 159 | def forward(self, t, u): 160 | bs = u.shape[:-1] 161 | out = torch.relu(self.layer1(u.view(-1, self.x_dim).unsqueeze(-2))) 162 | out = torch.cat((out**2, out), dim=-2) 163 | out = self.layer2(out).squeeze(-2).view(*bs, self.x_dim) 164 | return out 165 | 166 | class L96_ODE_Net_2(nn.Module): 167 | def __init__(self, x_dim): 168 | super().__init__() 169 | self.x_dim = x_dim 170 | self.layer1 = nn.Conv1d(1, 72, 5, padding=2, padding_mode='circular') 171 | # self.layer1b = nn.Conv1d(1, 24, 5, padding=2, padding_mode='circular') 172 | # self.layer1c = nn.Conv1d(1, 24, 5, padding=2, padding_mode='circular') 173 | self.layer2 = nn.Conv1d(48, 37, 5, padding=2, padding_mode='circular') 174 | self.layer3 = nn.Conv1d(37, 1, 1) 175 | 176 | # self.layer1 = nn,Conv1d(1,6,5) 177 | 178 | def forward(self, t, u): 179 | bs = u.shape[:-1] # (*bs, x_dim) 180 | out = torch.relu(self.layer1(u.view(-1, self.x_dim).unsqueeze(-2))) # (bs, 1, x_dim) -> (bs, 72, x_dim) 181 | out = torch.cat((out[...,:24,:], out[...,24:48,:] * out[...,48:,:]), dim=-2) # (bs, 72, x_dim) -> (bs, 48, x_dim) 182 | out = torch.relu(self.layer2(out)) # (bs, 48, x_dim) -> (bs, 37, x_dim) 183 | out = self.layer3(out).squeeze(-2).view(*bs, self.x_dim) # (bs, 37, x_dim) -> (bs, 1, x_dim) -> (*bs, x_dim) 184 | return out 185 | 186 | class 
Lorenz96(nn.Module): 187 | def __init__(self, F, x_dim, device): 188 | super().__init__() 189 | self.F = nn.Parameter(torch.tensor(F)) 190 | self.x_dim = x_dim 191 | self.indices_p1 = torch.tensor([(i+1)%self.x_dim for i in range(x_dim)], dtype=torch.long, device=device) 192 | self.indices_m2 = torch.tensor([(i-2)%self.x_dim for i in range(x_dim)], dtype=torch.long, device=device) 193 | self.indices_m1 = torch.tensor([(i-1)%self.x_dim for i in range(x_dim)], dtype=torch.long, device=device) 194 | 195 | def forward(self, t, u): 196 | # du/dt = f(u, t), input: N * x_dim, output: N * x_dim 197 | out = (u.index_select(-1, self.indices_p1) - u.index_select(-1, self.indices_m2)) * u.index_select(-1, self.indices_m1) - u + self.F 198 | return out 199 | 200 | class Lorenz96_correction(nn.Module): 201 | def __init__(self, coeff, x_dim=40): 202 | super().__init__() 203 | device = coeff.device 204 | self.coeff = coeff 205 | self.x_dim = x_dim 206 | self.indices_p1 = torch.tensor([(i+1)%self.x_dim for i in range(x_dim)], dtype=torch.long, device=device) 207 | self.indices_p2 = torch.tensor([(i+2)%self.x_dim for i in range(x_dim)], dtype=torch.long, device=device) 208 | self.indices_m2 = torch.tensor([(i-2)%self.x_dim for i in range(x_dim)], dtype=torch.long, device=device) 209 | self.indices_m1 = torch.tensor([(i-1)%self.x_dim for i in range(x_dim)], dtype=torch.long, device=device) 210 | 211 | self.layer1 = nn.Conv1d(1, 72, 5, padding=2, padding_mode='circular') 212 | # self.layer1b = nn.Conv1d(1, 24, 5, padding=2, padding_mode='circular') 213 | # self.layer1c = nn.Conv1d(1, 24, 5, padding=2, padding_mode='circular') 214 | self.layer2 = nn.Conv1d(48, 37, 5, padding=2, padding_mode='circular') 215 | self.layer3 = nn.Conv1d(37, 1, 1) 216 | 217 | def forward(self, t, u): 218 | # (*bs, x_dim) -> (*bs, x_dim) 219 | u_m2 = u.index_select(-1, self.indices_m2) 220 | u_m1 = u.index_select(-1, self.indices_m1) 221 | u_p1 = u.index_select(-1, self.indices_p1) 222 | u_p2 = u.index_select(-1, self.indices_p2) 223 | to_cat = [] 224 | to_cat.append(torch.ones_like(u)) 225 | to_cat.extend([u_m2, u_m1, u, u_p1, u_p2]) 226 | # to_cat.append(u) 227 | to_cat.extend([u_m2**2, u_m1**2, u**2, u_p1**2, u_p2**2]) 228 | to_cat.extend([u_m2*u_m1, u_m1*u, u*u_p1, u_p1*u_p2]) 229 | # to_cat.append(u_m2*u_m1) 230 | to_cat.extend([u_m2*u, u_m1*u_p1, u*u_p2]) 231 | # to_cat.append(u_m1*u_p1) 232 | out1 = torch.stack(to_cat, dim=-1) @ self.coeff # (*bs, x_dim, N_a) @ (N_a) -> (*bs, x_dim) 233 | 234 | bs = u.shape[:-1] # (*bs, x_dim) 235 | out2 = torch.relu(self.layer1(u.view(-1, self.x_dim).unsqueeze(-2))) # (bs, 1, x_dim) -> (bs, 72, x_dim) 236 | out2 = torch.cat((out2[...,:24,:], out2[...,24:48,:] * out2[...,48:,:]), dim=-2) # (bs, 72, x_dim) -> (bs, 48, x_dim) 237 | out2 = torch.relu(self.layer2(out2)) # (bs, 48, x_dim) -> (bs, 37, x_dim) 238 | out2 = self.layer3(out2).squeeze(-2).view(*bs, self.x_dim) # (bs, 37, x_dim) -> (bs, 1, x_dim) -> (*bs, x_dim) 239 | return out1+out2 240 | 241 | class Lorenz96_dict_param(nn.Module): 242 | def __init__(self, coeff, device, x_dim=40): 243 | super().__init__() 244 | self.coeff = nn.Parameter(coeff) 245 | self.x_dim = x_dim 246 | self.indices_p1 = torch.tensor([(i+1)%self.x_dim for i in range(x_dim)], dtype=torch.long, device=device) 247 | self.indices_p2 = torch.tensor([(i+2)%self.x_dim for i in range(x_dim)], dtype=torch.long, device=device) 248 | self.indices_m2 = torch.tensor([(i-2)%self.x_dim for i in range(x_dim)], dtype=torch.long, device=device) 249 | self.indices_m1 = 
torch.tensor([(i-1)%self.x_dim for i in range(x_dim)], dtype=torch.long, device=device) 250 | 251 | def forward(self, t, u): 252 | # (*bs, x_dim) -> (*bs, x_dim) 253 | u_m2 = u.index_select(-1, self.indices_m2) 254 | u_m1 = u.index_select(-1, self.indices_m1) 255 | u_p1 = u.index_select(-1, self.indices_p1) 256 | u_p2 = u.index_select(-1, self.indices_p2) 257 | to_cat = [] 258 | to_cat.append(torch.ones_like(u)) 259 | to_cat.extend([u_m2, u_m1, u, u_p1, u_p2]) 260 | # to_cat.append(u) 261 | to_cat.extend([u_m2**2, u_m1**2, u**2, u_p1**2, u_p2**2]) 262 | to_cat.extend([u_m2*u_m1, u_m1*u, u*u_p1, u_p1*u_p2]) 263 | # to_cat.append(u_m2*u_m1) 264 | to_cat.extend([u_m2*u, u_m1*u_p1, u*u_p2]) 265 | # to_cat.append(u_m1*u_p1) 266 | out = torch.stack(to_cat, dim=-1) @ self.coeff # (*bs, x_dim, N_a) @ (N_a) -> (*bs, x_dim) 267 | return out 268 | 269 | 270 | class Lorenz96_FS(nn.Module): 271 | def __init__(self, param, device, xx_dim=36, xy_dim=10): 272 | super().__init__() 273 | self.param = nn.Parameter(param) 274 | 275 | self.xx_dim = xx_dim 276 | self.xy_dim = xy_dim 277 | self.x_dim = xx_dim * (xy_dim + 1) 278 | 279 | self.indices_x = torch.tensor([i for i in range(xx_dim)], dtype=torch.long) 280 | self.indices_x_p1 = torch.tensor([(i + 1) % self.xx_dim for i in range(xx_dim)], dtype=torch.long, device=device) 281 | self.indices_x_m2 = torch.tensor([(i - 2) % self.xx_dim for i in range(xx_dim)], dtype=torch.long, device=device) 282 | self.indices_x_m1 = torch.tensor([(i - 1) % self.xx_dim for i in range(xx_dim)], dtype=torch.long, device=device) 283 | self.indices_y_p1 = torch.tensor([(i + 1) % self.xy_dim for i in range(xy_dim)], dtype=torch.long, device=device) 284 | self.indices_y_p2 = torch.tensor([(i + 2) % self.xy_dim for i in range(xy_dim)], dtype=torch.long, device=device) 285 | self.indices_y_m1 = torch.tensor([(i - 1) % self.xy_dim for i in range(xy_dim)], dtype=torch.long, device=device) 286 | 287 | def forward(self, t, u): 288 | # (*bs * x_dim) -> (*bs * x_dim) 289 | bs = u.shape[:-1] 290 | F, h, c, b = self.param 291 | to_cat = [] 292 | u_y = u[..., self.xx_dim:].reshape(*bs, self.xx_dim, self.xy_dim) # (*bs, xx_dim, xy_dim) 293 | # print(u.index_select(-1, self.indices_x_p1).shape, u_y.mean(dim=-1).shape) 294 | to_cat.append((u.index_select(-1, self.indices_x_p1) - u.index_select(-1, self.indices_x_m2)) * u.index_select(-1, self.indices_x_m1) - u[...,:self.xx_dim] + F - h * c * u_y.mean(dim=-1)) 295 | to_cat.append(c * (-b * u_y.index_select(-1, self.indices_y_p1) * (u_y.index_select(-1, self.indices_y_p2) - u_y.index_select(-1,self.indices_y_m1)) - u_y + h / self.xy_dim * u[...,:self.xx_dim].unsqueeze(-1)).view(*bs, self.xx_dim * self.xy_dim)) 296 | out = torch.cat(to_cat, dim=-1) 297 | return out 298 | 299 | 300 | 301 | class One_Layer_NN(nn.Module): 302 | def __init__(self, input_dim, output_dim, H=None, residual=False, bias=False): 303 | super().__init__() 304 | self.residual = residual 305 | self.layer1 = nn.Linear(input_dim, output_dim, bias=bias) 306 | if H is not None: 307 | self.layer1.weight.data = H 308 | self.H = H 309 | 310 | def forward(self, x): 311 | res = self.layer1(x) 312 | if self.residual: 313 | out = res + x 314 | else: 315 | out = res 316 | return out 317 | 318 | class Two_Layer_NN(nn.Module): 319 | def __init__(self, input_dim, output_dim, hidden_dim, residual=True, activation="relu", batchnorm=False): 320 | super().__init__() 321 | if activation == "relu": 322 | self.activation = torch.relu 323 | elif activation == "tanh": 324 | self.activation = 
torch.tanh 325 | self.residual = residual 326 | self.batchnorm = batchnorm 327 | self.input_dim = input_dim 328 | self.output_dim = output_dim 329 | self.layer1 = nn.Linear(input_dim, hidden_dim) 330 | self.layer2 = nn.Linear(hidden_dim, output_dim) 331 | 332 | def forward(self, x): 333 | # x = x[:,1].unsqueeze(1) 334 | res = x 335 | if self.batchnorm: 336 | res = nn.BatchNorm1d(self.input_dim)(res) 337 | res = self.activation(self.layer1(res)) 338 | res = self.layer2(res) 339 | if self.residual: 340 | out = res + x 341 | else: 342 | out = res 343 | # out = res + 0.5*x[:,1].unsqueeze(1) 344 | return out 345 | 346 | class Three_Layer_NN(nn.Module): 347 | def __init__(self, input_dim, output_dim, hidden_dim, residual=True, activation="relu"): 348 | super().__init__() 349 | if activation == "relu": 350 | self.activation = torch.relu 351 | elif activation == "tanh": 352 | self.activation = torch.tanh 353 | self.residual = residual 354 | self.output_dim = output_dim 355 | self.layer1 = nn.Linear(input_dim, hidden_dim[0]) 356 | self.layer2 = nn.Linear(hidden_dim[0], hidden_dim[1]) 357 | self.layer3 = nn.Linear(hidden_dim[1], output_dim) 358 | 359 | def forward(self, x): 360 | res = self.activation(self.layer1(x)) 361 | res = self.activation(self.layer2(res)) 362 | res = self.layer3(res) 363 | if self.residual: 364 | out = res + x 365 | else: 366 | out = res 367 | return out 368 | 369 | class Four_Layer_NN(nn.Module): 370 | def __init__(self, input_dim, output_dim, hidden_dim, residual=True, activation="relu"): 371 | super().__init__() 372 | if activation == "relu": 373 | self.activation = torch.relu 374 | elif activation == "tanh": 375 | self.activation = torch.tanh 376 | self.residual = residual 377 | self.output_dim = output_dim 378 | self.layer1 = nn.Linear(input_dim, hidden_dim[0]) 379 | self.layer2 = nn.Linear(hidden_dim[0], hidden_dim[1]) 380 | self.layer3 = nn.Linear(hidden_dim[1], hidden_dim[2]) 381 | self.layer4 = nn.Linear(hidden_dim[2], output_dim) 382 | 383 | def forward(self, x): 384 | res = self.activation(self.layer1(x)) 385 | res = self.activation(self.layer2(res)) 386 | res = self.activation(self.layer3(res)) 387 | res = self.layer4(res) 388 | if self.residual: 389 | out = res + x 390 | else: 391 | out = res 392 | return out 393 | 394 | 395 | -------------------------------------------------------------------------------- /torchEnKF/noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchEnKF import misc 4 | import math 5 | 6 | class AddGaussian(nn.Module): 7 | """ 8 | torch.nn.Module that adds a Gaussian perturbation to a given input. 9 | (softplus function is used to ensure positiveness.) 10 | 11 | The Gaussian perturbation is parameterized by q, which may take the following forms, depending on "param_type": 12 | 1. "scalar": q has shape (1,). The perturbation is N(0, softplus(q)**2 * Id) 13 | 2. "diag": q has shape (x_dim,). The perturbation is N(0, diag(softplus(q)**2)) 14 | 3. "tril": q has shape (x_dim, x_dim). The perturbation is N(0, LL^T), where l is lower-triangular. 15 | Diagonal entries of L are the diagonal entries of q transformed by a softplus. 16 | Lower triangular part of L is the same as that of q. 17 | 4. "full": q has shape (x_dim, x_dim) and is positive definite. The perturbation is N(0, q). 18 | Use this if q does not need to be learned. 
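Example (illustrative): AddGaussian(x_dim, torch.tensor(0.5), 'scalar') adds N(0, 0.5**2 * Id) perturbations, i.e. module(X) returns X plus freshly sampled noise of the same shape.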
19 | """ 20 | 21 | def __init__(self, x_dim, q_true, param_type, q_shape=None): 22 | # q_shape: Additional parameter for the linear Gaussian experiment appeared in paper 23 | super().__init__() 24 | self.x_dim = x_dim 25 | self.q = nn.Parameter(self.pre_process(q_true, param_type)) 26 | self.param_type = param_type 27 | self.q_shape = q_shape 28 | if q_shape is not None: 29 | self.base = misc.construct_exp(x_dim) 30 | self.q_shape = nn.Parameter(q_shape) 31 | 32 | def pre_process(self, q_true, param_type): 33 | # We want to pass something to nn.Parameter that can take *every* real value, not just positives. 34 | if param_type == "scalar": 35 | return misc.softplus_inv(q_true) 36 | elif param_type == "diag": 37 | return misc.softplus_inv(q_true) 38 | elif param_type == "tril": 39 | return torch.tril(q_true, diagonal=-1) + torch.diag(misc.softplus_inv(q_true.diag())) 40 | elif param_type == "full": 41 | return q_true 42 | 43 | def post_process(self, q, param_type): 44 | if param_type == "scalar": 45 | return misc.softplus(q) 46 | elif param_type == "diag": 47 | return misc.softplus(q) 48 | elif param_type == "tril": 49 | return torch.tril(q, diagonal=-1) + torch.diag(misc.softplus(q.diag())) 50 | elif param_type == "full": 51 | return q 52 | 53 | def forward(self, X): 54 | if self.param_type == "scalar": 55 | if self.q_shape is None: 56 | X = X + self.post_process(self.q, self.param_type) * torch.randn_like(X) 57 | else: 58 | chol = self.post_process(self.q, self.param_type) * torch.linalg.cholesky( 59 | torch.exp(self.q_shape * self.base)) 60 | X = X + torch.randn_like(X) @ chol.t() 61 | elif self.param_type == "diag": 62 | X = X + self.post_process(self.q, self.param_type) * torch.randn_like(X) # (x_dim) * (*bs, N_ensem, x_dim) 63 | elif self.param_type == "tril": 64 | # batch_shape = X.shape[:-1] 65 | chol = self.post_process(self.q, self.param_type) 66 | X = X + torch.randn_like(X) @ chol.t() # (*bs, N_ensem, x_dim) @ (x_dim, x_dim) 67 | # X = X + torch.distributions.MultivariateNormal(torch.zeros(self.x_dim, device=self.q.device), scale_tril=chol).sample(batch_shape) 68 | elif self.param_type == "full": 69 | batch_shape = X.shape[:-1] 70 | chol = torch.linalg.cholesky(self.q) 71 | X = X + torch.distributions.MultivariateNormal(torch.zeros(self.x_dim, device=self.q.device), 72 | scale_tril=chol).sample(batch_shape) # (*bs, N_ensem, x_dim) 73 | return X 74 | 75 | def chol(self): 76 | if self.param_type == "scalar": 77 | if self.q_shape is None: 78 | return self.post_process(self.q, self.param_type) * torch.eye(self.x_dim, device=self.q.device) 79 | else: 80 | return self.post_process(self.q, self.param_type) * torch.linalg.cholesky( 81 | torch.exp(self.q_shape * self.base)) 82 | elif self.param_type == "diag": 83 | return self.post_process(self.q, self.param_type) * torch.eye(self.x_dim, device=self.q.device) 84 | elif self.param_type == "tril": 85 | return self.post_process(self.q, self.param_type) 86 | elif self.param_type == "full": 87 | return torch.linalg.cholesky(self.q) 88 | 89 | def inv(self): 90 | if self.param_type == "scalar": 91 | return 1 / (self.post_process(self.q, self.param_type) ** 2) * torch.eye(self.x_dim,device=self.q.device) 92 | elif self.param_type == "diag": 93 | return 1 / (self.post_process(self.q, self.param_type) ** 2) * torch.eye(self.x_dim,device=self.q.device) 94 | elif self.param_type == "tril": 95 | return torch.cholesky_inverse(self.post_process(self.q, self.param_type)) 96 | elif self.param_type == "full": 97 | return 
torch.cholesky_inverse(torch.linalg.cholesky(self.q)) 98 | 99 | def logdet(self): 100 | if self.param_type == "scalar": 101 | return 2 * self.x_dim * torch.log(self.post_process(self.q, self.param_type)) 102 | elif self.param_type == "diag": 103 | return 2 * self.post_process(self.q, self.param_type).log().sum() 104 | elif self.param_type == "tril": 105 | return 2 * self.post_process(self.q, self.param_type).diagonal(dim1=-2,dim2=-1).log().sum(-1) 106 | elif self.param_type == "full": 107 | return 2 * torch.linalg.cholesky(self.q).diagonal(dim1=-2,dim2=-1).log().sum(-1) 108 | 109 | def full(self): 110 | chol = self.chol() 111 | return chol @ chol.t() 112 | 113 | def q_true(self): 114 | return self.post_process(self.q, self.param_type) 115 | 116 | def post_grad(self): 117 | # Some pytorch tricks to compute d(loss)/d(q_true) where q_true = post_process(self.q) 118 | leaf = self.post_process(self.q, self.param_type).detach().requires_grad_() 119 | q_sub = self.pre_process(leaf,self.param_type) # ideally should recover self.q, but we can compute d(q_sub)/d(leaf) 120 | q_sub.backward(gradient=self.q.grad) 121 | return leaf.grad 122 | --------------------------------------------------------------------------------
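For orientation, the following is a minimal, illustrative sketch (not part of the repository) of how the modules above fit together, following the argument conventions documented in `da_methods.EnKF`. The dimensions, noise levels, observation times and observed values are placeholders; in practice they would come from `generate_data.py` and the demo scripts listed in the README.

import torch
from torchEnKF import da_methods, nn_templates, noise

device = torch.device('cpu')
x_dim, y_dim, N_ensem = 40, 40, 100

ode_func = nn_templates.Lorenz96(8., x_dim, device)                      # Lorenz-96 dynamics, forcing F is a learnable parameter
obs_func = nn_templates.Linear(x_dim, y_dim, torch.eye(y_dim, x_dim))    # linear observation model y = Hx + noise

init_m = torch.zeros(x_dim)
init_C_param = noise.AddGaussian(x_dim, torch.tensor(1.), 'scalar')      # initial ensemble spread
model_Q_param = noise.AddGaussian(x_dim, torch.tensor(0.1), 'scalar')    # model error covariance
noise_R_param = noise.AddGaussian(y_dim, torch.tensor(0.5), 'scalar')    # observation error covariance

t_obs = torch.linspace(0.05, 0.5, 10)        # placeholder observation times
y_obs = torch.randn(10, y_dim)               # placeholder observations (use generate_data.py for real ones)

X, res, log_likelihood = da_methods.EnKF(
    ode_func, obs_func, t_obs, y_obs, N_ensem, init_m,
    init_C_param, model_Q_param, noise_R_param, device,
    ode_options=dict(step_size=0.01), adjoint=False)

(-log_likelihood).backward()                 # AD-EnKF: gradients w.r.t. ode_func and noise parameters

An optimizer step on the parameters of `ode_func` and the noise modules would then complete one AD-EnKF iteration; see the demo scripts at the repository root for full training loops on simulated data.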