├── README.md ├── latentode ├── latent_ode │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── generate_timeseries.py │ ├── lib │ │ ├── base_models.py │ │ ├── create_latent_ode_model.py │ │ ├── diffeq_solver.py │ │ ├── encoder_decoder.py │ │ ├── latent_ode.py │ │ ├── likelihood_eval.py │ │ ├── ode_func.py │ │ ├── ode_rnn.py │ │ ├── parse_datasets.py │ │ ├── plotting.py │ │ ├── rnn_baselines.py │ │ └── utils.py │ ├── mujoco_physics.py │ ├── person_activity.py │ ├── physionet.py │ ├── run_models.py │ ├── train-activity.sh │ ├── train-ecg.sh │ ├── train-periodic100.sh │ └── train-periodic1000.sh └── ours_impl │ ├── __pycache__ │ └── lv_field.cpython-38.pyc │ └── lv_field.py ├── robotic ├── .combine_stats.py.swp ├── .ipynb_checkpoints │ └── Untitled-checkpoint.ipynb ├── .run_stiff_time.py.swp ├── 3D_Cshape_top.mat ├── S_gt_traj.npy ├── bench_output │ └── out.txt ├── cube_pick.npy └── run.py ├── rotating_MNIST ├── LV_stat_dic_0.tar ├── autoencoder_state_dic_0.tar ├── evaluate.py ├── readme.md ├── rotating3_2.pdf ├── run_rotate.py └── samples │ ├── t10k-images-idx3-ubyte │ ├── t10k-labels-idx1-ubyte │ ├── train-images-idx3-ubyte │ └── train-labels-idx1-ubyte ├── run-all.sh ├── stiff_ode ├── run_rober.py └── run_rober_comp_dopri.py └── systems ├── LV_run_theirs.py ├── LV_train_run_ours.py ├── LV_train_theirs.py ├── Untitled.ipynb ├── f_x_base_save_good_eod2.tar ├── inn2_save_good_eod2.tar ├── lor_run_all.py ├── lor_train_ours.py └── lor_train_theirs.py /README.md: -------------------------------------------------------------------------------- 1 | # ODElearning_INN 2 | [ICML 2022] Learning Efficient and Robust Ordinary Differential Equations via Invertible Neural Networks 3 | 4 | The code requires PyTorch, torchdiffeq (for neural ODEs), and FrEIA (for invertible neural networks). The latent ODE experiments borrow heavily from the code of the original Latent ODE work. 5 | To understand the method, we recommend starting with the code in `systems`.
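FrEIA is listed above as the invertible-network dependency. As a rough, hedged illustration of how such networks stay exactly invertible, here is a single RealNVP-style affine coupling layer in plain Python. This is not FrEIA's API and not necessarily the architecture used in `ours_impl/lv_field.py`. The class name `AffineCoupling` and its random linear "subnetworks" are hypothetical stand-ins for the small MLPs a real INN would train.

```python
import math
import random


class AffineCoupling:
    """Toy RealNVP-style affine coupling layer on a flat list of floats.

    Hypothetical illustration only (not FrEIA's API). The input is split into
    halves (x1, x2); x1 passes through unchanged and conditions a log-scale
    s(x1) and a shift t(x1) that are applied to x2. Because x1 is untouched,
    the inverse can recompute s and t exactly.
    """

    def __init__(self, dim, seed=0):
        if dim % 2 != 0:
            raise ValueError("this sketch assumes an even dimension")
        self.half = dim // 2
        rng = random.Random(seed)
        # Random linear maps standing in for the subnetworks a real INN would learn.
        self.w_s = [[rng.uniform(-0.5, 0.5) for _ in range(self.half)] for _ in range(self.half)]
        self.w_t = [[rng.uniform(-0.5, 0.5) for _ in range(self.half)] for _ in range(self.half)]

    def _scale_shift(self, x1):
        # tanh keeps the log-scale bounded, so exp(s) cannot overflow.
        s = [math.tanh(sum(w * v for w, v in zip(row, x1))) for row in self.w_s]
        t = [sum(w * v for w, v in zip(row, x1)) for row in self.w_t]
        return s, t

    def forward(self, x):
        x1, x2 = x[:self.half], x[self.half:]
        s, t = self._scale_shift(x1)
        y2 = [xi * math.exp(si) + ti for xi, si, ti in zip(x2, s, t)]
        return x1 + y2

    def inverse(self, y):
        y1, y2 = y[:self.half], y[self.half:]
        s, t = self._scale_shift(y1)  # y1 equals x1, so s and t match the forward pass
        x2 = [(yi - ti) * math.exp(-si) for yi, si, ti in zip(y2, s, t)]
        return y1 + x2

    def log_abs_det_jacobian(self, x):
        # The Jacobian is block lower-triangular with diagonal 1 (for x1) and exp(s) (for x2),
        # so log|det J| = sum(s).
        s, _ = self._scale_shift(x[:self.half])
        return sum(s)


layer = AffineCoupling(dim=4, seed=1)
x = [0.3, -1.2, 2.0, 0.5]
y = layer.forward(x)
x_rec = layer.inverse(y)
print("max reconstruction error:", max(abs(a - b) for a, b in zip(x, x_rec)))
```

A single coupling layer leaves the first half unchanged. Practical INNs therefore stack several such layers with permutations in between, so that every coordinate is eventually transformed.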
6 | Code has yet to be refactored for clarity and readability -- this is coming soon. 7 | -------------------------------------------------------------------------------- /latentode/latent_ode/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /latentode/latent_ode/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yulia Rubanova 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /latentode/latent_ode/README.md: -------------------------------------------------------------------------------- 1 | # Latent ODEs for Irregularly-Sampled Time Series 2 | 3 | Code for the paper: 4 | > Yulia Rubanova, Ricky Chen, David Duvenaud. 
"Latent ODEs for Irregularly-Sampled Time Series" (2019) 5 | [[arxiv]](https://arxiv.org/abs/1907.03907) 6 | 7 |
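A Latent ODE works as follows. An encoder proposes an initial latent state `z0`. An ODE solver integrates the latent dynamics to every requested, possibly irregular, time point. A decoder then maps each latent state back to data space. In this code base those roles are played by `encoder_z0`, `DiffeqSolver` (a wrapper around `torchdiffeq.odeint`) and `Decoder` in `lib/`. The dependency-free sketch below mirrors that generative path under explicitly hypothetical choices: the 2-D rotation dynamics, the linear decoder weights, and a fixed-step RK4 integrator standing in for `odeint`.

```python
import math


def rk4_solve(f, z0, times, steps_per_interval=20):
    """Integrate dz/dt = f(t, z) and return the state at every time in `times`.

    Fixed-step classical RK4, a stand-in for torchdiffeq.odeint in this sketch.
    `times` must be sorted; its first entry is the time of z0.
    """
    z = list(z0)
    out = [list(z)]
    for t_start, t_end in zip(times[:-1], times[1:]):
        h = (t_end - t_start) / steps_per_interval
        t = t_start
        for _ in range(steps_per_interval):
            k1 = f(t, z)
            k2 = f(t + h / 2, [zi + h / 2 * ki for zi, ki in zip(z, k1)])
            k3 = f(t + h / 2, [zi + h / 2 * ki for zi, ki in zip(z, k2)])
            k4 = f(t + h, [zi + h * ki for zi, ki in zip(z, k3)])
            z = [zi + h / 6 * (a + 2 * b + 2 * c + d)
                 for zi, a, b, c, d in zip(z, k1, k2, k3, k4)]
            t += h
        out.append(list(z))
    return out


# Hypothetical latent dynamics: a 2-D rotation dz/dt = [[0, -w], [w, 0]] z.
OMEGA = math.pi


def latent_dynamics(t, z):
    return [-OMEGA * z[1], OMEGA * z[0]]


# Hypothetical linear decoder: observation = w . z (here it simply reads off z[0]).
DECODER_W = [1.0, 0.0]


def decode(z):
    return sum(w * zi for w, zi in zip(DECODER_W, z))


def latent_ode_predict(z0, times):
    """Generative path of a latent ODE: z0 -> ODE solve -> decode each latent state."""
    return [decode(z) for z in rk4_solve(latent_dynamics, z0, times)]


times = [0.0, 0.25, 0.5, 1.0, 1.7]  # irregularly spaced observation times
z0 = [1.0, 0.0]  # in the real model, z0 is sampled from the encoder's posterior q(z0 | x)
preds = latent_ode_predict(z0, times)
for t, p in zip(times, preds):
    print(f"t={t:.2f}  prediction={p:+.5f}  exact cos(pi t)={math.cos(OMEGA * t):+.5f}")
```

For these toy dynamics the exact latent path is `[cos(pi t), sin(pi t)]`, so the decoded predictions can be checked against `cos(pi t)`. Because the ODE is solved between arbitrary time points, irregular sampling is handled without any binning.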


10 | 11 | ## Prerequisites 12 | 13 | Install `torchdiffeq` from https://github.com/rtqichen/torchdiffeq. 14 | 15 | ## Experiments on different datasets 16 | 17 | By default, the datasets are downloaded and processed when the script is run for the first time. 18 | 19 | Raw datasets: 20 | [[MuJoCo]](http://www.cs.toronto.edu/~rtqichen/datasets/HopperPhysics/training.pt) 21 | [[Physionet]](https://physionet.org/physiobank/database/challenge/2012/) 22 | [[Human Activity]](https://archive.ics.uci.edu/ml/datasets/Localization+Data+for+Person+Activity/) 23 | 24 | To generate MuJoCo trajectories from scratch, the [DeepMind Control Suite](https://github.com/deepmind/dm_control/) is required. 25 | 26 | 27 | * Toy dataset of 1D periodic functions 28 | ``` 29 | python3 run_models.py --niters 500 -n 1000 -s 50 -l 10 --dataset periodic --latent-ode --noise-weight 0.01 30 | ``` 31 | 32 | * MuJoCo 33 | 34 | ``` 35 | python3 run_models.py --niters 300 -n 10000 -l 15 --dataset hopper --latent-ode --rec-dims 30 --gru-units 100 --units 300 --gen-layers 3 --rec-layers 3 36 | ``` 37 | 38 | * Physionet (discretization by 1 min) 39 | ``` 40 | python3 run_models.py --niters 100 -n 8000 -l 20 --dataset physionet --latent-ode --rec-dims 40 --rec-layers 3 --gen-layers 3 --units 50 --gru-units 50 --quantization 0.016 --classif 41 | 42 | ``` 43 | 44 | * Human Activity 45 | ``` 46 | python3 run_models.py --niters 200 -n 10000 -l 15 --dataset activity --latent-ode --rec-dims 100 --rec-layers 4 --gen-layers 2 --units 500 --gru-units 50 --classif --linear-classif 47 | 48 | ``` 49 | 50 | 51 | ### Running different models 52 | 53 | * ODE-RNN 54 | ``` 55 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --ode-rnn 56 | ``` 57 | 58 | * Latent ODE with ODE-RNN encoder 59 | ``` 60 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --latent-ode 61 | ``` 62 | 63 | * Latent ODE with ODE-RNN encoder and Poisson likelihood 64 | ``` 65 | python3 run_models.py --niters 500 -n 1000
-l 10 --dataset periodic --latent-ode --poisson 66 | ``` 67 | 68 | * Latent ODE with RNN encoder (Chen et al., 2018) 69 | ``` 70 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --latent-ode --z0-encoder rnn 71 | ``` 72 | 73 | * RNN-VAE 74 | ``` 75 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --rnn-vae 76 | ``` 77 | 78 | * Classic RNN 79 | ``` 80 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --classic-rnn 81 | ``` 82 | 83 | * GRU-D 84 | 85 | GRU-D consists of two parts: input imputation (`--input-decay`) and exponential decay of the hidden state (`--rnn-cell expdecay`). 86 | 87 | ``` 88 | python3 run_models.py --niters 500 -n 100 -b 30 -l 10 --dataset periodic --classic-rnn --input-decay --rnn-cell expdecay 89 | ``` 90 | 91 | 92 | ### Making the visualization 93 | ``` 94 | python3 run_models.py --niters 100 -n 5000 -b 100 -l 3 --dataset periodic --latent-ode --noise-weight 0.5 --lr 0.01 --viz --rec-layers 2 --gen-layers 2 -u 100 -c 30 95 | ``` 96 | -------------------------------------------------------------------------------- /latentode/latent_ode/generate_timeseries.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | # Create a synthetic dataset 7 | from __future__ import absolute_import, division 8 | from __future__ import print_function 9 | import os 10 | import matplotlib 11 | # if os.path.exists("/Users/yulia"): 12 | # matplotlib.use('TkAgg') 13 | # else: 14 | # matplotlib.use('Agg') 15 | 16 | import numpy as np 17 | import numpy.random as npr 18 | from scipy.special import expit as sigmoid 19 | import pickle 20 | import matplotlib.pyplot as plt 21 | import matplotlib.image 22 | import torch 23 | import lib.utils as utils 24 | 25 | #
====================================================================================== 26 | 27 | def get_next_val(init, t, tmin, tmax, final = None): 28 | if final is None: 29 | return init 30 | val = init + (final - init) / (tmax - tmin) * (t - tmin) 31 | return val 32 | 33 | 34 | def generate_periodic(time_steps, init_freq, init_amplitude, starting_point, 35 | final_freq = None, final_amplitude = None, phi_offset = 0.): 36 | 37 | tmin = time_steps.min() 38 | tmax = time_steps.max() 39 | 40 | data = [] 41 | t_prev = time_steps[0] 42 | phi = phi_offset 43 | for t in time_steps: 44 | dt = t - t_prev 45 | amp = get_next_val(init_amplitude, t, tmin, tmax, final_amplitude) 46 | freq = get_next_val(init_freq, t, tmin, tmax, final_freq) 47 | phi = phi + 2 * np.pi * freq * dt # integrate to get phase 48 | 49 | y = amp * np.sin(phi) + starting_point 50 | t_prev = t 51 | data.append([t,y]) 52 | return np.array(data) 53 | 54 | def assign_value_or_sample(value, sampling_interval = [0.,1.]): 55 | if value is None: 56 | int_length = sampling_interval[1] - sampling_interval[0] 57 | return np.random.random() * int_length + sampling_interval[0] 58 | else: 59 | return value 60 | 61 | class TimeSeries: 62 | def __init__(self, device = torch.device("cpu")): 63 | self.device = device 64 | self.z0 = None 65 | 66 | def init_visualization(self): 67 | self.fig = plt.figure(figsize=(10, 4), facecolor='white') 68 | self.ax = self.fig.add_subplot(111, frameon=False) 69 | plt.show(block=False) 70 | 71 | def visualize(self, truth): 72 | self.ax.plot(truth[:,0], truth[:,1]) 73 | 74 | def add_noise(self, traj_list, time_steps, noise_weight): 75 | n_samples = traj_list.size(0) 76 | 77 | # Add noise to all the points except the first point 78 | n_tp = len(time_steps) - 1 79 | noise = np.random.sample((n_samples, n_tp)) 80 | noise = torch.Tensor(noise).to(self.device) 81 | 82 | traj_list_w_noise = traj_list.clone() 83 | # traj_list holds values only (time stamps were cut in sample_traj); [:,1:,0] skips the first time point 84 |
traj_list_w_noise[:,1:,0] += noise_weight * noise 85 | return traj_list_w_noise 86 | 87 | 88 | 89 | class Periodic_1d(TimeSeries): 90 | def __init__(self, device = torch.device("cpu"), 91 | init_freq = 0.3, init_amplitude = 1., 92 | final_amplitude = 10., final_freq = 1., 93 | z0 = 0.): 94 | """ 95 | If init_freq, init_amplitude or final_amplitude is None, it is sampled uniformly at random; if final_freq is None, it is set to init_freq (constant frequency). 96 | For now, all the time series share the time points; each starts near z0 (Gaussian jitter with std 0.1). 97 | """ 98 | super(Periodic_1d, self).__init__(device) 99 | 100 | self.init_freq = init_freq 101 | self.init_amplitude = init_amplitude 102 | self.final_amplitude = final_amplitude 103 | self.final_freq = final_freq 104 | self.z0 = z0 105 | 106 | def sample_traj(self, time_steps, n_samples = 1, noise_weight = 1., 107 | cut_out_section = None): 108 | """ 109 | Sample periodic functions. 110 | """ 111 | traj_list = [] 112 | for i in range(n_samples): 113 | init_freq = assign_value_or_sample(self.init_freq, [0.4,0.8]) 114 | if self.final_freq is None: 115 | final_freq = init_freq 116 | else: 117 | final_freq = assign_value_or_sample(self.final_freq, [0.4,0.8]) 118 | init_amplitude = assign_value_or_sample(self.init_amplitude, [0.,1.]) 119 | final_amplitude = assign_value_or_sample(self.final_amplitude, [0.,1.]) 120 | 121 | noisy_z0 = self.z0 + np.random.normal(loc=0., scale=0.1) 122 | 123 | traj = generate_periodic(time_steps, init_freq = init_freq, 124 | init_amplitude = init_amplitude, starting_point = noisy_z0, 125 | final_amplitude = final_amplitude, final_freq = final_freq) 126 | 127 | # Cut the time dimension 128 | traj = np.expand_dims(traj[:,1:], 0) 129 | traj_list.append(traj) 130 | 131 | # After np.array and squeeze(1) below, shape: [n_samples, n_timesteps, 1] 132 | # traj_list[:,:,0] -- values at the time stamps 133 | # (the time stamps themselves were cut above and are shared across samples) 134 | traj_list = np.array(traj_list) 135 | traj_list = torch.Tensor().new_tensor(traj_list, device = self.device) 136 | traj_list = 
traj_list.squeeze(1) 137 | 138 | traj_list = self.add_noise(traj_list, time_steps, noise_weight) 139 | return traj_list 140 | 141 | -------------------------------------------------------------------------------- /latentode/latent_ode/lib/base_models.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.functional import relu 10 | 11 | import lib.utils as utils 12 | from lib.encoder_decoder import * 13 | from lib.likelihood_eval import * 14 | 15 | from torch.distributions.multivariate_normal import MultivariateNormal 16 | from torch.distributions.normal import Normal 17 | from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase 18 | 19 | from torch.distributions.normal import Normal 20 | from torch.distributions import Independent 21 | from torch.nn.parameter import Parameter 22 | 23 | 24 | def create_classifier(z0_dim, n_labels): 25 | return nn.Sequential( 26 | nn.Linear(z0_dim, 300), 27 | nn.ReLU(), 28 | nn.Linear(300, 300), 29 | nn.ReLU(), 30 | nn.Linear(300, n_labels),) 31 | 32 | 33 | class Baseline(nn.Module): 34 | def __init__(self, input_dim, latent_dim, device, 35 | obsrv_std = 0.01, use_binary_classif = False, 36 | classif_per_tp = False, 37 | use_poisson_proc = False, 38 | linear_classifier = False, 39 | n_labels = 1, 40 | train_classif_w_reconstr = False): 41 | super(Baseline, self).__init__() 42 | 43 | self.input_dim = input_dim 44 | self.latent_dim = latent_dim 45 | self.n_labels = n_labels 46 | 47 | self.obsrv_std = torch.Tensor([obsrv_std]).to(device) 48 | self.device = device 49 | 50 | self.use_binary_classif = use_binary_classif 51 | self.classif_per_tp = classif_per_tp 52 | self.use_poisson_proc = use_poisson_proc 53 | self.linear_classifier = linear_classifier 54 | 
self.train_classif_w_reconstr = train_classif_w_reconstr 55 | 56 | z0_dim = latent_dim 57 | if use_poisson_proc: 58 | z0_dim += latent_dim 59 | 60 | if use_binary_classif: 61 | if linear_classifier: 62 | self.classifier = nn.Sequential( 63 | nn.Linear(z0_dim, n_labels)) 64 | else: 65 | self.classifier = create_classifier(z0_dim, n_labels) 66 | utils.init_network_weights(self.classifier) 67 | 68 | 69 | def get_gaussian_likelihood(self, truth, pred_y, mask = None): 70 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim] 71 | # truth shape [n_traj, n_tp, n_dim] 72 | if mask is not None: 73 | mask = mask.repeat(pred_y.size(0), 1, 1, 1) 74 | 75 | # Compute likelihood of the data under the predictions 76 | log_density_data = masked_gaussian_log_density(pred_y, truth, 77 | obsrv_std = self.obsrv_std, mask = mask) 78 | log_density_data = log_density_data.permute(1,0) 79 | 80 | # Compute the total density 81 | # Take mean over n_traj_samples 82 | log_density = torch.mean(log_density_data, 0) 83 | 84 | # shape: [n_traj] 85 | return log_density 86 | 87 | 88 | def get_mse(self, truth, pred_y, mask = None): 89 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim] 90 | # truth shape [n_traj, n_tp, n_dim] 91 | if mask is not None: 92 | mask = mask.repeat(pred_y.size(0), 1, 1, 1) 93 | 94 | # Compute the MSE of the predictions against the data 95 | log_density_data = compute_mse(pred_y, truth, mask = mask) 96 | # shape: [1] 97 | return torch.mean(log_density_data) 98 | 99 | 100 | def compute_all_losses(self, batch_dict, 101 | n_tp_to_sample = None, n_traj_samples = 1, kl_coef = 1.): 102 | 103 | # Condition on subsampled points 104 | # Make predictions for all the points 105 | pred_x, info = self.get_reconstruction(batch_dict["tp_to_predict"], 106 | batch_dict["observed_data"], batch_dict["observed_tp"], 107 | mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples, 108 | mode = batch_dict["mode"]) 109 | 110 | # Compute likelihood of all the points 111 | likelihood =
self.get_gaussian_likelihood(batch_dict["data_to_predict"], pred_x, 112 | mask = batch_dict["mask_predicted_data"]) 113 | 114 | mse = self.get_mse(batch_dict["data_to_predict"], pred_x, 115 | mask = batch_dict["mask_predicted_data"]) 116 | 117 | ################################ 118 | # Compute CE loss for binary classification on Physionet 119 | # Use only last attribute -- mortality in the hospital 120 | device = get_device(batch_dict["data_to_predict"]) 121 | ce_loss = torch.Tensor([0.]).to(device) 122 | 123 | if (batch_dict["labels"] is not None) and self.use_binary_classif: 124 | if (batch_dict["labels"].size(-1) == 1) or (len(batch_dict["labels"].size()) == 1): 125 | ce_loss = compute_binary_CE_loss( 126 | info["label_predictions"], 127 | batch_dict["labels"]) 128 | else: 129 | ce_loss = compute_multiclass_CE_loss( 130 | info["label_predictions"], 131 | batch_dict["labels"], 132 | mask = batch_dict["mask_predicted_data"]) 133 | 134 | if torch.isnan(ce_loss): 135 | print("label pred") 136 | print(info["label_predictions"]) 137 | print("labels") 138 | print( batch_dict["labels"]) 139 | raise Exception("CE loss is NaN!") 140 | 141 | pois_log_likelihood = torch.Tensor([0.]).to(get_device(batch_dict["data_to_predict"])) 142 | if self.use_poisson_proc: 143 | pois_log_likelihood = compute_poisson_proc_likelihood( 144 | batch_dict["data_to_predict"], pred_x, 145 | info, mask = batch_dict["mask_predicted_data"]) 146 | # Take mean over n_traj 147 | pois_log_likelihood = torch.mean(pois_log_likelihood, 1) 148 | 149 | loss = - torch.mean(likelihood) 150 | 151 | if self.use_poisson_proc: 152 | loss = loss - 0.1 * pois_log_likelihood 153 | 154 | if self.use_binary_classif: 155 | if self.train_classif_w_reconstr: 156 | loss = loss + ce_loss * 100 157 | else: 158 | loss = ce_loss 159 | 160 | # Take mean over the number of samples in a batch 161 | results = {} 162 | results["loss"] = torch.mean(loss) 163 | results["likelihood"] = torch.mean(likelihood).detach() 164 |
results["mse"] = torch.mean(mse).detach() 165 | results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach() 166 | results["ce_loss"] = torch.mean(ce_loss).detach() 167 | results["kl"] = 0. 168 | results["kl_first_p"] = 0. 169 | results["std_first_p"] = 0. 170 | 171 | if batch_dict["labels"] is not None and self.use_binary_classif: 172 | results["label_predictions"] = info["label_predictions"].detach() 173 | return results 174 | 175 | 176 | 177 | class VAE_Baseline(nn.Module): 178 | def __init__(self, input_dim, latent_dim, 179 | z0_prior, device, 180 | obsrv_std = 0.01, 181 | use_binary_classif = False, 182 | classif_per_tp = False, 183 | use_poisson_proc = False, 184 | linear_classifier = False, 185 | n_labels = 1, 186 | train_classif_w_reconstr = False): 187 | 188 | super(VAE_Baseline, self).__init__() 189 | 190 | self.input_dim = input_dim 191 | self.latent_dim = latent_dim 192 | self.device = device 193 | self.n_labels = n_labels 194 | 195 | self.obsrv_std = torch.Tensor([obsrv_std]).to(device) 196 | 197 | self.z0_prior = z0_prior 198 | self.use_binary_classif = use_binary_classif 199 | self.classif_per_tp = classif_per_tp 200 | self.use_poisson_proc = use_poisson_proc 201 | self.linear_classifier = linear_classifier 202 | self.train_classif_w_reconstr = train_classif_w_reconstr 203 | 204 | z0_dim = latent_dim 205 | if use_poisson_proc: 206 | z0_dim += latent_dim 207 | 208 | if use_binary_classif: 209 | if linear_classifier: 210 | self.classifier = nn.Sequential( 211 | nn.Linear(z0_dim, n_labels)) 212 | else: 213 | self.classifier = create_classifier(z0_dim, n_labels) 214 | utils.init_network_weights(self.classifier) 215 | 216 | 217 | def get_gaussian_likelihood(self, truth, pred_y, mask = None): 218 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim] 219 | # truth shape [n_traj, n_tp, n_dim] 220 | n_traj, n_tp, n_dim = truth.size() 221 | 222 | # Compute likelihood of the data under the predictions 223 | truth_repeated = 
truth.repeat(pred_y.size(0), 1, 1, 1) 224 | 225 | if mask is not None: 226 | mask = mask.repeat(pred_y.size(0), 1, 1, 1) 227 | log_density_data = masked_gaussian_log_density(pred_y, truth_repeated, 228 | obsrv_std = self.obsrv_std, mask = mask) 229 | log_density_data = log_density_data.permute(1,0) 230 | log_density = torch.mean(log_density_data, 1) 231 | 232 | # shape: [n_traj_samples] 233 | return log_density 234 | 235 | 236 | def get_mse(self, truth, pred_y, mask = None): 237 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim] 238 | # truth shape [n_traj, n_tp, n_dim] 239 | n_traj, n_tp, n_dim = truth.size() 240 | 241 | # Repeat the ground truth once per trajectory sample to match pred_y 242 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1) 243 | 244 | if mask is not None: 245 | mask = mask.repeat(pred_y.size(0), 1, 1, 1) 246 | 247 | # Compute the MSE of the predictions against the data 248 | log_density_data = compute_mse(pred_y, truth_repeated, mask = mask) 249 | # shape: [1] 250 | return torch.mean(log_density_data) 251 | 252 | 253 | def compute_all_losses(self, batch_dict, n_traj_samples = 1, kl_coef = 1., weight=None): 254 | # Condition on subsampled points 255 | # Make predictions for all the points 256 | pred_y, info = self.get_reconstruction(batch_dict["tp_to_predict"], 257 | batch_dict["observed_data"], batch_dict["observed_tp"], 258 | mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples, 259 | mode = batch_dict["mode"]) 260 | 261 | #print("get_reconstruction done -- computing likelihood") 262 | fp_mu, fp_std, fp_enc = info["first_point"] 263 | fp_std = fp_std.abs() 264 | fp_distr = Normal(fp_mu, fp_std) 265 | 266 | assert(torch.sum(fp_std < 0) == 0.)
267 | 268 | kldiv_z0 = kl_divergence(fp_distr, self.z0_prior) 269 | 270 | if torch.isnan(kldiv_z0).any(): 271 | print(fp_mu) 272 | print(fp_std) 273 | raise Exception("kldiv_z0 is Nan!") 274 | 275 | # Mean over number of latent dimensions 276 | # kldiv_z0 shape: [n_traj_samples, n_traj, n_latent_dims] if prior is a mixture of gaussians (KL is estimated) 277 | # kldiv_z0 shape: [1, n_traj, n_latent_dims] if prior is a standard gaussian (KL is computed exactly) 278 | # shape after: [n_traj_samples] 279 | kldiv_z0 = torch.mean(kldiv_z0,(1,2)) 280 | 281 | # Compute likelihood of all the points 282 | rec_likelihood = self.get_gaussian_likelihood( 283 | batch_dict["data_to_predict"], pred_y, 284 | mask = batch_dict["mask_predicted_data"]) 285 | 286 | mse = self.get_mse( 287 | batch_dict["data_to_predict"], pred_y, 288 | mask = batch_dict["mask_predicted_data"]) 289 | 290 | pois_log_likelihood = torch.Tensor([0.]).to(get_device(batch_dict["data_to_predict"])) 291 | if self.use_poisson_proc: 292 | pois_log_likelihood = compute_poisson_proc_likelihood( 293 | batch_dict["data_to_predict"], pred_y, 294 | info, mask = batch_dict["mask_predicted_data"]) 295 | # Take mean over n_traj 296 | pois_log_likelihood = torch.mean(pois_log_likelihood, 1) 297 | 298 | ################################ 299 | # Compute CE loss for binary classification on Physionet 300 | device = get_device(batch_dict["data_to_predict"]) 301 | ce_loss = torch.Tensor([0.]).to(device) 302 | # print(info["label_predictions"]) 303 | # print(batch_dict["labels"]) 304 | # print(batch_dict["mask_predicted_data"]) 305 | # print(info["label_predictions"].shape) 306 | # print(batch_dict["labels"].shape) 307 | # print(batch_dict["mask_predicted_data"].shape) 308 | # exit() 309 | if (batch_dict["labels"] is not None) and self.use_binary_classif: 310 | 311 | if (batch_dict["labels"].size(-1) == 1) or (len(batch_dict["labels"].size()) == 1): 312 | ce_loss = compute_binary_CE_loss( 313 | info["label_predictions"], 314 | 
batch_dict["labels"]) 315 | else: 316 | ce_loss = compute_multiclass_CE_loss( 317 | info["label_predictions"], 318 | batch_dict["labels"], 319 | mask = batch_dict["mask_predicted_data"], 320 | weight=weight) 321 | 322 | # IWAE loss 323 | loss = - torch.logsumexp(rec_likelihood - kl_coef * kldiv_z0,0) 324 | if torch.isnan(loss): 325 | loss = - torch.mean(rec_likelihood - kl_coef * kldiv_z0,0) 326 | 327 | if self.use_poisson_proc: 328 | loss = loss - 0.1 * pois_log_likelihood 329 | 330 | if self.use_binary_classif: 331 | if self.train_classif_w_reconstr: 332 | loss = loss + ce_loss * 100 333 | else: 334 | loss = ce_loss 335 | 336 | results = {} 337 | results["loss"] = torch.mean(loss) 338 | results["likelihood"] = torch.mean(rec_likelihood).detach() 339 | results["mse"] = torch.mean(mse).detach() 340 | results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach() 341 | results["ce_loss"] = torch.mean(ce_loss).detach() 342 | results["kl_first_p"] = torch.mean(kldiv_z0).detach() 343 | results["std_first_p"] = torch.mean(fp_std).detach() 344 | 345 | if batch_dict["labels"] is not None and self.use_binary_classif: 346 | results["label_predictions"] = info["label_predictions"].detach() 347 | 348 | return results 349 | 350 | 351 | 352 | -------------------------------------------------------------------------------- /latentode/latent_ode/lib/create_latent_ode_model.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import os 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn.functional import relu 12 | 13 | import lib.utils as utils 14 | from lib.latent_ode import LatentODE 15 | from lib.encoder_decoder import * 16 | from lib.diffeq_solver import DiffeqSolver 17 | 18 | import sys 19 | sys.path.insert(0, '../ours_impl') 20 | from lv_field 
import LVField 21 | 22 | from torch.distributions.normal import Normal 23 | from lib.ode_func import ODEFunc, ODEFunc_w_Poisson 24 | 25 | ##################################################################################################### 26 | 27 | def create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device, 28 | classif_per_tp = False, n_labels = 1, 29 | BLOB_OF_MODEL_SETTINGS=None 30 | ): 31 | 32 | dim = args.latents 33 | if args.poisson: 34 | lambda_net = utils.create_net(dim, input_dim, 35 | n_layers = 1, n_units = args.units, nonlinear = nn.Tanh) 36 | 37 | # ODE function produces the gradient for latent state and for poisson rate 38 | ode_func_net = utils.create_net(dim * 2, args.latents * 2, 39 | n_layers = args.gen_layers, n_units = args.units, nonlinear = nn.Tanh) 40 | 41 | gen_ode_func = ODEFunc_w_Poisson( 42 | input_dim = input_dim, 43 | latent_dim = args.latents * 2, 44 | ode_func_net = ode_func_net, 45 | lambda_net = lambda_net, 46 | device = device).to(device) 47 | else: 48 | dim = args.latents 49 | ode_func_net = utils.create_net(dim, args.latents, 50 | n_layers = args.gen_layers, n_units = args.units, nonlinear = nn.Tanh) 51 | 52 | gen_ode_func = ODEFunc( 53 | input_dim = input_dim, 54 | latent_dim = args.latents, 55 | ode_func_net = ode_func_net, 56 | device = device).to(device) 57 | 58 | z0_diffeq_solver = None 59 | n_rec_dims = args.rec_dims 60 | enc_input_dim = int(input_dim) * 2 # we concatenate the mask 61 | gen_data_dim = input_dim 62 | 63 | z0_dim = args.latents 64 | if args.poisson: 65 | z0_dim += args.latents # predict the initial poisson rate 66 | 67 | if args.z0_encoder == "odernn": 68 | ode_func_net = utils.create_net(n_rec_dims, n_rec_dims, 69 | n_layers = args.rec_layers, n_units = args.units, nonlinear = nn.Tanh) 70 | 71 | 72 | # using = 'ours' 73 | # using = 'theirs' 74 | 75 | 76 | # if using == 'ours': 77 | 78 | # z0_diffeq_solver = LVField( 79 | # dim=enc_input_dim * args.latents, 80 | # augmented_dim=input_dim * 
args.latents + 5, 81 | # num_layers=3, 82 | # ) 83 | # # pred_y = self.lv_field(first_point, time_steps_to_predict) 84 | 85 | 86 | # elif using == 'theirs': 87 | 88 | if True: 89 | """ 90 | We won't be using this, as the integration time steps within the encoder tend to be very short. 91 | """ 92 | rec_ode_func = ODEFunc( 93 | input_dim = enc_input_dim, 94 | latent_dim = n_rec_dims, 95 | ode_func_net = ode_func_net, 96 | device = device).to(device) 97 | 98 | z0_diffeq_solver = DiffeqSolver(enc_input_dim, rec_ode_func, "euler", args.latents, 99 | odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device) 100 | 101 | 102 | encoder_z0 = Encoder_z0_ODE_RNN(n_rec_dims, enc_input_dim, z0_diffeq_solver, 103 | z0_dim = z0_dim, n_gru_units = args.gru_units, device = device).to(device) 104 | 105 | 106 | 107 | elif args.z0_encoder == "rnn": 108 | encoder_z0 = Encoder_z0_RNN(z0_dim, enc_input_dim, 109 | lstm_output_size = n_rec_dims, device = device).to(device) 110 | else: 111 | raise Exception("Unknown encoder for Latent ODE model: " + args.z0_encoder) 112 | 113 | decoder = Decoder(args.latents, gen_data_dim).to(device) 114 | 115 | if BLOB_OF_MODEL_SETTINGS['model'] == 'node': 116 | diffeq_solver = DiffeqSolver(gen_data_dim, gen_ode_func, 117 | BLOB_OF_MODEL_SETTINGS['node__int_method'], 118 | # 'euler', 119 | args.latents, 120 | odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device) 121 | 122 | elif BLOB_OF_MODEL_SETTINGS['model'] == 'ours': 123 | e_vec_factor = 1.0 124 | t_d_factor = 1.0 125 | if BLOB_OF_MODEL_SETTINGS['dataset'] == 'activity': 126 | e_vec_factor = 0.001 127 | t_d_factor = 0.1 128 | elif BLOB_OF_MODEL_SETTINGS['dataset'] == 'periodic': 129 | if BLOB_OF_MODEL_SETTINGS['timepoints'] == 1000: 130 | e_vec_factor = 1.0 131 | t_d_factor = 0.01 132 | diffeq_solver = LVField( 133 | dim=args.latents, 134 | augmented_dim=args.latents + BLOB_OF_MODEL_SETTINGS['ours__extra_augnmented_dim'], 135 | num_layers=BLOB_OF_MODEL_SETTINGS['ours__num_layers'], 136 |
hidden_dim=BLOB_OF_MODEL_SETTINGS['ours__inn_hidden_dim'], 137 | 138 | e_vec_factor=e_vec_factor, 139 | t_d_factor=t_d_factor, 140 | ) 141 | else: raise ValueError("Unknown model: " + str(BLOB_OF_MODEL_SETTINGS['model'])) 142 | model = LatentODE( 143 | input_dim = gen_data_dim, 144 | latent_dim = args.latents, 145 | encoder_z0 = encoder_z0, 146 | decoder = decoder, 147 | diffeq_solver = diffeq_solver, 148 | z0_prior = z0_prior, 149 | device = device, 150 | obsrv_std = obsrv_std, 151 | use_poisson_proc = args.poisson, 152 | use_binary_classif = args.classif, 153 | linear_classifier = args.linear_classif, 154 | classif_per_tp = classif_per_tp, 155 | n_labels = n_labels, 156 | train_classif_w_reconstr = (args.dataset == "physionet") 157 | ).to(device) 158 | 159 | return model 160 | -------------------------------------------------------------------------------- /latentode/latent_ode/lib/diffeq_solver.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import time 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | import lib.utils as utils 13 | from torch.distributions.multivariate_normal import MultivariateNormal 14 | 15 | # git clone https://github.com/rtqichen/torchdiffeq.git 16 | from torchdiffeq import odeint as odeint 17 | 18 | ##################################################################################################### 19 | 20 | class DiffeqSolver(nn.Module): 21 | def __init__(self, input_dim, ode_func, method, latents, 22 | odeint_rtol = 1e-4, odeint_atol = 1e-5, device = torch.device("cpu")): 23 | super(DiffeqSolver, self).__init__() 24 | 25 | self.ode_method = method 26 | self.latents = latents 27 | self.device = device 28 | self.ode_func = ode_func 29 | 30 | self.odeint_rtol = odeint_rtol 31 | self.odeint_atol = odeint_atol 32 | 33 | def forward(self, first_point, time_steps_to_predict, backwards = False): 34 | """
35 | Decode the trajectory through the ODE solver. 36 | """ 37 | n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1] 38 | n_dims = first_point.size()[-1] 39 | 40 | pred_y = odeint(self.ode_func, first_point, time_steps_to_predict, 41 | rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method) 42 | pred_y = pred_y.permute(1,2,0,3) 43 | 44 | assert(torch.mean(torch.abs(pred_y[:, :, 0, :] - first_point)) < 0.001) 45 | assert(pred_y.size()[0] == n_traj_samples) 46 | assert(pred_y.size()[1] == n_traj) 47 | 48 | return pred_y 49 | 50 | def sample_traj_from_prior(self, starting_point_enc, time_steps_to_predict, 51 | n_traj_samples = 1): 52 | """ 53 | Decode the trajectory through the ODE solver using samples from the prior. 54 | 55 | time_steps_to_predict: time steps at which we want to sample the new trajectory 56 | """ 57 | func = self.ode_func.sample_next_point_from_prior 58 | 59 | pred_y = odeint(func, starting_point_enc, time_steps_to_predict, 60 | rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method) 61 | # shape: [n_traj_samples, n_traj, n_tp, n_dim] 62 | pred_y = pred_y.permute(1,2,0,3) 63 | return pred_y 64 | 65 | 66 | -------------------------------------------------------------------------------- /latentode/latent_ode/lib/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.functional import relu 10 | import lib.utils as utils 11 | from torch.distributions import Categorical, Normal 12 | import lib.utils as utils 13 | from torch.nn.modules.rnn import LSTM, GRU 14 | from lib.utils import get_device 15 | 16 | 17 | # GRU description: 18 | #
http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/ 19 | class GRU_unit(nn.Module): 20 | def __init__(self, latent_dim, input_dim, 21 | update_gate = None, 22 | reset_gate = None, 23 | new_state_net = None, 24 | n_units = 100, 25 | device = torch.device("cpu")): 26 | super(GRU_unit, self).__init__() 27 | 28 | if update_gate is None: 29 | self.update_gate = nn.Sequential( 30 | nn.Linear(latent_dim * 2 + input_dim, n_units), 31 | nn.Tanh(), 32 | nn.Linear(n_units, latent_dim), 33 | nn.Sigmoid()) 34 | utils.init_network_weights(self.update_gate) 35 | else: 36 | self.update_gate = update_gate 37 | 38 | if reset_gate is None: 39 | self.reset_gate = nn.Sequential( 40 | nn.Linear(latent_dim * 2 + input_dim, n_units), 41 | nn.Tanh(), 42 | nn.Linear(n_units, latent_dim), 43 | nn.Sigmoid()) 44 | utils.init_network_weights(self.reset_gate) 45 | else: 46 | self.reset_gate = reset_gate 47 | 48 | if new_state_net is None: 49 | self.new_state_net = nn.Sequential( 50 | nn.Linear(latent_dim * 2 + input_dim, n_units), 51 | nn.Tanh(), 52 | nn.Linear(n_units, latent_dim * 2)) 53 | utils.init_network_weights(self.new_state_net) 54 | else: 55 | self.new_state_net = new_state_net 56 | 57 | 58 | def forward(self, y_mean, y_std, x, masked_update = True): 59 | y_concat = torch.cat([y_mean, y_std, x], -1) 60 | 61 | update_gate = self.update_gate(y_concat) 62 | reset_gate = self.reset_gate(y_concat) 63 | concat = torch.cat([y_mean * reset_gate, y_std * reset_gate, x], -1) 64 | 65 | new_state, new_state_std = utils.split_last_dim(self.new_state_net(concat)) 66 | new_state_std = new_state_std.abs() 67 | 68 | new_y = (1-update_gate) * new_state + update_gate * y_mean 69 | new_y_std = (1-update_gate) * new_state_std + update_gate * y_std 70 | 71 | assert(not torch.isnan(new_y).any()) 72 | 73 | if masked_update: 74 | # IMPORTANT: assumes that x contains both data and mask 75 | # update only the hidden states for hidden state 
only if at least one feature is present for the current time point 76 | n_data_dims = x.size(-1)//2 77 | mask = x[:, :, n_data_dims:] 78 | utils.check_mask(x[:, :, :n_data_dims], mask) 79 | 80 | mask = (torch.sum(mask, -1, keepdim = True) > 0).float() 81 | 82 | assert(not torch.isnan(mask).any()) 83 | 84 | new_y = mask * new_y + (1-mask) * y_mean 85 | new_y_std = mask * new_y_std + (1-mask) * y_std 86 | 87 | if torch.isnan(new_y).any(): 88 | print("new_y is nan!") 89 | print(mask) 90 | print(y_mean) 91 | print(y_std) 92 | raise RuntimeError("NaN encountered in GRU_unit update") 93 | 94 | new_y_std = new_y_std.abs() 95 | return new_y, new_y_std 96 | 97 | 98 | 99 | class Encoder_z0_RNN(nn.Module): 100 | def __init__(self, latent_dim, input_dim, lstm_output_size = 20, 101 | use_delta_t = True, device = torch.device("cpu")): 102 | 103 | super(Encoder_z0_RNN, self).__init__() 104 | 105 | self.gru_rnn_output_size = lstm_output_size 106 | self.latent_dim = latent_dim 107 | self.input_dim = input_dim 108 | self.device = device 109 | self.use_delta_t = use_delta_t 110 | 111 | self.hiddens_to_z0 = nn.Sequential( 112 | nn.Linear(self.gru_rnn_output_size, 50), 113 | nn.Tanh(), 114 | nn.Linear(50, latent_dim * 2),) 115 | 116 | utils.init_network_weights(self.hiddens_to_z0) 117 | 118 | input_dim = self.input_dim 119 | 120 | if use_delta_t: 121 | self.input_dim += 1 122 | self.gru_rnn = GRU(self.input_dim, self.gru_rnn_output_size).to(device) 123 | 124 | def forward(self, data, time_steps, run_backwards = True): 125 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 126 | 127 | # data shape: [n_traj, n_tp, n_dims] 128 | # shape required for rnn: (seq_len, batch, input_size) 129 | # t0: not used here 130 | n_traj = data.size(0) 131 | 132 | assert(not torch.isnan(data).any()) 133 | assert(not torch.isnan(time_steps).any()) 134 | 135 | data = data.permute(1,0,2) 136 | 137 | if run_backwards: 138 | # Look at data in the reverse order: from later points to the first 139 | data = utils.reverse(data) 140 | 141
| if self.use_delta_t: 142 | delta_t = time_steps[1:] - time_steps[:-1] 143 | if run_backwards: 144 | # we are going backwards in time, so reverse the time deltas as well 145 | delta_t = utils.reverse(delta_t) 146 | # append a zero delta t at the end 147 | delta_t = torch.cat((delta_t, torch.zeros(1).to(self.device))) 148 | delta_t = delta_t.unsqueeze(1).repeat((1,n_traj)).unsqueeze(-1) 149 | data = torch.cat((delta_t, data),-1) 150 | 151 | outputs, _ = self.gru_rnn(data) 152 | 153 | # GRU output shape: (seq_len, batch, num_directions * hidden_size) 154 | last_output = outputs[-1] 155 | 156 | self.extra_info ={"rnn_outputs": outputs, "time_points": time_steps} 157 | 158 | mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output)) 159 | std = std.abs() 160 | 161 | assert(not torch.isnan(mean).any()) 162 | assert(not torch.isnan(std).any()) 163 | 164 | return mean.unsqueeze(0), std.unsqueeze(0) 165 | 166 | 167 | 168 | 169 | 170 | class Encoder_z0_ODE_RNN(nn.Module): 171 | # Derive z0 by running the ODE backwards. 172 | # For every y_i we have two versions: encoded from data and derived from ODE by running it backwards from t_i+1 to t_i 173 | # Compute a weighted sum of y_i from data and y_i from ode.
Use the weighted y_i as the initial value for the ODE running from t_i to t_i-1 174 | # Continue until we get to z0 175 | def __init__(self, latent_dim, input_dim, z0_diffeq_solver = None, 176 | z0_dim = None, GRU_update = None, 177 | n_gru_units = 100, 178 | device = torch.device("cpu")): 179 | 180 | super(Encoder_z0_ODE_RNN, self).__init__() 181 | 182 | if z0_dim is None: 183 | self.z0_dim = latent_dim 184 | else: 185 | self.z0_dim = z0_dim 186 | 187 | if GRU_update is None: 188 | self.GRU_update = GRU_unit(latent_dim, input_dim, 189 | n_units = n_gru_units, 190 | device=device).to(device) 191 | else: 192 | self.GRU_update = GRU_update 193 | 194 | self.z0_diffeq_solver = z0_diffeq_solver 195 | self.latent_dim = latent_dim 196 | self.input_dim = input_dim 197 | self.device = device 198 | self.extra_info = None 199 | 200 | self.transform_z0 = nn.Sequential( 201 | nn.Linear(latent_dim * 2, 100), 202 | nn.Tanh(), 203 | nn.Linear(100, self.z0_dim * 2),) 204 | utils.init_network_weights(self.transform_z0) 205 | 206 | 207 | def forward(self, data, time_steps, run_backwards = True, save_info = False): 208 | # data, time_steps -- observations and their time stamps 209 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 210 | assert(not torch.isnan(data).any()) 211 | assert(not torch.isnan(time_steps).any()) 212 | 213 | n_traj, n_tp, n_dims = data.size() 214 | if len(time_steps) == 1: 215 | prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(self.device) 216 | prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(self.device) 217 | 218 | xi = data[:,0,:].unsqueeze(0) 219 | 220 | last_yi, last_yi_std = self.GRU_update(prev_y, prev_std, xi) 221 | extra_info = None 222 | else: 223 | 224 | last_yi, last_yi_std, _, extra_info = self.run_odernn( 225 | data, time_steps, run_backwards = run_backwards, 226 | save_info = save_info) 227 | 228 | means_z0 = last_yi.reshape(1, n_traj, self.latent_dim) 229 | std_z0 = last_yi_std.reshape(1, n_traj, self.latent_dim) 230 | 231
| mean_z0, std_z0 = utils.split_last_dim( self.transform_z0( torch.cat((means_z0, std_z0), -1))) 232 | std_z0 = std_z0.abs() 233 | if save_info: 234 | self.extra_info = extra_info 235 | 236 | return mean_z0, std_z0 237 | 238 | 239 | def run_odernn(self, data, time_steps, 240 | run_backwards = True, save_info = False): 241 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 242 | 243 | n_traj, n_tp, n_dims = data.size() 244 | extra_info = [] 245 | 246 | t0 = time_steps[-1] 247 | if run_backwards: 248 | t0 = time_steps[0] 249 | 250 | device = get_device(data) 251 | 252 | prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device) 253 | prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device) 254 | 255 | prev_t, t_i = time_steps[-1] + 0.01, time_steps[-1] 256 | 257 | interval_length = time_steps[-1] - time_steps[0] 258 | minimum_step = interval_length / 50 259 | 260 | #print("minimum step: {}".format(minimum_step)) 261 | 262 | assert(not torch.isnan(data).any()) 263 | assert(not torch.isnan(time_steps).any()) 264 | 265 | latent_ys = [] 266 | # Run ODE backwards and combine the y(t) estimates using gating 267 | time_points_iter = range(0, len(time_steps)) 268 | if run_backwards: 269 | time_points_iter = reversed(time_points_iter) 270 | 271 | for i in time_points_iter: 272 | """ 273 | if (prev_t - t_i) < minimum_step: 274 | 275 | time_points = torch.stack((prev_t, t_i)) 276 | inc = self.z0_diffeq_solver.ode_func(prev_t, prev_y) * (t_i - prev_t) 277 | 278 | assert(not torch.isnan(inc).any()) 279 | 280 | ode_sol = prev_y + inc 281 | ode_sol = torch.stack((prev_y, ode_sol), 2).to(device) 282 | 283 | assert(not torch.isnan(ode_sol).any()) 284 | else: 285 | """ 286 | # Skip doing the above hacks 287 | if 1: 288 | n_intermediate_tp = max(2, ((prev_t - t_i) / minimum_step).int()) 289 | 290 | time_points = utils.linspace_vector(prev_t, t_i, n_intermediate_tp).to(device) 291 | ode_sol = self.z0_diffeq_solver(prev_y, time_points) 292 | 293 | assert(not 
torch.isnan(ode_sol).any()) 294 | 295 | if False and torch.mean(ode_sol[:, :, 0, :] - prev_y) >= 0.001: 296 | print("Error: first point of the ODE is not equal to initial value") 297 | print(torch.mean(ode_sol[:, :, 0, :] - prev_y)) 298 | exit() 299 | #assert(torch.mean(ode_sol[:, :, 0, :] - prev_y) < 0.001) 300 | 301 | yi_ode = ode_sol[:, :, -1, :] 302 | xi = data[:,i,:].unsqueeze(0) 303 | 304 | yi, yi_std = self.GRU_update(yi_ode, prev_std, xi) 305 | 306 | prev_y, prev_std = yi, yi_std 307 | prev_t, t_i = time_steps[i], time_steps[i-1] 308 | 309 | latent_ys.append(yi) 310 | 311 | if save_info: 312 | d = {"yi_ode": yi_ode.detach(), #"yi_from_data": yi_from_data, 313 | "yi": yi.detach(), "yi_std": yi_std.detach(), 314 | "time_points": time_points.detach(), "ode_sol": ode_sol.detach()} 315 | extra_info.append(d) 316 | 317 | latent_ys = torch.stack(latent_ys, 1) 318 | 319 | assert(not torch.isnan(yi).any()) 320 | assert(not torch.isnan(yi_std).any()) 321 | 322 | return yi, yi_std, latent_ys, extra_info 323 | 324 | 325 | 326 | class Decoder(nn.Module): 327 | def __init__(self, latent_dim, input_dim): 328 | super(Decoder, self).__init__() 329 | # decode data from latent space where we are solving an ODE back to the data space 330 | 331 | decoder = nn.Sequential( 332 | nn.Linear(latent_dim, input_dim),) 333 | 334 | utils.init_network_weights(decoder) 335 | self.decoder = decoder 336 | 337 | def forward(self, data): 338 | return self.decoder(data) 339 | 340 | 341 | -------------------------------------------------------------------------------- /latentode/latent_ode/lib/latent_ode.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import numpy as np 7 | import sklearn as sk 8 | import numpy as np 9 | #import gc 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn.functional 
import relu 13 | 14 | import lib.utils as utils 15 | from lib.utils import get_device 16 | from lib.encoder_decoder import * 17 | from lib.likelihood_eval import * 18 | 19 | from torch.distributions.multivariate_normal import MultivariateNormal 20 | from torch.distributions.normal import Normal 21 | from torch.distributions import kl_divergence, Independent 22 | from lib.base_models import VAE_Baseline 23 | 24 | 25 | 26 | class LatentODE(VAE_Baseline): 27 | def __init__(self, input_dim, latent_dim, encoder_z0, decoder, diffeq_solver, 28 | z0_prior, device, obsrv_std = None, 29 | use_binary_classif = False, use_poisson_proc = False, 30 | linear_classifier = False, 31 | classif_per_tp = False, 32 | n_labels = 1, 33 | train_classif_w_reconstr = False): 34 | 35 | super(LatentODE, self).__init__( 36 | input_dim = input_dim, latent_dim = latent_dim, 37 | z0_prior = z0_prior, 38 | device = device, obsrv_std = obsrv_std, 39 | use_binary_classif = use_binary_classif, 40 | classif_per_tp = classif_per_tp, 41 | linear_classifier = linear_classifier, 42 | use_poisson_proc = use_poisson_proc, 43 | n_labels = n_labels, 44 | train_classif_w_reconstr = train_classif_w_reconstr) 45 | 46 | self.encoder_z0 = encoder_z0 47 | self.diffeq_solver = diffeq_solver 48 | self.decoder = decoder 49 | self.use_poisson_proc = use_poisson_proc 50 | 51 | def get_reconstruction(self, time_steps_to_predict, truth, truth_time_steps, 52 | mask = None, n_traj_samples = 1, run_backwards = True, mode = None): 53 | 54 | if isinstance(self.encoder_z0, Encoder_z0_ODE_RNN) or \ 55 | isinstance(self.encoder_z0, Encoder_z0_RNN): 56 | 57 | truth_w_mask = truth 58 | if mask is not None: 59 | truth_w_mask = torch.cat((truth, mask), -1) 60 | first_point_mu, first_point_std = self.encoder_z0( 61 | truth_w_mask, truth_time_steps, run_backwards = run_backwards) 62 | 63 | means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1) 64 | sigma_z0 = first_point_std.repeat(n_traj_samples, 1, 1) 65 | first_point_enc = 
utils.sample_standard_gaussian(means_z0, sigma_z0) 66 | 67 | else: 68 | raise Exception("Unknown encoder type {}".format(type(self.encoder_z0).__name__)) 69 | 70 | first_point_std = first_point_std.abs() 71 | assert(torch.sum(first_point_std < 0) == 0.) 72 | 73 | if self.use_poisson_proc: 74 | n_traj_samples, n_traj, n_dims = first_point_enc.size() 75 | # append a vector of zeros to compute the integral of lambda 76 | zeros = torch.zeros([n_traj_samples, n_traj,self.input_dim]).to(get_device(truth)) 77 | first_point_enc_aug = torch.cat((first_point_enc, zeros), -1) 78 | means_z0_aug = torch.cat((means_z0, zeros), -1) 79 | else: 80 | first_point_enc_aug = first_point_enc 81 | means_z0_aug = means_z0 82 | 83 | assert(not torch.isnan(time_steps_to_predict).any()) 84 | assert(not torch.isnan(first_point_enc).any()) 85 | assert(not torch.isnan(first_point_enc_aug).any()) 86 | 87 | # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents] 88 | 89 | # from soraxas_toolbox import timer 90 | # timer.stamp("diff") 91 | sol_y = self.diffeq_solver(first_point_enc_aug, time_steps_to_predict) 92 | # timer.stamp("rest") 93 | 94 | if self.use_poisson_proc: 95 | sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(sol_y) 96 | 97 | assert(torch.sum(int_lambda[:,:,0,:]) == 0.) 98 | assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.) 
99 | 100 | pred_x = self.decoder(sol_y) 101 | 102 | all_extra_info = { 103 | "first_point": (first_point_mu, first_point_std, first_point_enc), 104 | "latent_traj": sol_y.detach() 105 | } 106 | 107 | if self.use_poisson_proc: 108 | # integral of lambda from the last step of ODE Solver 109 | all_extra_info["int_lambda"] = int_lambda[:,:,-1,:] 110 | all_extra_info["log_lambda_y"] = log_lambda_y 111 | 112 | if self.use_binary_classif: 113 | if self.classif_per_tp: 114 | all_extra_info["label_predictions"] = self.classifier(sol_y) 115 | else: 116 | all_extra_info["label_predictions"] = self.classifier(first_point_enc).squeeze(-1) 117 | 118 | return pred_x, all_extra_info 119 | 120 | 121 | def sample_traj_from_prior(self, time_steps_to_predict, n_traj_samples = 1): 122 | # input_dim = starting_point.size()[-1] 123 | # starting_point = starting_point.view(1,1,input_dim) 124 | 125 | # Sample z0 from prior 126 | starting_point_enc = self.z0_prior.sample([n_traj_samples, 1, self.latent_dim]).squeeze(-1) 127 | 128 | starting_point_enc_aug = starting_point_enc 129 | if self.use_poisson_proc: 130 | n_traj_samples, n_traj, n_dims = starting_point_enc.size() 131 | # append a vector of zeros to compute the integral of lambda 132 | zeros = torch.zeros(n_traj_samples, n_traj,self.input_dim).to(self.device) 133 | starting_point_enc_aug = torch.cat((starting_point_enc, zeros), -1) 134 | 135 | sol_y = self.diffeq_solver.sample_traj_from_prior(starting_point_enc_aug, time_steps_to_predict, 136 | n_traj_samples = n_traj_samples) 137 | 138 | if self.use_poisson_proc: 139 | sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(sol_y) 140 | 141 | return self.decoder(sol_y) 142 | 143 | 144 | -------------------------------------------------------------------------------- /latentode/latent_ode/lib/likelihood_eval.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled
Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import gc 7 | import numpy as np 8 | import sklearn as sk 9 | import numpy as np 10 | #import gc 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn.functional import relu 14 | 15 | import lib.utils as utils 16 | from lib.utils import get_device 17 | from lib.encoder_decoder import * 18 | from lib.likelihood_eval import * 19 | 20 | from torch.distributions.multivariate_normal import MultivariateNormal 21 | from torch.distributions.normal import Normal 22 | from torch.distributions import kl_divergence, Independent 23 | 24 | 25 | def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices = None): 26 | n_data_points = mu_2d.size()[-1] 27 | 28 | if n_data_points > 0: 29 | gaussian = Independent(Normal(loc = mu_2d, scale = obsrv_std.repeat(n_data_points)), 1) 30 | log_prob = gaussian.log_prob(data_2d) 31 | log_prob = log_prob / n_data_points 32 | else: 33 | log_prob = torch.zeros([1]).to(get_device(data_2d)).squeeze() 34 | return log_prob 35 | 36 | 37 | def poisson_log_likelihood(masked_log_lambdas, masked_data, indices, int_lambdas): 38 | # masked_log_lambdas and masked_data 39 | n_data_points = masked_data.size()[-1] 40 | 41 | if n_data_points > 0: 42 | log_prob = torch.sum(masked_log_lambdas) - int_lambdas[indices] 43 | #log_prob = log_prob / n_data_points 44 | else: 45 | log_prob = torch.zeros([1]).to(get_device(masked_data)).squeeze() 46 | return log_prob 47 | 48 | 49 | 50 | def compute_binary_CE_loss(label_predictions, mortality_label): 51 | #print("Computing binary classification loss: compute_CE_loss") 52 | 53 | mortality_label = mortality_label.reshape(-1) 54 | 55 | if len(label_predictions.size()) == 1: 56 | label_predictions = label_predictions.unsqueeze(0) 57 | 58 | n_traj_samples = label_predictions.size(0) 59 | label_predictions = label_predictions.reshape(n_traj_samples, -1) 60 | 61 | idx_not_nan = ~torch.isnan(mortality_label) 62 | if torch.sum(idx_not_nan) == 0: 63 | print("All labels are NaNs!") 64 | ce_loss = torch.tensor(0.).to(get_device(mortality_label)) 65 | return ce_loss 66 | label_predictions = label_predictions[:,idx_not_nan] 67 | mortality_label = mortality_label[idx_not_nan] 68 | 69 | if torch.sum(mortality_label == 0.) == 0 or torch.sum(mortality_label == 1.) == 0: 70 | print("Warning: all examples in a batch belong to the same class -- please increase the batch size.") 71 | 72 | assert(not torch.isnan(label_predictions).any()) 73 | assert(not torch.isnan(mortality_label).any()) 74 | 75 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them 76 | mortality_label = mortality_label.repeat(n_traj_samples, 1) 77 | ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label) 78 | 79 | # additionally divide by the number of z0 samples per trajectory (n_traj_samples) 80 | ce_loss = ce_loss / n_traj_samples 81 | return ce_loss 82 | 83 | 84 | def compute_multiclass_CE_loss(label_predictions, true_label, mask, weight=None): 85 | #print("Computing multi-class classification loss: compute_multiclass_CE_loss") 86 | 87 | if (len(label_predictions.size()) == 3): 88 | label_predictions = label_predictions.unsqueeze(0) 89 | 90 | n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size() 91 | 92 | 93 | _ori_mask = mask 94 | 95 | if False: 96 | # assert(not torch.isnan(label_predictions).any()) 97 | # assert(not torch.isnan(true_label).any()) 98 | 99 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them 100 | true_label = true_label.repeat(n_traj_samples, 1, 1) 101 | 102 | label_predictions = label_predictions.reshape(n_traj_samples * n_traj * n_tp, n_dims) 103 | true_label = true_label.reshape(n_traj_samples * n_traj * n_tp, n_dims) 104 | 105 | # choose time points with at least one measurement 106 | mask = torch.sum(mask, -1) > 0 107 | 108 | # repeat the mask for each label to mark that the label for this time point is present 109 | pred_mask = mask.repeat(n_dims, 1,1).permute(1,2,0) 110
| 111 | label_mask = mask 112 | pred_mask = pred_mask.repeat(n_traj_samples,1,1,1) 113 | label_mask = label_mask.repeat(n_traj_samples,1,1,1) 114 | 115 | pred_mask = pred_mask.reshape(n_traj_samples * n_traj * n_tp, n_dims) 116 | label_mask = label_mask.reshape(n_traj_samples * n_traj * n_tp, 1) 117 | 118 | if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1): 119 | assert(label_predictions.size(-1) == true_label.size(-1)) 120 | # targets are in one-hot encoding -- convert to indices 121 | _, true_label = true_label.max(-1) 122 | 123 | res = [] 124 | for i in range(true_label.size(0)): 125 | pred_masked = torch.masked_select(label_predictions[i], pred_mask[i].bool()) 126 | labels = torch.masked_select(true_label[i], label_mask[i].bool()) 127 | 128 | pred_masked = pred_masked.reshape(-1, n_dims) 129 | 130 | if (len(labels) == 0): 131 | continue 132 | 133 | ce_loss = nn.CrossEntropyLoss()(pred_masked, labels.long()) 134 | res.append(ce_loss) 135 | 136 | ce_loss = torch.stack(res, 0).to(get_device(label_predictions)) 137 | ce_loss = torch.mean(ce_loss) 138 | 139 | else: 140 | cel = nn.CrossEntropyLoss(weight=weight) 141 | if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1): 142 | assert(label_predictions.size(-1) == true_label.size(-1)) 143 | # targets are in one-hot encoding -- convert to indices 144 | _, true_label = true_label.max(-1) 145 | 146 | if (mask == 1).all(): 147 | # SKIP applying for mask 148 | true_label = true_label.repeat(n_traj_samples * n_traj, 1, 1).flatten() 149 | # true_label = true_label.repeat(n_traj_samples, 1, 1).flatten() 150 | label_predictions = label_predictions.reshape(-1, n_dims) 151 | _ce_loss = cel(label_predictions, true_label.long()) 152 | else: 153 | time_steps_to_consider = (torch.sum(_ori_mask, -1) > 0).repeat(n_traj_samples, 1, 1).flatten() 154 | true_label = true_label.repeat(n_traj_samples, 1, 1).flatten() 155 | label_predictions = label_predictions.reshape(-1, n_dims) 156 | _ce_loss = 
cel(label_predictions[time_steps_to_consider], true_label[time_steps_to_consider].long()) 157 | ce_loss = _ce_loss 158 | # print(ce_loss) 159 | # print(_ce_loss) 160 | 161 | 162 | 163 | # # divide by number of patients in a batch 164 | # ce_loss = ce_loss / n_traj_samples 165 | return ce_loss 166 | 167 | 168 | 169 | 170 | def compute_masked_likelihood(mu, data, mask, likelihood_func, _type, obsrv_std=None): 171 | # Compute the likelihood per patient and per attribute so that we don't priorize patients with more measurements 172 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size() 173 | 174 | #time_steps_to_consider = (torch.sum(mask, -1) > 0).repeat(n_traj_samples, 1, 1) 175 | time_steps_to_consider = (torch.sum(mask, -1) > 0) 176 | 177 | if False: 178 | 179 | res = [] 180 | for i in range(n_traj_samples): 181 | for k in range(n_traj): 182 | for j in range(n_dims): 183 | data_masked = torch.masked_select(data[i,k,:,j], mask[i,k,:,j].bool()) 184 | 185 | #assert(torch.sum(data_masked == 0.) < 10) 186 | 187 | mu_masked = torch.masked_select(mu[i,k,:,j], mask[i,k,:,j].bool()) 188 | # print(mu_masked.detach().cpu().numpy()) 189 | # print(data_masked.detach().cpu().numpy()) 190 | # print('---') 191 | log_prob = likelihood_func(mu_masked, data_masked, indices = (i,k,j)) 192 | res.append(log_prob) 193 | # shape: [n_traj*n_traj_samples, 1] 194 | 195 | res = torch.stack(res, 0).to(get_device(data)) 196 | res = res.reshape((n_traj_samples, n_traj, n_dims)) 197 | # Take mean over the number of dimensions 198 | res = torch.mean(res, -1) # !!!!!!!!!!! 
changed from sum to mean 199 | res = res.transpose(0,1) 200 | 201 | 202 | 203 | 204 | # create a mask to map original shape into an unstructured shape 205 | _mask = mask.bool() 206 | # set target shape as the one without n_timepoints 207 | target_shape = n_traj_samples, n_traj, n_dims 208 | # apply mask to obtain an unstructured tensor 209 | _mu = mu[_mask] 210 | _data = data[_mask] 211 | 212 | if _type == 'logprob': 213 | # compute log prob as unstructured data shape 214 | gaussian = Independent(Normal(loc = _mu, scale=obsrv_std), 0) 215 | log_prob = gaussian.log_prob(_data) 216 | 217 | # map the masked log-prob back to the original shape 218 | lol = torch.zeros(data.shape).to(data.device) 219 | lol[_mask] = log_prob 220 | 221 | # create a divisor to average the log prob across timepoints 222 | diviser = _mask.float().sum(2) 223 | # replace any zero with 1 to avoid dividing by zero 224 | diviser[diviser==0] = 1 225 | 226 | # get mean along timepoints 227 | lol = lol.sum(2) / diviser 228 | 229 | elif _type == 'mse': 230 | # compute mse as unstructured data shape 231 | #mse = nn.MSELoss()(_mu, _data) 232 | mse = (_mu - _data)**2 233 | 234 | # map the masked mse back to the original shape 235 | lol = torch.zeros(data.shape).to(data.device) 236 | lol[_mask] = mse 237 | 238 | # create a divisor to average the squared error across timepoints 239 | diviser = _mask.float().sum(2) 240 | # replace any zero with 1 to avoid dividing by zero 241 | diviser[diviser==0] = 1 242 | 243 | # get mean along timepoints 244 | lol = lol.sum(2) / diviser 245 | 246 | # marginalise the dim 247 | lol = lol.mean(-1) 248 | 249 | # transpose back to the expected shape 250 | lol = lol.transpose(0,1) 251 | # assert torch.isclose(lol, res).all(), (lol, res) 252 | 253 | 254 | return lol 255 | 256 | 257 | return res 258 | 259 | 260 | def masked_gaussian_log_density(mu, data, obsrv_std, mask = None): 261 | # these cases are for plotting through plot_estim_density 262 | if (len(mu.size()) == 3): 263 | # add additional dimension for gp samples 264 | mu =
mu.unsqueeze(0) 265 | 266 | if (len(data.size()) == 2): 267 | # add additional dimension for gp samples and time step 268 | data = data.unsqueeze(0).unsqueeze(2) 269 | elif (len(data.size()) == 3): 270 | # add additional dimension for gp samples 271 | data = data.unsqueeze(0) 272 | 273 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size() 274 | 275 | assert(data.size()[-1] == n_dims) 276 | 277 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims] 278 | if mask is None: 279 | mu_flat = mu.reshape(n_traj_samples*n_traj, n_timepoints * n_dims) 280 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size() 281 | data_flat = data.reshape(n_traj_samples*n_traj, n_timepoints * n_dims) 282 | 283 | res = gaussian_log_likelihood(mu_flat, data_flat, obsrv_std) 284 | res = res.reshape(n_traj_samples, n_traj).transpose(0,1) 285 | else: 286 | # Compute the likelihood per patient so that we don't prioritize patients with more measurements 287 | func = lambda mu, data, indices: gaussian_log_likelihood(mu, data, obsrv_std = obsrv_std, indices = indices) 288 | #res = compute_masked_likelihood(mu, data, mask, func) 289 | res = compute_masked_likelihood(mu, data, mask, func, _type='logprob', obsrv_std=obsrv_std) 290 | return res 291 | 292 | 293 | 294 | def mse(mu, data, indices = None): 295 | n_data_points = mu.size()[-1] 296 | 297 | if n_data_points > 0: 298 | mse = nn.MSELoss()(mu, data) 299 | else: 300 | mse = torch.zeros([1]).to(get_device(data)).squeeze() 301 | return mse 302 | 303 | 304 | def compute_mse(mu, data, mask = None): 305 | # these cases are for plotting through plot_estim_density 306 | if (len(mu.size()) == 3): 307 | # add additional dimension for gp samples 308 | mu = mu.unsqueeze(0) 309 | 310 | if (len(data.size()) == 2): 311 | # add additional dimension for gp samples and time step 312 | data = data.unsqueeze(0).unsqueeze(2) 313 | elif (len(data.size()) == 3): 314 | # add additional dimension for gp samples 315 | data = data.unsqueeze(0)
316 | 317 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size() 318 | assert(data.size()[-1] == n_dims) 319 | 320 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims] 321 | if mask is None: 322 | mu_flat = mu.reshape(n_traj_samples*n_traj, n_timepoints * n_dims) 323 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size() 324 | data_flat = data.reshape(n_traj_samples*n_traj, n_timepoints * n_dims) 325 | res = mse(mu_flat, data_flat) 326 | else: 327 | # Compute the MSE per patient so that we don't prioritize patients with more measurements 328 | #res = compute_masked_likelihood(mu, data, mask, mse) 329 | res = compute_masked_likelihood(mu, data, mask, mse, _type='mse') 330 | return res 331 | 332 | 333 | 334 | 335 | def compute_poisson_proc_likelihood(truth, pred_y, info, mask = None): 336 | # Compute Poisson likelihood 337 | # https://math.stackexchange.com/questions/344487/log-likelihood-of-a-realization-of-a-poisson-process 338 | # Sum log lambdas across all time points 339 | if mask is None: 340 | poisson_log_l = torch.sum(info["log_lambda_y"], 2) - info["int_lambda"] 341 | # Average over data dims 342 | poisson_log_l = torch.mean(poisson_log_l, -1) 343 | else: 344 | # Compute likelihood of the data under the predictions 345 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1) 346 | mask_repeated = mask.repeat(pred_y.size(0), 1, 1, 1) 347 | 348 | # Compute the likelihood per patient and per attribute so that we don't prioritize patients with more measurements 349 | int_lambda = info["int_lambda"] 350 | f = lambda log_lam, data, indices: poisson_log_likelihood(log_lam, data, indices, int_lambda) 351 | poisson_log_l = compute_masked_likelihood(info["log_lambda_y"], truth_repeated, mask_repeated, f) 352 | poisson_log_l = poisson_log_l.permute(1,0) 353 | # Take mean over n_traj 354 | #poisson_log_l = torch.mean(poisson_log_l, 1) 355 | 356 | # poisson_log_l shape: [n_traj_samples, n_traj] 357 | return poisson_log_l 358 | 359 | 360 |
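The vectorized branch of compute_masked_likelihood scatters per-observation scores back into a zero tensor, sums them over the time axis, divides by the number of observed time points (a zero count is replaced by 1), and finally averages over feature dimensions. The pure-Python sketch below reproduces that averaging for a single sample and a single trajectory. The function name masked_time_mean and the nested-list layout are illustrative assumptions, not part of the repository.

```python
def masked_time_mean(values, mask):
    """Average `values` over observed time points per feature, then over features.

    values, mask: nested lists of shape [n_timepoints][n_dims]; a truthy mask
    entry marks an observed value. Like the divisor trick in
    compute_masked_likelihood, a feature with no observations contributes 0
    instead of NaN because its count is clamped to 1.
    """
    n_dims = len(values[0])
    per_dim = []
    for j in range(n_dims):
        # only observed entries are summed, so masked-out NaNs never leak in
        total = sum(row[j] for row, m_row in zip(values, mask) if m_row[j])
        count = sum(1 for m_row in mask if m_row[j])
        per_dim.append(total / max(count, 1))
    # marginalise over features, as lol.mean(-1) does in the original
    return sum(per_dim) / n_dims


sq_err = [[1.0, 4.0], [9.0, 0.0], [4.0, 0.0]]
mask = [[1, 0], [1, 0], [0, 0]]
print(masked_time_mean(sq_err, mask))  # (10 / 2 + 0 / 1) / 2 = 2.5
```

Averaging per feature before averaging over features keeps a densely sampled feature from dominating the score, which matches the "don't prioritize patients with more measurements" comments in the file.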
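compute_binary_CE_loss drops trajectories whose label is NaN, repeats the remaining labels once per z0 sample, and scores the logits with nn.BCEWithLogitsLoss (mean reduction). The sketch below restates that core computation in plain Python using the standard numerically stable form of binary cross-entropy on logits. It is an illustration under stated assumptions: the name binary_ce_with_logits is hypothetical, and the sketch omits the extra division by n_traj_samples that the original applies afterwards.

```python
import math


def binary_ce_with_logits(logits, labels):
    """Mean binary cross-entropy on logits, skipping trajectories with NaN labels.

    logits: [n_traj_samples][n_traj] raw scores; labels: [n_traj], may contain NaN.
    Every z0 sample is scored against the same label, as
    mortality_label.repeat(n_traj_samples, 1) does in the original.
    """
    keep = [i for i, y in enumerate(labels) if not math.isnan(y)]
    if not keep:
        return 0.0  # every label is missing, so there is nothing to score
    losses = []
    for sample in logits:
        for i in keep:
            x, y = sample[i], labels[i]
            # stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
            losses.append(max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x))))
    return sum(losses) / len(losses)


nan = float("nan")
print(binary_ce_with_logits([[0.0, 5.0]], [1.0, nan]))  # log(2), about 0.693
```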
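compute_poisson_proc_likelihood uses the standard log-likelihood of an inhomogeneous Poisson process: for events t_1, ..., t_n on [0, T], log p = sum_i log lambda(t_i) - integral_0^T lambda(t) dt. The integral term (info["int_lambda"]) is obtained by appending extra dimensions to the ODE state whose derivative is the rate. ODEFunc_w_Poisson in ode_func.py does this with the rate divided by const_for_lambda inside the ODE, and the accumulated integral is multiplied back afterwards. The sketch below shows the same idea for a scalar state. Forward Euler stands in for torchdiffeq's solvers, and the names augmented_euler and poisson_process_log_lik are hypothetical.

```python
import math


def augmented_euler(f, log_rate, z0, t_grid, const=100.0):
    """Forward-Euler integration of the augmented state (z, Lambda).

    dz/dt      = f(z)
    dLambda/dt = exp(log_rate(z)) / const

    The running integral is multiplied back by `const` on output, mirroring
    the const_for_lambda rescaling used in ODEFunc_w_Poisson.
    """
    z, lam_int = z0, 0.0
    zs, integrals = [z0], [0.0]
    for t_prev, t_next in zip(t_grid[:-1], t_grid[1:]):
        dt = t_next - t_prev
        dz = f(z)
        dlam = math.exp(log_rate(z)) / const
        z, lam_int = z + dt * dz, lam_int + dt * dlam
        zs.append(z)
        integrals.append(lam_int * const)
    return zs, integrals


def poisson_process_log_lik(log_rates_at_events, rate_integral):
    # log p(t_1, ..., t_n) = sum_i log lambda(t_i) - integral_0^T lambda(t) dt
    return sum(log_rates_at_events) - rate_integral


# Constant rate lambda = 2 on [0, 1] with three observed events.
zs, integrals = augmented_euler(
    f=lambda z: 0.0,
    log_rate=lambda z: math.log(2.0),
    z0=0.0,
    t_grid=[0.0, 0.5, 1.0],
)
log_lik = poisson_process_log_lik([math.log(2.0)] * 3, integrals[-1])
print(integrals[-1], log_lik)  # integral is approximately 2.0
```

Solving the integral jointly with the latent state means one solver call yields both the trajectory and the compensator term, without a separate quadrature pass.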
361 | -------------------------------------------------------------------------------- /latentode/latent_ode/lib/ode_func.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.utils.spectral_norm import spectral_norm 10 | 11 | import lib.utils as utils 12 | 13 | ##################################################################################################### 14 | 15 | class ODEFunc(nn.Module): 16 | def __init__(self, input_dim, latent_dim, ode_func_net, device = torch.device("cpu")): 17 | """ 18 | input_dim: dimensionality of the input 19 | latent_dim: dimensionality used for ODE. Analog of a continuous latent state 20 | """ 21 | super(ODEFunc, self).__init__() 22 | 23 | self.input_dim = input_dim 24 | self.device = device 25 | 26 | utils.init_network_weights(ode_func_net) 27 | self.gradient_net = ode_func_net 28 | 29 | def forward(self, t_local, y, backwards = False): 30 | """ 31 | Perform one step in solving ODE.
Given current data point y and current time point t_local, returns gradient dy/dt at this time point 32 | 33 | t_local: current time point 34 | y: value at the current time point 35 | """ 36 | grad = self.get_ode_gradient_nn(t_local, y) 37 | if backwards: 38 | grad = -grad 39 | return grad 40 | 41 | def get_ode_gradient_nn(self, t_local, y): 42 | return self.gradient_net(y) 43 | 44 | def sample_next_point_from_prior(self, t_local, y): 45 | """ 46 | t_local: current time point 47 | y: value at the current time point 48 | """ 49 | return self.get_ode_gradient_nn(t_local, y) 50 | 51 | ##################################################################################################### 52 | 53 | class ODEFunc_w_Poisson(ODEFunc): 54 | 55 | def __init__(self, input_dim, latent_dim, ode_func_net, 56 | lambda_net, device = torch.device("cpu")): 57 | """ 58 | input_dim: dimensionality of the input 59 | latent_dim: dimensionality used for ODE. Analog of a continuous latent state 60 | """ 61 | super(ODEFunc_w_Poisson, self).__init__(input_dim, latent_dim, ode_func_net, device) 62 | 63 | self.latent_ode = ODEFunc(input_dim = input_dim, 64 | latent_dim = latent_dim, 65 | ode_func_net = ode_func_net, 66 | device = device) 67 | 68 | self.latent_dim = latent_dim 69 | self.lambda_net = lambda_net 70 | # The computation of the Poisson likelihood can become numerically unstable. 71 | # The integral lambda(t) dt can take large values.
In fact, it is equal to the expected number of events on the interval [0,T] 72 | #Exponent of lambda can also take large values 73 | # So we divide lambda by the constant and then multiply the integral of lambda by the constant 74 | self.const_for_lambda = torch.Tensor([100.]).to(device) 75 | 76 | def extract_poisson_rate(self, augmented, final_result = True): 77 | y, log_lambdas, int_lambda = None, None, None 78 | 79 | assert(augmented.size(-1) == self.latent_dim + self.input_dim) 80 | latent_lam_dim = self.latent_dim // 2 81 | 82 | if len(augmented.size()) == 3: 83 | int_lambda = augmented[:,:,-self.input_dim:] 84 | y_latent_lam = augmented[:,:,:-self.input_dim] 85 | 86 | log_lambdas = self.lambda_net(y_latent_lam[:,:,-latent_lam_dim:]) 87 | y = y_latent_lam[:,:,:-latent_lam_dim] 88 | 89 | elif len(augmented.size()) == 4: 90 | int_lambda = augmented[:,:,:,-self.input_dim:] 91 | y_latent_lam = augmented[:,:,:,:-self.input_dim] 92 | 93 | log_lambdas = self.lambda_net(y_latent_lam[:,:,:,-latent_lam_dim:]) 94 | y = y_latent_lam[:,:,:,:-latent_lam_dim] 95 | 96 | # Multiply the integral over lambda by a constant 97 | # only when we have finished the integral computation (i.e.
this is not a call in get_ode_gradient_nn) 98 | if final_result: 99 | int_lambda = int_lambda * self.const_for_lambda 100 | 101 | # Latents for performing reconstruction (y) have the same size as latent poisson rate (log_lambdas) 102 | assert(y.size(-1) == latent_lam_dim) 103 | 104 | return y, log_lambdas, int_lambda, y_latent_lam 105 | 106 | 107 | def get_ode_gradient_nn(self, t_local, augmented): 108 | y, log_lam, int_lambda, y_latent_lam = self.extract_poisson_rate(augmented, final_result = False) 109 | dydt_dldt = self.latent_ode(t_local, y_latent_lam) 110 | 111 | log_lam = log_lam - torch.log(self.const_for_lambda) 112 | return torch.cat((dydt_dldt, torch.exp(log_lam)),-1) 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /latentode/latent_ode/lib/ode_rnn.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.functional import relu 10 | 11 | import lib.utils as utils 12 | from lib.encoder_decoder import * 13 | from lib.likelihood_eval import * 14 | 15 | from torch.distributions.multivariate_normal import MultivariateNormal 16 | from torch.distributions.normal import Normal 17 | from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase 18 | 19 | from torch.distributions.normal import Normal 20 | from torch.distributions import Independent 21 | from torch.nn.parameter import Parameter 22 | from lib.base_models import Baseline 23 | 24 | 25 | class ODE_RNN(Baseline): 26 | def __init__(self, input_dim, latent_dim, device = torch.device("cpu"), 27 | z0_diffeq_solver = None, n_gru_units = 100, n_units = 100, 28 | concat_mask = False, obsrv_std = 0.1, use_binary_classif = False, 29 | classif_per_tp = False, n_labels = 1, 
train_classif_w_reconstr = False): 30 | 31 | Baseline.__init__(self, input_dim, latent_dim, device = device, 32 | obsrv_std = obsrv_std, use_binary_classif = use_binary_classif, 33 | classif_per_tp = classif_per_tp, 34 | n_labels = n_labels, 35 | train_classif_w_reconstr = train_classif_w_reconstr) 36 | 37 | ode_rnn_encoder_dim = latent_dim 38 | 39 | self.ode_gru = Encoder_z0_ODE_RNN( 40 | latent_dim = ode_rnn_encoder_dim, 41 | input_dim = (input_dim) * 2, # input and the mask 42 | z0_diffeq_solver = z0_diffeq_solver, 43 | n_gru_units = n_gru_units, 44 | device = device).to(device) 45 | 46 | self.z0_diffeq_solver = z0_diffeq_solver 47 | 48 | self.decoder = nn.Sequential( 49 | nn.Linear(latent_dim, n_units), 50 | nn.Tanh(), 51 | nn.Linear(n_units, input_dim),) 52 | 53 | utils.init_network_weights(self.decoder) 54 | 55 | 56 | def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps, 57 | mask = None, n_traj_samples = None, mode = None): 58 | 59 | if (len(truth_time_steps) != len(time_steps_to_predict)) or (torch.sum(time_steps_to_predict - truth_time_steps) != 0): 60 | raise Exception("Extrapolation mode not implemented for ODE-RNN") 61 | 62 | # time_steps_to_predict and truth_time_steps should be the same 63 | assert(len(truth_time_steps) == len(time_steps_to_predict)) 64 | assert(mask is not None) 65 | 66 | data_and_mask = data 67 | if mask is not None: 68 | data_and_mask = torch.cat([data, mask],-1) 69 | 70 | _, _, latent_ys, _ = self.ode_gru.run_odernn( 71 | data_and_mask, truth_time_steps, run_backwards = False) 72 | 73 | latent_ys = latent_ys.permute(0,2,1,3) 74 | last_hidden = latent_ys[:,:,-1,:] 75 | 76 | #assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.) 77 | 78 | outputs = self.decoder(latent_ys) 79 | # Shift outputs for computing the loss -- we should compare the first output to the second data point, etc. 
80 | first_point = data[:,0,:] 81 | outputs = utils.shift_outputs(outputs, first_point) 82 | 83 | extra_info = {"first_point": (latent_ys[:,:,-1,:], 0.0, latent_ys[:,:,-1,:])} 84 | 85 | if self.use_binary_classif: 86 | if self.classif_per_tp: 87 | extra_info["label_predictions"] = self.classifier(latent_ys) 88 | else: 89 | extra_info["label_predictions"] = self.classifier(last_hidden).squeeze(-1) 90 | 91 | # outputs shape: [n_traj_samples, n_traj, n_tp, n_dims] 92 | return outputs, extra_info 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /latentode/latent_ode/lib/plotting.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import matplotlib 7 | # matplotlib.use('TkAgg') 8 | # matplotlib.use('Agg') 9 | import matplotlib.pyplot 10 | import matplotlib.pyplot as plt 11 | from matplotlib.lines import Line2D 12 | 13 | import os 14 | from scipy.stats import kde 15 | 16 | import numpy as np 17 | import subprocess 18 | import torch 19 | import lib.utils as utils 20 | import matplotlib.gridspec as gridspec 21 | from lib.utils import get_device 22 | 23 | from lib.encoder_decoder import * 24 | from lib.rnn_baselines import * 25 | from lib.ode_rnn import * 26 | import torch.nn.functional as functional 27 | from torch.distributions.normal import Normal 28 | from lib.latent_ode import LatentODE 29 | 30 | from lib.likelihood_eval import masked_gaussian_log_density 31 | # try: 32 | # import umap 33 | # except: 34 | # print("Couldn't import umap") 35 | 36 | from generate_timeseries import Periodic_1d 37 | from person_activity import PersonActivity 38 | 39 | from lib.utils import compute_loss_all_batches 40 | 41 | 42 | SMALL_SIZE = 14 43 | MEDIUM_SIZE = 16 44 | BIGGER_SIZE = 18 45 | LARGE_SIZE = 22 46 | 47 | def init_fonts(main_font_size = 
LARGE_SIZE): 48 | plt.rc('font', size=main_font_size) # controls default text sizes 49 | plt.rc('axes', titlesize=main_font_size) # fontsize of the axes title 50 | plt.rc('axes', labelsize=main_font_size - 2) # fontsize of the x and y labels 51 | plt.rc('xtick', labelsize=main_font_size - 2) # fontsize of the tick labels 52 | plt.rc('ytick', labelsize=main_font_size - 2) # fontsize of the tick labels 53 | plt.rc('legend', fontsize=main_font_size - 2) # legend fontsize 54 | plt.rc('figure', titlesize=main_font_size) # fontsize of the figure title 55 | 56 | 57 | def plot_trajectories(ax, traj, time_steps, min_y = None, max_y = None, title = "", 58 | add_to_plot = False, label = None, add_legend = False, dim_to_show = 0, 59 | linestyle = '-', marker = 'o', mask = None, color = None, linewidth = 1): 60 | # expected shape of traj: [n_traj, n_timesteps, n_dims] 61 | # The function will produce one line per trajectory (n_traj lines in total) 62 | if not add_to_plot: 63 | ax.cla() 64 | ax.set_title(title) 65 | ax.set_xlabel('Time') 66 | ax.set_ylabel('x') 67 | 68 | if min_y is not None: 69 | ax.set_ylim(bottom = min_y) 70 | 71 | if max_y is not None: 72 | ax.set_ylim(top = max_y) 73 | 74 | for i in range(traj.size()[0]): 75 | d = traj[i].cpu().numpy()[:, dim_to_show] 76 | ts = time_steps.cpu().numpy() 77 | if mask is not None: 78 | m = mask[i].cpu().numpy()[:, dim_to_show] 79 | d = d[m == 1] 80 | ts = ts[m == 1] 81 | ax.plot(ts, d, linestyle = linestyle, label = label, marker=marker, color = color, linewidth = linewidth) 82 | 83 | if add_legend: 84 | ax.legend() 85 | 86 | 87 | def plot_std(ax, traj, traj_std, time_steps, min_y = None, max_y = None, title = "", 88 | add_to_plot = False, label = None, alpha=0.2, color = None): 89 | 90 | # take only the first (and only?) 
dimension 91 | mean_minus_std = (traj - traj_std).cpu().numpy()[:, :, 0] 92 | mean_plus_std = (traj + traj_std).cpu().numpy()[:, :, 0] 93 | 94 | for i in range(traj.size()[0]): 95 | ax.fill_between(time_steps.cpu().numpy(), mean_minus_std[i], mean_plus_std[i], 96 | alpha=alpha, color = color) 97 | 98 | 99 | 100 | def plot_vector_field(ax, odefunc, latent_dim, device): 101 | # Code borrowed from https://github.com/rtqichen/ffjord/blob/29c016131b702b307ceb05c70c74c6e802bb8a44/diagnostics/viz_toy.py 102 | K = 13j 103 | y, x = np.mgrid[-6:6:K, -6:6:K] 104 | K = int(K.imag) 105 | zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32) 106 | if latent_dim > 2: 107 | # Only dimensions 0 and 1 are plotted; the remaining latent dimensions are set to zero 108 | zs = torch.cat((zs, torch.zeros(K * K, latent_dim-2, device = device)), 1) 109 | dydt = odefunc(0, zs) 110 | dydt = -dydt.cpu().detach().numpy() 111 | if latent_dim > 2: 112 | dydt = dydt[:,:2] 113 | 114 | mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1) 115 | dydt = (dydt / mag) 116 | dydt = dydt.reshape(K, K, 2) 117 | 118 | ax.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], #color = dydt[:, :, 0], 119 | cmap="coolwarm", linewidth=2) 120 | 121 | # ax.quiver( 122 | # x, y, dydt[:, :, 0], dydt[:, :, 1], 123 | # np.exp(logmag), cmap="coolwarm", pivot="mid", scale = 100, 124 | # ) 125 | ax.set_xlim(-6, 6) 126 | ax.set_ylim(-6, 6) 127 | #ax.axis("off") 128 | 129 | 130 | 131 | def get_meshgrid(npts, int_y1, int_y2): 132 | min_y1, max_y1 = int_y1 133 | min_y2, max_y2 = int_y2 134 | 135 | y1_grid = np.linspace(min_y1, max_y1, npts) 136 | y2_grid = np.linspace(min_y2, max_y2, npts) 137 | 138 | xx, yy = np.meshgrid(y1_grid, y2_grid) 139 | 140 | flat_inputs = np.concatenate((np.expand_dims(xx.flatten(),1), np.expand_dims(yy.flatten(),1)), 1) 141 | flat_inputs = torch.from_numpy(flat_inputs).float() 142 | 143 | return xx, yy, flat_inputs 144 | 145 | 146 | def add_white(cmap): 147 | cmaplist = [cmap(i) for i in range(cmap.N)] 148 | # force the first color entry to
be white 149 | cmaplist[0] = (1.,1.,1.,1.0) 150 | # create the new map 151 | cmap = cmap.from_list('Custom cmap', cmaplist, cmap.N) 152 | return cmap 153 | 154 | 155 | class Visualizations(): 156 | def __init__(self, device): 157 | self.init_visualization() 158 | init_fonts(SMALL_SIZE) 159 | self.device = device 160 | 161 | def init_visualization(self): 162 | self.fig = plt.figure(figsize=(12, 7), facecolor='white') 163 | 164 | self.ax_traj = [] 165 | for i in range(1,4): 166 | self.ax_traj.append(self.fig.add_subplot(2,3,i, frameon=False)) 167 | 168 | # self.ax_density = [] 169 | # for i in range(4,7): 170 | # self.ax_density.append(self.fig.add_subplot(3,3,i, frameon=False)) 171 | 172 | #self.ax_samples_same_traj = self.fig.add_subplot(3,3,7, frameon=False) 173 | self.ax_latent_traj = self.fig.add_subplot(2,3,4, frameon=False) 174 | self.ax_vector_field = self.fig.add_subplot(2,3,5, frameon=False) 175 | self.ax_traj_from_prior = self.fig.add_subplot(2,3,6, frameon=False) 176 | 177 | self.plot_limits = {} 178 | plt.show(block=False) 179 | 180 | def set_plot_lims(self, ax, name): 181 | if name not in self.plot_limits: 182 | self.plot_limits[name] = (ax.get_xlim(), ax.get_ylim()) 183 | return 184 | 185 | xlim, ylim = self.plot_limits[name] 186 | ax.set_xlim(xlim) 187 | ax.set_ylim(ylim) 188 | 189 | def draw_one_density_plot(self, ax, model, data_dict, traj_id, 190 | multiply_by_poisson = False): 191 | 192 | scale = 5 193 | cmap = add_white(plt.get_cmap('Blues', 9)) # plt.cm.BuGn_r 194 | cmap2 = add_white(plt.get_cmap('Reds', 9)) # plt.cm.BuGn_r 195 | #cmap = plt.cm.get_cmap('viridis') 196 | 197 | data = data_dict["data_to_predict"] 198 | time_steps = data_dict["tp_to_predict"] 199 | mask = data_dict["mask_predicted_data"] 200 | 201 | observed_data = data_dict["observed_data"] 202 | observed_time_steps = data_dict["observed_tp"] 203 | observed_mask = data_dict["observed_mask"] 204 | 205 | npts = 50 206 | xx, yy, z0_grid = get_meshgrid(npts = npts, int_y1 = 
(-scale,scale), int_y2 = (-scale,scale)) 207 | z0_grid = z0_grid.to(get_device(data)) 208 | 209 | if model.latent_dim > 2: 210 | z0_grid = torch.cat((z0_grid, torch.zeros(z0_grid.size(0), model.latent_dim-2, device = get_device(data))), 1) 211 | 212 | if model.use_poisson_proc: 213 | n_traj, n_dims = z0_grid.size() 214 | # append a vector of zeros to compute the integral of lambda and also zeros for the first point of lambda 215 | zeros = torch.zeros([n_traj, model.input_dim + model.latent_dim]).to(get_device(data)) 216 | z0_grid_aug = torch.cat((z0_grid, zeros), -1) 217 | else: 218 | z0_grid_aug = z0_grid 219 | 220 | # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents] 221 | sol_y = model.diffeq_solver(z0_grid_aug.unsqueeze(0), time_steps) 222 | 223 | if model.use_poisson_proc: 224 | sol_y, log_lambda_y, int_lambda, _ = model.diffeq_solver.ode_func.extract_poisson_rate(sol_y) 225 | 226 | assert(torch.sum(int_lambda[:,:,0,:]) == 0.) 227 | assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.) 228 | 229 | pred_x = model.decoder(sol_y) 230 | 231 | # Plot density for one trajectory 232 | one_traj = data[traj_id] 233 | mask_one_traj = None 234 | if mask is not None: 235 | mask_one_traj = mask[traj_id].unsqueeze(0) 236 | mask_one_traj = mask_one_traj.repeat(npts**2,1,1).unsqueeze(0) 237 | 238 | ax.cla() 239 | 240 | # Plot: prior 241 | prior_density_grid = model.z0_prior.log_prob(z0_grid.unsqueeze(0)).squeeze(0) 242 | # Sum the density over two dimensions 243 | prior_density_grid = torch.sum(prior_density_grid, -1) 244 | 245 | # ================================================= 246 | # Plot: p(x | y(t0)) 247 | 248 | masked_gaussian_log_density_grid = masked_gaussian_log_density(pred_x, 249 | one_traj.repeat(npts**2,1,1).unsqueeze(0), 250 | mask = mask_one_traj, 251 | obsrv_std = model.obsrv_std).squeeze(-1) 252 | 253 | # Plot p(t | y(t0)) 254 | if model.use_poisson_proc: 255 | poisson_info = {} 256 | poisson_info["int_lambda"] = int_lambda[:,:,-1,:] 257 | poisson_info["log_lambda_y"]
= log_lambda_y 258 | 259 | poisson_log_density_grid = compute_poisson_proc_likelihood( 260 | one_traj.repeat(npts**2,1,1).unsqueeze(0), 261 | pred_x, poisson_info, mask = mask_one_traj) 262 | poisson_log_density_grid = poisson_log_density_grid.squeeze(0) 263 | 264 | # ================================================= 265 | # Plot: p(x , y(t0)) 266 | 267 | log_joint_density = prior_density_grid + masked_gaussian_log_density_grid 268 | if multiply_by_poisson: 269 | log_joint_density = log_joint_density + poisson_log_density_grid 270 | 271 | density_grid = torch.exp(log_joint_density) 272 | 273 | density_grid = torch.reshape(density_grid, (xx.shape[0], xx.shape[1])) 274 | density_grid = density_grid.cpu().numpy() 275 | 276 | ax.contourf(xx, yy, density_grid, cmap=cmap, alpha=1) 277 | 278 | # ================================================= 279 | # Plot: q(y(t0)| x) 280 | #self.ax_density.set_title("Red: q(y(t0) | x) Blue: p(x, y(t0))") 281 | ax.set_xlabel('z1(t0)') 282 | ax.set_ylabel('z2(t0)') 283 | 284 | data_w_mask = observed_data[traj_id].unsqueeze(0) 285 | if observed_mask is not None: 286 | data_w_mask = torch.cat((data_w_mask, observed_mask[traj_id].unsqueeze(0)), -1) 287 | z0_mu, z0_std = model.encoder_z0( 288 | data_w_mask, observed_time_steps) 289 | 290 | if model.use_poisson_proc: 291 | z0_mu = z0_mu[:, :, :model.latent_dim] 292 | z0_std = z0_std[:, :, :model.latent_dim] 293 | 294 | q_z0 = Normal(z0_mu, z0_std) 295 | 296 | q_density_grid = q_z0.log_prob(z0_grid) 297 | # Sum the density over two dimensions 298 | q_density_grid = torch.sum(q_density_grid, -1) 299 | density_grid = torch.exp(q_density_grid) 300 | 301 | density_grid = torch.reshape(density_grid, (xx.shape[0], xx.shape[1])) 302 | density_grid = density_grid.cpu().numpy() 303 | 304 | ax.contourf(xx, yy, density_grid, cmap=cmap2, alpha=0.3) 305 | 306 | 307 | 308 | def draw_all_plots_one_dim(self, data_dict, model, 309 | plot_name = "", save = False, experimentID = 0.): 310 | 311 | data = 
data_dict["data_to_predict"] 312 | time_steps = data_dict["tp_to_predict"] 313 | mask = data_dict["mask_predicted_data"] 314 | 315 | observed_data = data_dict["observed_data"] 316 | observed_time_steps = data_dict["observed_tp"] 317 | observed_mask = data_dict["observed_mask"] 318 | 319 | device = get_device(time_steps) 320 | 321 | time_steps_to_predict = time_steps 322 | if isinstance(model, LatentODE): 323 | # sample at 100 evenly spaced time points spanning the original time range 324 | time_steps_to_predict = utils.linspace_vector(time_steps[0], time_steps[-1], 100).to(device) 325 | 326 | reconstructions, info = model.get_reconstruction(time_steps_to_predict, 327 | observed_data, observed_time_steps, mask = observed_mask, n_traj_samples = 10) 328 | 329 | n_traj_to_show = 3 330 | # plot only the first n_traj_to_show trajectories 331 | data_for_plotting = observed_data[:n_traj_to_show] 332 | mask_for_plotting = observed_mask[:n_traj_to_show] 333 | reconstructions_for_plotting = reconstructions.mean(dim=0)[:n_traj_to_show] 334 | reconstr_std = reconstructions.std(dim=0)[:n_traj_to_show] 335 | 336 | dim_to_show = 0 337 | max_y = max( 338 | data_for_plotting[:,:,dim_to_show].cpu().numpy().max(), 339 | reconstructions[:,:,:,dim_to_show].cpu().numpy().max()) 340 | min_y = min( 341 | data_for_plotting[:,:,dim_to_show].cpu().numpy().min(), 342 | reconstructions[:,:,:,dim_to_show].cpu().numpy().min()) 343 | 344 | ############################################ 345 | # Plot reconstructions, true posterior and approximate posterior 346 | 347 | cmap = plt.get_cmap('Set1') 348 | for traj_id in range(3): 349 | # Plot observations 350 | plot_trajectories(self.ax_traj[traj_id], 351 | data_for_plotting[traj_id].unsqueeze(0), observed_time_steps, 352 | mask = mask_for_plotting[traj_id].unsqueeze(0), 353 | min_y = min_y, max_y = max_y, #title="True trajectories", 354 | marker = 'o', linestyle='', dim_to_show = dim_to_show, 355 | color = cmap(2)) 356 | # Plot reconstructions 357 | plot_trajectories(self.ax_traj[traj_id], 358 | 
reconstructions_for_plotting[traj_id].unsqueeze(0), time_steps_to_predict, 359 | min_y = min_y, max_y = max_y, title="Sample {} (data space)".format(traj_id), dim_to_show = dim_to_show, 360 | add_to_plot = True, marker = '', color = cmap(3), linewidth = 3) 361 | # Plot the standard deviation estimated over multiple samples from the approximate posterior 362 | plot_std(self.ax_traj[traj_id], 363 | reconstructions_for_plotting[traj_id].unsqueeze(0), reconstr_std[traj_id].unsqueeze(0), 364 | time_steps_to_predict, alpha=0.5, color = cmap(3)) 365 | self.set_plot_lims(self.ax_traj[traj_id], "traj_" + str(traj_id)) 366 | 367 | # Plot true posterior and approximate posterior 368 | # self.draw_one_density_plot(self.ax_density[traj_id], 369 | # model, data_dict, traj_id = traj_id, 370 | # multiply_by_poisson = False) 371 | # self.set_plot_lims(self.ax_density[traj_id], "density_" + str(traj_id)) 372 | # self.ax_density[traj_id].set_title("Sample {}: p(z0) and q(z0 | x)".format(traj_id)) 373 | ############################################ 374 | # Get several samples for the same trajectory 375 | # one_traj = data_for_plotting[:1] 376 | # first_point = one_traj[:,0] 377 | 378 | # samples_same_traj, _ = model.get_reconstruction(time_steps_to_predict, 379 | # observed_data[:1], observed_time_steps, mask = observed_mask[:1], n_traj_samples = 5) 380 | # samples_same_traj = samples_same_traj.squeeze(1) 381 | 382 | # plot_trajectories(self.ax_samples_same_traj, samples_same_traj, time_steps_to_predict, marker = '') 383 | # plot_trajectories(self.ax_samples_same_traj, one_traj, time_steps, linestyle = "", 384 | # label = "True traj", add_to_plot = True, title="Reconstructions for the same trajectory (data space)") 385 | 386 | ############################################ 387 | # Plot trajectories from prior 388 | 389 | try: 390 | if isinstance(model, LatentODE): 391 | torch.manual_seed(1991) 392 | np.random.seed(1991) 393 | 394 | traj_from_prior = model.sample_traj_from_prior(time_steps_to_predict,
n_traj_samples = 3) 395 | # Here n_traj = 1 and n_traj_samples is the number of samples requested from the prior, so squeeze the n_traj dimension 396 | traj_from_prior = traj_from_prior.squeeze(1) 397 | 398 | plot_trajectories(self.ax_traj_from_prior, traj_from_prior, time_steps_to_predict, 399 | marker = '', linewidth = 3) 400 | self.ax_traj_from_prior.set_title("Samples from prior (data space)", pad = 20) 401 | #self.set_plot_lims(self.ax_traj_from_prior, "traj_from_prior") 402 | except AttributeError: 403 | pass 404 | ################################################ 405 | 406 | # Plot z0 407 | # first_point_mu, first_point_std, first_point_enc = info["first_point"] 408 | 409 | # dim1 = 0 410 | # dim2 = 1 411 | # self.ax_z0.cla() 412 | # # first_point_enc shape: [1, n_traj, n_dims] 413 | # self.ax_z0.scatter(first_point_enc.cpu()[0,:,dim1], first_point_enc.cpu()[0,:,dim2]) 414 | # self.ax_z0.set_title("Encodings z0 of all test trajectories (latent space)") 415 | # self.ax_z0.set_xlabel('dim {}'.format(dim1)) 416 | # self.ax_z0.set_ylabel('dim {}'.format(dim2)) 417 | 418 | try: 419 | ################################################ 420 | # Show vector field 421 | self.ax_vector_field.cla() 422 | plot_vector_field(self.ax_vector_field, model.diffeq_solver.ode_func, model.latent_dim, device) 423 | self.ax_vector_field.set_title("Slice of vector field (latent space)", pad = 20) 424 | self.set_plot_lims(self.ax_vector_field, "vector_field") 425 | #self.ax_vector_field.set_ylim((-0.5, 1.5)) 426 | except AttributeError: 427 | pass 428 | 429 | ################################################ 430 | # Plot trajectories in the latent space 431 | 432 | # shape before [1, n_traj, n_tp, n_latent_dims] 433 | # Take only the first sample from approx posterior 434 | latent_traj = info["latent_traj"][0,:n_traj_to_show] 435 | # shape after indexing: [n_traj_to_show, n_tp, n_latent_dims] 436 | 437 | self.ax_latent_traj.cla() 438 | cmap = plt.get_cmap('Accent') 439 | n_latent_dims = 
latent_traj.size(-1) 440 | 441 | custom_labels = {} 442 | for i in range(n_latent_dims): 443 | col = cmap(i) 444 | plot_trajectories(self.ax_latent_traj, latent_traj, time_steps_to_predict, 445 | title="Latent trajectories z(t) (latent space)", dim_to_show = i, color = col, 446 | marker = '', add_to_plot = True, 447 | linewidth = 3) 448 | custom_labels['dim ' + str(i)] = Line2D([0], [0], color=col) 449 | 450 | self.ax_latent_traj.set_ylabel("z") 451 | self.ax_latent_traj.set_title("Latent trajectories z(t) (latent space)", pad = 20) 452 | self.ax_latent_traj.legend(custom_labels.values(), custom_labels.keys(), loc = 'lower left') 453 | self.set_plot_lims(self.ax_latent_traj, "latent_traj") 454 | 455 | ################################################ 456 | 457 | self.fig.tight_layout() 458 | plt.draw() 459 | 460 | if save: 461 | dirname = "plots/" + str(experimentID) + "/" 462 | os.makedirs(dirname, exist_ok=True) 463 | self.fig.savefig(dirname + plot_name) 464 | 465 | 466 | 467 | 468 | 469 | -------------------------------------------------------------------------------- /latentode/latent_ode/lib/rnn_baselines.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.functional import relu 10 | 11 | import lib.utils as utils 12 | from lib.utils import get_device 13 | from lib.encoder_decoder import * 14 | from lib.likelihood_eval import * 15 | 16 | from torch.distributions.multivariate_normal import MultivariateNormal 17 | from torch.distributions.normal import Normal 18 | from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase 19 | 20 | from torch.distributions.normal import Normal 21 | from torch.distributions import Independent 22 | from torch.nn.parameter import Parameter 23 | from 
lib.base_models import Baseline, VAE_Baseline 24 | 25 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 26 | # Exponential decay of the hidden states for RNN 27 | # adapted from GRU-D implementation: https://github.com/zhiyongc/GRU-D/ 28 | 29 | # Exp decay between hidden states 30 | class GRUCellExpDecay(RNNCellBase): 31 | def __init__(self, input_size, input_size_for_decay, hidden_size, device, bias=True): 32 | super(GRUCellExpDecay, self).__init__(input_size, hidden_size, bias, num_chunks=3) 33 | 34 | self.device = device 35 | self.input_size_for_decay = input_size_for_decay 36 | self.decay = nn.Sequential(nn.Linear(input_size_for_decay, 1),) 37 | utils.init_network_weights(self.decay) 38 | 39 | def gru_exp_decay_cell(self, input, hidden, w_ih, w_hh, b_ih, b_hh): 40 | # IMPORTANT: assumes that the cumulative delta t values occupy the last input_size_for_decay dimensions of the input 41 | batch_size, n_dims = input.size() 42 | 43 | # "input" contains the data, mask and also cumulative deltas for all inputs 44 | cum_delta_ts = input[:, -self.input_size_for_decay:] 45 | data = input[:, :-self.input_size_for_decay] 46 | 47 | decay = torch.exp( - torch.min(torch.max( 48 | torch.zeros([1]).to(self.device), self.decay(cum_delta_ts)), 49 | torch.ones([1]).to(self.device) * 1000 )) 50 | 51 | hidden = hidden * decay 52 | 53 | gi = torch.mm(data, w_ih.t()) + b_ih 54 | gh = torch.mm(hidden, w_hh.t()) + b_hh 55 | i_r, i_i, i_n = gi.chunk(3, 1) 56 | h_r, h_i, h_n = gh.chunk(3, 1) 57 | 58 | resetgate = torch.sigmoid(i_r + h_r) 59 | inputgate = torch.sigmoid(i_i + h_i) 60 | newgate = torch.tanh(i_n + resetgate * h_n) 61 | hy = newgate + inputgate * (hidden - newgate) 62 | return hy 63 | 64 | def forward(self, input, hx=None): 65 | # type: (Tensor, Optional[Tensor]) -> Tensor 66 | #self.check_forward_input(input) 67 | if hx is None: 68 | hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) 69 | #self.check_forward_hidden(input, hx, '') 70 | 71 | 
return self.gru_exp_decay_cell( 72 | input, hx, 73 | self.weight_ih, self.weight_hh, 74 | self.bias_ih, self.bias_hh 75 | ) 76 | 77 | 78 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 79 | # Imputation with a weighted average of the previous value and the empirical mean 80 | # adapted from GRU-D implementation: https://github.com/zhiyongc/GRU-D/ 81 | def get_cum_delta_ts(data, delta_ts, mask): 82 | n_traj, n_tp, n_dims = data.size() 83 | 84 | cum_delta_ts = delta_ts.repeat(1, 1, n_dims) 85 | missing_index = np.where(mask.cpu().numpy() == 0) 86 | 87 | for idx in range(missing_index[0].shape[0]): 88 | i = missing_index[0][idx] 89 | j = missing_index[1][idx] 90 | k = missing_index[2][idx] 91 | 92 | if j != 0 and j != (n_tp-1): 93 | cum_delta_ts[i,j+1,k] = cum_delta_ts[i,j+1,k] + cum_delta_ts[i,j,k] 94 | cum_delta_ts = cum_delta_ts / cum_delta_ts.max() # normalize 95 | 96 | return cum_delta_ts 97 | 98 | 99 | # adapted from GRU-D implementation: https://github.com/zhiyongc/GRU-D/ 100 | # very slow 101 | def impute_using_input_decay(data, delta_ts, mask, w_input_decay, b_input_decay): 102 | n_traj, n_tp, n_dims = data.size() 103 | 104 | cum_delta_ts = delta_ts.repeat(1, 1, n_dims) 105 | missing_index = np.where(mask.cpu().numpy() == 0) 106 | 107 | data_last_obsv = np.copy(data.cpu().numpy()) 108 | for idx in range(missing_index[0].shape[0]): 109 | i = missing_index[0][idx] 110 | j = missing_index[1][idx] 111 | k = missing_index[2][idx] 112 | 113 | if j != 0 and j != (n_tp-1): 114 | cum_delta_ts[i,j+1,k] = cum_delta_ts[i,j+1,k] + cum_delta_ts[i,j,k] 115 | if j != 0: 116 | data_last_obsv[i,j,k] = data_last_obsv[i,j-1,k] # last observation 117 | cum_delta_ts = cum_delta_ts / cum_delta_ts.max() # normalize 118 | 119 | data_last_obsv = torch.Tensor(data_last_obsv).to(get_device(data)) 120 | 121 | zeros = torch.zeros([n_traj, n_tp, n_dims]).to(get_device(data)) 122 | decay = torch.exp( - torch.min( torch.max(zeros, 123 | w_input_decay * 
cum_delta_ts + b_input_decay), zeros + 1000 )) 124 | 125 | data_means = torch.mean(data, 1).unsqueeze(1) 126 | 127 | data_imputed = data * mask + (1-mask) * (decay * data_last_obsv + (1-decay) * data_means) 128 | return data_imputed 129 | 130 | 131 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 132 | 133 | def run_rnn(inputs, delta_ts, cell, first_hidden=None, 134 | mask = None, feed_previous=False, n_steps=0, 135 | decoder = None, input_decay_params = None, 136 | feed_previous_w_prob = 0., 137 | masked_update = True): 138 | if (feed_previous or feed_previous_w_prob) and decoder is None: 139 | raise Exception("feed_previous is set to True -- please specify RNN decoder") 140 | 141 | if n_steps == 0: 142 | n_steps = inputs.size(1) 143 | 144 | if (feed_previous or feed_previous_w_prob) and mask is None: 145 | mask = torch.ones((inputs.size(0), n_steps, inputs.size(-1))).to(get_device(inputs)) 146 | 147 | if isinstance(cell, GRUCellExpDecay): 148 | cum_delta_ts = get_cum_delta_ts(inputs, delta_ts, mask) 149 | 150 | if input_decay_params is not None: 151 | w_input_decay, b_input_decay = input_decay_params 152 | inputs = impute_using_input_decay(inputs, delta_ts, mask, 153 | w_input_decay, b_input_decay) 154 | 155 | all_hiddens = [] 156 | hidden = first_hidden 157 | 158 | if hidden is not None: 159 | all_hiddens.append(hidden) 160 | n_steps -= 1 161 | 162 | for i in range(n_steps): 163 | delta_t = delta_ts[:,i] 164 | if i == 0: 165 | rnn_input = inputs[:,i] 166 | elif feed_previous: 167 | rnn_input = decoder(hidden) 168 | elif feed_previous_w_prob > 0: 169 | feed_prev = np.random.uniform() > feed_previous_w_prob 170 | if feed_prev: 171 | rnn_input = decoder(hidden) 172 | else: 173 | rnn_input = inputs[:,i] 174 | else: 175 | rnn_input = inputs[:,i] 176 | 177 | if mask is not None: 178 | mask_i = mask[:,i,:] 179 | rnn_input = torch.cat((rnn_input, mask_i), -1) 180 | 181 | if isinstance(cell, GRUCellExpDecay): 182 | cum_delta_t = 
cum_delta_ts[:,i] 183 | input_w_t = torch.cat((rnn_input, cum_delta_t), -1).squeeze(1) 184 | else: 185 | input_w_t = torch.cat((rnn_input, delta_t), -1).squeeze(1) 186 | 187 | prev_hidden = hidden 188 | hidden = cell(input_w_t, hidden) 189 | 190 | if masked_update and (mask is not None) and (prev_hidden is not None): 191 | # update the hidden state only if at least one feature is present at the current time point 192 | summed_mask = (torch.sum(mask_i, -1, keepdim = True) > 0).float() 193 | assert(not torch.isnan(summed_mask).any()) 194 | hidden = summed_mask * hidden + (1-summed_mask) * prev_hidden 195 | 196 | all_hiddens.append(hidden) 197 | 198 | all_hiddens = torch.stack(all_hiddens, 0) 199 | all_hiddens = all_hiddens.permute(1,0,2).unsqueeze(0) 200 | return hidden, all_hiddens 201 | 202 | 203 | 204 | 205 | class Classic_RNN(Baseline): 206 | def __init__(self, input_dim, latent_dim, device, 207 | concat_mask = False, obsrv_std = 0.1, 208 | use_binary_classif = False, 209 | linear_classifier = False, 210 | classif_per_tp = False, 211 | input_space_decay = False, 212 | cell = "gru", n_units = 100, 213 | n_labels = 1, 214 | train_classif_w_reconstr = False): 215 | 216 | super(Classic_RNN, self).__init__(input_dim, latent_dim, device, 217 | obsrv_std = obsrv_std, 218 | use_binary_classif = use_binary_classif, 219 | classif_per_tp = classif_per_tp, 220 | linear_classifier = linear_classifier, 221 | n_labels = n_labels, 222 | train_classif_w_reconstr = train_classif_w_reconstr) 223 | 224 | self.concat_mask = concat_mask 225 | 226 | encoder_dim = int(input_dim) 227 | if concat_mask: 228 | encoder_dim = encoder_dim * 2 229 | 230 | self.decoder = nn.Sequential( 231 | nn.Linear(latent_dim, n_units), 232 | nn.Tanh(), 233 | nn.Linear(n_units, input_dim),) 234 | 235 | #utils.init_network_weights(self.encoder) 236 | utils.init_network_weights(self.decoder) 237 | 238 | if cell == "gru": 239 | self.rnn_cell = GRUCell(encoder_dim + 1, latent_dim) # +1 for
delta t 240 | elif cell == "expdecay": 241 | self.rnn_cell = GRUCellExpDecay( 242 | input_size = encoder_dim, 243 | input_size_for_decay = input_dim, 244 | hidden_size = latent_dim, 245 | device = device) 246 | else: 247 | raise Exception("Unknown RNN cell: {}".format(cell)) 248 | 249 | if input_space_decay: 250 | self.w_input_decay = Parameter(torch.Tensor(1, int(input_dim)).to(self.device)) 251 | self.b_input_decay = Parameter(torch.Tensor(1, int(input_dim)).to(self.device)) 252 | self.input_space_decay = input_space_decay 253 | 254 | self.z0_net = lambda hidden_state: hidden_state 255 | 256 | 257 | def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps, 258 | mask = None, n_traj_samples = 1, mode = None): 259 | 260 | assert(mask is not None) 261 | n_traj, n_tp, n_dims = data.size() 262 | 263 | if (len(truth_time_steps) != len(time_steps_to_predict)) or (torch.sum(time_steps_to_predict - truth_time_steps) != 0): 264 | raise Exception("Extrapolation mode not implemented for RNN models") 265 | 266 | # for classic RNN time_steps_to_predict should be the same as truth_time_steps 267 | assert(len(truth_time_steps) == len(time_steps_to_predict)) 268 | 269 | batch_size = data.size(0) 270 | zero_delta_t = torch.Tensor([0.]).to(self.device) 271 | 272 | delta_ts = truth_time_steps[1:] - truth_time_steps[:-1] 273 | delta_ts = torch.cat((delta_ts, zero_delta_t)) 274 | if len(delta_ts.size()) == 1: 275 | # delta_ts are shared for all trajectories in a batch 276 | assert(data.size(1) == delta_ts.size(0)) 277 | delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size,1,1)) 278 | 279 | input_decay_params = None 280 | if self.input_space_decay: 281 | input_decay_params = (self.w_input_decay, self.b_input_decay) 282 | 283 | if mask is not None: 284 | utils.check_mask(data, mask) 285 | 286 | hidden_state, all_hiddens = run_rnn(data, delta_ts, 287 | cell = self.rnn_cell, mask = mask, 288 | input_decay_params = input_decay_params, 289 | 
feed_previous_w_prob = (0.
if self.use_binary_classif else 0.5), 290 | decoder = self.decoder) 291 | 292 | outputs = self.decoder(all_hiddens) 293 | # Shift outputs for computing the loss -- we should compare the first output to the second data point, etc. 294 | first_point = data[:,0,:] 295 | outputs = utils.shift_outputs(outputs, first_point) 296 | 297 | extra_info = {"first_point": (hidden_state.unsqueeze(0), 0.0, hidden_state.unsqueeze(0))} 298 | 299 | if self.use_binary_classif: 300 | if self.classif_per_tp: 301 | extra_info["label_predictions"] = self.classifier(all_hiddens) 302 | else: 303 | extra_info["label_predictions"] = self.classifier(hidden_state).reshape(1,-1) 304 | 305 | # outputs shape: [n_traj_samples, n_traj, n_tp, n_dims] 306 | return outputs, extra_info 307 | 308 | 309 | 310 | class RNN_VAE(VAE_Baseline): 311 | def __init__(self, input_dim, latent_dim, rec_dims, 312 | z0_prior, device, 313 | concat_mask = False, obsrv_std = 0.1, 314 | input_space_decay = False, 315 | use_binary_classif = False, 316 | classif_per_tp =False, 317 | linear_classifier = False, 318 | cell = "gru", n_units = 100, 319 | n_labels = 1, 320 | train_classif_w_reconstr = False): 321 | 322 | super(RNN_VAE, self).__init__( 323 | input_dim = input_dim, latent_dim = latent_dim, 324 | z0_prior = z0_prior, 325 | device = device, obsrv_std = obsrv_std, 326 | use_binary_classif = use_binary_classif, 327 | classif_per_tp = classif_per_tp, 328 | linear_classifier = linear_classifier, 329 | n_labels = n_labels, 330 | train_classif_w_reconstr = train_classif_w_reconstr) 331 | 332 | self.concat_mask = concat_mask 333 | 334 | encoder_dim = int(input_dim) 335 | if concat_mask: 336 | encoder_dim = encoder_dim * 2 337 | 338 | if cell == "gru": 339 | self.rnn_cell_enc = GRUCell(encoder_dim + 1, rec_dims) # +1 for delta t 340 | self.rnn_cell_dec = GRUCell(encoder_dim + 1, latent_dim) # +1 for delta t 341 | elif cell == "expdecay": 342 | self.rnn_cell_enc = GRUCellExpDecay( 343 | input_size = encoder_dim, 344 | 
input_size_for_decay = input_dim, 345 | hidden_size = rec_dims, 346 | device = device) 347 | self.rnn_cell_dec = GRUCellExpDecay( 348 | input_size = encoder_dim, 349 | input_size_for_decay = input_dim, 350 | hidden_size = latent_dim, 351 | device = device) 352 | else: 353 | raise Exception("Unknown RNN cell: {}".format(cell)) 354 | 355 | self.z0_net = nn.Sequential( 356 | nn.Linear(rec_dims, n_units), 357 | nn.Tanh(), 358 | nn.Linear(n_units, latent_dim * 2),) 359 | utils.init_network_weights(self.z0_net) 360 | 361 | self.decoder = nn.Sequential( 362 | nn.Linear(latent_dim, n_units), 363 | nn.Tanh(), 364 | nn.Linear(n_units, input_dim),) 365 | 366 | #utils.init_network_weights(self.encoder) 367 | utils.init_network_weights(self.decoder) 368 | 369 | if input_space_decay: 370 | self.w_input_decay = Parameter(torch.Tensor(1, int(input_dim)).to(self.device)) 371 | self.b_input_decay = Parameter(torch.Tensor(1, int(input_dim)).to(self.device)) 372 | self.input_space_decay = input_space_decay 373 | 374 | def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps, 375 | mask = None, n_traj_samples = 1, mode = None): 376 | 377 | assert(mask is not None) 378 | 379 | batch_size = data.size(0) 380 | zero_delta_t = torch.Tensor([0.]).to(self.device) 381 | 382 | # run encoder backwards 383 | run_backwards = bool(time_steps_to_predict[0] < truth_time_steps[-1]) 384 | 385 | if run_backwards: 386 | # Look at data in the reverse order: from later points to the first 387 | data = utils.reverse(data) 388 | mask = utils.reverse(mask) 389 | 390 | delta_ts = truth_time_steps[1:] - truth_time_steps[:-1] 391 | if run_backwards: 392 | # we are going backwards in time 393 | delta_ts = utils.reverse(delta_ts) 394 | 395 | 396 | delta_ts = torch.cat((delta_ts, zero_delta_t)) 397 | if len(delta_ts.size()) == 1: 398 | # delta_ts are shared for all trajectories in a batch 399 | assert(data.size(1) == delta_ts.size(0)) 400 | delta_ts =
delta_ts.unsqueeze(-1).repeat((batch_size,1,1)) 401 | 402 | input_decay_params = None 403 | if self.input_space_decay: 404 | input_decay_params = (self.w_input_decay, self.b_input_decay) 405 | 406 | hidden_state, _ = run_rnn(data, delta_ts, 407 | cell = self.rnn_cell_enc, mask = mask, 408 | input_decay_params = input_decay_params) 409 | 410 | z0_mean, z0_std = utils.split_last_dim(self.z0_net(hidden_state)) 411 | z0_std = z0_std.abs() 412 | z0_sample = utils.sample_standard_gaussian(z0_mean, z0_std) 413 | 414 | # Decoder # # # # # # # # # # # # # # # # # # # # 415 | delta_ts = torch.cat((zero_delta_t, time_steps_to_predict[1:] - time_steps_to_predict[:-1])) 416 | if len(delta_ts.size()) == 1: 417 | delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size,1,1)) 418 | 419 | _, all_hiddens = run_rnn(data, delta_ts, 420 | cell = self.rnn_cell_dec, 421 | first_hidden = z0_sample, feed_previous = True, 422 | n_steps = time_steps_to_predict.size(0), 423 | decoder = self.decoder, 424 | input_decay_params = input_decay_params) 425 | 426 | outputs = self.decoder(all_hiddens) 427 | # Shift outputs for computing the loss -- we should compare the first output to the second data point, etc. 
428 | first_point = data[:,0,:] 429 | outputs = utils.shift_outputs(outputs, first_point) 430 | 431 | extra_info = {"first_point": (z0_mean.unsqueeze(0), z0_std.unsqueeze(0), z0_sample.unsqueeze(0))} 432 | 433 | if self.use_binary_classif: 434 | if self.classif_per_tp: 435 | extra_info["label_predictions"] = self.classifier(all_hiddens) 436 | else: 437 | extra_info["label_predictions"] = self.classifier(z0_mean).reshape(1,-1) 438 | 439 | # outputs shape: [n_traj_samples, n_traj, n_tp, n_dims] 440 | return outputs, extra_info 441 | 442 | 443 | 444 | -------------------------------------------------------------------------------- /latentode/latent_ode/mujoco_physics.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Authors: Yulia Rubanova and Ricky Chen 4 | ########################### 5 | 6 | import os 7 | import numpy as np 8 | import torch 9 | from lib.utils import get_dict_template 10 | import lib.utils as utils 11 | from torchvision.datasets.utils import download_url 12 | 13 | class HopperPhysics(object): 14 | 15 | T = 200 16 | D = 14 17 | 18 | n_training_samples = 10000 19 | 20 | training_file = 'training.pt' 21 | 22 | def __init__(self, root, download = True, generate=False, device = torch.device("cpu")): 23 | self.root = root 24 | if download: 25 | self._download() 26 | 27 | if generate: 28 | self._generate_dataset() 29 | 30 | if not self._check_exists(): 31 | raise RuntimeError('Dataset not found.' 
+ ' You can use download=True to download it') 32 | 33 | data_file = os.path.join(self.data_folder, self.training_file) 34 | 35 | self.data = torch.Tensor(torch.load(data_file)).to(device) 36 | self.data, self.data_min, self.data_max = utils.normalize_data(self.data) 37 | 38 | self.device =device 39 | 40 | def visualize(self, traj, plot_name = 'traj', dirname='hopper_imgs', video_name = None): 41 | r"""Generates images of the trajectory and stores them as /traj-.jpg""" 42 | 43 | T, D = traj.size() 44 | 45 | traj = traj.cpu() * self.data_max.cpu() + self.data_min.cpu() 46 | 47 | try: 48 | from dm_control import suite # noqa: F401 49 | except ImportError as e: 50 | raise Exception('Deepmind Control Suite is required to visualize the dataset.') from e 51 | 52 | try: 53 | from PIL import Image # noqa: F401 54 | except ImportError as e: 55 | raise Exception('PIL is required to visualize the dataset.') from e 56 | 57 | def save_image(data, filename): 58 | im = Image.fromarray(data) 59 | im.save(filename) 60 | 61 | os.makedirs(dirname, exist_ok=True) 62 | 63 | env = suite.load('hopper', 'stand') 64 | physics = env.physics 65 | 66 | for t in range(T): 67 | with physics.reset_context(): 68 | physics.data.qpos[:] = traj[t, :D // 2] 69 | physics.data.qvel[:] = traj[t, D // 2:] 70 | save_image( 71 | physics.render(height=480, width=640, camera_id=0), 72 | os.path.join(dirname, plot_name + '-{:03d}.jpg'.format(t)) 73 | ) 74 | 75 | def _generate_dataset(self): 76 | if self._check_exists(): 77 | return 78 | os.makedirs(self.data_folder, exist_ok=True) 79 | print('Generating dataset...') 80 | train_data = self._generate_random_trajectories(self.n_training_samples) 81 | torch.save(train_data, os.path.join(self.data_folder, self.training_file)) 82 | 83 | def _download(self): 84 | if self._check_exists(): 85 | return 86 | 87 | print("Downloading the dataset [325MB] ...") 88 | os.makedirs(self.data_folder, exist_ok=True) 89 | url = 
"http://www.cs.toronto.edu/~rtqichen/datasets/HopperPhysics/training.pt" 90 | download_url(url, self.data_folder, "training.pt", None) 91 | 92 | def _generate_random_trajectories(self, n_samples): 93 | 94 | try: 95 | from dm_control import suite # noqa: F401 96 | except ImportError as e: 97 | raise Exception('Deepmind Control Suite is required to generate the dataset.') from e 98 | 99 | env = suite.load('hopper', 'stand') 100 | physics = env.physics 101 | 102 | # Store the state of the RNG to restore later. 103 | st0 = np.random.get_state() 104 | np.random.seed(123) 105 | 106 | data = np.zeros((n_samples, self.T, self.D)) 107 | for i in range(n_samples): 108 | with physics.reset_context(): 109 | # x and z positions of the hopper. We want z > 0 for the hopper to stay above ground. 110 | physics.data.qpos[:2] = np.random.uniform(0, 0.5, size=2) 111 | physics.data.qpos[2:] = np.random.uniform(-2, 2, size=physics.data.qpos[2:].shape) 112 | physics.data.qvel[:] = np.random.uniform(-5, 5, size=physics.data.qvel.shape) 113 | for t in range(self.T): 114 | data[i, t, :self.D // 2] = physics.data.qpos 115 | data[i, t, self.D // 2:] = physics.data.qvel 116 | physics.step() 117 | 118 | # Restore RNG. 
119 | np.random.set_state(st0) 120 | return data 121 | 122 | def _check_exists(self): 123 | return os.path.exists(os.path.join(self.data_folder, self.training_file)) 124 | 125 | @property 126 | def data_folder(self): 127 | return os.path.join(self.root, self.__class__.__name__) 128 | 129 | # def __getitem__(self, index): 130 | # return self.data[index] 131 | 132 | def get_dataset(self): 133 | return self.data 134 | 135 | def __len__(self): 136 | return len(self.data) 137 | 138 | def size(self, ind = None): 139 | if ind is not None: 140 | return self.data.shape[ind] 141 | return self.data.shape 142 | 143 | def __repr__(self): 144 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 145 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 146 | fmt_str += ' Root Location: {}\n'.format(self.root) 147 | return fmt_str 148 | 149 | -------------------------------------------------------------------------------- /latentode/latent_ode/person_activity.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Authors: Yulia Rubanova and Ricky Chen 4 | ########################### 5 | 6 | import os 7 | 8 | import lib.utils as utils 9 | import numpy as np 10 | import tarfile 11 | import torch 12 | from torch.utils.data import DataLoader 13 | from torchvision.datasets.utils import download_url 14 | from lib.utils import get_device 15 | 16 | # Adapted from: https://github.com/rtqichen/time-series-datasets 17 | 18 | class PersonActivity(object): 19 | urls = [ 20 | 'https://archive.ics.uci.edu/ml/machine-learning-databases/00196/ConfLongDemo_JSI.txt', 21 | ] 22 | 23 | tag_ids = [ 24 | "010-000-024-033", #"ANKLE_LEFT", 25 | "010-000-030-096", #"ANKLE_RIGHT", 26 | "020-000-033-111", #"CHEST", 27 | "020-000-032-221" #"BELT" 28 | ] 29 | 30 | tag_dict = {k: i for i, k in enumerate(tag_ids)} 31 | 32 | label_names = [ 33 | "walking", 34 | "falling", 35 | 
"lying down", 36 | "lying", 37 | "sitting down", 38 | "sitting", 39 | "standing up from lying", 40 | "on all fours", 41 | "sitting on the ground", 42 | "standing up from sitting", 43 | "standing up from sit on grnd" 44 | ] 45 | 46 | #label_dict = {k: i for i, k in enumerate(label_names)} 47 | 48 | #Merge similar labels into one class 49 | label_dict = { 50 | "walking": 0, 51 | "falling": 1, 52 | "lying": 2, 53 | "lying down": 2, 54 | "sitting": 3, 55 | "sitting down" : 3, 56 | "standing up from lying": 4, 57 | "standing up from sitting": 4, 58 | "standing up from sit on grnd": 4, 59 | "on all fours": 5, 60 | "sitting on the ground": 6 61 | } 62 | 63 | 64 | def __init__(self, root, download=False, 65 | reduce='average', max_seq_length = 50, 66 | n_samples = None, device = torch.device("cpu")): 67 | 68 | self.root = root 69 | self.reduce = reduce 70 | self.max_seq_length = max_seq_length 71 | 72 | if download: 73 | self.download() 74 | 75 | if not self._check_exists(): 76 | raise RuntimeError('Dataset not found. 
You can use download=True to download it') 77 | 78 | if device == torch.device("cpu"): 79 | self.data = torch.load(os.path.join(self.processed_folder, self.data_file), map_location='cpu') 80 | else: 81 | self.data = torch.load(os.path.join(self.processed_folder, self.data_file)) 82 | 83 | if n_samples is not None: 84 | self.data = self.data[:n_samples] 85 | 86 | def download(self): 87 | if self._check_exists(): 88 | return 89 | 90 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 91 | 92 | os.makedirs(self.raw_folder, exist_ok=True) 93 | os.makedirs(self.processed_folder, exist_ok=True) 94 | 95 | def save_record(records, record_id, tt, vals, mask, labels): 96 | tt = torch.tensor(tt).to(self.device) 97 | 98 | vals = torch.stack(vals) 99 | mask = torch.stack(mask) 100 | labels = torch.stack(labels) 101 | 102 | # flatten the measurements for different tags 103 | vals = vals.reshape(vals.size(0), -1) 104 | mask = mask.reshape(mask.size(0), -1) 105 | assert(len(tt) == vals.size(0)) 106 | assert(mask.size(0) == vals.size(0)) 107 | assert(labels.size(0) == vals.size(0)) 108 | 109 | #records.append((record_id, tt, vals, mask, labels)) 110 | 111 | seq_length = len(tt) 112 | # split the long time series into smaller ones 113 | offset = 0 114 | slide = self.max_seq_length // 2 115 | 116 | while (offset + self.max_seq_length < seq_length): 117 | idx = range(offset, offset + self.max_seq_length) 118 | 119 | first_tp = tt[idx][0] 120 | records.append((record_id, tt[idx] - first_tp, vals[idx], mask[idx], labels[idx])) 121 | offset += slide 122 | 123 | for url in self.urls: 124 | filename = url.rpartition('/')[2] 125 | download_url(url, self.raw_folder, filename, None) 126 | 127 | print('Processing {}...'.format(filename)) 128 | 129 | dirname = os.path.join(self.raw_folder) 130 | records = [] 131 | first_tp = None 132 | 133 | for txtfile in os.listdir(dirname): 134 | with open(os.path.join(dirname, txtfile)) as f: 135 | lines = f.readlines() 136 | 
prev_time = -1 137 | tt = [] 138 | 139 | record_id = None 140 | for l in lines: 141 | cur_record_id, tag_id, time, date, val1, val2, val3, label = l.strip().split(',') 142 | value_vec = torch.Tensor((float(val1), float(val2), float(val3))).to(self.device) 143 | time = float(time) 144 | 145 | if cur_record_id != record_id: 146 | if record_id is not None: 147 | save_record(records, record_id, tt, vals, mask, labels) 148 | tt, vals, mask, nobs, labels = [], [], [], [], [] 149 | record_id = cur_record_id 150 | 151 | tt = [torch.zeros(1).to(self.device)] 152 | vals = [torch.zeros(len(self.tag_ids),3).to(self.device)] 153 | mask = [torch.zeros(len(self.tag_ids),3).to(self.device)] 154 | nobs = [torch.zeros(len(self.tag_ids)).to(self.device)] 155 | labels = [torch.zeros(len(self.label_names)).to(self.device)] 156 | 157 | first_tp = time 158 | time = round((time - first_tp)/ 10**5) 159 | prev_time = time 160 | else: 161 | # for speed -- we actually don't need to quantize it in Latent ODE 162 | time = round((time - first_tp)/ 10**5) # quantizing by 100 ms.
10,000 is one millisecond, 10,000,000 is one second 163 | 164 | if time != prev_time: 165 | tt.append(time) 166 | vals.append(torch.zeros(len(self.tag_ids),3).to(self.device)) 167 | mask.append(torch.zeros(len(self.tag_ids),3).to(self.device)) 168 | nobs.append(torch.zeros(len(self.tag_ids)).to(self.device)) 169 | labels.append(torch.zeros(len(self.label_names)).to(self.device)) 170 | prev_time = time 171 | 172 | if tag_id in self.tag_ids: 173 | n_observations = nobs[-1][self.tag_dict[tag_id]] 174 | if (self.reduce == 'average') and (n_observations > 0): 175 | prev_val = vals[-1][self.tag_dict[tag_id]] 176 | new_val = (prev_val * n_observations + value_vec) / (n_observations + 1) 177 | vals[-1][self.tag_dict[tag_id]] = new_val 178 | else: 179 | vals[-1][self.tag_dict[tag_id]] = value_vec 180 | 181 | mask[-1][self.tag_dict[tag_id]] = 1 182 | nobs[-1][self.tag_dict[tag_id]] += 1 183 | 184 | if label in self.label_names: 185 | if torch.sum(labels[-1][self.label_dict[label]]) == 0: 186 | labels[-1][self.label_dict[label]] = 1 187 | else: 188 | assert tag_id == 'RecordID', 'Read unexpected tag id {}'.format(tag_id) 189 | save_record(records, record_id, tt, vals, mask, labels) 190 | 191 | torch.save( 192 | records, 193 | os.path.join(self.processed_folder, 'data.pt') 194 | ) 195 | 196 | print('Done!') 197 | 198 | def _check_exists(self): 199 | for url in self.urls: 200 | filename = url.rpartition('/')[2] 201 | if not os.path.exists( 202 | os.path.join(self.processed_folder, 'data.pt') 203 | ): 204 | return False 205 | return True 206 | 207 | @property 208 | def raw_folder(self): 209 | return os.path.join(self.root, self.__class__.__name__, 'raw') 210 | 211 | @property 212 | def processed_folder(self): 213 | return os.path.join(self.root, self.__class__.__name__, 'processed') 214 | 215 | @property 216 | def data_file(self): 217 | return 'data.pt' 218 | 219 | def __getitem__(self, index): 220 | return self.data[index] 221 | 222 | def __len__(self): 223 | return 
len(self.data) 224 | 225 | def __repr__(self): 226 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 227 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 228 | fmt_str += ' Root Location: {}\n'.format(self.root) 229 | fmt_str += ' Max length: {}\n'.format(self.max_seq_length) 230 | fmt_str += ' Reduce: {}\n'.format(self.reduce) 231 | return fmt_str 232 | 233 | def get_person_id(record_id): 234 | # The first letter is the person id 235 | person_id = record_id[0] 236 | person_id = ord(person_id) - ord("A") 237 | return person_id 238 | 239 | 240 | 241 | def variable_time_collate_fn_activity(batch, args, device = torch.device("cpu"), data_type = "train"): 242 | """ 243 | Expects a batch of time series data in the form of (record_id, tt, vals, mask, labels) where 244 | - record_id is a patient id 245 | - tt is a 1-dimensional tensor containing T time values of observations. 246 | - vals is a (T, D) tensor containing observed values for D variables. 247 | - mask is a (T, D) tensor containing 1 where values were observed and 0 otherwise. 248 | - labels is a list of labels for the current patient, if labels are available. Otherwise None. 249 | Returns: 250 | combined_tt: The union of all time observations. 251 | combined_vals: (M, T, D) tensor containing the observed values. 252 | combined_mask: (M, T, D) tensor containing 1 where values were observed and 0 otherwise. 
253 | """ 254 | D = batch[0][2].shape[1] 255 | N = batch[0][-1].shape[1] # number of labels 256 | 257 | combined_tt, inverse_indices = torch.unique(torch.cat([ex[1] for ex in batch]), sorted=True, return_inverse=True) 258 | combined_tt = combined_tt.to(device) 259 | 260 | offset = 0 261 | combined_vals = torch.zeros([len(batch), len(combined_tt), D]).to(device) 262 | combined_mask = torch.zeros([len(batch), len(combined_tt), D]).to(device) 263 | combined_labels = torch.zeros([len(batch), len(combined_tt), N]).to(device) 264 | 265 | for b, (record_id, tt, vals, mask, labels) in enumerate(batch): 266 | tt = tt.to(device) 267 | vals = vals.to(device) 268 | mask = mask.to(device) 269 | labels = labels.to(device) 270 | 271 | indices = inverse_indices[offset:offset + len(tt)] 272 | offset += len(tt) 273 | 274 | combined_vals[b, indices] = vals 275 | combined_mask[b, indices] = mask 276 | combined_labels[b, indices] = labels 277 | 278 | combined_tt = combined_tt.float() 279 | 280 | if torch.max(combined_tt) != 0.: 281 | combined_tt = combined_tt / torch.max(combined_tt) 282 | 283 | data_dict = { 284 | "data": combined_vals, 285 | "time_steps": combined_tt, 286 | "mask": combined_mask, 287 | "labels": combined_labels} 288 | 289 | data_dict = utils.split_and_subsample_batch(data_dict, args, data_type = data_type) 290 | return data_dict 291 | 292 | 293 | if __name__ == '__main__': 294 | torch.manual_seed(1991) 295 | 296 | dataset = PersonActivity('data/PersonActivity', download=True) 297 | dataloader = DataLoader(dataset, batch_size=30, shuffle=True, collate_fn= variable_time_collate_fn_activity) 298 | next(iter(dataloader)) 299 | -------------------------------------------------------------------------------- /latentode/latent_ode/physionet.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Authors: Yulia Rubanova and Ricky Chen 4 |
########################### 5 | 6 | import os 7 | import matplotlib 8 | if os.path.exists("/Users/yulia"): 9 | matplotlib.use('TkAgg') 10 | else: 11 | matplotlib.use('Agg') 12 | import matplotlib.pyplot 13 | import matplotlib.pyplot as plt 14 | 15 | import lib.utils as utils 16 | import numpy as np 17 | import tarfile 18 | import torch 19 | from torch.utils.data import DataLoader 20 | from torchvision.datasets.utils import download_url 21 | from lib.utils import get_device 22 | 23 | # Adapted from: https://github.com/rtqichen/time-series-datasets 24 | 25 | # get minimum and maximum for each feature across the whole dataset 26 | def get_data_min_max(records): 27 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 28 | 29 | data_min, data_max = None, None 30 | inf = torch.Tensor([float("Inf")])[0].to(device) 31 | 32 | for b, (record_id, tt, vals, mask, labels) in enumerate(records): 33 | n_features = vals.size(-1) 34 | 35 | batch_min = [] 36 | batch_max = [] 37 | for i in range(n_features): 38 | non_missing_vals = vals[:,i][mask[:,i] == 1] 39 | if len(non_missing_vals) == 0: 40 | batch_min.append(inf) 41 | batch_max.append(-inf) 42 | else: 43 | batch_min.append(torch.min(non_missing_vals)) 44 | batch_max.append(torch.max(non_missing_vals)) 45 | 46 | batch_min = torch.stack(batch_min) 47 | batch_max = torch.stack(batch_max) 48 | 49 | if (data_min is None) and (data_max is None): 50 | data_min = batch_min 51 | data_max = batch_max 52 | else: 53 | data_min = torch.min(data_min, batch_min) 54 | data_max = torch.max(data_max, batch_max) 55 | 56 | return data_min, data_max 57 | 58 | 59 | class PhysioNet(object): 60 | 61 | urls = [ 62 | 'https://physionet.org/files/challenge-2012/1.0.0/set-a.tar.gz?download', 63 | 'https://physionet.org/files/challenge-2012/1.0.0/set-b.tar.gz?download', 64 | ] 65 | 66 | outcome_urls = ['https://physionet.org/files/challenge-2012/1.0.0/Outcomes-a.txt'] 67 | 68 | params = [ 69 | 'Age', 'Gender', 'Height', 'ICUType', 
'Weight', 'Albumin', 'ALP', 'ALT', 'AST', 'Bilirubin', 'BUN', 70 | 'Cholesterol', 'Creatinine', 'DiasABP', 'FiO2', 'GCS', 'Glucose', 'HCO3', 'HCT', 'HR', 'K', 'Lactate', 'Mg', 71 | 'MAP', 'MechVent', 'Na', 'NIDiasABP', 'NIMAP', 'NISysABP', 'PaCO2', 'PaO2', 'pH', 'Platelets', 'RespRate', 72 | 'SaO2', 'SysABP', 'Temp', 'TroponinI', 'TroponinT', 'Urine', 'WBC' 73 | ] 74 | 75 | params_dict = {k: i for i, k in enumerate(params)} 76 | 77 | labels = [ "SAPS-I", "SOFA", "Length_of_stay", "Survival", "In-hospital_death" ] 78 | labels_dict = {k: i for i, k in enumerate(labels)} 79 | 80 | def __init__(self, root, train=True, download=False, 81 | quantization = 0.1, n_samples = None, device = torch.device("cpu")): 82 | 83 | self.root = root 84 | self.train = train 85 | self.reduce = "average" 86 | self.quantization = quantization 87 | 88 | if download: 89 | self.download() 90 | 91 | if not self._check_exists(): 92 | raise RuntimeError('Dataset not found. You can use download=True to download it') 93 | 94 | if self.train: 95 | data_file = self.training_file 96 | else: 97 | data_file = self.test_file 98 | 99 | if device == torch.device("cpu"): 100 | self.data = torch.load(os.path.join(self.processed_folder, data_file), map_location='cpu') 101 | self.labels = torch.load(os.path.join(self.processed_folder, self.label_file), map_location='cpu') 102 | else: 103 | self.data = torch.load(os.path.join(self.processed_folder, data_file)) 104 | self.labels = torch.load(os.path.join(self.processed_folder, self.label_file)) 105 | 106 | if n_samples is not None: 107 | self.data = self.data[:n_samples] 108 | self.labels = self.labels[:n_samples] 109 | 110 | 111 | def download(self): 112 | if self._check_exists(): 113 | return 114 | 115 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 116 | 117 | os.makedirs(self.raw_folder, exist_ok=True) 118 | os.makedirs(self.processed_folder, exist_ok=True) 119 | 120 | # Download outcome data 121 | for url in 
self.outcome_urls: 122 | filename = url.rpartition('/')[2] 123 | download_url(url, self.raw_folder, filename, None) 124 | 125 | txtfile = os.path.join(self.raw_folder, filename) 126 | with open(txtfile) as f: 127 | lines = f.readlines() 128 | outcomes = {} 129 | for l in lines[1:]: 130 | l = l.rstrip().split(',') 131 | record_id, labels = l[0], np.array(l[1:]).astype(float) 132 | outcomes[record_id] = torch.Tensor(labels).to(self.device) 133 | 134 | torch.save( 135 | labels, 136 | os.path.join(self.processed_folder, filename.split('.')[0] + '.pt') 137 | ) 138 | 139 | for url in self.urls: 140 | filename = url.rpartition('/')[2] 141 | download_url(url, self.raw_folder, filename, None) 142 | tar = tarfile.open(os.path.join(self.raw_folder, filename), "r:gz") 143 | tar.extractall(self.raw_folder) 144 | tar.close() 145 | 146 | print('Processing {}...'.format(filename)) 147 | 148 | dirname = os.path.join(self.raw_folder, filename.split('.')[0]) 149 | patients = [] 150 | total = 0 151 | for txtfile in os.listdir(dirname): 152 | record_id = txtfile.split('.')[0] 153 | with open(os.path.join(dirname, txtfile)) as f: 154 | lines = f.readlines() 155 | prev_time = 0 156 | tt = [0.] 157 | vals = [torch.zeros(len(self.params)).to(self.device)] 158 | mask = [torch.zeros(len(self.params)).to(self.device)] 159 | nobs = [torch.zeros(len(self.params))] 160 | for l in lines[1:]: 161 | total += 1 162 | time, param, val = l.split(',') 163 | # Time in hours 164 | time = float(time.split(':')[0]) + float(time.split(':')[1]) / 60. 
165 | # round the time stamps to the nearest multiple of self.quantization (0.1 h = 6 min by default) 166 | # used for speed -- we actually don't need to quantize it in Latent ODE 167 | time = round(time / self.quantization) * self.quantization 168 | 169 | if time != prev_time: 170 | tt.append(time) 171 | vals.append(torch.zeros(len(self.params)).to(self.device)) 172 | mask.append(torch.zeros(len(self.params)).to(self.device)) 173 | nobs.append(torch.zeros(len(self.params)).to(self.device)) 174 | prev_time = time 175 | 176 | if param in self.params_dict: 177 | #vals[-1][self.params_dict[param]] = float(val) 178 | n_observations = nobs[-1][self.params_dict[param]] 179 | if self.reduce == 'average' and n_observations > 0: 180 | prev_val = vals[-1][self.params_dict[param]] 181 | new_val = (prev_val * n_observations + float(val)) / (n_observations + 1) 182 | vals[-1][self.params_dict[param]] = new_val 183 | else: 184 | vals[-1][self.params_dict[param]] = float(val) 185 | mask[-1][self.params_dict[param]] = 1 186 | nobs[-1][self.params_dict[param]] += 1 187 | else: 188 | assert param == 'RecordID', 'Read unexpected param {}'.format(param) 189 | tt = torch.tensor(tt).to(self.device) 190 | vals = torch.stack(vals) 191 | mask = torch.stack(mask) 192 | 193 | labels = None 194 | if record_id in outcomes: 195 | # Only training set has labels 196 | labels = outcomes[record_id] 197 | # Out of 5 label types provided for PhysioNet, take only the last one -- mortality 198 | labels = labels[4] 199 | 200 | patients.append((record_id, tt, vals, mask, labels)) 201 | 202 | torch.save( 203 | patients, 204 | os.path.join(self.processed_folder, 205 | filename.split('.')[0] + "_" + str(self.quantization) + '.pt') 206 | ) 207 | 208 | print('Done!') 209 | 210 | def _check_exists(self): 211 | for url in self.urls: 212 | filename = url.rpartition('/')[2] 213 | 214 | if not os.path.exists( 215 | os.path.join(self.processed_folder, 216 | filename.split('.')[0] + "_" + str(self.quantization) + '.pt') 217 | ): 218 | return False 
219 |
return True 220 | 221 | @property 222 | def raw_folder(self): 223 | return os.path.join(self.root, self.__class__.__name__, 'raw') 224 | 225 | @property 226 | def processed_folder(self): 227 | return os.path.join(self.root, self.__class__.__name__, 'processed') 228 | 229 | @property 230 | def training_file(self): 231 | return 'set-a_{}.pt'.format(self.quantization) 232 | 233 | @property 234 | def test_file(self): 235 | return 'set-b_{}.pt'.format(self.quantization) 236 | 237 | @property 238 | def label_file(self): 239 | return 'Outcomes-a.pt' 240 | 241 | def __getitem__(self, index): 242 | return self.data[index] 243 | 244 | def __len__(self): 245 | return len(self.data) 246 | 247 | def get_label(self, record_id): 248 | return self.labels[record_id] 249 | 250 | def __repr__(self): 251 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 252 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 253 | fmt_str += ' Split: {}\n'.format('train' if self.train is True else 'test') 254 | fmt_str += ' Root Location: {}\n'.format(self.root) 255 | fmt_str += ' Quantization: {}\n'.format(self.quantization) 256 | fmt_str += ' Reduce: {}\n'.format(self.reduce) 257 | return fmt_str 258 | 259 | def visualize(self, timesteps, data, mask, plot_name): 260 | width = 15 261 | height = 15 262 | 263 | non_zero_attributes = (torch.sum(mask,0) > 2).numpy() 264 | non_zero_idx = [i for i in range(len(non_zero_attributes)) if non_zero_attributes[i] == 1.] 
265 | n_non_zero = sum(non_zero_attributes) 266 | 267 | mask = mask[:, non_zero_idx] 268 | data = data[:, non_zero_idx] 269 | 270 | params_non_zero = [self.params[i] for i in non_zero_idx] 271 | params_dict = {k: i for i, k in enumerate(params_non_zero)} 272 | 273 | n_col = 3 274 | n_row = n_non_zero // n_col + (n_non_zero % n_col > 0) 275 | fig, ax_list = plt.subplots(n_row, n_col, figsize=(width, height), facecolor='white') 276 | 277 | #for i in range(len(self.params)): 278 | for i in range(n_non_zero): 279 | param = params_non_zero[i] 280 | param_id = params_dict[param] 281 | 282 | tp_mask = mask[:,param_id].long() 283 | 284 | tp_cur_param = timesteps[tp_mask == 1.] 285 | data_cur_param = data[tp_mask == 1., param_id] 286 | 287 | ax_list[i // n_col, i % n_col].plot(tp_cur_param.numpy(), data_cur_param.numpy(), marker='o') 288 | ax_list[i // n_col, i % n_col].set_title(param) 289 | 290 | fig.tight_layout() 291 | fig.savefig(plot_name) 292 | plt.close(fig) 293 | 294 | 295 | def variable_time_collate_fn(batch, args, device = torch.device("cpu"), data_type = "train", 296 | data_min = None, data_max = None): 297 | """ 298 | Expects a batch of time series data in the form of (record_id, tt, vals, mask, labels) where 299 | - record_id is a patient id 300 | - tt is a 1-dimensional tensor containing T time values of observations. 301 | - vals is a (T, D) tensor containing observed values for D variables. 302 | - mask is a (T, D) tensor containing 1 where values were observed and 0 otherwise. 303 | - labels is a list of labels for the current patient, if labels are available. Otherwise None. 304 | Returns: 305 | combined_tt: The union of all time observations. 306 | combined_vals: (M, T, D) tensor containing the observed values. 307 | combined_mask: (M, T, D) tensor containing 1 where values were observed and 0 otherwise. 
308 | """ 309 | D = batch[0][2].shape[1] 310 | combined_tt, inverse_indices = torch.unique(torch.cat([ex[1] for ex in batch]), sorted=True, return_inverse=True) 311 | combined_tt = combined_tt.to(device) 312 | 313 | offset = 0 314 | combined_vals = torch.zeros([len(batch), len(combined_tt), D]).to(device) 315 | combined_mask = torch.zeros([len(batch), len(combined_tt), D]).to(device) 316 | 317 | combined_labels = None 318 | N_labels = 1 319 | 320 | combined_labels = torch.zeros(len(batch), N_labels) + torch.tensor(float('nan')) 321 | combined_labels = combined_labels.to(device = device) 322 | 323 | for b, (record_id, tt, vals, mask, labels) in enumerate(batch): 324 | tt = tt.to(device) 325 | vals = vals.to(device) 326 | mask = mask.to(device) 327 | if labels is not None: 328 | labels = labels.to(device) 329 | 330 | indices = inverse_indices[offset:offset + len(tt)] 331 | offset += len(tt) 332 | 333 | combined_vals[b, indices] = vals 334 | combined_mask[b, indices] = mask 335 | 336 | if labels is not None: 337 | combined_labels[b] = labels 338 | 339 | combined_vals, _, _ = utils.normalize_masked_data(combined_vals, combined_mask, 340 | att_min = data_min, att_max = data_max) 341 | 342 | if torch.max(combined_tt) != 0.: 343 | combined_tt = combined_tt / torch.max(combined_tt) 344 | 345 | data_dict = { 346 | "data": combined_vals, 347 | "time_steps": combined_tt, 348 | "mask": combined_mask, 349 | "labels": combined_labels} 350 | 351 | data_dict = utils.split_and_subsample_batch(data_dict, args, data_type = data_type) 352 | return data_dict 353 | 354 | if __name__ == '__main__': 355 | torch.manual_seed(1991) 356 | 357 | dataset = PhysioNet('data/physionet', train=False, download=True) 358 | dataloader = DataLoader(dataset, batch_size=10, shuffle=True, collate_fn=variable_time_collate_fn) 359 | print(next(iter(dataloader))) 360 | -------------------------------------------------------------------------------- /latentode/latent_ode/train-activity.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python3 run_models.py --niters 200 -n 10000 -l 15 --dataset activity --latent-ode --rec-dims 100 --rec-layers 4 --gen-layers 2 --units 500 --gru-units 50 --classif --linear-classif -b 100 4 | -------------------------------------------------------------------------------- /latentode/latent_ode/train-ecg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #python3 run_models.py --niters 200 -n 10000 -l 15 --dataset ecg --latent-ode --rec-dims 100 --rec-layers 4 --gen-layers 2 --units 50 --gru-units 50 --classif --linear-classif -b 1000 4 | 5 | #python3 run_models.py --niters 400 -n 10000 -l 15 --dataset ecg --latent-ode --rec-dims 100 --rec-layers 4 --gen-layers 2 --units 50 --gru-units 50 --classif --linear-classif -b 4000 6 | 7 | # smaller learning rate 8 | python3 run_models.py --niters 400 -n 10000 -l 15 --dataset ecg --latent-ode --rec-dims 100 --rec-layers 4 --gen-layers 2 --units 50 --gru-units 50 --classif --linear-classif -b 2000 --lr 0.0001 9 | -------------------------------------------------------------------------------- /latentode/latent_ode/train-periodic100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python3 run_models.py --niters 500 -n 1000 -s 50 -l 10 --dataset periodic --latent-ode --noise-weight 0.01 --viz -t 100 4 | -------------------------------------------------------------------------------- /latentode/latent_ode/train-periodic1000.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python3 run_models.py --niters 500 -n 1000 -s 50 -l 10 --dataset periodic --latent-ode --noise-weight 0.01 --viz -t 1000 4 | -------------------------------------------------------------------------------- /latentode/ours_impl/__pycache__/lv_field.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/latentode/ours_impl/__pycache__/lv_field.cpython-38.pyc -------------------------------------------------------------------------------- /latentode/ours_impl/lv_field.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | import FrEIA.framework as Ff 9 | import FrEIA.modules as Fm 10 | 11 | 12 | class LVField(nn.Module): 13 | 14 | def __init__(self, dim=3, augmented_dim=6, hidden_dim=200, num_layers=8, e_vec_factor=1.0, t_d_factor=1.0): 15 | super().__init__() 16 | vec = torch.zeros(augmented_dim) 17 | vec[:dim] = .1 18 | self.dim = dim 19 | self.augmented_dim = augmented_dim 20 | self.w_vec = nn.Parameter(torch.cat([vec.unsqueeze(0),torch.eye(augmented_dim)],dim=0)) 21 | self.pad = nn.ConstantPad1d((0, augmented_dim - dim), 0) 22 | 23 | self.e_vec_factor = e_vec_factor 24 | self.t_d_factor = t_d_factor 25 | 26 | # build inn 27 | def subnet_fc(dims_in, dims_out): 28 | return nn.Sequential(nn.Linear(dims_in, hidden_dim), nn.ReLU(), 29 | nn.Linear(hidden_dim, dims_out)) 30 | 31 | self.inn = Ff.SequenceINN(augmented_dim) 32 | for k in range(num_layers): 33 | self.inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc 34 | # ,permute_soft=True 35 | ) 36 | 37 | # H = 1000 38 | # self._nn = torch.nn.Sequential( 39 | # torch.nn.Linear(augmented_dim, H), 40 | # torch.nn.ReLU(), 41 | # torch.nn.Linear(H, H), 42 | # torch.nn.ReLU(), 43 | # torch.nn.Linear(H, H), 44 | # torch.nn.ReLU(), 45 | # torch.nn.Linear(H, augmented_dim), 46 | # ) 47 | 48 | # def node(self, t, x): 49 | # return self._nn(x) 50 | 51 | 52 | # def eigen_ode(self,w_vec,init_x,t_d): 53 | # e_val, e_vec= w_vec[0], w_vec[1:] 54 | # init_v = torch.mm(torch.linalg.inv(e_vec), init_x.reshape((-1,1))) 55 | 56 | # 
return torch.mm(init_v.T * e_vec, torch.exp(e_val[:, None] * t_d)).T 57 | 58 | def eigen_ode(self, w_vec,init_x,t_d): 59 | from torchdiffeq import odeint, odeint_adjoint 60 | 61 | #print(init_x.shape) 62 | #print(t_d.shape) 63 | # out = odeint(func=self.node, y0=init_x, t=t_d, method='midpoint', options={ 64 | # 'step_size': 1, 65 | # }) 66 | # out = out.transpose(0,1) 67 | # #print(out) 68 | # #print(out.shape) 69 | # return out 70 | 71 | 72 | e_val,e_vec=(w_vec[0],w_vec[1:]) 73 | 74 | # #_e_vec = e_vec 75 | # #_t_d = t_d 76 | # #_t_d = t_d * 0.001 77 | # 78 | # # for act 79 | # _e_vec = e_vec * 0.001 80 | # _t_d = t_d * 0.1 81 | # 82 | # # for periodic 83 | # _e_vec = e_vec #* 0.001 84 | # _t_d = t_d #* 0.1 85 | # 86 | # 87 | # _e_vec = e_vec * 1.0 88 | # _t_d = t_d * 0.01 89 | 90 | _e_vec = e_vec * self.e_vec_factor 91 | _t_d = t_d * self.t_d_factor 92 | 93 | init_v=torch.bmm(torch.inverse(_e_vec).expand(init_x.shape[0],-1,-1),init_x[:, :, None]) 94 | rs=torch.bmm((init_v.transpose(1,2)*_e_vec.expand((init_x.shape[0],-1,-1))), 95 | torch.exp( 96 | e_val[:,None] * _t_d 97 | #self._nn(t_d[:, None]).T 98 | ) 99 | .expand((init_x.shape[0],-1,-1))).transpose(1,2) 100 | return rs 101 | 102 | 103 | def _forward(self, init_v, t_d, padding, remove_padding_after): 104 | 105 | if padding: 106 | init_v = self.pad(init_v) 107 | 108 | init_v_in = self.inn(init_v)[0] 109 | eval_lin = self.eigen_ode(self.w_vec,init_v_in,t_d) 110 | 111 | _ori_shape = eval_lin.shape 112 | out = self.inn(eval_lin.reshape(-1, eval_lin.shape[-1]),rev=True)[0] 113 | out = out.reshape(_ori_shape) 114 | if remove_padding_after: 115 | out = out[:, :, :self.dim] 116 | return out 117 | 118 | def forward(self,init_v,t_d, padding=True, remove_padding_after=True): 119 | # return self._forward(init_v, t_d, padding, remove_padding_after) 120 | # print(init_v.shape) 121 | if len(init_v.shape) == 2: 122 | # batch of multiple traj 123 | return self._forward(init_v, t_d, padding, remove_padding_after) 124 | elif 
len(init_v.shape) == 3: 125 | # stack of several batches: roll out each batch separately 126 | # TODO make it forward a whole batch 127 | out = [] 128 | for i in range(init_v.shape[0]): 129 | out.append(self(init_v[i], t_d, padding, remove_padding_after)) 130 | # timer.print_stats() 131 | return torch.stack(out) 132 | else: 133 | raise NotImplementedError(f"input has dimensionality {init_v.shape}") 134 | -------------------------------------------------------------------------------- /robotic/.combine_stats.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/robotic/.combine_stats.py.swp -------------------------------------------------------------------------------- /robotic/.run_stiff_time.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/robotic/.run_stiff_time.py.swp -------------------------------------------------------------------------------- /robotic/3D_Cshape_top.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/robotic/3D_Cshape_top.mat -------------------------------------------------------------------------------- /robotic/S_gt_traj.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/robotic/S_gt_traj.npy -------------------------------------------------------------------------------- /robotic/bench_output/out.txt: -------------------------------------------------------------------------------- 1 | ============================== 2 | dataset C 3 | dim=3 4 | seed = 0 5 | ============================== 6 | dataset S 7 | dim=2 8 | ============================== 9 |
dataset S 10 | dim=2 11 | seed = 0 12 | ============================== 13 | dataset S 14 | dim=2 15 | seed = 0 16 | -------------------------------------------------------------------------------- /robotic/cube_pick.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/robotic/cube_pick.npy -------------------------------------------------------------------------------- /rotating_MNIST/LV_stat_dic_0.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/rotating_MNIST/LV_stat_dic_0.tar -------------------------------------------------------------------------------- /rotating_MNIST/autoencoder_state_dic_0.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/rotating_MNIST/autoencoder_state_dic_0.tar -------------------------------------------------------------------------------- /rotating_MNIST/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as pl 3 | import torch 4 | import scipy 5 | from mnist import MNIST 6 | from scipy import ndimage 7 | import torch.nn as nn 8 | import torchdiffeq 9 | from scipy import ndimage 10 | import FrEIA.framework as Ff 11 | import FrEIA.modules as Fm 12 | 13 | mndata = MNIST('samples') 14 | images, labels = mndata.load_training() 15 | ds=np.array(images) 16 | d_l=np.array(labels) 17 | dt3=ds[d_l[:]==3] 18 | dt3_tor=torch.tensor(dt3).reshape((-1,28,28)) 19 | 20 | 21 | class Autoencoder(nn.Module): 22 | def __init__(self): 23 | super(Autoencoder, self).__init__() 24 | self.encoder = nn.Sequential( # like the Composition layer you built 25 | nn.Conv2d(1, 16, 3, 
stride=2, padding=1), 26 | nn.ReLU(), 27 | nn.Conv2d(16, 32, 3, stride=2, padding=1), 28 | nn.ReLU(), 29 | nn.Conv2d(32, 64, 7) 30 | 31 | ) 32 | self.decoder = nn.Sequential( 33 | nn.ConvTranspose2d(64, 32, 7), 34 | nn.ReLU(), 35 | nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1), 36 | nn.ReLU(), 37 | nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1), 38 | nn.Sigmoid() 39 | ) 40 | 41 | def forward(self, x): 42 | x = self.encoder(x) 43 | #x=nn.Flatten()(x) 44 | #x=nn.Unflatten(dim=1,unflattened_size=(64,1,1))(x) 45 | x = self.decoder(x) 46 | return x 47 | 48 | def train(model, num_epochs=5, batch_size=64, learning_rate=1e-4): 49 | torch.manual_seed(42) 50 | criterion = nn.MSELoss() 51 | optimizer = torch.optim.Adam(model.parameters(), 52 | lr=learning_rate, 53 | weight_decay=1e-5) 54 | train_loader = torch.utils.data.DataLoader(all_imgs_l, 55 | batch_size=batch_size, 56 | shuffle=True) 57 | outputs = [] 58 | for epoch in range(num_epochs): 59 | for data in train_loader: 60 | img = data 61 | recon = model(img) 62 | loss = criterion(recon, img) 63 | loss.backward() 64 | optimizer.step() 65 | optimizer.zero_grad() 66 | if((epoch+1)%5==0): 67 | print('Epoch:{}, Loss:{:.4f}'.format(epoch, float(loss))) 68 | outputs.append((epoch, img, recon),) 69 | return outputs 70 | 71 | model = Autoencoder() 72 | gt_angles=np.linspace(5,177,440) 73 | class LVField(nn.Module): 74 | 75 | def __init__(self, dim=64, augmented_dim=64+32, hidden_dim=500, num_layers=8, e_vec_factor=1e-7, t_d_factor=1e-4): 76 | super().__init__() 77 | vec = torch.zeros(augmented_dim) 78 | vec[:dim] = .1 79 | self.dim = dim 80 | self.augmented_dim = augmented_dim 81 | self.w_vec = nn.Parameter(torch.cat([vec.unsqueeze(0),torch.eye(augmented_dim)],dim=0)) 82 | self.pad = nn.ConstantPad1d((0, augmented_dim - dim), 0) 83 | 84 | self.e_vec_factor = torch.tensor(e_vec_factor) 85 | self.t_d_factor = torch.tensor(t_d_factor) 86 | 87 | # build inn 88 | def subnet_fc(dims_in, 
dims_out): 89 | return nn.Sequential(nn.Linear(dims_in, hidden_dim), nn.ReLU(), 90 | nn.Linear(hidden_dim, dims_out)) 91 | 92 | self.inn = Ff.SequenceINN(augmented_dim) 93 | for k in range(num_layers): 94 | self.inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc 95 | # ,permute_soft=True 96 | ) 97 | 98 | 99 | def eigen_ode(self, w_vec,init_x,t_d): 100 | from torchdiffeq import odeint, odeint_adjoint 101 | 102 | 103 | 104 | e_val,e_vec=(w_vec[0],w_vec[1:]) 105 | 106 | #e_val = -e_val**2 - 1e-10 107 | 108 | 109 | _e_vec = e_vec * self.e_vec_factor 110 | _t_d = t_d * self.t_d_factor 111 | 112 | init_v=torch.bmm(torch.inverse(_e_vec).expand(init_x.shape[0],-1,-1),init_x[:, :, None]) 113 | rs=torch.bmm((init_v.transpose(1,2)*_e_vec.expand((init_x.shape[0],-1,-1))), 114 | torch.exp( 115 | e_val[:,None] * _t_d 116 | #self._nn(t_d[:, None]).T 117 | ) 118 | .expand((init_x.shape[0],-1,-1))).transpose(1,2) 119 | return rs 120 | 121 | 122 | def _forward(self, init_v, t_d, padding, remove_padding_after): 123 | 124 | if padding: 125 | init_v = self.pad(init_v) 126 | 127 | init_v_in = self.inn(init_v)[0] 128 | eval_lin = self.eigen_ode(self.w_vec,init_v_in,t_d) 129 | 130 | _ori_shape = eval_lin.shape 131 | out = self.inn(eval_lin.reshape(-1, eval_lin.shape[-1]),rev=True)[0] 132 | out = out.reshape(_ori_shape) 133 | if remove_padding_after: 134 | out = out[:, :, :self.dim] 135 | return out 136 | 137 | def forward(self,init_v,t_d, padding=True, remove_padding_after=True): 138 | # return self._forward(init_v, t_d, padding, remove_padding_after) 139 | # print(init_v.shape) 140 | if len(init_v.shape) == 2: 141 | # batch of multiple traj 142 | return self._forward(init_v, t_d, padding, remove_padding_after) 143 | elif len(init_v.shape) == 3: 144 | # batch of multiple traj 145 | # TODO make it forward a whole batch 146 | out = [] 147 | for i in range(init_v.shape[0]): 148 | out.append(self(init_v[i], t_d, padding, remove_padding_after)) 149 | # timer.print_stats() 150 | return 
torch.stack(out) 151 | #else: 152 | # raise NotImplementedError(f"input has dimensionality {init_v.shape}") 153 | #create ground truth 154 | save_list=[] 155 | 156 | for n in range(100,200): 157 | #print(n) 158 | s_list=[] 159 | for i in range(len(gt_angles)): 160 | #pl.figure() 161 | rotated = ndimage.rotate(dt3_tor[n], gt_angles[i],reshape=False) 162 | #rot_len=int(len(rotated)/2) 163 | #rotated=rotated[] 164 | s_list.append(torch.tensor(rotated)[None]) 165 | s_list_tor=torch.cat(s_list) 166 | save_list.append(s_list_tor[None]) 167 | all_tor_ds=torch.cat(save_list) 168 | all_tor_ds=all_tor_ds.reshape(-1,1,28,28).float() 169 | all_tor_ds_max=torch.max(all_tor_ds.reshape(44000,-1),dim=1) 170 | all_tor_ds_s=all_tor_ds/all_tor_ds_max.values.reshape((-1,1,1,1)) 171 | m=LVField() 172 | 173 | 174 | 175 | #create ground truth 176 | save_list=[] 177 | 178 | for n in range(100,200): 179 | #print(n) 180 | s_list=[] 181 | for i in range(len(gt_angles)): 182 | #pl.figure() 183 | rotated = ndimage.rotate(dt3_tor[n], gt_angles[i],reshape=False) 184 | #rot_len=int(len(rotated)/2) 185 | #rotated=rotated[] 186 | s_list.append(torch.tensor(rotated)[None]) 187 | s_list_tor=torch.cat(s_list) 188 | save_list.append(s_list_tor[None]) 189 | 190 | all_tor_ds=torch.cat(save_list) 191 | all_tor_ds=all_tor_ds.reshape(-1,1,28,28).float() 192 | all_tor_ds_max=torch.max(all_tor_ds.reshape(44000,-1),dim=1) 193 | all_tor_ds_s=all_tor_ds/all_tor_ds_max.values.reshape((-1,1,1,1)) 194 | 195 | 196 | model.load_state_dict(torch.load('autoencoder_state_dic_0.tar')) 197 | m.load_state_dict(torch.load('LV_stat_dic_0.tar')) 198 | 199 | all_tor_rr=all_tor_ds_s.reshape((100,440,1,28,28)) 200 | test_c=all_tor_rr[:,0] 201 | enc=model.encoder(test_c) 202 | 203 | #GPU 204 | device='cpu' 205 | m_gpu=m.to(device) 206 | st_gpu=enc[0,:,0,0][None].to(device) 207 | 208 | 209 | 210 | nt_gpu_exp=torch.arange(0,44,0.1).to(device) 211 | 212 | import time 213 | time_list=[] 214 | for i in range(len(enc)): 215 | 
st_gpu=enc[i,:,0,0][None].to(device) 216 | start_time = time.time() 217 | out_traj=m_gpu.forward(st_gpu,nt_gpu_exp) 218 | end_time=time.time() 219 | time_list.append(end_time-start_time) 220 | print("mean (ms):") 221 | print(np.array(time_list).mean()*1000) 222 | print('std:') 223 | print(np.array(time_list).std()*1000) 224 | 225 | st_gpu=enc[:,:,0,0][None].to(device) 226 | out_traj=m_gpu.forward(st_gpu,nt_gpu_exp) 227 | 228 | rot_t1=model.decoder(out_traj[:,:,:,None,None].to('cpu').reshape((-1,64,1,1))).detach() 229 | 230 | print('mse') 231 | print(nn.MSELoss()(rot_t1,all_tor_ds_s)) 232 | -------------------------------------------------------------------------------- /rotating_MNIST/readme.md: -------------------------------------------------------------------------------- 1 | Run run_rotate.py to train and save weights, then run evaluate.py to evaluate 2 | -------------------------------------------------------------------------------- /rotating_MNIST/rotating3_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/rotating_MNIST/rotating3_2.pdf -------------------------------------------------------------------------------- /rotating_MNIST/run_rotate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as pl 3 | import torch 4 | import scipy 5 | from mnist import MNIST 6 | 7 | mndata = MNIST('samples') 8 | 9 | images, labels = mndata.load_training() 10 | ds=np.array(images) 11 | d_l=np.array(labels) 12 | dt3=ds[d_l[:]==3] 13 | dt3_tor=torch.tensor(dt3).reshape((-1,28,28)) 14 | 15 | 16 | from scipy import ndimage 17 | save_list=[] 18 | print('generating data') 19 | for n in range(0,100): 20 | 21 | s_list=[] 22 | for i in range(5,180,4): 23 | #pl.figure() 24 | rotated = ndimage.rotate(dt3_tor[n], i,reshape=False) 25 | #rot_len=int(len(rotated)/2) 26 | 
#rotated=rotated[] 27 | s_list.append(torch.tensor(rotated)[None]) 28 | s_list_tor=torch.cat(s_list) 29 | save_list.append(s_list_tor[None]) 30 | save_list_tor=torch.cat(save_list) 31 | print('finished generating data') 32 | save_list_tor_sample=save_list_tor 33 | traj=save_list_tor_sample 34 | import torch.nn as nn 35 | import torchdiffeq 36 | 37 | class Autoencoder(nn.Module): 38 | def __init__(self): 39 | super(Autoencoder, self).__init__() 40 | self.encoder = nn.Sequential( # like the Composition layer you built 41 | nn.Conv2d(1, 16, 3, stride=2, padding=1), 42 | nn.ReLU(), 43 | nn.Conv2d(16, 32, 3, stride=2, padding=1), 44 | nn.ReLU(), 45 | nn.Conv2d(32, 64, 7) 46 | 47 | ) 48 | self.decoder = nn.Sequential( 49 | nn.ConvTranspose2d(64, 32, 7), 50 | nn.ReLU(), 51 | nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1), 52 | nn.ReLU(), 53 | nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1), 54 | nn.Sigmoid() 55 | ) 56 | 57 | def forward(self, x): 58 | x = self.encoder(x) 59 | #x=nn.Flatten()(x) 60 | #x=nn.Unflatten(dim=1,unflattened_size=(64,1,1))(x) 61 | x = self.decoder(x) 62 | return x 63 | 64 | all_imgs=traj[:,:,:,:].reshape(-1,1,28,28).float() 65 | all_imgs_max=torch.max(all_imgs.reshape(4400,-1),dim=1) 66 | all_imgs=all_imgs/all_imgs_max.values.reshape((-1,1,1,1)) 67 | all_imgs_l=list(all_imgs) 68 | 69 | def train(model, num_epochs=5, batch_size=64, learning_rate=1e-4): 70 | torch.manual_seed(42) 71 | criterion = nn.MSELoss() # mean square error loss 72 | optimizer = torch.optim.Adam(model.parameters(), 73 | lr=learning_rate, 74 | weight_decay=1e-5) # <-- 75 | train_loader = torch.utils.data.DataLoader(all_imgs_l, 76 | batch_size=batch_size, 77 | shuffle=True) 78 | outputs = [] 79 | for epoch in range(num_epochs): 80 | for data in train_loader: 81 | img = data 82 | recon = model(img) 83 | loss = criterion(recon, img) 84 | loss.backward() 85 | optimizer.step() 86 | optimizer.zero_grad() 87 | 88 | print('Epoch:{}, 
Loss:{:.4f}'.format(epoch+1, float(loss))) 89 | outputs.append((epoch, img, recon),) 90 | return outputs 91 | 92 | from scipy import ndimage 93 | save_list=[] 94 | print('generating test data') 95 | for n in range(100,200): 96 | print(n) 97 | s_list=[] 98 | for i in range(5,180,4): 99 | #pl.figure() 100 | rotated = ndimage.rotate(dt3_tor[n], i,reshape=False) 101 | #rot_len=int(len(rotated)/2) 102 | #rotated=rotated[] 103 | s_list.append(torch.tensor(rotated)[None]) 104 | s_list_tor=torch.cat(s_list) 105 | save_list.append(s_list_tor[None]) 106 | all_tor_ds=torch.cat(save_list) 107 | 108 | all_tor_ds=all_tor_ds.reshape(-1,1,28,28).float() 109 | all_tor_ds_max=torch.max(all_tor_ds.reshape(4400,-1),dim=1) 110 | all_tor_ds_s=all_tor_ds/all_tor_ds_max.values.reshape((-1,1,1,1)) 111 | print('generated test data') 112 | 113 | print('train conv autoencoder') 114 | seed = 0 115 | torch.manual_seed(seed) 116 | import random 117 | random.seed(seed) 118 | np.random.seed(seed) 119 | device='cpu' 120 | 121 | model = Autoencoder() 122 | 123 | max_epochs = 50 124 | outputs = train(model, num_epochs=max_epochs) 125 | 126 | import FrEIA.framework as Ff 127 | import FrEIA.modules as Fm 128 | 129 | 130 | class LVField(nn.Module): 131 | 132 | def __init__(self, dim=64, augmented_dim=64+32, hidden_dim=500, num_layers=8, e_vec_factor=1e-7, t_d_factor=1e-4): 133 | super().__init__() 134 | vec = torch.zeros(augmented_dim) 135 | vec[:dim] = .1 136 | self.dim = dim 137 | self.augmented_dim = augmented_dim 138 | self.w_vec = nn.Parameter(torch.cat([vec.unsqueeze(0),torch.eye(augmented_dim)],dim=0)) 139 | self.pad = nn.ConstantPad1d((0, augmented_dim - dim), 0) 140 | 141 | self.e_vec_factor = torch.tensor(e_vec_factor) 142 | self.t_d_factor = torch.tensor(t_d_factor) 143 | 144 | # build inn 145 | def subnet_fc(dims_in, dims_out): 146 | return nn.Sequential(nn.Linear(dims_in, hidden_dim), nn.ReLU(), 147 | nn.Linear(hidden_dim, dims_out)) 148 | 149 | self.inn = Ff.SequenceINN(augmented_dim) 150 
| for k in range(num_layers): 151 | self.inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc 152 | # ,permute_soft=True 153 | ) 154 | 155 | 156 | def eigen_ode(self, w_vec,init_x,t_d): 157 | from torchdiffeq import odeint, odeint_adjoint 158 | 159 | 160 | 161 | e_val,e_vec=(w_vec[0],w_vec[1:]) 162 | 163 | #e_val = -e_val**2 - 1e-10 164 | 165 | 166 | _e_vec = e_vec * self.e_vec_factor 167 | _t_d = t_d * self.t_d_factor 168 | 169 | init_v=torch.bmm(torch.inverse(_e_vec).expand(init_x.shape[0],-1,-1),init_x[:, :, None]) 170 | rs=torch.bmm((init_v.transpose(1,2)*_e_vec.expand((init_x.shape[0],-1,-1))), 171 | torch.exp( 172 | e_val[:,None] * _t_d 173 | #self._nn(t_d[:, None]).T 174 | ) 175 | .expand((init_x.shape[0],-1,-1))).transpose(1,2) 176 | return rs 177 | 178 | 179 | def _forward(self, init_v, t_d, padding, remove_padding_after): 180 | 181 | if padding: 182 | init_v = self.pad(init_v) 183 | 184 | init_v_in = self.inn(init_v)[0] 185 | eval_lin = self.eigen_ode(self.w_vec,init_v_in,t_d) 186 | 187 | _ori_shape = eval_lin.shape 188 | out = self.inn(eval_lin.reshape(-1, eval_lin.shape[-1]),rev=True)[0] 189 | out = out.reshape(_ori_shape) 190 | if remove_padding_after: 191 | out = out[:, :, :self.dim] 192 | return out 193 | 194 | def forward(self,init_v,t_d, padding=True, remove_padding_after=True): 195 | # return self._forward(init_v, t_d, padding, remove_padding_after) 196 | # print(init_v.shape) 197 | if len(init_v.shape) == 2: 198 | # batch of multiple traj 199 | return self._forward(init_v, t_d, padding, remove_padding_after) 200 | elif len(init_v.shape) == 3: 201 | # batch of multiple traj 202 | # TODO make it forward a whole batch 203 | out = [] 204 | for i in range(init_v.shape[0]): 205 | out.append(self(init_v[i], t_d, padding, remove_padding_after)) 206 | # timer.print_stats() 207 | return torch.stack(out) 208 | else: 209 | raise NotImplementedError(f"input has dimensionality {init_v.shape}") 210 | 211 | # seed = 0 212 | # torch.manual_seed(seed) 213 | # 
import random 214 | # random.seed(seed) 215 | # np.random.seed(seed) 216 | m=LVField() 217 | opt_ours=torch.optim.Adam(m.parameters(),lr=0.0005,weight_decay=1e-5) 218 | all_tor_ds=all_tor_ds.reshape(-1,1,28,28).float() 219 | all_tor_ds_max=torch.max(all_tor_ds.reshape(4400,-1),dim=1) 220 | all_tor_ds_s=all_tor_ds/all_tor_ds_max.values.reshape((-1,1,1,1)) 221 | tt=model.encoder(all_imgs) 222 | tt_re=tt.reshape((-1,44,64)) 223 | nt=torch.arange(0,44).float() 224 | start_p=tt_re[:,0].detach().clone() 225 | for i in range(5000): 226 | opt_ours.zero_grad() 227 | out_traj_o=m.forward(start_p,nt) 228 | 229 | nloss=nn.MSELoss()(out_traj_o,tt_re.detach()) 230 | if(i%20==0): 231 | print(str(i)+':'+str(float(nloss))) 232 | nloss.backward() 233 | opt_ours.step() 234 | torch.save(m.state_dict(), 'LV_stat_dic_0.tar') 235 | torch.save(model.state_dict(), 'autoencoder_state_dic_0.tar') 236 | -------------------------------------------------------------------------------- /rotating_MNIST/samples/t10k-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/rotating_MNIST/samples/t10k-images-idx3-ubyte -------------------------------------------------------------------------------- /rotating_MNIST/samples/t10k-labels-idx1-ubyte: --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- /rotating_MNIST/samples/train-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/rotating_MNIST/samples/train-images-idx3-ubyte -------------------------------------------------------------------------------- /rotating_MNIST/samples/train-labels-idx1-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/rotating_MNIST/samples/train-labels-idx1-ubyte -------------------------------------------------------------------------------- /run-all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | (cd rotating_MNIST/ && python3 evaluate.py) 4 | (cd stiff_ode/ && python3 run_rober.py) 5 | (cd robotic/ && python3 run.py) 6 | (cd systems/ && python3 LV_train_run_ours.py) 7 | (cd systems/ && python3 LV_train_theirs.py) 8 | (cd systems/ && python3 LV_run_theirs.py) 9 | (cd systems/ && python3 lor_train_ours.py) 10 | (cd systems/ && python3 lor_train_theirs.py) 11 | (cd systems/ && python3 lor_run_all.py) 12 | (cd latentode/latent_ode/ && ./train-periodic100.sh) 13 | (cd latentode/latent_ode/ && ./train-periodic1000.sh) 14 | (cd latentode/latent_ode/ && ./train-activity.sh) 15 | (cd latentode/latent_ode/ && ./train-ecg.sh) 16 | 17 | --------------------------------------------------------------------------------
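The `eigen_ode` methods of the `LVField` classes above evaluate a linear ODE in closed form instead of calling a numerical solver. Row 0 of `w_vec` holds the eigenvalues λ, and rows 1 onward form a matrix E whose columns play the role of eigenvectors. The returned trajectory is x(t) = E diag(exp(λ t)) E^-1 x0, which solves dx/dt = A x with A = E diag(λ) E^-1. The sketch below is a minimal, dependency-free illustration of that formula for a 2x2 case only: it has no torch, no INN and no batching, and `eigen_solution`, `generator`, `mat_vec` and `inv2` are hypothetical helper names, not part of the repository. It checks the closed form against A x(t) with a central finite difference.

```python
import math

# Minimal 2x2 illustration of the closed-form linear ODE used by eigen_ode.
# Matrices are lists of rows; the columns of E are the eigenvectors.


def mat_vec(m, v):
    """Multiply a 2x2 matrix by a length-2 vector."""
    return [m[0][0] * v[0] + m[0][1] * v[1],
            m[1][0] * v[0] + m[1][1] * v[1]]


def inv2(m):
    """Inverse of an invertible 2x2 matrix."""
    det = m[0][0] * m[1][1] - m[0][1] * m[1][0]
    return [[m[1][1] / det, -m[0][1] / det],
            [-m[1][0] / det, m[0][0] / det]]


def eigen_solution(e_val, e_vec, x0, t):
    """x(t) = E diag(exp(lambda * t)) E^{-1} x0."""
    v0 = mat_vec(inv2(e_vec), x0)                            # x0 in eigen-coordinates
    vt = [v0[j] * math.exp(e_val[j] * t) for j in range(2)]  # each mode evolves independently
    return mat_vec(e_vec, vt)                                # back to the original coordinates


def generator(e_val, e_vec):
    """A = E diag(lambda) E^{-1}, the linear vector field that x(t) integrates."""
    e_inv = inv2(e_vec)
    return [[sum(e_vec[i][k] * e_val[k] * e_inv[k][j] for k in range(2))
             for j in range(2)] for i in range(2)]


if __name__ == "__main__":
    E = [[1.0, 1.0], [0.0, 1.0]]   # eigenvectors (1, 0) and (1, 1) as columns
    lam = [-1.0, -2.0]             # matching eigenvalues
    x0 = [1.0, 2.0]
    A = generator(lam, E)
    t, h = 0.7, 1e-5
    x = eigen_solution(lam, E, x0, t)
    # a central difference of the closed form should agree with A x(t)
    fd = [(p - q) / (2 * h) for p, q in zip(eigen_solution(lam, E, x0, t + h),
                                            eigen_solution(lam, E, x0, t - h))]
    print("x(t)       :", x)
    print("dx/dt (fd) :", fd)
    print("A x(t)     :", mat_vec(A, x))
```

One consequence of this form: multiplying E by a nonzero constant, which is what `e_vec_factor` does, cancels between E and E^-1, so in exact arithmetic it leaves the trajectory unchanged and can only affect floating-point behaviour. `t_d_factor`, by contrast, rescales time and does change the trajectory.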
/stiff_ode/run_rober.py: -------------------------------------------------------------------------------- 1 | from scipy.integrate import solve_ivp 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import platform 5 | def robertson_conserved ( t, y ): 6 | h = np.sum ( y, axis = 0 ) 7 | return h 8 | 9 | def robertson_deriv ( t, y ): 10 | y1 = y[0] 11 | y2 = y[1] 12 | y3 = y[2] 13 | dydt = np.zeros(3) 14 | dydt[0] = - 0.04 * y1 + 10000.0 * y2 * y3 15 | dydt[1] = 0.04 * y1 - 10000.0 * y2 * y3 - 30000000.0 * y2 * y2 16 | dydt[2] = + 30000000.0 * y2 * y2 17 | return dydt 18 | from scipy.integrate import solve_ivp 19 | import matplotlib.pyplot as plt 20 | import numpy as np 21 | import platform 22 | 23 | tmin = 0.0 24 | tmin_logsp=-6 25 | tmax_logsp = 6#10000.0 26 | tmax=1000000.0 27 | y0 = np.array ( [ 1.0, 0.0, 0.0 ] ) 28 | tspan = np.array ( [ tmin, tmax ] ) 29 | t = np.logspace ( tmin_logsp, tmax_logsp, num=50 ) 30 | # 31 | # Use the LSODA solver, which is suitable for stiff systems.
32 | # 33 | sol = solve_ivp ( robertson_deriv, tspan, y0,t_eval=t, method = 'LSODA' ) 34 | 35 | import torch 36 | import torch.nn as n 37 | import FrEIA.framework as Ff 38 | import FrEIA.modules as Fm 39 | 40 | tt=torch.tensor(np.log10(sol.t)) 41 | yy=torch.tensor(sol.y.T) 42 | 43 | seed = 0 44 | torch.manual_seed(seed) 45 | import random 46 | random.seed(seed) 47 | np.random.seed(seed) 48 | tt_tors=torch.tensor([[1.,0.,0.,0.,0.,0.]]) 49 | device='cpu' 50 | t_d=tt 51 | t_d = t_d.to(device) 52 | xy_d =torch.tensor(yy).to(device) 53 | 54 | xy_d_max=torch.max(xy_d,dim=0).values 55 | factor=1./xy_d_max 56 | hidden_size=1500 57 | num_layers=5 58 | 59 | f_x=n.Sequential( 60 | n.Linear(6, 30), 61 | n.Tanh(), 62 | n.Linear(30, 30), 63 | n.Tanh(), 64 | n.Linear(30, 30), 65 | n.Tanh(), 66 | n.Linear(30, 6) 67 | ) 68 | 69 | xy_d_max=torch.max(xy_d,dim=0).values 70 | 71 | def fx(t,x): 72 | return(f_x(x)) 73 | 74 | def for_inn(x): 75 | return(inn(x)[0]) 76 | def rev_inn(x): 77 | return(inn(x,rev=True)[0]) 78 | def rev_mse_inn_eig(rf,x_gt): 79 | return(torch.mean(torch.norm(rf-x_gt,dim=1))) 80 | def linear_val_ode(w_vec,init_v,t_d): 81 | init_v_in=rev_inn(init_v) 82 | eval_lin=eigen_ode__(w_vec,init_v_in,t_d) 83 | ori_shape = eval_lin.shape 84 | eval_out=for_inn(eval_lin.reshape(-1, eval_lin.shape[-1])) 85 | return(eval_out.reshape(ori_shape)) 86 | def linear_val_ode2(init_v,t_d): 87 | init_v_in=rev_inn(init_v) 88 | eval_lin=torchdiffeq.odeint(fx,init_v_in,t_d,atol=1e-5,method='dopri5')[:,0,:]#options={'step_size':0.01} 89 | eval_out=for_inn(eval_lin) 90 | return(eval_out) 91 | 92 | 93 | 94 | seed = 42 95 | torch.manual_seed(seed) 96 | import random 97 | random.seed(seed) 98 | np.random.seed(seed) 99 | tt_tors=torch.tensor([[1.,0.,0.,0.,0.,0.]]) 100 | device='cpu' 101 | 102 | t_d = tt.to(device) 103 | xy_d =torch.tensor(yy).to(device) 104 | 105 | xy_d_max=torch.max(xy_d,dim=0).values 106 | factor=1./xy_d_max 107 | hidden_size=1500 108 | num_layers=5 109 | 110 | f_x=n.Sequential( 
111 | n.Linear(6, 30), 112 | n.Tanh(), 113 | n.Linear(30, 30), 114 | n.Tanh(), 115 | n.Linear(30, 30), 116 | n.Tanh(), 117 | n.Linear(30, 6) 118 | ) 119 | 120 | xy_d_max=torch.max(xy_d,dim=0).values 121 | 122 | def fx(t,x): 123 | return(f_x(x)) 124 | 125 | def for_inn(x): 126 | return(inn(x)[0]) 127 | def rev_inn(x): 128 | return(inn(x,rev=True)[0]) 129 | def rev_mse_inn_eig(rf,x_gt): 130 | return(torch.mean(torch.norm(rf-x_gt,dim=1))) 131 | def linear_val_ode(w_vec,init_v,t_d): 132 | init_v_in=rev_inn(init_v) 133 | eval_lin=eigen_ode__(w_vec,init_v_in,t_d) 134 | ori_shape = eval_lin.shape 135 | eval_out=for_inn(eval_lin.reshape(-1, eval_lin.shape[-1])) 136 | return(eval_out.reshape(ori_shape)) 137 | def linear_val_ode2(init_v,t_d): 138 | init_v_in=rev_inn(init_v) 139 | eval_lin=torchdiffeq.odeint(fx,init_v_in,t_d,atol=1e-5,method='dopri5')[:,0,:]#options={'step_size':0.01} 140 | eval_out=for_inn(eval_lin) 141 | return(eval_out) 142 | 143 | 144 | 145 | N_DIM = 6 146 | def subnet_fc(dims_in, dims_out): 147 | return n.Sequential(n.Linear(dims_in, hidden_size), n.ReLU(), 148 | n.Linear(hidden_size, dims_out)) 149 | 150 | inn = Ff.SequenceINN(N_DIM) 151 | for k in range(num_layers): 152 | inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc,permute_soft=True) 153 | 154 | 155 | optimizer_comb = torch.optim.Adam( 156 | [{'params': f_x.parameters(),'lr': 0.0001},{'params': inn.parameters(), 157 | 'lr': 0.0001}]) 158 | print(sum(p.numel() for p in inn.parameters())) 159 | 160 | startings=tt_tors.clone().detach() 161 | 162 | #Training loop 163 | import timeit 164 | epoch_time=[] 165 | import tqdm 166 | from tqdm import trange 167 | 168 | tt_tors = tt_tors.to(device) 169 | #t_d = t_d.to(device) 170 | inn.to(device) 171 | import torchdiffeq 172 | for i in trange(0, 5000): 173 | optimizer_comb.zero_grad() 174 | loss=0.0 175 | start = timeit.default_timer() 176 | eval_nl=linear_val_ode2(tt_tors,t_d) 177 | 
#eval_nl=torchdiffeq.odeint(fx,tt_tors,t_d,atol=1e-7,method='dopri5')[:,0,:]#options={'step_size':0.01} 178 | """ 179 | for j in range(len(xy_d_list)): 180 | eval_nl=linear_val_ode(w_vec,tt_tors[j],t_d) 181 | 182 | #torchdiffeq.odeint(fx, 183 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5, 184 | # method='euler')[:,0,:] 185 | 186 | #loss_cur = rev_mse_inn_eig(eval_nl[:,:3],xy_d_list[j]) 187 | #loss+=loss_cur 188 | """ 189 | #factor 190 | eval_cal=eval_nl[:,:3] 191 | eval_gt=factor*xy_d 192 | # ground truth scaled per species by its maximum value 193 | loss_cur = torch.mean(torch.norm(eval_cal-eval_gt,dim=1)) 194 | loss+=loss_cur 195 | 196 | loss.backward() 197 | optimizer_comb.step() 198 | end = timeit.default_timer() 199 | epoch_time.append(end-start) 200 | if(i%100==0): 201 | print('Combined loss:'+str(i)+': '+str(loss)) 202 | ep_time=np.array(epoch_time) 203 | #print(f'mean train time:{ep_time.mean():.3f} {ep_time.std():.3f}') 204 | #print(f'total: {ep_time.sum() / run_for * target:.2f}') 205 | 206 | torch.save(f_x.state_dict(),'n_stiff_fx.tar') 207 | torch.save(inn.state_dict(),'n_stiff_inn.tar') 208 | 209 | device2='cuda' if torch.cuda.is_available() else 'cpu' 210 | tt_tors_n=torch.tensor([[1.0,0.,0.,0.,0.,0.]]) 211 | texp = np.logspace ( tmin_logsp, tmax_logsp, num=500 ) 212 | tt_tors_interp=torch.tensor(texp) 213 | t_d_exp=torch.log10(tt_tors_interp) 214 | #eval_nl=linear_val_ode2(tt_tors_n,t_d_exp) 215 | tt_tors_n = tt_tors_n.to(device2) 216 | t_d_exp=t_d_exp.to(device2) 217 | inn=inn.to(device2) 218 | f_x=f_x.to(device2) 219 | 220 | q_time=[] 221 | for i in range(10): 222 | start = timeit.default_timer() 223 | eval_nl=linear_val_ode2(tt_tors_n,t_d_exp) 224 | end = timeit.default_timer() 225 | q_time.append(end-start) 226 | q_time_np=np.array(q_time) 227 | print("mean time:") 228 | print(q_time_np.mean()*1000) 229 | print("std time:") 230 | print(q_time_np.std()*1000) 231 | 232 | eval_nl=linear_val_ode2(tt_tors_n,t_d_exp) 233 |
#eval_nl=torchdiffeq.odeint(fx,tt_tors,t_d,atol=1e-7,method='dopri5')[:,0,:]#options={'step_size':0.01} 234 | """ 235 | for j in range(len(xy_d_list)): 236 | eval_nl=linear_val_ode(w_vec,tt_tors[j],t_d) 237 | 238 | #torchdiffeq.odeint(fx, 239 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5, 240 | # method='euler')[:,0,:] 241 | 242 | #loss_cur = rev_mse_inn_eig(eval_nl[:,:3],xy_d_list[j]) 243 | #loss+=loss_cur 244 | """ 245 | #factor 246 | sol = solve_ivp ( robertson_deriv, tspan, y0,t_eval=texp, method = 'LSODA' ) 247 | yy_exp=torch.tensor(sol.y.T) 248 | eval_cal=eval_nl[:,:3] 249 | eval_gt=torch.tensor(yy_exp).to(device) 250 | 251 | print('MAE:') 252 | print(torch.mean(torch.norm(eval_nl.detach()[:,:3].to('cpu')-factor*yy_exp.to('cpu'),dim=1))) 253 | -------------------------------------------------------------------------------- /stiff_ode/run_rober_comp_dopri.py: -------------------------------------------------------------------------------- 1 | from scipy.integrate import solve_ivp 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import platform 5 | def robertson_conserved ( t, y ): 6 | h = np.sum ( y, axis = 0 ) 7 | return h 8 | 9 | def robertson_deriv ( t, y ): 10 | y1 = y[0] 11 | y2 = y[1] 12 | y3 = y[2] 13 | dydt = np.zeros(3) 14 | dydt[0] = - 0.04 * y1 + 10000.0 * y2 * y3 15 | dydt[1] = 0.04 * y1 - 10000.0 * y2 * y3 - 30000000.0 * y2 * y2 16 | dydt[2] = + 30000000.0 * y2 * y2 17 | return dydt 18 | from scipy.integrate import solve_ivp 19 | import matplotlib.pyplot as plt 20 | import numpy as np 21 | import platform 22 | 23 | tmin = 0.0 24 | tmin_logsp=-6 25 | tmax_logsp = 6#10000.0 26 | tmax=1000000.0 27 | y0 = np.array ( [ 1.0, 0.0, 0.0 ] ) 28 | tspan = np.array ( [ tmin, tmax ] ) 29 | t = np.logspace ( tmin_logsp, tmax_logsp, num=50 ) 30 | # 31 | # Use the LSODA solver, that is suitable for stiff systems. 
32 | # 33 | sol = solve_ivp ( robertson_deriv, tspan, y0,t_eval=t, method = 'LSODA' ) 34 | 35 | import torch 36 | import torch.nn as n 37 | import FrEIA.framework as Ff 38 | import FrEIA.modules as Fm 39 | 40 | tt=torch.tensor(np.log10(sol.t)) 41 | yy=torch.tensor(sol.y.T) 42 | 43 | seed = 0 44 | torch.manual_seed(seed) 45 | import random 46 | random.seed(seed) 47 | np.random.seed(seed) 48 | tt_tors=torch.tensor([[1.,0.,0.,0.,0.,0.]]) 49 | device='cpu' 50 | t_d=tt 51 | t_d = t_d.to(device) 52 | xy_d =torch.tensor(yy).to(device) 53 | 54 | xy_d_max=torch.max(xy_d,dim=0).values 55 | factor=1./xy_d_max 56 | hidden_size=1500 57 | num_layers=5 58 | 59 | f_x=n.Sequential( 60 | n.Linear(6, 30), 61 | n.Tanh(), 62 | n.Linear(30, 30), 63 | n.Tanh(), 64 | n.Linear(30, 30), 65 | n.Tanh(), 66 | n.Linear(30, 6) 67 | ) 68 | 69 | xy_d_max=torch.max(xy_d,dim=0).values 70 | 71 | def fx(t,x): 72 | return(f_x(x)) 73 | 74 | def for_inn(x): 75 | return(inn(x)[0]) 76 | def rev_inn(x): 77 | return(inn(x,rev=True)[0]) 78 | def rev_mse_inn_eig(rf,x_gt): 79 | return(torch.mean(torch.norm(rf-x_gt,dim=1))) 80 | def linear_val_ode(w_vec,init_v,t_d): 81 | init_v_in=rev_inn(init_v) 82 | eval_lin=eigen_ode__(w_vec,init_v_in,t_d) 83 | ori_shape = eval_lin.shape 84 | eval_out=for_inn(eval_lin.reshape(-1, eval_lin.shape[-1])) 85 | return(eval_out.reshape(ori_shape)) 86 | def linear_val_ode2(init_v,t_d): 87 | init_v_in=rev_inn(init_v) 88 | eval_lin=torchdiffeq.odeint(fx,init_v_in,t_d,atol=1e-5,method='dopri5')[:,0,:]#options={'step_size':0.01} 89 | eval_out=for_inn(eval_lin) 90 | return(eval_out) 91 | 92 | 93 | 94 | seed = 42 95 | torch.manual_seed(seed) 96 | import random 97 | import torchdiffeq 98 | random.seed(seed) 99 | np.random.seed(seed) 100 | f_x2=n.Sequential( 101 | n.Linear(6, 150), 102 | n.Tanh(), 103 | n.Linear(150, 150), 104 | n.Tanh(), 105 | n.Linear(150, 150), 106 | n.Tanh(), 107 | n.Linear(150, 6) 108 | ) 109 | device='cpu' 110 | t_d = tt.to(device) 111 | 
tt_tors=torch.tensor([[1.,0.,0.,0.,0.,0.]]) 112 | tt_tors=tt_tors.to(device) 113 | xy_d =torch.tensor(yy).to(device) 114 | xy_d_max=torch.max(xy_d,dim=0).values 115 | factor=1./xy_d_max 116 | def fx2(t,x): 117 | return(f_x2(x)) 118 | optimizer = torch.optim.Adam( 119 | [{'params': f_x2.parameters(),'lr': 0.0001}]) 120 | import timeit 121 | epoch_time=[] 122 | import tqdm 123 | from tqdm import trange 124 | for i in trange(0, 5000): 125 | optimizer.zero_grad() 126 | loss=0.0 127 | start = timeit.default_timer() 128 | eval_nl=torchdiffeq.odeint(fx2,tt_tors,t_d,atol=1e-5,method='dopri5')[:,0,:]#options={'step_size':0.01} 129 | """ 130 | for j in range(len(xy_d_list)): 131 | eval_nl=linear_val_ode(w_vec,tt_tors[j],t_d) 132 | 133 | #torchdiffeq.odeint(fx, 134 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5, 135 | # method='euler')[:,0,:] 136 | 137 | #loss_cur = rev_mse_inn_eig(eval_nl[:,:3],xy_d_list[j]) 138 | #loss+=loss_cur 139 | """ 140 | #factor 141 | eval_cal=eval_nl[:,:3] 142 | eval_gt=factor*xy_d 143 | loss_cur = torch.mean(torch.norm(eval_cal-eval_gt,dim=1)) 144 | loss+=loss_cur 145 | 146 | loss.backward() 147 | optimizer.step() 148 | end = timeit.default_timer() 149 | epoch_time.append(end-start) 150 | if(i%100==0): 151 | print('Combined loss:'+str(i)+': '+str(loss)) 152 | ep_time=np.array(epoch_time) 153 | torch.save(f_x2.state_dict(),'f_x2.tar') 154 | 155 | 156 | device2='cuda' if torch.cuda.is_available() else 'cpu' 157 | tt_tors_n=torch.tensor([[1.0,0.,0.,0.,0.,0.]]) 158 | texp = np.logspace ( tmin_logsp, tmax_logsp, num=500 ) 159 | tt_tors_interp=torch.tensor(texp) 160 | t_d_exp=torch.log10(tt_tors_interp) 161 | #eval_nl=linear_val_ode2(tt_tors_n,t_d_exp) 162 | tt_tors_n = tt_tors_n.to(device2) 163 | t_d_exp=t_d_exp.to(device2) 164 | f_x2=f_x2.to(device2) 165 | 166 | q_time=[] 167 | for i in range(10): 168 | start = timeit.default_timer() 169 | eval_nl=torchdiffeq.odeint(fx2,tt_tors_n,t_d_exp,atol=1e-5,method='dopri5')[:,0,:] 170 | end = timeit.default_timer() 171 |
q_time.append(end-start) 172 | q_time_np=np.array(q_time) 173 | print("mean time:") 174 | print(q_time_np.mean()*1000) 175 | print("std time:") 176 | print(q_time_np.std()*1000) 177 | 178 | eval_nl=torchdiffeq.odeint(fx2,tt_tors_n,t_d_exp,atol=1e-5,method='dopri5')[:,0,:] 179 | #eval_nl=torchdiffeq.odeint(fx,tt_tors,t_d,atol=1e-7,method='dopri5')[:,0,:]#options={'step_size':0.01} 180 | """ 181 | for j in range(len(xy_d_list)): 182 | eval_nl=linear_val_ode(w_vec,tt_tors[j],t_d) 183 | 184 | #torchdiffeq.odeint(fx, 185 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5, 186 | # method='euler')[:,0,:] 187 | 188 | #loss_cur = rev_mse_inn_eig(eval_nl[:,:3],xy_d_list[j]) 189 | #loss+=loss_cur 190 | """ 191 | #factor 192 | sol = solve_ivp ( robertson_deriv, tspan, y0,t_eval=texp, method = 'LSODA' ) 193 | yy_exp=torch.tensor(sol.y.T) 194 | eval_cal=eval_nl[:,:3] 195 | eval_gt=yy_exp.to(device) 196 | 197 | print('MAE:') 198 | print(torch.mean(torch.norm(eval_nl.detach()[:,:3].to('cpu')-factor*yy_exp.to('cpu'),dim=1))) 199 | -------------------------------------------------------------------------------- /systems/LV_run_theirs.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import numpy as np 4 | import matplotlib.pyplot as pl 5 | import torchdiffeq 6 | import torch.nn as n 7 | import FrEIA.framework as Ff 8 | import FrEIA.modules as Fm 9 | from mpl_toolkits import mplot3d 10 | 11 | ##########################################################3 12 | # system dynamics 13 | 14 | tt=torch.tensor([0.,5.]) 15 | t_d=torch.linspace(0,7,700) 16 | 17 | devicestr="cpu" 18 | device_str="cuda" if torch.cuda.is_available() else "cpu" 19 | 20 | def test_fun(t,x): 21 | A=0.75 22 | B=0.75 23 | C=0.75 24 | D=0.75 25 | E=0.75 26 | F=0.75 27 | G=0.75 28 | vel=torch.zeros((3,1)) 29 | vel[0]=x[0]*(A-B*x[1]) 30 | vel[1]=x[1]*(-C+D*x[0]-E*x[2]) 31 | vel[2]=x[2]*(-F+G*x[1]) 32 | return(vel) 33 | tt_tors=torch.tensor([[3,3.,1.,0.,0.,0.], 34 |
[2,2.,2.,0.,0.,0.], 35 | [4,4.,3.,0.,0.,0.], 36 | [3,3.,4.,0.,0.,0.], 37 | [1,1.,5.,0.,0.,0.], 38 | [5,5.,1.,0.,0.,0.], 39 | [2,6.,2.,0.,0.,0.], 40 | [3,1.,4.,0.,0.,0.], 41 | [7,1.,2.,0.,0.,0.], 42 | [6,2.,4.,0.,0.,0.]]) 43 | 44 | xy_d_list=[] 45 | for i in range(len(tt_tors)): 46 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.05*torch.ones((len(t_d),3))) 47 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T,t_d).reshape((-1,3))#+noise_c 48 | xy_d_list.append(xy_d.clone().detach()) 49 | 50 | ##########################################################3 51 | 52 | def test_theirs(method): 53 | global tt_tors, t_d 54 | def test_fun(t,x): 55 | A=0.75 56 | B=0.75 57 | C=0.75 58 | D=0.75 59 | E=0.75 60 | F=0.75 61 | G=0.75 62 | vel=torch.zeros((3,1)) 63 | vel[0]=x[0]*(A-B*x[1]) 64 | vel[1]=x[1]*(-C+D*x[0]-E*x[2]) 65 | vel[2]=x[2]*(-F+G*x[1]) 66 | return(vel) 67 | tt_tors=torch.tensor([[3,3.,1.,0.,0.,0.], 68 | [2,2.,2.,0.,0.,0.], 69 | [4,4.,3.,0.,0.,0.], 70 | [3,3.,4.,0.,0.,0.], 71 | [1,1.,5.,0.,0.,0.], 72 | [5,5.,1.,0.,0.,0.], 73 | [2,6.,2.,0.,0.,0.], 74 | [3,1.,4.,0.,0.,0.], 75 | [7,1.,2.,0.,0.,0.], 76 | [6,2.,4.,0.,0.,0.]]) 77 | 78 | xy_d_list=[] 79 | for i in range(len(tt_tors)): 80 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.05*torch.ones((len(t_d),3))) 81 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T,t_d.to('cpu')).reshape((-1,3))#+noise_c 82 | xy_d_list.append(xy_d.clone().detach()) 83 | 84 | if method == 'euler': 85 | weight_fn = 'LV_euler0' 86 | elif method == 'midpoint': 87 | weight_fn = 'LV_mid0' 88 | elif method == 'rk4': 89 | weight_fn = 'LV_rk40' 90 | elif method == 'dopri5': 91 | weight_fn = 'LV_dopri0' 92 | else: 93 | raise ValueError('unknown method: '+method) 94 | f_x=n.Sequential( 95 | n.Linear(6, 150), 96 | n.Tanh(), 97 | n.Linear(150, 150), 98 | n.Tanh(), 99 | n.Linear(150, 150), 100 | n.Tanh(), 101 | n.Linear(150, 150), 102 | n.Tanh(), 103 | n.Linear(150, 150), 104 | n.Tanh(), 105 | n.Linear(150, 6) 106 | ) 107 | def fx(t,x): 108 |
return(f_x(x)) 109 | 110 | f_x.load_state_dict(torch.load(weight_fn,map_location='cpu')) 111 | import time 112 | time_list=[] 113 | 114 | f_x = f_x.to(device_str) 115 | tt_tors = tt_tors.to(device_str) 116 | t_d = t_d.to(device_str) 117 | for i in range(10): 118 | tic = time.perf_counter() 119 | tx=torchdiffeq.odeint(fx,tt_tors,t_d,method=method,atol=1e-5,rtol=1e-5) 120 | toc = time.perf_counter() 121 | time_list.append(toc-tic) 122 | tx=tx.permute(1,0,2) 123 | #interpolation 124 | sum_num=0 125 | 126 | for i in range(len(xy_d_list)): 127 | xy_d=xy_d_list[i] 128 | error=torch.mean(torch.norm(tx[i][:,:3].to('cpu')-xy_d,dim=1)**2) 129 | sum_num+=error 130 | 131 | print(f'Interpolation MSE: {sum_num/len(xy_d_list):.4f}') 132 | 133 | times=np.array(time_list) 134 | 135 | print(method) 136 | print('mean time:') 137 | print(times.mean()) 138 | print('time std:') 139 | print(times.std()) 140 | t_d_test_tt=torch.linspace(0,7,700) 141 | test_tors_list=[] 142 | for i in range(2,6,1): 143 | for j in range(2,6,1): 144 | test_tors_list.append(torch.tensor([[float(i),float(j),2.,0.,0.,0.]])) 145 | 146 | test_tors=torch.cat(test_tors_list) 147 | test_d_list=[] 148 | for i in range(len(test_tors)): 149 | xy_d=torchdiffeq.odeint(test_fun,test_tors[i,:3][None].T,t_d_test_tt.to('cpu')).reshape((-1,3)) 150 | test_d_list.append(xy_d.clone().detach()) 151 | 152 | pushed_list_tt=[] 153 | diff_list=[] 154 | 155 | #w_vec=w_vec.to('cuda:0') 156 | #inn=inn.to('cuda:0') 157 | #t_d_test_tt=t_d_test_tt.to('cuda:0') 158 | f_x = f_x.to('cpu') 159 | for j in range(len(test_d_list)): 160 | pushed=torchdiffeq.odeint(fx,test_tors[j],t_d_test_tt,method=method,atol=1e-5,rtol=1e-5) 161 | pushed_list_tt.append(pushed.clone().detach()) 162 | print('Extrapolation Results:') 163 | i=0 164 | xy_d=test_d_list[i] 165 | sum_num=0 166 | for i in range(len(test_d_list)): 167 | xy_d=test_d_list[i] 168 | error=torch.mean(torch.norm(pushed_list_tt[i][...,:3].to('cpu')-xy_d.to('cpu'),dim=1)**2) 169 | sum_num+=error 170 | 171 |
| print(f'Extrapolation MSE: {sum_num/len(test_d_list):.4f}') 172 | 173 | 174 | test_theirs('euler') 175 | test_theirs('rk4') 176 | test_theirs('midpoint') 177 | test_theirs('dopri5') 178 | ##########################################################3 179 | 180 | -------------------------------------------------------------------------------- /systems/LV_train_run_ours.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as pl 4 | import torchdiffeq 5 | import torch.nn as n 6 | import FrEIA.framework as Ff 7 | import FrEIA.modules as Fm 8 | 9 | device = 'cpu' 10 | 11 | xx=torch.tensor([0.,0.,0.]) 12 | tt=torch.tensor([0.,5.]) 13 | t_d=torch.linspace(0,7,70) 14 | torch.manual_seed(0) 15 | import random 16 | random.seed(0) 17 | np.random.seed(0) 18 | def test_fun(t,x): 19 | A=0.75 20 | B=0.75 21 | C=0.75 22 | D=0.75 23 | E=0.75 24 | F=0.75 25 | G=0.75 26 | vel=torch.zeros((3,1)) 27 | vel[0]=x[0]*(A-B*x[1]) 28 | vel[1]=x[1]*(-C+D*x[0]-E*x[2]) 29 | vel[2]=x[2]*(-F+G*x[1]) 30 | return(vel) 31 | tt_tors=torch.tensor([[3,3.,1.,0.,0.,0.], 32 | [2,2.,2.,0.,0.,0.], 33 | [4,4.,3.,0.,0.,0.], 34 | [3,3.,4.,0.,0.,0.], 35 | [1,1.,5.,0.,0.,0.], 36 | [5,5.,1.,0.,0.,0.], 37 | [2,6.,2.,0.,0.,0.], 38 | [3,1.,4.,0.,0.,0.], 39 | [7,1.,2.,0.,0.,0.], 40 | [6,2.,4.,0.,0.,0.]]) 41 | 42 | 43 | def eigen_ode(w_vec,init_x,t_d): 44 | e_val,e_vec=(w_vec[0],w_vec[1:]) 45 | init_v=torch.mm(torch.linalg.inv(e_vec),init_x.reshape((-1,1))) 46 | #print(init_v) 47 | int_list=[] 48 | 49 | for i in range(0,6): 50 | int_c=0 51 | for j in range(0,6): 52 | int_c+=init_v[j]*e_vec[i,j]*torch.exp(e_val[j]*t_d) 53 | int_c.reshape((-1,1)) 54 | int_list.append(int_c[None].clone()) 55 | return(torch.cat(int_list).T) 56 | 57 | def eigen_ode__(w_vec,init_x,t_d): 58 | 59 | e_val,e_vec=(w_vec[0],w_vec[1:]) 60 | 61 | _e_vec = e_vec * 1 62 | _t_d = t_d * 1 63 | 64 | 65 | 
init_v=torch.bmm(torch.inverse(_e_vec).expand(init_x.shape[0],-1,-1),init_x[:, :, None]) 66 | rs=torch.bmm((init_v.transpose(1,2)*_e_vec.expand((init_x.shape[0],-1,-1))), 67 | torch.exp( 68 | e_val[:,None] * _t_d 69 | #self._nn(t_d[:, None]).T 70 | ) 71 | .expand((init_x.shape[0],-1,-1))).transpose(1,2) 72 | return rs 73 | 74 | def loss_er(x_pred,x_gt): 75 | mse_l=torch.norm(x_pred-x_gt,dim=1) 76 | sum_c=0.0 77 | for i in range(len(mse_l)): 78 | sum_c+=(mse_l[i]**2)#*(1/float(i+1)) 79 | return(sum_c/len(mse_l)) 80 | 81 | 82 | 83 | 84 | 85 | 86 | def train(hidden_size, num_layers): 87 | print('='*20) 88 | #print(f'hidden_size={hidden_size}, num_layers={num_layers}') 89 | global tt_tors, xy_d_list, t_d 90 | 91 | def for_inn(x): 92 | return(inn(x)[0]) 93 | def rev_inn(x): 94 | return(inn(x,rev=True)[0]) 95 | def rev_mse_inn_eig(rf,x_gt): 96 | return(torch.mean(torch.norm(rf-x_gt,dim=1))) 97 | def linear_val_ode(w_vec,init_v,t_d): 98 | init_v_in=rev_inn(init_v) 99 | eval_lin=eigen_ode__(w_vec,init_v_in,t_d) 100 | ori_shape = eval_lin.shape 101 | eval_out=for_inn(eval_lin.reshape(-1, eval_lin.shape[-1])) 102 | return(eval_out.reshape(ori_shape)) 103 | 104 | seed = 123 105 | torch.manual_seed(seed) 106 | import random 107 | random.seed(seed) 108 | np.random.seed(seed) 109 | 110 | 111 | e_val=torch.tensor([[0.1,0.1,0.1,0.,0.,0.]]) 112 | e_vec=torch.eye(6) 113 | w_vec=torch.cat([e_val,e_vec],dim=0) 114 | w_vec.requires_grad=True 115 | 116 | 117 | 118 | xy_d_list=[] 119 | for i in range(len(tt_tors)): 120 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.05*torch.ones((len(t_d),3))) 121 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T.to('cpu'),t_d.to('cpu')).reshape((-1,3))+noise_c 122 | xy_d_list.append(xy_d.clone().detach()) 123 | 124 | 125 | N_DIM = 6 126 | def subnet_fc(dims_in, dims_out): 127 | return n.Sequential(n.Linear(dims_in, hidden_size), n.ReLU(), 128 | n.Linear(hidden_size, dims_out)) 129 | 130 | inn = Ff.SequenceINN(N_DIM) 131 | for k in 
range(num_layers): 132 | inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc,permute_soft=True) 133 | 134 | 135 | optimizer_comb = torch.optim.Adam( 136 | [{'params': w_vec,'lr': 0.0001},{'params': inn.parameters(), 137 | 'lr': 0.0001}]) 138 | print(sum(p.numel() for p in inn.parameters())) 139 | opt=torch.optim.Adam([w_vec],lr=0.005) 140 | startings=tt_tors.clone().detach() 141 | #quick initialise 142 | #for i in range(0,1): 143 | for i in range(0,200): 144 | opt.zero_grad() 145 | loss = 0. 146 | for j in range(len(xy_d_list)): 147 | start_d=startings[j][None] 148 | output = eigen_ode(w_vec,start_d.to('cpu'),t_d.to('cpu')) 149 | loss_cur = loss_er(output[:,:3], xy_d_list[j][:,:3].to('cpu')) 150 | loss+=loss_cur 151 | loss.backward() 152 | opt.step() 153 | 154 | #Training loop 155 | import timeit 156 | epoch_time=[] 157 | import tqdm 158 | from tqdm import trange 159 | 160 | w_vec = w_vec.to(device) 161 | tt_tors = tt_tors.to(device) 162 | #t_d = t_d.to(device) 163 | inn.to(device) 164 | t_d = t_d.to(device) 165 | xy_d_list = torch.stack(xy_d_list).to(device) 166 | 167 | 168 | 169 | for i in trange(0, 5000): 170 | optimizer_comb.zero_grad() 171 | loss=0.0 172 | start = timeit.default_timer() 173 | eval_nl=linear_val_ode(w_vec,tt_tors,t_d) 174 | """ 175 | for j in range(len(xy_d_list)): 176 | eval_nl=linear_val_ode(w_vec,tt_tors[j],t_d) 177 | 178 | #torchdiffeq.odeint(fx, 179 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5, 180 | # method='euler')[:,0,:] 181 | 182 | #loss_cur = rev_mse_inn_eig(eval_nl[:,:3],xy_d_list[j]) 183 | #loss+=loss_cur 184 | """ 185 | loss_cur = rev_mse_inn_eig(eval_nl[...,:3],xy_d_list) 186 | loss+=loss_cur 187 | 188 | loss.backward() 189 | optimizer_comb.step() 190 | end = timeit.default_timer() 191 | epoch_time.append(end-start) 192 | if(i%100==0): 193 | print('Combined loss:'+str(i)+': '+str(loss)) 194 | ep_time=np.array(epoch_time) 195 | #print(f'mean train time:{ep_time.mean():.3f} {ep_time.std():.3f}') 196 | 
#print(f'total: {ep_time.sum() / run_for * target:.2f}') 197 | 198 | 199 | ############################### 200 | #Evaluation (Interpolation): 201 | ############################### 202 | 203 | xx=torch.tensor([0.,0.]) 204 | tt=torch.tensor([0.,5.]) 205 | t_d=torch.linspace(0,7,700) 206 | 207 | 208 | xy_d_list=[] 209 | for i in range(len(tt_tors)): 210 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T.to('cpu'),t_d).reshape((-1,3)) 211 | xy_d_list.append(xy_d.clone().detach()) 212 | 213 | xy_d_list = torch.stack(xy_d_list).to(device) 214 | 215 | test_d_list = xy_d_list 216 | 217 | 218 | t_d_test_tt=torch.linspace(0,7,700).to(device) 219 | 220 | 221 | import timeit 222 | 223 | pushed_list_tt=[] 224 | diff_list=[] 225 | 226 | #w_vec=w_vec.to('cuda:0') 227 | #inn=inn.to('cuda:0') 228 | #t_d_test_tt=t_d_test_tt.to('cuda:0') 229 | 230 | for j in range(len(tt_tors)): 231 | #test_tors[j]=test_tors[j].to('cuda:0') 232 | start = timeit.default_timer() 233 | #linear_val_ode(w_vec,tt_tors[j],t_d_test_tt) 234 | #pushed=torchdiffeq.odeint(fx, 235 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5, 236 | # method='euler')[:,0,:] 237 | pushed = linear_val_ode(w_vec,tt_tors[j],t_d_test_tt) 238 | stop=timeit.default_timer() 239 | diff_list.append(stop-start) 240 | pushed_list_tt.append(pushed.clone().detach()) 241 | 242 | print('Interpolation:') 243 | print(f'time: {np.mean(np.array(diff_list)):.3f} {np.std(np.array(diff_list)):.3f}') 244 | #pl.figure(figsize=(4,4)) 245 | 246 | i=0 247 | xy_d=test_d_list[i] 248 | #pl.plot(xy_d[:,0],xy_d[:,1],c='b',alpha=0.5,label='Ground truth') 249 | 250 | sum_num=0 251 | 252 | for i in range(len(xy_d_list)): 253 | xy_d=xy_d_list[i] 254 | error=torch.mean(torch.norm(pushed_list_tt[i][...,:3]-xy_d,dim=2)**2) 255 | sum_num+=error 256 | 257 | print(f'Interpolation MSE: {sum_num/len(xy_d_list):.4f}') 258 | 259 | 260 | ############################### 261 | #Evaluation (Extrapolation): 262 | ############################### 263 | 264 | 265 |
test_tors_list=[] 266 | for i in range(2,6,1): 267 | for j in range(2,6,1): 268 | test_tors_list.append(torch.tensor([[float(i),float(j),2.,0.,0.,0.]])) 269 | 270 | test_tors=torch.cat(test_tors_list) 271 | test_d_list=[] 272 | for i in range(len(test_tors)): 273 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.05*torch.ones((len(t_d),3))) 274 | xy_d=torchdiffeq.odeint(test_fun,test_tors[i,:3][None].T,t_d_test_tt.to('cpu')).reshape((-1,3)) 275 | test_d_list.append(xy_d.clone().detach()) 276 | 277 | test_tors=test_tors.to(device) 278 | import timeit 279 | 280 | pushed_list_tt=[] 281 | diff_list=[] 282 | 283 | #w_vec=w_vec.to('cuda:0') 284 | #inn=inn.to('cuda:0') 285 | #t_d_test_tt=t_d_test_tt.to('cuda:0') 286 | 287 | for j in range(len(test_d_list)): 288 | #test_tors[j]=test_tors[j].to('cuda:0') 289 | start = timeit.default_timer() 290 | #linear_val_ode(w_vec,test_tors[j],t_d_test_tt) 291 | pushed = linear_val_ode(w_vec,test_tors[j],t_d_test_tt) 292 | stop=timeit.default_timer() 293 | diff_list.append(stop-start) 294 | pushed_list_tt.append(pushed.clone().detach()) 295 | print('Extrapolation Results:') 296 | print(f'time: {np.mean(np.array(diff_list)):.3f} {np.std(np.array(diff_list)):.3f}') 297 | 298 | 299 | i=0 300 | xy_d=test_d_list[i] 301 | sum_num=0 302 | for i in range(len(test_d_list)): 303 | xy_d=test_d_list[i].to(device) 304 | error=torch.mean(torch.norm(pushed_list_tt[i][...,:3]-xy_d,dim=2)**2) 305 | sum_num+=error 306 | 307 | print(f'Extrapolation MSE: {sum_num/len(test_d_list):.4f}') 308 | 309 | 310 | 311 | train(1500, 5) 312 | 313 | -------------------------------------------------------------------------------- /systems/LV_train_theirs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as pl 4 | import torchdiffeq 5 | import torch.nn as n 6 | import FrEIA.framework as Ff 7 | import FrEIA.modules as Fm 8 | 9 | #xx=torch.tensor([0.,0.,0.]) 10 |
tt=torch.tensor([0.,5.]) 11 | t_d=torch.linspace(0,7,70) 12 | 13 | import random 14 | 15 | random.seed(10) 16 | np.random.seed(0) 17 | 18 | device_str='cpu' 19 | def test_fun(t,x): 20 | A=0.75 21 | B=0.75 22 | C=0.75 23 | D=0.75 24 | E=0.75 25 | F=0.75 26 | G=0.75 27 | vel=torch.zeros((3,1)) 28 | vel[0]=x[0]*(A-B*x[1]) 29 | vel[1]=x[1]*(-C+D*x[0]-E*x[2]) 30 | vel[2]=x[2]*(-F+G*x[1]) 31 | return(vel) 32 | tt_tors=torch.tensor([[3,3.,1.,0.,0.,0.], 33 | [2,2.,2.,0.,0.,0.], 34 | [4,4.,3.,0.,0.,0.], 35 | [3,3.,4.,0.,0.,0.], 36 | [1,1.,5.,0.,0.,0.], 37 | [5,5.,1.,0.,0.,0.], 38 | [2,6.,2.,0.,0.,0.], 39 | [3,1.,4.,0.,0.,0.], 40 | [7,1.,2.,0.,0.,0.], 41 | [6,2.,4.,0.,0.,0.]]) 42 | 43 | xy_d_list=[] 44 | for i in range(len(tt_tors)): 45 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.05*torch.ones((len(t_d),3))) 46 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T,t_d).reshape((-1,3))+noise_c 47 | xy_d_list.append(xy_d.clone().detach()) 48 | 49 | 50 | for kk in range(1): 51 | f_x=n.Sequential( 52 | n.Linear(6, 150), 53 | n.Tanh(), 54 | n.Linear(150, 150), 55 | n.Tanh(), 56 | n.Linear(150, 150), 57 | n.Tanh(), 58 | n.Linear(150, 150), 59 | n.Tanh(), 60 | n.Linear(150, 150), 61 | n.Tanh(), 62 | n.Linear(150, 6) 63 | ) 64 | 65 | def fx(t,x): 66 | return(f_x(x)) 67 | 68 | 69 | optimizer = torch.optim.Adam( 70 | [{'params': f_x.parameters(),'lr': 0.0001}]) 71 | 72 | 73 | for i in range(0,5000): 74 | loss_all=0. 
75 | optimizer.zero_grad() 76 | for j in range(0,len(tt_tors)): 77 | xx=xy_d_list[j] 78 | tx=torchdiffeq.odeint(fx, 79 | tt_tors[j][None],t_d,atol=1e-2,#,rtol=1e-5, 80 | method='euler')[:,0,:] 81 | loss_c=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1)) 82 | loss_all+=loss_c 83 | if(i%100==0): 84 | print(str(i)+'euler:'+str(loss_all/10)) 85 | loss_all.backward() 86 | optimizer.step() 87 | torch.save(f_x.state_dict(),'LV_euler'+str(kk)) 88 | 89 | f_x=n.Sequential( 90 | n.Linear(6, 150), 91 | n.Tanh(), 92 | n.Linear(150, 150), 93 | n.Tanh(), 94 | n.Linear(150, 150), 95 | n.Tanh(), 96 | n.Linear(150, 150), 97 | n.Tanh(), 98 | n.Linear(150, 150), 99 | n.Tanh(), 100 | n.Linear(150, 6) 101 | ) 102 | 103 | def fx(t,x): 104 | return(f_x(x)) 105 | 106 | 107 | optimizer = torch.optim.Adam( 108 | [{'params': f_x.parameters(),'lr': 0.0001}]) 109 | 110 | print('Mid') 111 | 112 | for i in range(0,5000): 113 | loss_all=0. 114 | optimizer.zero_grad() 115 | for j in range(0,len(tt_tors)): 116 | xx=xy_d_list[j] 117 | tx=torchdiffeq.odeint(fx, 118 | tt_tors[j][None],t_d,atol=1e-2,#,rtol=1e-5, 119 | method='midpoint')[:,0,:] 120 | loss_c=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1)) 121 | loss_all+=loss_c 122 | if(i%100==0): 123 | print(str(i)+'mid:'+str(loss_all/10)) 124 | loss_all.backward() 125 | optimizer.step() 126 | torch.save(f_x.state_dict(),'LV_mid'+str(kk)) 127 | 128 | f_x=n.Sequential( 129 | n.Linear(6, 150), 130 | n.Tanh(), 131 | n.Linear(150, 150), 132 | n.Tanh(), 133 | n.Linear(150, 150), 134 | n.Tanh(), 135 | n.Linear(150, 150), 136 | n.Tanh(), 137 | n.Linear(150, 150), 138 | n.Tanh(), 139 | n.Linear(150, 6) 140 | ) 141 | 142 | def fx(t,x): 143 | return(f_x(x)) 144 | 145 | 146 | optimizer = torch.optim.Adam( 147 | [{'params': f_x.parameters(),'lr': 0.0001}]) 148 | 149 | print('RK') 150 | for i in range(0,5000): 151 | loss_all=0. 
optimizer.zero_grad() 153 | for j in range(0,len(tt_tors)): 154 | xx=xy_d_list[j] 155 | tx=torchdiffeq.odeint(fx, 156 | tt_tors[j][None],t_d,atol=1e-2,#,rtol=1e-5, 157 | method='rk4')[:,0,:] 158 | loss_c=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1)) 159 | loss_all+=loss_c 160 | if(i%100==0): 161 | print(str(i)+'rk:'+str(loss_all/10)) 162 | loss_all.backward() 163 | optimizer.step() 164 | 165 | torch.save(f_x.state_dict(),'LV_rk4'+str(kk)) 166 | 167 | f_x=n.Sequential( 168 | n.Linear(6, 150), 169 | n.Tanh(), 170 | n.Linear(150, 150), 171 | n.Tanh(), 172 | n.Linear(150, 150), 173 | n.Tanh(), 174 | n.Linear(150, 150), 175 | n.Tanh(), 176 | n.Linear(150, 150), 177 | n.Tanh(), 178 | n.Linear(150, 6) 179 | ) 180 | 181 | def fx(t,x): 182 | return(f_x(x)) 183 | 184 | 185 | optimizer = torch.optim.Adam( 186 | [{'params': f_x.parameters(),'lr': 0.0001}]) 187 | 188 | print('Dop') 189 | 190 | for i in range(0,5000): 191 | loss_all=0. 192 | optimizer.zero_grad() 193 | for j in range(0,len(tt_tors)): 194 | xx=xy_d_list[j] 195 | tx=torchdiffeq.odeint(fx, 196 | tt_tors[j][None],t_d,rtol=1e-5, atol=1e-5,#,rtol=1e-5, 197 | method='dopri5')[:,0,:] 198 | loss_c=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1)) 199 | loss_all+=loss_c 200 | if(i%100==0): 201 | print(str(i)+'dopri:'+str(loss_all/10)) 202 | loss_all.backward() 203 | optimizer.step() 204 | 205 | torch.save(f_x.state_dict(),'LV_dopri'+str(kk)) 206 | -------------------------------------------------------------------------------- /systems/f_x_base_save_good_eod2.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/systems/f_x_base_save_good_eod2.tar -------------------------------------------------------------------------------- /systems/inn2_save_good_eod2.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/systems/inn2_save_good_eod2.tar
--------------------------------------------------------------------------------
/systems/lor_run_all.py:
--------------------------------------------------------------------------------
import time
import torch
import numpy as np
import matplotlib.pyplot as pl
import torchdiffeq
import torch.nn as n
import FrEIA.framework as Ff
import FrEIA.modules as Fm
from mpl_toolkits import mplot3d


##########################################################
# system dynamics

xx=torch.tensor([0.,0.,0.])
tt=torch.tensor([0.,5.])
t_d=torch.linspace(0,2,800)
device_str="cuda"  # all timing benchmarks below run on the GPU

def test_fun(t,x):
    # Lorenz system with the classic chaotic parameters
    sig=10.
    rho=28.
    beta=8/3
    vel=torch.zeros((3,1))
    vel[0]=sig*(x[1]-x[0])
    vel[1]=x[0]*(rho-x[2])-x[1]
    vel[2]=x[0]*x[1]-beta*x[2]
    return(vel)

# initial condition: 3 Lorenz states followed by 3 extra zero-initialised
# dimensions of the 6-D learned state
tt_tors=torch.tensor([[.15,.15,.15,0.,0.,0.]])
xy_d_list=[]
for i in range(len(tt_tors)):
    noise_c=torch.normal(torch.zeros((len(t_d),3)),0.01*torch.ones((len(t_d),3)))
    xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T,t_d).reshape((-1,3))#+noise_c #no noise for test
    xy_d_list.append(xy_d.clone().detach())

pl.figure(figsize=(12,10))
for i in range(len(xy_d_list)):
    xy_d=xy_d_list[i]
    pl.plot(xy_d[:,0],xy_d[:,1],marker='o')

xx=xy_d_list[0]

##########################################################

def test_theirs(method):
    """Time and score a baseline neural ODE trained with the given solver."""
    global tt_tors, t_d
    if method == 'euler':
        weight_fn = 'lor_euler.tar'
    elif method == 'midpoint':
        weight_fn = 'lor_mid.tar'
    elif method == 'rk4':
        weight_fn = 'lor_rk4.tar'
    elif method == 'dopri5':
        weight_fn = 'lor_dopri.tar'
    else:
        raise ValueError(f'unknown method: {method}')
    f_x=n.Sequential(
        n.Linear(6, 150),
        n.Tanh(),
        n.Linear(150, 150),
        n.Tanh(),
        n.Linear(150, 150),
        n.Tanh(),
        n.Linear(150, 150),
        n.Tanh(),
        n.Linear(150, 150),
        n.Tanh(),
        n.Linear(150, 6)
    )
    def fx(t,x):
        return(f_x(x))

    f_x.load_state_dict(torch.load(weight_fn))
    time_list=[]

    f_x = f_x.to(device_str)
    tt_tors = tt_tors.to(device_str)
    t_d = t_d.to(device_str)
    for i in range(10):
        tic = time.perf_counter()
        tx=torchdiffeq.odeint(fx,tt_tors,t_d,method=method,rtol=1e-5, atol=1e-5)
        if device_str == "cuda":
            torch.cuda.synchronize()  # CUDA runs asynchronously; wait before stopping the clock
        toc = time.perf_counter()
        time_list.append(toc-tic)

    tx=tx.permute(1,0,2)  # (time, batch, dim) -> (batch, time, dim)
    # interpolation error against the ground-truth trajectories
    sum_num=0

    for i in range(len(xy_d_list)):
        xy_d=xy_d_list[i]
        error=torch.mean(torch.norm(tx[i][:,:3].to('cpu')-xy_d,dim=1))
        sum_num+=error

    print(f'Interpolation MAE: {sum_num/len(xy_d_list):.4f}')

    times=np.array(time_list)

    print(method)
    print('mean:')
    print(times.mean())
    print('std:')
    print(times.std())


#test_theirs('euler')
#test_theirs('rk4')
#test_theirs('midpoint')

##########################################################

N_DIM = 6
def subnet_fc(dims_in, dims_out):
    return n.Sequential(n.Linear(dims_in, 1500), n.ReLU(),
                        n.Linear(1500, dims_out))
inn2 = Ff.SequenceINN(N_DIM)
for k in range(5):
    inn2.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc)

f_x_base=n.Sequential(
    n.Linear(6, 30),
    n.Tanh(),
    n.Linear(30, 30),
    n.Tanh(),
    n.Linear(30, 30),
    n.Tanh(),
    n.Linear(30, 6)
)
def fx_base(t,x):
    return(f_x_base(x))

def for_inn2(x):
    return(inn2(x)[0])
def rev_inn2(x):
    return(inn2(x,rev=True)[0])
def linear_val_ode2(init_v,t_d):
    # map the initial state into the INN's latent space
    init_v_in=rev_inn2(init_v)
    # integrate the simple base ODE in latent space
    eval_lin=torchdiffeq.odeint(fx_base,init_v_in,t_d,
                                method='euler')[:,0,:]#options={'step_size':0.01}
    # map the latent trajectory back to the data space
    eval_out=for_inn2(eval_lin)
    return(eval_out)

f_x_base.load_state_dict(torch.load('f_x_base_save_good_eod2.tar'))
inn2.load_state_dict(torch.load('inn2_save_good_eod2.tar'))
f_x_base = f_x_base.to(device_str)
inn2 = inn2.to(device_str)

tt_tors = tt_tors.to(device_str)
t_d = t_d.to(device_str)

time_list=[]
for i in range(10):
    tic = time.perf_counter()
    tx=linear_val_ode2(tt_tors,t_d)
    if device_str == "cuda":
        torch.cuda.synchronize()  # wait for queued CUDA kernels before stopping the clock
    toc = time.perf_counter()
    time_list.append(toc-tic)


ours=np.array(time_list)
print('ours time mean:')
print(ours.mean())
print('ours time std:')
print(ours.std())


sum_num=0

for i in range(0,len(tt_tors)):
    xy_d=xy_d_list[i]
    print(xy_d.shape)
    error=torch.mean(torch.norm(tx[:,:3].to('cpu')-xy_d,dim=1))
    sum_num+=error

print(f'Ours Interpolation MAE: {sum_num/len(tt_tors):.4f}')

test_theirs('euler')
test_theirs('rk4')
test_theirs('midpoint')

--------------------------------------------------------------------------------
/systems/lor_train_ours.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import matplotlib.pyplot as pl
import torchdiffeq
import torch.nn as n
import FrEIA.framework as Ff
import FrEIA.modules as Fm

device = 'cpu'
seed = 0
torch.manual_seed(seed)
import random
random.seed(seed)
np.random.seed(seed)
xx=torch.tensor([0.,0.,0.])
tt=torch.tensor([0.,5.])
t_d=torch.linspace(0,2,80)


device_str='cpu'
def test_fun(t,x):
    sig=10.
    rho=28.
    beta=8/3
    vel=torch.zeros((3,1))
    vel[0]=sig*(x[1]-x[0])
    vel[1]=x[0]*(rho-x[2])-x[1]
    vel[2]=x[0]*x[1]-beta*x[2]
    return(vel)

tt_tors=torch.tensor([[.15,.15,.15,0.,0.,0.]])
xy_d_list=[]
for i in range(len(tt_tors)):
    # noisy training target: Lorenz trajectory plus Gaussian noise (std 0.01)
    noise_c=torch.normal(torch.zeros((len(t_d),3)),0.01*torch.ones((len(t_d),3)))
    xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T,t_d).reshape((-1,3))+noise_c
    xy_d_list.append(xy_d.clone().detach())


def loss_er(x_pred,x_gt):
    mse_l=torch.norm(x_pred-x_gt,dim=1)
    sum_c=0.0
    for i in range(len(mse_l)):
        sum_c+=(mse_l[i])#*(1/float(i+1))
    return(sum_c/len(mse_l))


def train(hidden_size, num_layers):
    print('='*20)
    #print(f'hidden_size={hidden_size}, num_layers={num_layers}')
    global tt_tors, xy_d_list, t_d
    #seed = 123
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    # base vector field: a small MLP integrated in the INN's latent space
    f_x=n.Sequential(
        n.Linear(6, 30),
        n.Tanh(),
        n.Linear(30, 30),
        n.Tanh(),
        n.Linear(30, 30),
        n.Tanh(),
        n.Linear(30, 6)
    )
    def fx(t,x):
        return(f_x(x))

    def for_inn(x):
        return(inn(x)[0])
    def rev_inn(x):
        return(inn(x,rev=True)[0])
    def rev_mse_inn_eig(rf,x_gt):
        return(torch.mean(torch.norm(rf-x_gt,dim=1)))
    def linear_val_ode2(init_v,t_d):
        # data -> latent, integrate the base ODE, then latent -> data
        init_v_in=rev_inn(init_v)
        eval_lin=torchdiffeq.odeint(fx,init_v_in,t_d,method='dopri5',atol=1e-5,rtol=1e-5)[:,0,:]#options={'step_size':0.01}
        eval_out=for_inn(eval_lin)
        return(eval_out)


    N_DIM = 6
    def subnet_fc(dims_in, dims_out):
        return n.Sequential(n.Linear(dims_in, hidden_size), n.ReLU(),
                            n.Linear(hidden_size, dims_out))

    inn = Ff.SequenceINN(N_DIM)
    for k in range(num_layers):
        inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc,permute_soft=True)


    optimizer_comb = torch.optim.Adam(
        [{'params': f_x.parameters(),'lr': 0.0001},{'params': inn.parameters(),
         'lr': 0.0001}])
    print(sum(p.numel() for p in inn.parameters()))  # number of INN parameters

    startings=tt_tors.clone().detach()

    #Training loop
    import timeit
    from tqdm import trange
    epoch_time=[]

    tt_tors = tt_tors.to(device)
    f_x.to(device)
    inn.to(device)
    t_d = t_d.to(device)
    xy_d_list = torch.stack(xy_d_list).to(device)

    print('training INN')
    for i in trange(0, 5000):
        optimizer_comb.zero_grad()
        loss=0.0
        start = timeit.default_timer()
        eval_nl=linear_val_ode2(tt_tors,t_d)
        # mean Euclidean distance to the (single) training trajectory
        loss_cur = torch.mean(torch.norm(eval_nl[:,:3]-xy_d_list[0],dim=1))
        loss+=loss_cur

        loss.backward()
        optimizer_comb.step()
        end = timeit.default_timer()
        epoch_time.append(end-start)
        if(i%100==0):
            print('Combined loss:'+str(i)+': '+str(loss))
    ep_time=np.array(epoch_time)
    #print(f'mean train time:{ep_time.mean():.3f} {ep_time.std():.3f}')
    torch.save(f_x.state_dict(),'f_x_base_save_good_eod2.tar')
    torch.save(inn.state_dict(),'inn2_save_good_eod2.tar')


train(1500, 5)
--------------------------------------------------------------------------------
/systems/lor_train_theirs.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import matplotlib.pyplot as pl
import torchdiffeq
import torch.nn as n
import FrEIA.framework as Ff
import FrEIA.modules as Fm

xx=torch.tensor([0.,0.,0.])
tt=torch.tensor([0.,5.])
t_d=torch.linspace(0,2,80)


device_str='cpu'
def test_fun(t,x):
    sig=10.
    rho=28.
    beta=8/3
    vel=torch.zeros((3,1))
    vel[0]=sig*(x[1]-x[0])
    vel[1]=x[0]*(rho-x[2])-x[1]
    vel[2]=x[0]*x[1]-beta*x[2]
    return(vel)

tt_tors=torch.tensor([[.15,.15,.15,0.,0.,0.]])
xy_d_list=[]
for i in range(len(tt_tors)):
    noise_c=torch.normal(torch.zeros((len(t_d),3)),0.01*torch.ones((len(t_d),3)))
    xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T,t_d).reshape((-1,3))+noise_c
    xy_d_list.append(xy_d.clone().detach())

pl.figure(figsize=(12,10))
for i in range(len(xy_d_list)):
    xy_d=xy_d_list[i]
    pl.plot(xy_d[:,0],xy_d[:,1],marker='o')

pl.savefig('lorr.png')
xx=xy_d_list[0]

# baseline neural ODE: an MLP vector field trained through the Euler solver
f_x=n.Sequential(
    n.Linear(6, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 6)
)

def fx(t,x):
    return(f_x(x))


optimizer = torch.optim.Adam(
    [{'params': f_x.parameters(),'lr': 0.0001}])

for j in range(0,5000):
    optimizer.zero_grad()
    tx=torchdiffeq.odeint(fx,
        tt_tors,t_d,#,rtol=1e-5,
        method='euler')[:,0,:]
    loss=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1))
    if(j%100==0):
        print('euler: '+str(j)+': '+str(loss))
    loss.backward()
    optimizer.step()
torch.save(f_x.state_dict(),'lor_euler.tar')

# same MLP vector field, now trained through the midpoint solver
f_x=n.Sequential(
    n.Linear(6, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 6)
)

def fx(t,x):
    return(f_x(x))


optimizer = torch.optim.Adam(
    [{'params': f_x.parameters(),'lr': 0.0001}])

for j in range(0,5000):
    optimizer.zero_grad()
    tx=torchdiffeq.odeint(fx,
        tt_tors,t_d,#,rtol=1e-5,
        method='midpoint')[:,0,:]
    loss=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1))
    if(j%100==0):
        print('midpoint: '+str(j)+': '+str(loss))
    loss.backward()
    optimizer.step()
torch.save(f_x.state_dict(),'lor_mid.tar')

# same MLP vector field, now trained through the classic RK4 solver
f_x=n.Sequential(
    n.Linear(6, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 6)
)

def fx(t,x):
    return(f_x(x))


optimizer = torch.optim.Adam(
    [{'params': f_x.parameters(),'lr': 0.0001}])

for j in range(0,5000):
    optimizer.zero_grad()
    tx=torchdiffeq.odeint(fx,
        tt_tors,t_d,#,rtol=1e-5,
        method='rk4')[:,0,:]
    loss=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1))
    if(j%100==0):
        print('rk4: '+str(j)+': '+str(loss))

    loss.backward()
    optimizer.step()
torch.save(f_x.state_dict(),'lor_rk4.tar')

# same MLP vector field, now trained through the adaptive dopri5 solver
f_x=n.Sequential(
    n.Linear(6, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 150),
    n.Tanh(),
    n.Linear(150, 6)
)

def fx(t,x):
    return(f_x(x))


optimizer = torch.optim.Adam(
    [{'params': f_x.parameters(),'lr': 0.0001}])

for j in range(0,5000):
    optimizer.zero_grad()
    tx=torchdiffeq.odeint(fx,
        tt_tors,t_d,rtol=1e-5, atol=1e-5,
        method='dopri5')[:,0,:]
    loss=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1))
    if(j%100==0):
        print('dopri: '+str(j)+': '+str(loss))
    loss.backward()
    optimizer.step()
torch.save(f_x.state_dict(),'lor_dopri.tar')

--------------------------------------------------------------------------------
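The "ours" path in `lor_train_ours.py` and `lor_run_all.py` (`linear_val_ode2`) maps the initial state into latent space with the inverse INN, integrates a simple base ODE there, and maps the whole latent trajectory back with the forward INN. The sketch below shows that composition with NumPy only, under loudly stated assumptions: the INN is a hypothetical hand-written pair of affine coupling layers with fixed parameters (not FrEIA's `AllInOneBlock`, and not trained), the base field is a fixed damped rotation instead of the learned `f_x_base` MLP, the state is 2-D rather than 6-D, and integration is a hand-rolled explicit Euler loop standing in for `torchdiffeq.odeint(..., method='euler')`. The names `linear_val_ode_sketch`, `inn_forward`, `inn_inverse` and `base_field` are illustrative only.

```python
import numpy as np

# Hypothetical stand-in for the FrEIA INN: two affine coupling layers with
# fixed, hand-picked parameters (NOT FrEIA's AllInOneBlock, and not trained).
COUPLING_PARAMS = [(0.5, 0.3), (-0.4, 0.2)]


def coupling_forward(x, a, b):
    """Keep x[..., 0]; scale and shift x[..., 1] using functions of x[..., 0]."""
    x1, x2 = x[..., 0], x[..., 1]
    y2 = x2 * np.exp(np.tanh(a * x1)) + b * x1
    return np.stack([x1, y2], axis=-1)


def coupling_inverse(y, a, b):
    """Exact inverse of coupling_forward with the same (a, b)."""
    y1, y2 = y[..., 0], y[..., 1]
    x2 = (y2 - b * y1) * np.exp(-np.tanh(a * y1))
    return np.stack([y1, x2], axis=-1)


def swap(x):
    """Swap the two coordinates so the next coupling updates the other one."""
    return x[..., ::-1]


def inn_forward(z):
    """Latent -> data (the role for_inn2 plays in lor_run_all.py)."""
    x = coupling_forward(z, *COUPLING_PARAMS[0])
    return coupling_forward(swap(x), *COUPLING_PARAMS[1])


def inn_inverse(x):
    """Data -> latent (the role rev_inn2 plays in lor_run_all.py)."""
    z = coupling_inverse(x, *COUPLING_PARAMS[1])
    return coupling_inverse(swap(z), *COUPLING_PARAMS[0])


# Hypothetical base dynamics in latent space: a fixed damped rotation,
# standing in for the trained f_x_base MLP.
A = np.array([[-0.1, 1.0], [-1.0, -0.1]])


def base_field(z):
    return z @ A.T


def euler_odeint(f, z0, t):
    """Explicit Euler on the given time grid (cf. method='euler')."""
    zs = [z0]
    for k in range(1, len(t)):
        zs.append(zs[-1] + (t[k] - t[k - 1]) * f(zs[-1]))
    return np.stack(zs)


def linear_val_ode_sketch(x0, t):
    z0 = inn_inverse(x0)                      # data -> latent initial state
    z_traj = euler_odeint(base_field, z0, t)  # integrate the simple base ODE
    return inn_forward(z_traj)                # latent -> data at every time step


t = np.linspace(0.0, 2.0, 80)
x0 = np.array([0.15, 0.15])
traj = linear_val_ode_sketch(x0, t)
print(traj.shape)                # (80, 2)
print(np.allclose(traj[0], x0))  # True, since inn_forward undoes inn_inverse
```

Because each coupling layer is invertible by construction, the returned trajectory starts at `x0` for any choice of base field, and the nonlinearity of the data-space path comes from the map while the latent ODE stays simple. The scripts above swap in `Ff.SequenceINN` and a learned `f_x_base`, which changes the components but not this composition.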