├── README.md
├── latentode
│   ├── latent_ode
│   │   ├── .gitignore
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── generate_timeseries.py
│   │   ├── lib
│   │   │   ├── base_models.py
│   │   │   ├── create_latent_ode_model.py
│   │   │   ├── diffeq_solver.py
│   │   │   ├── encoder_decoder.py
│   │   │   ├── latent_ode.py
│   │   │   ├── likelihood_eval.py
│   │   │   ├── ode_func.py
│   │   │   ├── ode_rnn.py
│   │   │   ├── parse_datasets.py
│   │   │   ├── plotting.py
│   │   │   ├── rnn_baselines.py
│   │   │   └── utils.py
│   │   ├── mujoco_physics.py
│   │   ├── person_activity.py
│   │   ├── physionet.py
│   │   ├── run_models.py
│   │   ├── train-activity.sh
│   │   ├── train-ecg.sh
│   │   ├── train-periodic100.sh
│   │   └── train-periodic1000.sh
│   └── ours_impl
│       ├── __pycache__
│       │   └── lv_field.cpython-38.pyc
│       └── lv_field.py
├── robotic
│   ├── .combine_stats.py.swp
│   ├── .ipynb_checkpoints
│   │   └── Untitled-checkpoint.ipynb
│   ├── .run_stiff_time.py.swp
│   ├── 3D_Cshape_top.mat
│   ├── S_gt_traj.npy
│   ├── bench_output
│   │   └── out.txt
│   ├── cube_pick.npy
│   └── run.py
├── rotating_MNIST
│   ├── LV_stat_dic_0.tar
│   ├── autoencoder_state_dic_0.tar
│   ├── evaluate.py
│   ├── readme.md
│   ├── rotating3_2.pdf
│   ├── run_rotate.py
│   └── samples
│       ├── t10k-images-idx3-ubyte
│       ├── t10k-labels-idx1-ubyte
│       ├── train-images-idx3-ubyte
│       └── train-labels-idx1-ubyte
├── run-all.sh
├── stiff_ode
│   ├── run_rober.py
│   └── run_rober_comp_dopri.py
└── systems
    ├── LV_run_theirs.py
    ├── LV_train_run_ours.py
    ├── LV_train_theirs.py
    ├── Untitled.ipynb
    ├── f_x_base_save_good_eod2.tar
    ├── inn2_save_good_eod2.tar
    ├── lor_run_all.py
    ├── lor_train_ours.py
    └── lor_train_theirs.py
/README.md:
--------------------------------------------------------------------------------
1 | # ODElearning_INN
2 | [ICML 2022] Learning Efficient and Robust Ordinary Differential Equations via Invertible Neural Networks
3 |
4 | The code requires PyTorch, torchdiffeq (for neural ODEs), and FrEIA (for invertible NNs). The latent ODE experiments borrow heavily from the code of the original Latent ODE work.
5 | To understand the method, start with the code in "systems".
6 | Code has yet to be refactored for clarity and readability -- this is coming soon.
7 |
--------------------------------------------------------------------------------
/latentode/latent_ode/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
--------------------------------------------------------------------------------
/latentode/latent_ode/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Yulia Rubanova
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/latentode/latent_ode/README.md:
--------------------------------------------------------------------------------
1 | # Latent ODEs for Irregularly-Sampled Time Series
2 |
3 | Code for the paper:
4 | > Yulia Rubanova, Ricky Chen, David Duvenaud. "Latent ODEs for Irregularly-Sampled Time Series" (2019)
5 | [[arxiv]](https://arxiv.org/abs/1907.03907)
6 |
7 |
8 |
9 |
10 |
11 | ## Prerequisites
12 |
13 | Install `torchdiffeq` from https://github.com/rtqichen/torchdiffeq.
14 |
15 | ## Experiments on different datasets
16 |
17 | By default, the datasets are downloaded and processed when the script is run for the first time.
18 |
19 | Raw datasets:
20 | [[MuJoCo]](http://www.cs.toronto.edu/~rtqichen/datasets/HopperPhysics/training.pt)
21 | [[Physionet]](https://physionet.org/physiobank/database/challenge/2012/)
22 | [[Human Activity]](https://archive.ics.uci.edu/ml/datasets/Localization+Data+for+Person+Activity/)
23 |
24 | To generate MuJoCo trajectories from scratch, [DeepMind Control Suite](https://github.com/deepmind/dm_control/) is required.
25 |
26 |
27 | * Toy dataset of 1d periodic functions
28 | ```
29 | python3 run_models.py --niters 500 -n 1000 -s 50 -l 10 --dataset periodic --latent-ode --noise-weight 0.01
30 | ```
31 |
32 | * MuJoCo
33 |
34 | ```
35 | python3 run_models.py --niters 300 -n 10000 -l 15 --dataset hopper --latent-ode --rec-dims 30 --gru-units 100 --units 300 --gen-layers 3 --rec-layers 3
36 | ```
37 |
38 | * Physionet (discretization by 1 min)
39 | ```
40 | python3 run_models.py --niters 100 -n 8000 -l 20 --dataset physionet --latent-ode --rec-dims 40 --rec-layers 3 --gen-layers 3 --units 50 --gru-units 50 --quantization 0.016 --classif
41 |
42 | ```
43 |
44 | * Human Activity
45 | ```
46 | python3 run_models.py --niters 200 -n 10000 -l 15 --dataset activity --latent-ode --rec-dims 100 --rec-layers 4 --gen-layers 2 --units 500 --gru-units 50 --classif --linear-classif
47 |
48 | ```
49 |
50 |
51 | ### Running different models
52 |
53 | * ODE-RNN
54 | ```
55 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --ode-rnn
56 | ```
57 |
58 | * Latent ODE with ODE-RNN encoder
59 | ```
60 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --latent-ode
61 | ```
62 |
63 | * Latent ODE with ODE-RNN encoder and Poisson likelihood
64 | ```
65 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --latent-ode --poisson
66 | ```
67 |
68 | * Latent ODE with RNN encoder (Chen et al, 2018)
69 | ```
70 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --latent-ode --z0-encoder rnn
71 | ```
72 |
73 | * RNN-VAE
74 | ```
75 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --rnn-vae
76 | ```
77 |
78 | * Classic RNN
79 | ```
80 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --classic-rnn
81 | ```
82 |
83 | * GRU-D
84 |
85 | GRU-D consists of two parts: input imputation (--input-decay) and exponential decay of the hidden state (--rnn-cell expdecay)
86 |
87 | ```
88 | python3 run_models.py --niters 500 -n 100 -b 30 -l 10 --dataset periodic --classic-rnn --input-decay --rnn-cell expdecay
89 | ```
90 |
91 |
92 | ### Making the visualization
93 | ```
94 | python3 run_models.py --niters 100 -n 5000 -b 100 -l 3 --dataset periodic --latent-ode --noise-weight 0.5 --lr 0.01 --viz --rec-layers 2 --gen-layers 2 -u 100 -c 30
95 | ```
96 |
--------------------------------------------------------------------------------
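The GRU-D entry in the README above mentions exponential decay of the hidden state. As a rough illustration only (this is not the code in `lib/rnn_baselines.py`), the decay from the GRU-D formulation can be sketched in plain Python. The function name `decay_hidden_state` and the per-unit parameters `w` and `b` are hypothetical.

```python
import math

def decay_hidden_state(h, delta_t, w, b):
    """Scale each hidden unit by gamma = exp(-max(0, w * delta_t + b)).

    h, w, b are lists of equal length; delta_t is the time since the
    last observation. A longer gap gives a smaller gamma (more decay).
    """
    decayed = []
    for h_i, w_i, b_i in zip(h, w, b):
        gamma = math.exp(-max(0.0, w_i * delta_t + b_i))
        decayed.append(h_i * gamma)
    return decayed

# With no elapsed time (and zero bias) the hidden state is unchanged.
unchanged = decay_hidden_state([1.0, -2.0], 0.0, [1.0, 1.0], [0.0, 0.0])
# After a gap of 2.0 each unit shrinks towards zero.
shrunk = decay_hidden_state([1.0, -2.0], 2.0, [1.0, 0.5], [0.0, 0.0])
```

The `max(0, ...)` clamp keeps gamma in (0, 1], so the decay can only shrink the hidden state and never amplify it.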
/latentode/latent_ode/generate_timeseries.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | # Create a synthetic dataset
7 | from __future__ import absolute_import, division
8 | from __future__ import print_function
9 | import os
10 | import matplotlib
11 | # if os.path.exists("/Users/yulia"):
12 | # matplotlib.use('TkAgg')
13 | # else:
14 | # matplotlib.use('Agg')
15 |
16 | import numpy as np
17 | import numpy.random as npr
18 | from scipy.special import expit as sigmoid
19 | import pickle
20 | import matplotlib.pyplot as plt
21 | import matplotlib.image
22 | import torch
23 | import lib.utils as utils
24 |
25 | # ======================================================================================
26 |
27 | def get_next_val(init, t, tmin, tmax, final = None):
28 |     if final is None:
29 |         return init
30 |     val = init + (final - init) / (tmax - tmin) * t
31 |     return val
32 |
33 |
34 | def generate_periodic(time_steps, init_freq, init_amplitude, starting_point,
35 |         final_freq = None, final_amplitude = None, phi_offset = 0.):
36 |
37 |     tmin = time_steps.min()
38 |     tmax = time_steps.max()
39 |
40 |     data = []
41 |     t_prev = time_steps[0]
42 |     phi = phi_offset
43 |     for t in time_steps:
44 |         dt = t - t_prev
45 |         amp = get_next_val(init_amplitude, t, tmin, tmax, final_amplitude)
46 |         freq = get_next_val(init_freq, t, tmin, tmax, final_freq)
47 |         phi = phi + 2 * np.pi * freq * dt # integrate to get phase
48 |
49 |         y = amp * np.sin(phi) + starting_point
50 |         t_prev = t
51 |         data.append([t,y])
52 |     return np.array(data)
53 |
54 | def assign_value_or_sample(value, sampling_interval = [0.,1.]):
55 |     if value is None:
56 |         int_length = sampling_interval[1] - sampling_interval[0]
57 |         return np.random.random() * int_length + sampling_interval[0]
58 |     else:
59 |         return value
60 |
61 | class TimeSeries:
62 |     def __init__(self, device = torch.device("cpu")):
63 |         self.device = device
64 |         self.z0 = None
65 |
66 |     def init_visualization(self):
67 |         self.fig = plt.figure(figsize=(10, 4), facecolor='white')
68 |         self.ax = self.fig.add_subplot(111, frameon=False)
69 |         plt.show(block=False)
70 |
71 |     def visualize(self, truth):
72 |         self.ax.plot(truth[:,0], truth[:,1])
73 |
74 |     def add_noise(self, traj_list, time_steps, noise_weight):
75 |         n_samples = traj_list.size(0)
76 |
77 |         # Add noise to all the points except the first point
78 |         n_tp = len(time_steps) - 1
79 |         noise = np.random.sample((n_samples, n_tp))
80 |         noise = torch.Tensor(noise).to(self.device)
81 |
82 |         traj_list_w_noise = traj_list.clone()
83 |         # When called from Periodic_1d.sample_traj the time column has already been cut, so [:,:,0] holds the values
84 |         traj_list_w_noise[:,1:,0] += noise_weight * noise
85 |         return traj_list_w_noise
86 |
87 |
88 |
89 | class Periodic_1d(TimeSeries):
90 |     def __init__(self, device = torch.device("cpu"),
91 |             init_freq = 0.3, init_amplitude = 1.,
92 |             final_amplitude = 10., final_freq = 1.,
93 |             z0 = 0.):
94 |         """
95 |         If any of the parameters (init_freq, init_amplitude, final_amplitude, final_freq) is None, it is randomly sampled.
96 |         For now, all the time series share the time points and the starting point.
97 |         """
98 |         super(Periodic_1d, self).__init__(device)
99 |
100 |         self.init_freq = init_freq
101 |         self.init_amplitude = init_amplitude
102 |         self.final_amplitude = final_amplitude
103 |         self.final_freq = final_freq
104 |         self.z0 = z0
105 |
106 |     def sample_traj(self, time_steps, n_samples = 1, noise_weight = 1.,
107 |             cut_out_section = None):
108 |         """
109 |         Sample periodic functions.
110 |         """
111 |         traj_list = []
112 |         for i in range(n_samples):
113 |             init_freq = assign_value_or_sample(self.init_freq, [0.4,0.8])
114 |             if self.final_freq is None:
115 |                 final_freq = init_freq
116 |             else:
117 |                 final_freq = assign_value_or_sample(self.final_freq, [0.4,0.8])
118 |             init_amplitude = assign_value_or_sample(self.init_amplitude, [0.,1.])
119 |             final_amplitude = assign_value_or_sample(self.final_amplitude, [0.,1.])
120 |
121 |             noisy_z0 = self.z0 + np.random.normal(loc=0., scale=0.1)
122 |
123 |             traj = generate_periodic(time_steps, init_freq = init_freq,
124 |                 init_amplitude = init_amplitude, starting_point = noisy_z0,
125 |                 final_amplitude = final_amplitude, final_freq = final_freq)
126 |
127 |             # Cut the time dimension
128 |             traj = np.expand_dims(traj[:,1:], 0)
129 |             traj_list.append(traj)
130 |
131 |         # shape after the squeeze below: [n_samples, n_timesteps, 1]
132 |         # the time stamps were cut above, so
133 |         # traj_list[:,:,0] -- values at the time stamps
134 |         traj_list = np.array(traj_list)
135 |         traj_list = torch.Tensor().new_tensor(traj_list, device = self.device)
136 |         traj_list = traj_list.squeeze(1)
137 |
138 |         traj_list = self.add_noise(traj_list, time_steps, noise_weight)
139 |         return traj_list
140 |
141 |
--------------------------------------------------------------------------------
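`generate_periodic` above accumulates the phase step by step instead of evaluating sin(2πft) directly, which keeps the signal continuous when the frequency changes over time. A minimal standalone sketch of that idea in plain Python, without NumPy, follows. `periodic_values` is a hypothetical helper, not part of the repository, and it assumes at least two distinct time points.

```python
import math

def periodic_values(time_steps, init_freq, final_freq, amplitude=1.0):
    """Sine wave whose frequency ramps linearly from init_freq to final_freq."""
    tmin, tmax = min(time_steps), max(time_steps)
    phi = 0.0
    t_prev = time_steps[0]
    values = []
    for t in time_steps:
        # Same linear ramp as get_next_val in generate_timeseries.py
        freq = init_freq + (final_freq - init_freq) / (tmax - tmin) * t
        # Integrate the instantaneous frequency to obtain the phase
        phi += 2 * math.pi * freq * (t - t_prev)
        values.append(amplitude * math.sin(phi))
        t_prev = t
    return values

ts = [i * 0.01 for i in range(101)]
constant = periodic_values(ts, 1.0, 1.0)   # plain 1 Hz sine
chirp = periodic_values(ts, 0.5, 2.0)      # frequency ramps from 0.5 to 2.0
```

With a constant frequency the accumulated phase reduces to 2πft, so the sketch matches an ordinary sine wave.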
/latentode/latent_ode/lib/base_models.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.functional import relu
10 |
11 | import lib.utils as utils
12 | from lib.encoder_decoder import *
13 | from lib.likelihood_eval import *
14 |
15 | from torch.distributions.multivariate_normal import MultivariateNormal
16 | from torch.distributions.normal import Normal
17 | from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase
18 |
19 | from torch.distributions.normal import Normal
20 | from torch.distributions import Independent
21 | from torch.nn.parameter import Parameter
22 |
23 |
24 | def create_classifier(z0_dim, n_labels):
25 |     return nn.Sequential(
26 |         nn.Linear(z0_dim, 300),
27 |         nn.ReLU(),
28 |         nn.Linear(300, 300),
29 |         nn.ReLU(),
30 |         nn.Linear(300, n_labels),)
31 |
32 |
33 | class Baseline(nn.Module):
34 |     def __init__(self, input_dim, latent_dim, device,
35 |             obsrv_std = 0.01, use_binary_classif = False,
36 |             classif_per_tp = False,
37 |             use_poisson_proc = False,
38 |             linear_classifier = False,
39 |             n_labels = 1,
40 |             train_classif_w_reconstr = False):
41 |         super(Baseline, self).__init__()
42 |
43 |         self.input_dim = input_dim
44 |         self.latent_dim = latent_dim
45 |         self.n_labels = n_labels
46 |
47 |         self.obsrv_std = torch.Tensor([obsrv_std]).to(device)
48 |         self.device = device
49 |
50 |         self.use_binary_classif = use_binary_classif
51 |         self.classif_per_tp = classif_per_tp
52 |         self.use_poisson_proc = use_poisson_proc
53 |         self.linear_classifier = linear_classifier
54 |         self.train_classif_w_reconstr = train_classif_w_reconstr
55 |
56 |         z0_dim = latent_dim
57 |         if use_poisson_proc:
58 |             z0_dim += latent_dim
59 |
60 |         if use_binary_classif:
61 |             if linear_classifier:
62 |                 self.classifier = nn.Sequential(
63 |                     nn.Linear(z0_dim, n_labels))
64 |             else:
65 |                 self.classifier = create_classifier(z0_dim, n_labels)
66 |             utils.init_network_weights(self.classifier)
67 |
68 |
69 |     def get_gaussian_likelihood(self, truth, pred_y, mask = None):
70 |         # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim]
71 |         # truth shape [n_traj, n_tp, n_dim]
72 |         if mask is not None:
73 |             mask = mask.repeat(pred_y.size(0), 1, 1, 1)
74 |
75 |         # Compute likelihood of the data under the predictions
76 |         log_density_data = masked_gaussian_log_density(pred_y, truth,
77 |             obsrv_std = self.obsrv_std, mask = mask)
78 |         log_density_data = log_density_data.permute(1,0)
79 |
80 |         # Compute the total density
81 |         # Take mean over n_traj_samples
82 |         log_density = torch.mean(log_density_data, 0)
83 |
84 |         # shape: [n_traj]
85 |         return log_density
86 |
87 |
88 |     def get_mse(self, truth, pred_y, mask = None):
89 |         # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim]
90 |         # truth shape [n_traj, n_tp, n_dim]
91 |         if mask is not None:
92 |             mask = mask.repeat(pred_y.size(0), 1, 1, 1)
93 |
94 |         # Compute MSE between the data and the predictions
95 |         log_density_data = compute_mse(pred_y, truth, mask = mask)
96 |         # shape: [1]
97 |         return torch.mean(log_density_data)
98 |
99 |
100 |     def compute_all_losses(self, batch_dict,
101 |             n_tp_to_sample = None, n_traj_samples = 1, kl_coef = 1.):
102 |
103 |         # Condition on subsampled points
104 |         # Make predictions for all the points
105 |         pred_x, info = self.get_reconstruction(batch_dict["tp_to_predict"],
106 |             batch_dict["observed_data"], batch_dict["observed_tp"],
107 |             mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples,
108 |             mode = batch_dict["mode"])
109 |
110 |         # Compute likelihood of all the points
111 |         likelihood = self.get_gaussian_likelihood(batch_dict["data_to_predict"], pred_x,
112 |             mask = batch_dict["mask_predicted_data"])
113 |
114 |         mse = self.get_mse(batch_dict["data_to_predict"], pred_x,
115 |             mask = batch_dict["mask_predicted_data"])
116 |
117 |         ################################
118 |         # Compute CE loss for binary classification on Physionet
119 |         # Use only last attribute -- mortality in the hospital
120 |         device = get_device(batch_dict["data_to_predict"])
121 |         ce_loss = torch.Tensor([0.]).to(device)
122 |
123 |         if (batch_dict["labels"] is not None) and self.use_binary_classif:
124 |             if (batch_dict["labels"].size(-1) == 1) or (len(batch_dict["labels"].size()) == 1):
125 |                 ce_loss = compute_binary_CE_loss(
126 |                     info["label_predictions"],
127 |                     batch_dict["labels"])
128 |             else:
129 |                 ce_loss = compute_multiclass_CE_loss(
130 |                     info["label_predictions"],
131 |                     batch_dict["labels"],
132 |                     mask = batch_dict["mask_predicted_data"])
133 |
134 |             if torch.isnan(ce_loss):
135 |                 print("label pred")
136 |                 print(info["label_predictions"])
137 |                 print("labels")
138 |                 print( batch_dict["labels"])
139 |                 raise Exception("CE loss is Nan!")
140 |
141 |         pois_log_likelihood = torch.Tensor([0.]).to(get_device(batch_dict["data_to_predict"]))
142 |         if self.use_poisson_proc:
143 |             pois_log_likelihood = compute_poisson_proc_likelihood(
144 |                 batch_dict["data_to_predict"], pred_x,
145 |                 info, mask = batch_dict["mask_predicted_data"])
146 |             # Take mean over n_traj
147 |             pois_log_likelihood = torch.mean(pois_log_likelihood, 1)
148 |
149 |         loss = - torch.mean(likelihood)
150 |
151 |         if self.use_poisson_proc:
152 |             loss = loss - 0.1 * pois_log_likelihood
153 |
154 |         if self.use_binary_classif:
155 |             if self.train_classif_w_reconstr:
156 |                 loss = loss + ce_loss * 100
157 |             else:
158 |                 loss = ce_loss
159 |
160 |         # Take mean over the number of samples in a batch
161 |         results = {}
162 |         results["loss"] = torch.mean(loss)
163 |         results["likelihood"] = torch.mean(likelihood).detach()
164 |         results["mse"] = torch.mean(mse).detach()
165 |         results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach()
166 |         results["ce_loss"] = torch.mean(ce_loss).detach()
167 |         results["kl"] = 0.
168 |         results["kl_first_p"] = 0.
169 |         results["std_first_p"] = 0.
170 |
171 |         if batch_dict["labels"] is not None and self.use_binary_classif:
172 |             results["label_predictions"] = info["label_predictions"].detach()
173 |         return results
174 |
175 |
176 |
177 | class VAE_Baseline(nn.Module):
178 |     def __init__(self, input_dim, latent_dim,
179 |             z0_prior, device,
180 |             obsrv_std = 0.01,
181 |             use_binary_classif = False,
182 |             classif_per_tp = False,
183 |             use_poisson_proc = False,
184 |             linear_classifier = False,
185 |             n_labels = 1,
186 |             train_classif_w_reconstr = False):
187 |
188 |         super(VAE_Baseline, self).__init__()
189 |
190 |         self.input_dim = input_dim
191 |         self.latent_dim = latent_dim
192 |         self.device = device
193 |         self.n_labels = n_labels
194 |
195 |         self.obsrv_std = torch.Tensor([obsrv_std]).to(device)
196 |
197 |         self.z0_prior = z0_prior
198 |         self.use_binary_classif = use_binary_classif
199 |         self.classif_per_tp = classif_per_tp
200 |         self.use_poisson_proc = use_poisson_proc
201 |         self.linear_classifier = linear_classifier
202 |         self.train_classif_w_reconstr = train_classif_w_reconstr
203 |
204 |         z0_dim = latent_dim
205 |         if use_poisson_proc:
206 |             z0_dim += latent_dim
207 |
208 |         if use_binary_classif:
209 |             if linear_classifier:
210 |                 self.classifier = nn.Sequential(
211 |                     nn.Linear(z0_dim, n_labels))
212 |             else:
213 |                 self.classifier = create_classifier(z0_dim, n_labels)
214 |             utils.init_network_weights(self.classifier)
215 |
216 |
217 |     def get_gaussian_likelihood(self, truth, pred_y, mask = None):
218 |         # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim]
219 |         # truth shape [n_traj, n_tp, n_dim]
220 |         n_traj, n_tp, n_dim = truth.size()
221 |
222 |         # Compute likelihood of the data under the predictions
223 |         truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1)
224 |
225 |         if mask is not None:
226 |             mask = mask.repeat(pred_y.size(0), 1, 1, 1)
227 |         log_density_data = masked_gaussian_log_density(pred_y, truth_repeated,
228 |             obsrv_std = self.obsrv_std, mask = mask)
229 |         log_density_data = log_density_data.permute(1,0)
230 |         log_density = torch.mean(log_density_data, 1)
231 |
232 |         # shape: [n_traj_samples]
233 |         return log_density
234 |
235 |
236 |     def get_mse(self, truth, pred_y, mask = None):
237 |         # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim]
238 |         # truth shape [n_traj, n_tp, n_dim]
239 |         n_traj, n_tp, n_dim = truth.size()
240 |
241 |         # Repeat the truth to match the n_traj_samples dimension of pred_y
242 |         truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1)
243 |
244 |         if mask is not None:
245 |             mask = mask.repeat(pred_y.size(0), 1, 1, 1)
246 |
247 |         # Compute MSE between the data and the predictions
248 |         log_density_data = compute_mse(pred_y, truth_repeated, mask = mask)
249 |         # shape: [1]
250 |         return torch.mean(log_density_data)
251 |
252 |
253 |     def compute_all_losses(self, batch_dict, n_traj_samples = 1, kl_coef = 1., weight=None):
254 |         # Condition on subsampled points
255 |         # Make predictions for all the points
256 |         pred_y, info = self.get_reconstruction(batch_dict["tp_to_predict"],
257 |             batch_dict["observed_data"], batch_dict["observed_tp"],
258 |             mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples,
259 |             mode = batch_dict["mode"])
260 |
261 |         #print("get_reconstruction done -- computing likelihood")
262 |         fp_mu, fp_std, fp_enc = info["first_point"]
263 |         fp_std = fp_std.abs()
264 |         fp_distr = Normal(fp_mu, fp_std)
265 |
266 |         assert(torch.sum(fp_std < 0) == 0.)
267 |
268 |         kldiv_z0 = kl_divergence(fp_distr, self.z0_prior)
269 |
270 |         if torch.isnan(kldiv_z0).any():
271 |             print(fp_mu)
272 |             print(fp_std)
273 |             raise Exception("kldiv_z0 is Nan!")
274 |
275 |         # Mean over number of latent dimensions
276 |         # kldiv_z0 shape: [n_traj_samples, n_traj, n_latent_dims] if prior is a mixture of gaussians (KL is estimated)
277 |         # kldiv_z0 shape: [1, n_traj, n_latent_dims] if prior is a standard gaussian (KL is computed exactly)
278 |         # shape after: [n_traj_samples]
279 |         kldiv_z0 = torch.mean(kldiv_z0,(1,2))
280 |
281 |         # Compute likelihood of all the points
282 |         rec_likelihood = self.get_gaussian_likelihood(
283 |             batch_dict["data_to_predict"], pred_y,
284 |             mask = batch_dict["mask_predicted_data"])
285 |
286 |         mse = self.get_mse(
287 |             batch_dict["data_to_predict"], pred_y,
288 |             mask = batch_dict["mask_predicted_data"])
289 |
290 |         pois_log_likelihood = torch.Tensor([0.]).to(get_device(batch_dict["data_to_predict"]))
291 |         if self.use_poisson_proc:
292 |             pois_log_likelihood = compute_poisson_proc_likelihood(
293 |                 batch_dict["data_to_predict"], pred_y,
294 |                 info, mask = batch_dict["mask_predicted_data"])
295 |             # Take mean over n_traj
296 |             pois_log_likelihood = torch.mean(pois_log_likelihood, 1)
297 |
298 |         ################################
299 |         # Compute CE loss for binary classification on Physionet
300 |         device = get_device(batch_dict["data_to_predict"])
301 |         ce_loss = torch.Tensor([0.]).to(device)
302 |         # print(info["label_predictions"])
303 |         # print(batch_dict["labels"])
304 |         # print(batch_dict["mask_predicted_data"])
305 |         # print(info["label_predictions"].shape)
306 |         # print(batch_dict["labels"].shape)
307 |         # print(batch_dict["mask_predicted_data"].shape)
308 |         # exit()
309 |         if (batch_dict["labels"] is not None) and self.use_binary_classif:
310 |
311 |             if (batch_dict["labels"].size(-1) == 1) or (len(batch_dict["labels"].size()) == 1):
312 |                 ce_loss = compute_binary_CE_loss(
313 |                     info["label_predictions"],
314 |                     batch_dict["labels"])
315 |             else:
316 |                 ce_loss = compute_multiclass_CE_loss(
317 |                     info["label_predictions"],
318 |                     batch_dict["labels"],
319 |                     mask = batch_dict["mask_predicted_data"],
320 |                     weight=weight)
321 |
322 |         # IWAE loss
323 |         loss = - torch.logsumexp(rec_likelihood - kl_coef * kldiv_z0,0)
324 |         if torch.isnan(loss):
325 |             loss = - torch.mean(rec_likelihood - kl_coef * kldiv_z0,0)
326 |
327 |         if self.use_poisson_proc:
328 |             loss = loss - 0.1 * pois_log_likelihood
329 |
330 |         if self.use_binary_classif:
331 |             if self.train_classif_w_reconstr:
332 |                 loss = loss + ce_loss * 100
333 |             else:
334 |                 loss = ce_loss
335 |
336 |         results = {}
337 |         results["loss"] = torch.mean(loss)
338 |         results["likelihood"] = torch.mean(rec_likelihood).detach()
339 |         results["mse"] = torch.mean(mse).detach()
340 |         results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach()
341 |         results["ce_loss"] = torch.mean(ce_loss).detach()
342 |         results["kl_first_p"] = torch.mean(kldiv_z0).detach()
343 |         results["std_first_p"] = torch.mean(fp_std).detach()
344 |
345 |         if batch_dict["labels"] is not None and self.use_binary_classif:
346 |             results["label_predictions"] = info["label_predictions"].detach()
347 |
348 |         return results
349 |
350 |
351 |
352 |
--------------------------------------------------------------------------------
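`VAE_Baseline.compute_all_losses` above reduces the per-sample reconstruction log-likelihood and KL term with a logsumexp over the trajectory samples. A minimal sketch of that reduction in plain Python follows, assuming one scalar per trajectory sample. `iwae_style_loss` and `logsumexp` are hypothetical helper names; like the code above, the sketch does not subtract log K for K samples.

```python
import math

def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) for a list of floats."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def iwae_style_loss(rec_likelihood, kldiv_z0, kl_coef=1.0):
    """Negative logsumexp of (reconstruction log-likelihood - kl_coef * KL)."""
    terms = [r - kl_coef * k for r, k in zip(rec_likelihood, kldiv_z0)]
    return -logsumexp(terms)

# One trajectory sample: the loss is just -(rec_likelihood - kl_coef * KL).
single = iwae_style_loss([-3.0], [0.5])
# Two identical samples differ from the single-sample case by log(2).
double = iwae_style_loss([-3.0, -3.0], [0.5, 0.5])
```

Subtracting the maximum before exponentiating avoids underflow when the log-likelihoods are very negative, which is the usual reason for preferring logsumexp over a naive sum of exponentials.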
/latentode/latent_ode/lib/create_latent_ode_model.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import os
7 | import numpy as np
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn.functional import relu
12 |
13 | import lib.utils as utils
14 | from lib.latent_ode import LatentODE
15 | from lib.encoder_decoder import *
16 | from lib.diffeq_solver import DiffeqSolver
17 |
18 | import sys
19 | sys.path.insert(0, '../ours_impl')
20 | from lv_field import LVField
21 |
22 | from torch.distributions.normal import Normal
23 | from lib.ode_func import ODEFunc, ODEFunc_w_Poisson
24 |
25 | #####################################################################################################
26 |
27 | def create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device,
28 |         classif_per_tp = False, n_labels = 1,
29 |         BLOB_OF_MODEL_SETTINGS=None
30 |         ):
31 |
32 |     dim = args.latents
33 |     if args.poisson:
34 |         lambda_net = utils.create_net(dim, input_dim,
35 |             n_layers = 1, n_units = args.units, nonlinear = nn.Tanh)
36 |
37 |         # ODE function produces the gradient for latent state and for poisson rate
38 |         ode_func_net = utils.create_net(dim * 2, args.latents * 2,
39 |             n_layers = args.gen_layers, n_units = args.units, nonlinear = nn.Tanh)
40 |
41 |         gen_ode_func = ODEFunc_w_Poisson(
42 |             input_dim = input_dim,
43 |             latent_dim = args.latents * 2,
44 |             ode_func_net = ode_func_net,
45 |             lambda_net = lambda_net,
46 |             device = device).to(device)
47 |     else:
48 |         dim = args.latents
49 |         ode_func_net = utils.create_net(dim, args.latents,
50 |             n_layers = args.gen_layers, n_units = args.units, nonlinear = nn.Tanh)
51 |
52 |         gen_ode_func = ODEFunc(
53 |             input_dim = input_dim,
54 |             latent_dim = args.latents,
55 |             ode_func_net = ode_func_net,
56 |             device = device).to(device)
57 |
58 |     z0_diffeq_solver = None
59 |     n_rec_dims = args.rec_dims
60 |     enc_input_dim = int(input_dim) * 2 # we concatenate the mask
61 |     gen_data_dim = input_dim
62 |
63 |     z0_dim = args.latents
64 |     if args.poisson:
65 |         z0_dim += args.latents # predict the initial poisson rate
66 |
67 |     if args.z0_encoder == "odernn":
68 |         ode_func_net = utils.create_net(n_rec_dims, n_rec_dims,
69 |             n_layers = args.rec_layers, n_units = args.units, nonlinear = nn.Tanh)
70 |
71 |
72 |         # using = 'ours'
73 |         # using = 'theirs'
74 |
75 |
76 |         # if using == 'ours':
77 |
78 |         # z0_diffeq_solver = LVField(
79 |         # dim=enc_input_dim * args.latents,
80 |         # augmented_dim=input_dim * args.latents + 5,
81 |         # num_layers=3,
82 |         # )
83 |         # # pred_y = self.lv_field(first_point, time_steps_to_predict)
84 |
85 |
86 |         # elif using == 'theirs':
87 |
88 |         if True:
89 |             """
90 |             We won't be using this as the integration time-steps within encoder tends to be very short.
91 |             """
92 |             rec_ode_func = ODEFunc(
93 |                 input_dim = enc_input_dim,
94 |                 latent_dim = n_rec_dims,
95 |                 ode_func_net = ode_func_net,
96 |                 device = device).to(device)
97 |
98 |             z0_diffeq_solver = DiffeqSolver(enc_input_dim, rec_ode_func, "euler", args.latents,
99 |                 odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device)
100 |
101 |
102 |         encoder_z0 = Encoder_z0_ODE_RNN(n_rec_dims, enc_input_dim, z0_diffeq_solver,
103 |             z0_dim = z0_dim, n_gru_units = args.gru_units, device = device).to(device)
104 |
105 |
106 |
107 |     elif args.z0_encoder == "rnn":
108 |         encoder_z0 = Encoder_z0_RNN(z0_dim, enc_input_dim,
109 |             lstm_output_size = n_rec_dims, device = device).to(device)
110 |     else:
111 |         raise Exception("Unknown encoder for Latent ODE model: " + args.z0_encoder)
112 |
113 |     decoder = Decoder(args.latents, gen_data_dim).to(device)
114 |
115 |     if BLOB_OF_MODEL_SETTINGS['model'] == 'node':
116 |         diffeq_solver = DiffeqSolver(gen_data_dim, gen_ode_func,
117 |             BLOB_OF_MODEL_SETTINGS['node__int_method'],
118 |             # 'euler',
119 |             args.latents,
120 |             odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device)
121 |
122 |     elif BLOB_OF_MODEL_SETTINGS['model'] == 'ours':
123 |         e_vec_factor = 1.0
124 |         t_d_factor = 1.0
125 |         if BLOB_OF_MODEL_SETTINGS['dataset'] == 'activity':
126 |             e_vec_factor = 0.001
127 |             t_d_factor = 0.1
128 |         elif BLOB_OF_MODEL_SETTINGS['dataset'] == 'periodic':
129 |             if BLOB_OF_MODEL_SETTINGS['timepoints'] == 1000:
130 |                 e_vec_factor = 1.0
131 |                 t_d_factor = 0.01
132 |         diffeq_solver = LVField(
133 |             dim=args.latents,
134 |             augmented_dim=args.latents + BLOB_OF_MODEL_SETTINGS['ours__extra_augnmented_dim'],
135 |             num_layers=BLOB_OF_MODEL_SETTINGS['ours__num_layers'],
136 |             hidden_dim=BLOB_OF_MODEL_SETTINGS['ours__inn_hidden_dim'],
137 |
138 |             e_vec_factor=e_vec_factor,
139 |             t_d_factor=t_d_factor,
140 |         )
141 |
142 |     model = LatentODE(
143 |         input_dim = gen_data_dim,
144 |         latent_dim = args.latents,
145 |         encoder_z0 = encoder_z0,
146 |         decoder = decoder,
147 |         diffeq_solver = diffeq_solver,
148 |         z0_prior = z0_prior,
149 |         device = device,
150 |         obsrv_std = obsrv_std,
151 |         use_poisson_proc = args.poisson,
152 |         use_binary_classif = args.classif,
153 |         linear_classifier = args.linear_classif,
154 |         classif_per_tp = classif_per_tp,
155 |         n_labels = n_labels,
156 |         train_classif_w_reconstr = (args.dataset == "physionet")
157 |         ).to(device)
158 |
159 |     return model
160 |
--------------------------------------------------------------------------------
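The `ours` branch of `create_LatentODE_model` above picks `e_vec_factor` and `t_d_factor` from the dataset name and the number of time points. A standalone sketch of that selection logic follows. The helper name `pick_scaling_factors` and the example settings are hypothetical; only the dictionary keys and the factor values come from the listing. It assumes both assignments in the `periodic` case sit under the `timepoints == 1000` check.

```python
def pick_scaling_factors(settings):
    """Return (e_vec_factor, t_d_factor) for a model-settings dictionary."""
    e_vec_factor, t_d_factor = 1.0, 1.0  # defaults for every other dataset
    if settings['dataset'] == 'activity':
        e_vec_factor, t_d_factor = 0.001, 0.1
    elif settings['dataset'] == 'periodic' and settings['timepoints'] == 1000:
        # e_vec_factor stays at its default; only the time scaling changes
        e_vec_factor, t_d_factor = 1.0, 0.01
    return e_vec_factor, t_d_factor

# Hypothetical settings dictionaries, for illustration only
activity = pick_scaling_factors({'dataset': 'activity'})
periodic_long = pick_scaling_factors({'dataset': 'periodic', 'timepoints': 1000})
periodic_short = pick_scaling_factors({'dataset': 'periodic', 'timepoints': 100})
```

Pulling the branching into a small pure function like this makes the per-dataset constants easy to check without building the full model.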
/latentode/latent_ode/lib/diffeq_solver.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import time
7 | import numpy as np
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 | import lib.utils as utils
13 | from torch.distributions.multivariate_normal import MultivariateNormal
14 |
15 | # git clone https://github.com/rtqichen/torchdiffeq.git
16 | from torchdiffeq import odeint as odeint
17 |
18 | #####################################################################################################
19 |
20 | class DiffeqSolver(nn.Module):
21 |     def __init__(self, input_dim, ode_func, method, latents,
22 |             odeint_rtol = 1e-4, odeint_atol = 1e-5, device = torch.device("cpu")):
23 |         super(DiffeqSolver, self).__init__()
24 |
25 |         self.ode_method = method
26 |         self.latents = latents
27 |         self.device = device
28 |         self.ode_func = ode_func
29 |
30 |         self.odeint_rtol = odeint_rtol
31 |         self.odeint_atol = odeint_atol
32 |
33 |     def forward(self, first_point, time_steps_to_predict, backwards = False):
34 |         """
35 |         # Decode the trajectory through ODE Solver
36 |         """
37 |         n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1]
38 |         n_dims = first_point.size()[-1]
39 |
40 |         pred_y = odeint(self.ode_func, first_point, time_steps_to_predict,
41 |             rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method)
42 |         pred_y = pred_y.permute(1,2,0,3)
43 |
44 |         assert(torch.mean(pred_y[:, :, 0, :] - first_point) < 0.001)
45 |         assert(pred_y.size()[0] == n_traj_samples)
46 |         assert(pred_y.size()[1] == n_traj)
47 |
48 |         return pred_y
49 |
50 |     def sample_traj_from_prior(self, starting_point_enc, time_steps_to_predict,
51 |             n_traj_samples = 1):
52 |         """
53 |         # Decode the trajectory through ODE Solver using samples from the prior
54 |
55 |         time_steps_to_predict: time steps at which we want to sample the new trajectory
56 |         """
57 |         func = self.ode_func.sample_next_point_from_prior
58 |
59 |         pred_y = odeint(func, starting_point_enc, time_steps_to_predict,
60 |             rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method)
61 |         # shape: [n_traj_samples, n_traj, n_tp, n_dim]
62 |         pred_y = pred_y.permute(1,2,0,3)
63 |         return pred_y
64 |
65 |
66 |
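To illustrate the contract that `DiffeqSolver.forward` relies on (one state per requested time, time-first, with the first state equal to the initial value), here is a dependency-free sketch. It swaps `torchdiffeq.odeint` for a fixed-step explicit Euler loop over plain Python lists, so it is not the solver used in this repository; `euler_odeint`, `decay`, and `substeps` are hypothetical names introduced only for this example.

```python
import math

def euler_odeint(func, y0, ts, substeps=1000):
    """Integrate dy/dt = func(t, y) with fixed-step explicit Euler (illustrative only).

    Returns one state per entry of ts, time-first, and the first state is y0.
    This mirrors the property that DiffeqSolver.forward asserts on its output.
    """
    traj = [list(y0)]
    y = list(y0)
    for t_start, t_end in zip(ts[:-1], ts[1:]):
        h = (t_end - t_start) / substeps  # a negative h integrates backwards in time
        t = t_start
        for _ in range(substeps):
            dy = func(t, y)
            y = [yi + h * di for yi, di in zip(y, dy)]
            t += h
        traj.append(list(y))
    return traj

def decay(t, y):
    # dy/dt = -y, whose exact solution is y(t) = y0 * exp(-t)
    return [-yi for yi in y]

trajectory = euler_odeint(decay, [1.0], [0.0, 0.5, 1.0])

# Same spirit as the assertion in forward(): compare the first state to y0
# using the absolute difference so that errors of opposite sign cannot cancel.
first_point_error = max(abs(a - b) for a, b in zip(trajectory[0], [1.0]))
```

Because the step size is signed, passing a decreasing `ts` integrates backwards in time, which is how the ODE-RNN encoder calls the solver when `run_backwards` is set.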
--------------------------------------------------------------------------------
/latentode/latent_ode/lib/encoder_decoder.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.functional import relu
10 | import lib.utils as utils
11 | from torch.distributions import Categorical, Normal
12 | import lib.utils as utils
13 | from torch.nn.modules.rnn import LSTM, GRU
14 | from lib.utils import get_device
15 |
16 |
17 | # GRU description:
18 | # http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/
19 | class GRU_unit(nn.Module):
20 | def __init__(self, latent_dim, input_dim,
21 | update_gate = None,
22 | reset_gate = None,
23 | new_state_net = None,
24 | n_units = 100,
25 | device = torch.device("cpu")):
26 | super(GRU_unit, self).__init__()
27 |
28 | if update_gate is None:
29 | self.update_gate = nn.Sequential(
30 | nn.Linear(latent_dim * 2 + input_dim, n_units),
31 | nn.Tanh(),
32 | nn.Linear(n_units, latent_dim),
33 | nn.Sigmoid())
34 | utils.init_network_weights(self.update_gate)
35 | else:
36 | self.update_gate = update_gate
37 |
38 | if reset_gate is None:
39 | self.reset_gate = nn.Sequential(
40 | nn.Linear(latent_dim * 2 + input_dim, n_units),
41 | nn.Tanh(),
42 | nn.Linear(n_units, latent_dim),
43 | nn.Sigmoid())
44 | utils.init_network_weights(self.reset_gate)
45 | else:
46 | self.reset_gate = reset_gate
47 |
48 | if new_state_net is None:
49 | self.new_state_net = nn.Sequential(
50 | nn.Linear(latent_dim * 2 + input_dim, n_units),
51 | nn.Tanh(),
52 | nn.Linear(n_units, latent_dim * 2))
53 | utils.init_network_weights(self.new_state_net)
54 | else:
55 | self.new_state_net = new_state_net
56 |
57 |
58 | def forward(self, y_mean, y_std, x, masked_update = True):
59 | y_concat = torch.cat([y_mean, y_std, x], -1)
60 |
61 | update_gate = self.update_gate(y_concat)
62 | reset_gate = self.reset_gate(y_concat)
63 | concat = torch.cat([y_mean * reset_gate, y_std * reset_gate, x], -1)
64 |
65 | new_state, new_state_std = utils.split_last_dim(self.new_state_net(concat))
66 | new_state_std = new_state_std.abs()
67 |
68 | new_y = (1-update_gate) * new_state + update_gate * y_mean
69 | new_y_std = (1-update_gate) * new_state_std + update_gate * y_std
70 |
71 | assert(not torch.isnan(new_y).any())
72 |
73 | if masked_update:
74 | # IMPORTANT: assumes that x contains both data and mask
75 | # update the hidden state only if at least one feature is present for the current time point
76 | n_data_dims = x.size(-1)//2
77 | mask = x[:, :, n_data_dims:]
78 | utils.check_mask(x[:, :, :n_data_dims], mask)
79 |
80 | mask = (torch.sum(mask, -1, keepdim = True) > 0).float()
81 |
82 | assert(not torch.isnan(mask).any())
83 |
84 | new_y = mask * new_y + (1-mask) * y_mean
85 | new_y_std = mask * new_y_std + (1-mask) * y_std
86 |
87 | if torch.isnan(new_y).any():
88 | print("new_y is nan!")
89 | print(mask)
90 | print(y_mean)
91 | print(new_y)
92 | exit()
93 |
94 | new_y_std = new_y_std.abs()
95 | return new_y, new_y_std
96 |
97 |
98 |
99 | class Encoder_z0_RNN(nn.Module):
100 | def __init__(self, latent_dim, input_dim, lstm_output_size = 20,
101 | use_delta_t = True, device = torch.device("cpu")):
102 |
103 | super(Encoder_z0_RNN, self).__init__()
104 |
105 | self.gru_rnn_output_size = lstm_output_size
106 | self.latent_dim = latent_dim
107 | self.input_dim = input_dim
108 | self.device = device
109 | self.use_delta_t = use_delta_t
110 |
111 | self.hiddens_to_z0 = nn.Sequential(
112 | nn.Linear(self.gru_rnn_output_size, 50),
113 | nn.Tanh(),
114 | nn.Linear(50, latent_dim * 2),)
115 |
116 | utils.init_network_weights(self.hiddens_to_z0)
117 |
118 | input_dim = self.input_dim
119 |
120 | if use_delta_t:
121 | self.input_dim += 1
122 | self.gru_rnn = GRU(self.input_dim, self.gru_rnn_output_size).to(device)
123 |
124 | def forward(self, data, time_steps, run_backwards = True):
125 | # IMPORTANT: assumes that 'data' already has mask concatenated to it
126 |
127 | # data shape: [n_traj, n_tp, n_dims]
128 | # shape required for rnn: (seq_len, batch, input_size)
129 | # t0: not used here
130 | n_traj = data.size(0)
131 |
132 | assert(not torch.isnan(data).any())
133 | assert(not torch.isnan(time_steps).any())
134 |
135 | data = data.permute(1,0,2)
136 |
137 | if run_backwards:
138 | # Look at data in the reverse order: from later points to the first
139 | data = utils.reverse(data)
140 |
141 | if self.use_delta_t:
142 | delta_t = time_steps[1:] - time_steps[:-1]
143 | if run_backwards:
144 | # we are going backwards in time, so reverse the time deltas as well
145 | delta_t = utils.reverse(delta_t)
146 | # append zero delta t in the end
147 | delta_t = torch.cat((delta_t, torch.zeros(1).to(self.device)))
148 | delta_t = delta_t.unsqueeze(1).repeat((1,n_traj)).unsqueeze(-1)
149 | data = torch.cat((delta_t, data),-1)
150 |
151 | outputs, _ = self.gru_rnn(data)
152 |
153 | # GRU output shape: (seq_len, batch, num_directions * hidden_size)
154 | last_output = outputs[-1]
155 |
156 | self.extra_info ={"rnn_outputs": outputs, "time_points": time_steps}
157 |
158 | mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
159 | std = std.abs()
160 |
161 | assert(not torch.isnan(mean).any())
162 | assert(not torch.isnan(std).any())
163 |
164 | return mean.unsqueeze(0), std.unsqueeze(0)
165 |
166 |
167 |
168 |
169 |
170 | class Encoder_z0_ODE_RNN(nn.Module):
171 | # Derive z0 by running ode backwards.
172 | # For every y_i we have two versions: encoded from data and derived from ODE by running it backwards from t_i+1 to t_i
173 | # Compute a weighted sum of y_i from data and y_i from ode. Use weighted y_i as an initial value for ODE running from t_i to t_i-1
174 | # Continue until we get to z0
175 | def __init__(self, latent_dim, input_dim, z0_diffeq_solver = None,
176 | z0_dim = None, GRU_update = None,
177 | n_gru_units = 100,
178 | device = torch.device("cpu")):
179 |
180 | super(Encoder_z0_ODE_RNN, self).__init__()
181 |
182 | if z0_dim is None:
183 | self.z0_dim = latent_dim
184 | else:
185 | self.z0_dim = z0_dim
186 |
187 | if GRU_update is None:
188 | self.GRU_update = GRU_unit(latent_dim, input_dim,
189 | n_units = n_gru_units,
190 | device=device).to(device)
191 | else:
192 | self.GRU_update = GRU_update
193 |
194 | self.z0_diffeq_solver = z0_diffeq_solver
195 | self.latent_dim = latent_dim
196 | self.input_dim = input_dim
197 | self.device = device
198 | self.extra_info = None
199 |
200 | self.transform_z0 = nn.Sequential(
201 | nn.Linear(latent_dim * 2, 100),
202 | nn.Tanh(),
203 | nn.Linear(100, self.z0_dim * 2),)
204 | utils.init_network_weights(self.transform_z0)
205 |
206 |
207 | def forward(self, data, time_steps, run_backwards = True, save_info = False):
208 | # data, time_steps -- observations and their time stamps
209 | # IMPORTANT: assumes that 'data' already has mask concatenated to it
210 | assert(not torch.isnan(data).any())
211 | assert(not torch.isnan(time_steps).any())
212 |
213 | n_traj, n_tp, n_dims = data.size()
214 | if len(time_steps) == 1:
215 | prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(self.device)
216 | prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(self.device)
217 |
218 | xi = data[:,0,:].unsqueeze(0)
219 |
220 | last_yi, last_yi_std = self.GRU_update(prev_y, prev_std, xi)
221 | extra_info = None
222 | else:
223 |
224 | last_yi, last_yi_std, _, extra_info = self.run_odernn(
225 | data, time_steps, run_backwards = run_backwards,
226 | save_info = save_info)
227 |
228 | means_z0 = last_yi.reshape(1, n_traj, self.latent_dim)
229 | std_z0 = last_yi_std.reshape(1, n_traj, self.latent_dim)
230 |
231 | mean_z0, std_z0 = utils.split_last_dim( self.transform_z0( torch.cat((means_z0, std_z0), -1)))
232 | std_z0 = std_z0.abs()
233 | if save_info:
234 | self.extra_info = extra_info
235 |
236 | return mean_z0, std_z0
237 |
238 |
239 | def run_odernn(self, data, time_steps,
240 | run_backwards = True, save_info = False):
241 | # IMPORTANT: assumes that 'data' already has mask concatenated to it
242 |
243 | n_traj, n_tp, n_dims = data.size()
244 | extra_info = []
245 |
246 | t0 = time_steps[-1]
247 | if run_backwards:
248 | t0 = time_steps[0]
249 |
250 | device = get_device(data)
251 |
252 | prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device)
253 | prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device)
254 |
255 | prev_t, t_i = time_steps[-1] + 0.01, time_steps[-1]
256 |
257 | interval_length = time_steps[-1] - time_steps[0]
258 | minimum_step = interval_length / 50
259 |
260 | #print("minimum step: {}".format(minimum_step))
261 |
262 | assert(not torch.isnan(data).any())
263 | assert(not torch.isnan(time_steps).any())
264 |
265 | latent_ys = []
266 | # Run ODE backwards and combine the y(t) estimates using gating
267 | time_points_iter = range(0, len(time_steps))
268 | if run_backwards:
269 | time_points_iter = reversed(time_points_iter)
270 |
271 | for i in time_points_iter:
272 | """
273 | if (prev_t - t_i) < minimum_step:
274 |
275 | time_points = torch.stack((prev_t, t_i))
276 | inc = self.z0_diffeq_solver.ode_func(prev_t, prev_y) * (t_i - prev_t)
277 |
278 | assert(not torch.isnan(inc).any())
279 |
280 | ode_sol = prev_y + inc
281 | ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)
282 |
283 | assert(not torch.isnan(ode_sol).any())
284 | else:
285 | """
286 | # Skip doing the above hacks
287 | if 1:
288 | n_intermediate_tp = max(2, ((prev_t - t_i) / minimum_step).int())
289 |
290 | time_points = utils.linspace_vector(prev_t, t_i, n_intermediate_tp).to(device)
291 | ode_sol = self.z0_diffeq_solver(prev_y, time_points)
292 |
293 | assert(not torch.isnan(ode_sol).any())
294 |
295 | if False and torch.mean(ode_sol[:, :, 0, :] - prev_y) >= 0.001:
296 | print("Error: first point of the ODE is not equal to initial value")
297 | print(torch.mean(ode_sol[:, :, 0, :] - prev_y))
298 | exit()
299 | #assert(torch.mean(ode_sol[:, :, 0, :] - prev_y) < 0.001)
300 |
301 | yi_ode = ode_sol[:, :, -1, :]
302 | xi = data[:,i,:].unsqueeze(0)
303 |
304 | yi, yi_std = self.GRU_update(yi_ode, prev_std, xi)
305 |
306 | prev_y, prev_std = yi, yi_std
307 | prev_t, t_i = time_steps[i], time_steps[i-1]
308 |
309 | latent_ys.append(yi)
310 |
311 | if save_info:
312 | d = {"yi_ode": yi_ode.detach(), #"yi_from_data": yi_from_data,
313 | "yi": yi.detach(), "yi_std": yi_std.detach(),
314 | "time_points": time_points.detach(), "ode_sol": ode_sol.detach()}
315 | extra_info.append(d)
316 |
317 | latent_ys = torch.stack(latent_ys, 1)
318 |
319 | assert(not torch.isnan(yi).any())
320 | assert(not torch.isnan(yi_std).any())
321 |
322 | return yi, yi_std, latent_ys, extra_info
323 |
324 |
325 |
326 | class Decoder(nn.Module):
327 | def __init__(self, latent_dim, input_dim):
328 | super(Decoder, self).__init__()
329 | # decode data from latent space where we are solving an ODE back to the data space
330 |
331 | decoder = nn.Sequential(
332 | nn.Linear(latent_dim, input_dim),)
333 |
334 | utils.init_network_weights(decoder)
335 | self.decoder = decoder
336 |
337 | def forward(self, data):
338 | return self.decoder(data)
339 |
340 |
341 |
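The gating arithmetic in `GRU_unit.forward` can be hard to read through the tensor code, so here is a minimal sketch of the same two steps on plain Python lists: blend the candidate state with the previous mean using the update gate, then keep the previous mean wherever no feature was observed. `gated_update` and its argument names are hypothetical; the real module operates on batched tensors and also tracks a standard deviation.

```python
def gated_update(y_mean, new_state, update_gate, mask_row):
    """Illustrative scalar version of the GRU_unit update for one trajectory.

    y_mean      : previous hidden mean, one float per latent dimension
    new_state   : candidate state proposed by the network
    update_gate : values in [0, 1]; 1 keeps the old state, 0 takes the new one
    mask_row    : observation mask for the current time point (0 or 1 per feature)
    """
    # new_y = (1 - update_gate) * new_state + update_gate * y_mean
    blended = [(1.0 - u) * n + u * y
               for y, n, u in zip(y_mean, new_state, update_gate)]

    # The update applies only if at least one feature is present at this time point
    observed = 1.0 if sum(mask_row) > 0 else 0.0
    return [observed * b + (1.0 - observed) * y
            for b, y in zip(blended, y_mean)]

with_observation = gated_update([1.0, 2.0], [3.0, 4.0], [0.5, 0.0], [0, 1])
without_observation = gated_update([1.0, 2.0], [3.0, 4.0], [0.5, 0.0], [0, 0])
```

With one observed feature the state moves towards the candidate; with an all-zero mask the previous state is returned unchanged, which is what lets the encoder skip time points that carry no data.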
--------------------------------------------------------------------------------
/latentode/latent_ode/lib/latent_ode.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import numpy as np
7 | import sklearn as sk
8 | import numpy as np
9 | #import gc
10 | import torch
11 | import torch.nn as nn
12 | from torch.nn.functional import relu
13 |
14 | import lib.utils as utils
15 | from lib.utils import get_device
16 | from lib.encoder_decoder import *
17 | from lib.likelihood_eval import *
18 |
19 | from torch.distributions.multivariate_normal import MultivariateNormal
20 | from torch.distributions.normal import Normal
21 | from torch.distributions import kl_divergence, Independent
22 | from lib.base_models import VAE_Baseline
23 |
24 |
25 |
26 | class LatentODE(VAE_Baseline):
27 | def __init__(self, input_dim, latent_dim, encoder_z0, decoder, diffeq_solver,
28 | z0_prior, device, obsrv_std = None,
29 | use_binary_classif = False, use_poisson_proc = False,
30 | linear_classifier = False,
31 | classif_per_tp = False,
32 | n_labels = 1,
33 | train_classif_w_reconstr = False):
34 |
35 | super(LatentODE, self).__init__(
36 | input_dim = input_dim, latent_dim = latent_dim,
37 | z0_prior = z0_prior,
38 | device = device, obsrv_std = obsrv_std,
39 | use_binary_classif = use_binary_classif,
40 | classif_per_tp = classif_per_tp,
41 | linear_classifier = linear_classifier,
42 | use_poisson_proc = use_poisson_proc,
43 | n_labels = n_labels,
44 | train_classif_w_reconstr = train_classif_w_reconstr)
45 |
46 | self.encoder_z0 = encoder_z0
47 | self.diffeq_solver = diffeq_solver
48 | self.decoder = decoder
49 | self.use_poisson_proc = use_poisson_proc
50 |
51 | def get_reconstruction(self, time_steps_to_predict, truth, truth_time_steps,
52 | mask = None, n_traj_samples = 1, run_backwards = True, mode = None):
53 |
54 | if isinstance(self.encoder_z0, Encoder_z0_ODE_RNN) or \
55 | isinstance(self.encoder_z0, Encoder_z0_RNN):
56 |
57 | truth_w_mask = truth
58 | if mask is not None:
59 | truth_w_mask = torch.cat((truth, mask), -1)
60 | first_point_mu, first_point_std = self.encoder_z0(
61 | truth_w_mask, truth_time_steps, run_backwards = run_backwards)
62 |
63 | means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1)
64 | sigma_z0 = first_point_std.repeat(n_traj_samples, 1, 1)
65 | first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)
66 |
67 | else:
68 | raise Exception("Unknown encoder type {}".format(type(self.encoder_z0).__name__))
69 |
70 | first_point_std = first_point_std.abs()
71 | assert(torch.sum(first_point_std < 0) == 0.)
72 |
73 | if self.use_poisson_proc:
74 | n_traj_samples, n_traj, n_dims = first_point_enc.size()
75 | # append a vector of zeros to compute the integral of lambda
76 | zeros = torch.zeros([n_traj_samples, n_traj,self.input_dim]).to(get_device(truth))
77 | first_point_enc_aug = torch.cat((first_point_enc, zeros), -1)
78 | means_z0_aug = torch.cat((means_z0, zeros), -1)
79 | else:
80 | first_point_enc_aug = first_point_enc
81 | means_z0_aug = means_z0
82 |
83 | assert(not torch.isnan(time_steps_to_predict).any())
84 | assert(not torch.isnan(first_point_enc).any())
85 | assert(not torch.isnan(first_point_enc_aug).any())
86 |
87 | # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents]
88 |
89 | # from soraxas_toolbox import timer
90 | # timer.stamp("diff")
91 | sol_y = self.diffeq_solver(first_point_enc_aug, time_steps_to_predict)
92 | # timer.stamp("rest")
93 |
94 | if self.use_poisson_proc:
95 | sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(sol_y)
96 |
97 | assert(torch.sum(int_lambda[:,:,0,:]) == 0.)
98 | assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.)
99 |
100 | pred_x = self.decoder(sol_y)
101 |
102 | all_extra_info = {
103 | "first_point": (first_point_mu, first_point_std, first_point_enc),
104 | "latent_traj": sol_y.detach()
105 | }
106 |
107 | if self.use_poisson_proc:
108 | # intergral of lambda from the last step of ODE Solver
109 | all_extra_info["int_lambda"] = int_lambda[:,:,-1,:]
110 | all_extra_info["log_lambda_y"] = log_lambda_y
111 |
112 | if self.use_binary_classif:
113 | if self.classif_per_tp:
114 | all_extra_info["label_predictions"] = self.classifier(sol_y)
115 | else:
116 | all_extra_info["label_predictions"] = self.classifier(first_point_enc).squeeze(-1)
117 |
118 | return pred_x, all_extra_info
119 |
120 |
121 | def sample_traj_from_prior(self, time_steps_to_predict, n_traj_samples = 1):
122 | # input_dim = starting_point.size()[-1]
123 | # starting_point = starting_point.view(1,1,input_dim)
124 |
125 | # Sample z0 from prior
126 | starting_point_enc = self.z0_prior.sample([n_traj_samples, 1, self.latent_dim]).squeeze(-1)
127 |
128 | starting_point_enc_aug = starting_point_enc
129 | if self.use_poisson_proc:
130 | n_traj_samples, n_traj, n_dims = starting_point_enc.size()
131 | # append a vector of zeros to compute the integral of lambda
132 | zeros = torch.zeros(n_traj_samples, n_traj,self.input_dim).to(self.device)
133 | starting_point_enc_aug = torch.cat((starting_point_enc, zeros), -1)
134 |
135 | sol_y = self.diffeq_solver.sample_traj_from_prior(starting_point_enc_aug, time_steps_to_predict,
136 | n_traj_samples = n_traj_samples)
137 |
138 | if self.use_poisson_proc:
139 | sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(sol_y)
140 |
141 | return self.decoder(sol_y)
142 |
143 |
144 |
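As a rough illustration of how `get_reconstruction` prepares the initial state, the sketch below draws several samples of z0 around the encoder's mean and then appends zeros for the Poisson-rate integral. It assumes that `utils.sample_standard_gaussian` (not shown in this file) computes `mu + sigma * eps` with standard normal `eps`; `sample_z0` and `augment_with_zeros` are hypothetical helper names, and plain lists stand in for tensors.

```python
import random

def sample_z0(mean, std, n_traj_samples, rng):
    """Draw n_traj_samples reparameterized samples z0 = mean + std * eps.

    Assumption: this is the behaviour of utils.sample_standard_gaussian,
    which is defined elsewhere in the repository.
    """
    return [[m + s * rng.gauss(0.0, 1.0) for m, s in zip(mean, std)]
            for _ in range(n_traj_samples)]

def augment_with_zeros(z0, input_dim):
    # Extra zero-initialized dimensions accumulate the integral of lambda
    # along the ODE solve when the Poisson process likelihood is enabled.
    return list(z0) + [0.0] * input_dim

rng = random.Random(0)
samples = sample_z0([0.5, -1.0], [0.0, 0.0], n_traj_samples=3, rng=rng)
augmented = augment_with_zeros(samples[0], input_dim=3)
```

With a zero standard deviation every sample collapses onto the mean, which makes the example easy to check by hand. Starting the extra dimensions at zero is also why the code above can assert that the integral of lambda is zero at the first time point.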
--------------------------------------------------------------------------------
/latentode/latent_ode/lib/likelihood_eval.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import gc
7 | import numpy as np
8 | import sklearn as sk
9 | import numpy as np
10 | #import gc
11 | import torch
12 | import torch.nn as nn
13 | from torch.nn.functional import relu
14 |
15 | import lib.utils as utils
16 | from lib.utils import get_device
17 | from lib.encoder_decoder import *
18 | from lib.likelihood_eval import *
19 |
20 | from torch.distributions.multivariate_normal import MultivariateNormal
21 | from torch.distributions.normal import Normal
22 | from torch.distributions import kl_divergence, Independent
23 |
24 |
25 | def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices = None):
26 | n_data_points = mu_2d.size()[-1]
27 |
28 | if n_data_points > 0:
29 | gaussian = Independent(Normal(loc = mu_2d, scale = obsrv_std.repeat(n_data_points)), 1)
30 | log_prob = gaussian.log_prob(data_2d)
31 | log_prob = log_prob / n_data_points
32 | else:
33 | log_prob = torch.zeros([1]).to(get_device(data_2d)).squeeze()
34 | return log_prob
35 |
36 |
37 | def poisson_log_likelihood(masked_log_lambdas, masked_data, indices, int_lambdas):
38 | # masked_log_lambdas and masked_data
39 | n_data_points = masked_data.size()[-1]
40 |
41 | if n_data_points > 0:
42 | log_prob = torch.sum(masked_log_lambdas) - int_lambdas[indices]
43 | #log_prob = log_prob / n_data_points
44 | else:
45 | log_prob = torch.zeros([1]).to(get_device(masked_data)).squeeze()
46 | return log_prob
47 |
48 |
49 |
50 | def compute_binary_CE_loss(label_predictions, mortality_label):
51 | #print("Computing binary classification loss: compute_CE_loss")
52 |
53 | mortality_label = mortality_label.reshape(-1)
54 |
55 | if len(label_predictions.size()) == 1:
56 | label_predictions = label_predictions.unsqueeze(0)
57 |
58 | n_traj_samples = label_predictions.size(0)
59 | label_predictions = label_predictions.reshape(n_traj_samples, -1)
60 |
61 | idx_not_nan = ~torch.isnan(mortality_label)
62 | if idx_not_nan.sum() == 0:
63 | print("All labels are NaNs!")
64 | return torch.tensor(0.).to(get_device(mortality_label))
65 |
66 | label_predictions = label_predictions[:,idx_not_nan]
67 | mortality_label = mortality_label[idx_not_nan]
68 |
69 | if torch.sum(mortality_label == 0.) == 0 or torch.sum(mortality_label == 1.) == 0:
70 | print("Warning: all examples in a batch belong to the same class -- please increase the batch size.")
71 |
72 | assert(not torch.isnan(label_predictions).any())
73 | assert(not torch.isnan(mortality_label).any())
74 |
75 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
76 | mortality_label = mortality_label.repeat(n_traj_samples, 1)
77 | ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label)
78 |
79 | # divide by the number of trajectory samples
80 | ce_loss = ce_loss / n_traj_samples
81 | return ce_loss
82 |
83 |
84 | def compute_multiclass_CE_loss(label_predictions, true_label, mask, weight=None):
85 | #print("Computing multi-class classification loss: compute_multiclass_CE_loss")
86 |
87 | if (len(label_predictions.size()) == 3):
88 | label_predictions = label_predictions.unsqueeze(0)
89 |
90 | n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size()
91 |
92 |
93 | _ori_mask = mask
94 |
95 | if False:
96 | # assert(not torch.isnan(label_predictions).any())
97 | # assert(not torch.isnan(true_label).any())
98 |
99 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
100 | true_label = true_label.repeat(n_traj_samples, 1, 1)
101 |
102 | label_predictions = label_predictions.reshape(n_traj_samples * n_traj * n_tp, n_dims)
103 | true_label = true_label.reshape(n_traj_samples * n_traj * n_tp, n_dims)
104 |
105 | # choose time points with at least one measurement
106 | mask = torch.sum(mask, -1) > 0
107 |
108 | # repeat the mask for each label to mark that the label for this time point is present
109 | pred_mask = mask.repeat(n_dims, 1,1).permute(1,2,0)
110 |
111 | label_mask = mask
112 | pred_mask = pred_mask.repeat(n_traj_samples,1,1,1)
113 | label_mask = label_mask.repeat(n_traj_samples,1,1,1)
114 |
115 | pred_mask = pred_mask.reshape(n_traj_samples * n_traj * n_tp, n_dims)
116 | label_mask = label_mask.reshape(n_traj_samples * n_traj * n_tp, 1)
117 |
118 | if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1):
119 | assert(label_predictions.size(-1) == true_label.size(-1))
120 | # targets are in one-hot encoding -- convert to indices
121 | _, true_label = true_label.max(-1)
122 |
123 | res = []
124 | for i in range(true_label.size(0)):
125 | pred_masked = torch.masked_select(label_predictions[i], pred_mask[i].bool())
126 | labels = torch.masked_select(true_label[i], label_mask[i].bool())
127 |
128 | pred_masked = pred_masked.reshape(-1, n_dims)
129 |
130 | if (len(labels) == 0):
131 | continue
132 |
133 | ce_loss = nn.CrossEntropyLoss()(pred_masked, labels.long())
134 | res.append(ce_loss)
135 |
136 | ce_loss = torch.stack(res, 0).to(get_device(label_predictions))
137 | ce_loss = torch.mean(ce_loss)
138 |
139 | else:
140 | cel = nn.CrossEntropyLoss(weight=weight)
141 | if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1):
142 | assert(label_predictions.size(-1) == true_label.size(-1))
143 | # targets are in one-hot encoding -- convert to indices
144 | _, true_label = true_label.max(-1)
145 |
146 | if (mask == 1).all():
147 | # SKIP applying for mask
148 | true_label = true_label.repeat(n_traj_samples * n_traj, 1, 1).flatten()
149 | # true_label = true_label.repeat(n_traj_samples, 1, 1).flatten()
150 | label_predictions = label_predictions.reshape(-1, n_dims)
151 | _ce_loss = cel(label_predictions, true_label.long())
152 | else:
153 | time_steps_to_consider = (torch.sum(_ori_mask, -1) > 0).repeat(n_traj_samples, 1, 1).flatten()
154 | true_label = true_label.repeat(n_traj_samples, 1, 1).flatten()
155 | label_predictions = label_predictions.reshape(-1, n_dims)
156 | _ce_loss = cel(label_predictions[time_steps_to_consider], true_label[time_steps_to_consider].long())
157 | ce_loss = _ce_loss
158 | # print(ce_loss)
159 | # print(_ce_loss)
160 |
161 |
162 |
163 | # # divide by number of patients in a batch
164 | # ce_loss = ce_loss / n_traj_samples
165 | return ce_loss
166 |
167 |
168 |
169 |
170 | def compute_masked_likelihood(mu, data, mask, likelihood_func, _type, obsrv_std=None):
171 | # Compute the likelihood per patient and per attribute so that we don't prioritize patients with more measurements
172 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size()
173 |
174 | #time_steps_to_consider = (torch.sum(mask, -1) > 0).repeat(n_traj_samples, 1, 1)
175 | time_steps_to_consider = (torch.sum(mask, -1) > 0)
176 |
177 | if False:
178 |
179 | res = []
180 | for i in range(n_traj_samples):
181 | for k in range(n_traj):
182 | for j in range(n_dims):
183 | data_masked = torch.masked_select(data[i,k,:,j], mask[i,k,:,j].bool())
184 |
185 | #assert(torch.sum(data_masked == 0.) < 10)
186 |
187 | mu_masked = torch.masked_select(mu[i,k,:,j], mask[i,k,:,j].bool())
188 | # print(mu_masked.detach().cpu().numpy())
189 | # print(data_masked.detach().cpu().numpy())
190 | # print('---')
191 | log_prob = likelihood_func(mu_masked, data_masked, indices = (i,k,j))
192 | res.append(log_prob)
193 | # shape: [n_traj*n_traj_samples, 1]
194 |
195 | res = torch.stack(res, 0).to(get_device(data))
196 | res = res.reshape((n_traj_samples, n_traj, n_dims))
197 | # Take mean over the number of dimensions
198 | res = torch.mean(res, -1) # !!!!!!!!!!! changed from sum to mean
199 | res = res.transpose(0,1)
200 |
201 |
202 |
203 |
204 | # create a mask to map original shape into an unstructured shape
205 | _mask = mask.bool()
206 | # set target shape as the one without n_timepoints
207 | target_shape = n_traj_samples, n_traj, n_dims
208 | # apply mask to obtain an unstructured tensor
209 | _mu = mu[_mask]
210 | _data = data[_mask]
211 |
212 | if _type == 'logprob':
213 | # compute log prob as unstructured data shape
214 | gaussian = Independent(Normal(loc = _mu, scale=obsrv_std), 0)
215 | log_prob = gaussian.log_prob(_data)
216 |
217 | # map the masked log-prob back to the original shape
218 | lol = torch.zeros(data.shape).to(data.device)
219 | lol[_mask] = log_prob
220 |
221 | # create a diviser to average the log prob across timepoints
222 | diviser = _mask.float().sum(2)
223 | # replace any zero with 1
224 | diviser[diviser==0] = 1
225 |
226 | # get mean along dim
227 | lol = lol.sum(2) / diviser
228 |
229 | elif _type == 'mse':
230 | # compute mse as unstructured data shape
231 | #mse = nn.MSELoss()(_mu, _data)
232 | mse = (_mu - _data)**2
233 |
234 | # map the masked mse back to the original shape
235 | lol = torch.zeros(data.shape).to(data.device)
236 | lol[_mask] = mse
237 |
238 | # create a diviser to average the squared error across timepoints
239 | diviser = _mask.float().sum(2)
240 | # replace any zero with 1
241 | diviser[diviser==0] = 1
242 |
243 | # get mean along timepoints
244 | lol = lol.sum(2) / diviser
245 |
246 | # marginalise the dim
247 | lol = lol.mean(-1)
248 |
249 | # transpose back to the expected shape
250 | lol = lol.transpose(0,1)
251 | # assert torch.isclose(lol, res).all(), (lol, res)
252 |
253 |
254 | return lol
255 |
256 |
257 | return res
258 |
259 |
260 | def masked_gaussian_log_density(mu, data, obsrv_std, mask = None):
261 | # these cases are for plotting through plot_estim_density
262 | if (len(mu.size()) == 3):
263 | # add additional dimension for gp samples
264 | mu = mu.unsqueeze(0)
265 |
266 | if (len(data.size()) == 2):
267 | # add additional dimension for gp samples and time step
268 | data = data.unsqueeze(0).unsqueeze(2)
269 | elif (len(data.size()) == 3):
270 | # add additional dimension for gp samples
271 | data = data.unsqueeze(0)
272 |
273 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size()
274 |
275 | assert(data.size()[-1] == n_dims)
276 |
277 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims]
278 | if mask is None:
279 | mu_flat = mu.reshape(n_traj_samples*n_traj, n_timepoints * n_dims)
280 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size()
281 | data_flat = data.reshape(n_traj_samples*n_traj, n_timepoints * n_dims)
282 |
283 | res = gaussian_log_likelihood(mu_flat, data_flat, obsrv_std)
284 | res = res.reshape(n_traj_samples, n_traj).transpose(0,1)
285 | else:
286 | # Compute the likelihood per patient so that we don't prioritize patients with more measurements
287 | func = lambda mu, data, indices: gaussian_log_likelihood(mu, data, obsrv_std = obsrv_std, indices = indices)
288 | #res = compute_masked_likelihood(mu, data, mask, func)
289 | res = compute_masked_likelihood(mu, data, mask, func, _type='logprob', obsrv_std=obsrv_std)
290 | return res
291 |
292 |
293 |
294 | def mse(mu, data, indices = None):
295 | n_data_points = mu.size()[-1]
296 |
297 | if n_data_points > 0:
298 | mse = nn.MSELoss()(mu, data)
299 | else:
300 | mse = torch.zeros([1]).to(get_device(data)).squeeze()
301 | return mse
302 |
303 |
304 | def compute_mse(mu, data, mask = None):
305 | # these cases are for plotting through plot_estim_density
306 | if (len(mu.size()) == 3):
307 | # add additional dimension for gp samples
308 | mu = mu.unsqueeze(0)
309 |
310 | if (len(data.size()) == 2):
311 | # add additional dimension for gp samples and time step
312 | data = data.unsqueeze(0).unsqueeze(2)
313 | elif (len(data.size()) == 3):
314 | # add additional dimension for gp samples
315 | data = data.unsqueeze(0)
316 |
317 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size()
318 | assert(data.size()[-1] == n_dims)
319 |
320 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims]
321 | if mask is None:
322 | mu_flat = mu.reshape(n_traj_samples*n_traj, n_timepoints * n_dims)
323 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size()
324 | data_flat = data.reshape(n_traj_samples*n_traj, n_timepoints * n_dims)
325 | res = mse(mu_flat, data_flat)
326 | else:
327 | # Compute the likelihood per patient so that we don't prioritize patients with more measurements
328 | #res = compute_masked_likelihood(mu, data, mask, mse)
329 | res = compute_masked_likelihood(mu, data, mask, mse, _type='mse')
330 | return res
331 |
332 |
333 |
334 |
335 | def compute_poisson_proc_likelihood(truth, pred_y, info, mask = None):
336 | # Compute Poisson likelihood
337 | # https://math.stackexchange.com/questions/344487/log-likelihood-of-a-realization-of-a-poisson-process
338 | # Sum log lambdas across all time points
339 | if mask is None:
340 | poisson_log_l = torch.sum(info["log_lambda_y"], 2) - info["int_lambda"]
341 | # Mean over data dims
342 | poisson_log_l = torch.mean(poisson_log_l, -1)
343 | else:
344 | # Compute likelihood of the data under the predictions
345 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1)
346 | mask_repeated = mask.repeat(pred_y.size(0), 1, 1, 1)
347 |
348 | # Compute the likelihood per patient and per attribute so that we don't prioritize patients with more measurements
349 | int_lambda = info["int_lambda"]
350 | f = lambda log_lam, data, indices: poisson_log_likelihood(log_lam, data, indices, int_lambda)
351 | poisson_log_l = compute_masked_likelihood(info["log_lambda_y"], truth_repeated, mask_repeated, f)
352 | poisson_log_l = poisson_log_l.permute(1,0)
353 | # Take mean over n_traj
354 | #poisson_log_l = torch.mean(poisson_log_l, 1)
355 |
356 | # poisson_log_l shape: [n_traj_samples, n_traj]
357 | return poisson_log_l
358 |
359 |
360 |
361 |
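The vectorized branch of `compute_masked_likelihood` averages a per-element quantity over the observed time points only, replaces a zero count with 1 to avoid dividing by zero, and then takes the mean over data dimensions. The following sketch reproduces that arithmetic for a single trajectory with plain Python lists; `masked_time_mean` and `masked_mse` are hypothetical names, and the repository performs the same steps with boolean tensor indexing.

```python
def masked_time_mean(values, mask):
    """Average values over time for each dimension, using only observed entries.

    values, mask: lists shaped [n_timepoints][n_dims]; mask entries are 0 or 1.
    """
    n_dims = len(values[0])
    means = []
    for j in range(n_dims):
        total = sum(v[j] for v, m in zip(values, mask) if m[j])
        count = sum(1 for m in mask if m[j])
        # A dimension with no observations contributes 0 instead of dividing by zero
        means.append(total / max(count, 1))
    return means

def masked_mse(mu, data, mask):
    squared_error = [[(p - x) ** 2 for p, x in zip(mu_t, data_t)]
                     for mu_t, data_t in zip(mu, data)]
    per_dim = masked_time_mean(squared_error, mask)
    # Mean over data dimensions, as in the final `lol.mean(-1)` step
    return sum(per_dim) / len(per_dim)

mu = [[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]
data = [[1.0, 5.0], [4.0, 5.0], [3.0, 5.0]]
mask = [[1, 0], [1, 0], [0, 0]]  # second dimension is never observed

per_dim_error = masked_time_mean(
    [[(p - x) ** 2 for p, x in zip(a, b)] for a, b in zip(mu, data)], mask)
loss = masked_mse(mu, data, mask)
```

Note that a never-observed dimension still counts in the mean over dimensions, so it pulls the result towards zero rather than being excluded; the sketch keeps that behaviour because the tensor code does the same.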
--------------------------------------------------------------------------------
/latentode/latent_ode/lib/ode_func.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.utils.spectral_norm import spectral_norm
10 |
11 | import lib.utils as utils
12 |
13 | #####################################################################################################
14 |
15 | class ODEFunc(nn.Module):
16 | def __init__(self, input_dim, latent_dim, ode_func_net, device = torch.device("cpu")):
17 | """
18 | input_dim: dimensionality of the input
19 | latent_dim: dimensionality used for ODE. Analog of a continuous latent state
20 | """
21 | super(ODEFunc, self).__init__()
22 |
23 | self.input_dim = input_dim
24 | self.device = device
25 |
26 | utils.init_network_weights(ode_func_net)
27 | self.gradient_net = ode_func_net
28 |
29 | def forward(self, t_local, y, backwards = False):
30 | """
31 | Perform one step in solving ODE. Given current data point y and current time point t_local, returns gradient dy/dt at this time point
32 |
33 | t_local: current time point
34 | y: value at the current time point
35 | """
36 | grad = self.get_ode_gradient_nn(t_local, y)
37 | if backwards:
38 | grad = -grad
39 | return grad
40 |
41 | def get_ode_gradient_nn(self, t_local, y):
42 | return self.gradient_net(y)
43 |
44 | def sample_next_point_from_prior(self, t_local, y):
45 | """
46 | t_local: current time point
47 | y: value at the current time point
48 | """
49 | return self.get_ode_gradient_nn(t_local, y)
50 |
51 | #####################################################################################################
52 |
53 | class ODEFunc_w_Poisson(ODEFunc):
54 |
55 | def __init__(self, input_dim, latent_dim, ode_func_net,
56 | lambda_net, device = torch.device("cpu")):
57 | """
58 | input_dim: dimensionality of the input
59 | latent_dim: dimensionality used for ODE. Analog of a continuous latent state
60 | """
61 | super(ODEFunc_w_Poisson, self).__init__(input_dim, latent_dim, ode_func_net, device)
62 |
63 | self.latent_ode = ODEFunc(input_dim = input_dim,
64 | latent_dim = latent_dim,
65 | ode_func_net = ode_func_net,
66 | device = device)
67 |
68 | self.latent_dim = latent_dim
69 | self.lambda_net = lambda_net
70 | # The computation of the Poisson likelihood can become numerically unstable.
71 | # The integral of lambda(t) dt can take large values: it equals the expected number of events on the interval [0,T].
72 | # The exponential of log-lambda can also take large values.
73 | # So we divide lambda by a constant during integration and then multiply the integral of lambda by the same constant.
74 | self.const_for_lambda = torch.Tensor([100.]).to(device)
75 |
76 | def extract_poisson_rate(self, augmented, final_result = True):
77 | y, log_lambdas, int_lambda = None, None, None
78 |
79 | assert(augmented.size(-1) == self.latent_dim + self.input_dim)
80 | latent_lam_dim = self.latent_dim // 2
81 |
82 | if len(augmented.size()) == 3:
83 | int_lambda = augmented[:,:,-self.input_dim:]
84 | y_latent_lam = augmented[:,:,:-self.input_dim]
85 |
86 | log_lambdas = self.lambda_net(y_latent_lam[:,:,-latent_lam_dim:])
87 | y = y_latent_lam[:,:,:-latent_lam_dim]
88 |
89 | elif len(augmented.size()) == 4:
90 | int_lambda = augmented[:,:,:,-self.input_dim:]
91 | y_latent_lam = augmented[:,:,:,:-self.input_dim]
92 |
93 | log_lambdas = self.lambda_net(y_latent_lam[:,:,:,-latent_lam_dim:])
94 | y = y_latent_lam[:,:,:,:-latent_lam_dim]
95 |
96 | # Multiply the integral over lambda by a constant
97 | # only when we have finished the integral computation (i.e. this is not a call in get_ode_gradient_nn)
98 | if final_result:
99 | int_lambda = int_lambda * self.const_for_lambda
100 |
101 | # Latents for performing reconstruction (y) have the same size as latent poisson rate (log_lambdas)
102 | assert(y.size(-1) == latent_lam_dim)
103 |
104 | return y, log_lambdas, int_lambda, y_latent_lam
105 |
106 |
107 | def get_ode_gradient_nn(self, t_local, augmented):
108 | y, log_lam, int_lambda, y_latent_lam = self.extract_poisson_rate(augmented, final_result = False)
109 | dydt_dldt = self.latent_ode(t_local, y_latent_lam)
110 |
111 | log_lam = log_lam - torch.log(self.const_for_lambda)
112 | return torch.cat((dydt_dldt, torch.exp(log_lam)),-1)
113 |
114 |
115 |
116 |
117 |
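The rescaling used in `ODEFunc_w_Poisson` (subtract `log(const)` from the log-rate while integrating, multiply the finished integral by `const`) leaves the value of the integral unchanged while keeping the integrand small. The sketch below checks this with a plain left Riemann sum. The integrator, the function name, and the test rate are assumptions chosen for illustration; the repository itself integrates with an ODE solver.

```python
import math

def integrate_rate(log_lambda, t_end, n_steps, const=1.0):
    """Left Riemann sum of lambda(t) / const over [0, t_end], rescaled by const."""
    dt = t_end / n_steps
    acc = 0.0
    for k in range(n_steps):
        t = k * dt
        # Same form as the ODE right-hand side: exp(log_lambda - log(const))
        acc += math.exp(log_lambda(t) - math.log(const)) * dt
    # Undo the scaling only once the integral is finished
    return acc * const

def log_lam(t):
    # Hypothetical log-rate, growing linearly in time
    return 0.5 * t

plain = integrate_rate(log_lam, 2.0, 1000)
scaled = integrate_rate(log_lam, 2.0, 1000, const=100.0)
```

Both calls approximate the same integral; only the size of the intermediate values differs.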
--------------------------------------------------------------------------------
/latentode/latent_ode/lib/ode_rnn.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.functional import relu
10 |
11 | import lib.utils as utils
12 | from lib.encoder_decoder import *
13 | from lib.likelihood_eval import *
14 |
15 | from torch.distributions.multivariate_normal import MultivariateNormal
16 | from torch.distributions.normal import Normal
17 | from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase
18 |
19 | from torch.distributions.normal import Normal
20 | from torch.distributions import Independent
21 | from torch.nn.parameter import Parameter
22 | from lib.base_models import Baseline
23 |
24 |
25 | class ODE_RNN(Baseline):
26 | def __init__(self, input_dim, latent_dim, device = torch.device("cpu"),
27 | z0_diffeq_solver = None, n_gru_units = 100, n_units = 100,
28 | concat_mask = False, obsrv_std = 0.1, use_binary_classif = False,
29 | classif_per_tp = False, n_labels = 1, train_classif_w_reconstr = False):
30 |
31 | Baseline.__init__(self, input_dim, latent_dim, device = device,
32 | obsrv_std = obsrv_std, use_binary_classif = use_binary_classif,
33 | classif_per_tp = classif_per_tp,
34 | n_labels = n_labels,
35 | train_classif_w_reconstr = train_classif_w_reconstr)
36 |
37 | ode_rnn_encoder_dim = latent_dim
38 |
39 | self.ode_gru = Encoder_z0_ODE_RNN(
40 | latent_dim = ode_rnn_encoder_dim,
41 | input_dim = (input_dim) * 2, # input and the mask
42 | z0_diffeq_solver = z0_diffeq_solver,
43 | n_gru_units = n_gru_units,
44 | device = device).to(device)
45 |
46 | self.z0_diffeq_solver = z0_diffeq_solver
47 |
48 | self.decoder = nn.Sequential(
49 | nn.Linear(latent_dim, n_units),
50 | nn.Tanh(),
51 | nn.Linear(n_units, input_dim),)
52 |
53 | utils.init_network_weights(self.decoder)
54 |
55 |
56 | def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps,
57 | mask = None, n_traj_samples = None, mode = None):
58 |
59 | if (len(truth_time_steps) != len(time_steps_to_predict)) or (torch.sum(time_steps_to_predict - truth_time_steps) != 0):
60 | raise Exception("Extrapolation mode not implemented for ODE-RNN")
61 |
62 | # time_steps_to_predict and truth_time_steps should be the same
63 | assert(len(truth_time_steps) == len(time_steps_to_predict))
64 | assert(mask is not None)
65 |
66 | data_and_mask = data
67 | if mask is not None:
68 | data_and_mask = torch.cat([data, mask],-1)
69 |
70 | _, _, latent_ys, _ = self.ode_gru.run_odernn(
71 | data_and_mask, truth_time_steps, run_backwards = False)
72 |
73 | latent_ys = latent_ys.permute(0,2,1,3)
74 | last_hidden = latent_ys[:,:,-1,:]
75 |
76 | #assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.)
77 |
78 | outputs = self.decoder(latent_ys)
79 | # Shift outputs for computing the loss -- we should compare the first output to the second data point, etc.
80 | first_point = data[:,0,:]
81 | outputs = utils.shift_outputs(outputs, first_point)
82 |
83 | extra_info = {"first_point": (latent_ys[:,:,-1,:], 0.0, latent_ys[:,:,-1,:])}
84 |
85 | if self.use_binary_classif:
86 | if self.classif_per_tp:
87 | extra_info["label_predictions"] = self.classifier(latent_ys)
88 | else:
89 | extra_info["label_predictions"] = self.classifier(last_hidden).squeeze(-1)
90 |
91 | # outputs shape: [n_traj_samples, n_traj, n_tp, n_dims]
92 | return outputs, extra_info
93 |
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/latentode/latent_ode/lib/plotting.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import matplotlib
7 | # matplotlib.use('TkAgg')
8 | # matplotlib.use('Agg')
9 | import matplotlib.pyplot
10 | import matplotlib.pyplot as plt
11 | from matplotlib.lines import Line2D
12 |
13 | import os
14 | from scipy.stats import kde
15 |
16 | import numpy as np
17 | import subprocess
18 | import torch
19 | import lib.utils as utils
20 | import matplotlib.gridspec as gridspec
21 | from lib.utils import get_device
22 |
23 | from lib.encoder_decoder import *
24 | from lib.rnn_baselines import *
25 | from lib.ode_rnn import *
26 | import torch.nn.functional as functional
27 | from torch.distributions.normal import Normal
28 | from lib.latent_ode import LatentODE
29 |
30 | from lib.likelihood_eval import masked_gaussian_log_density
31 | # try:
32 | # import umap
33 | # except:
34 | # print("Couldn't import umap")
35 |
36 | from generate_timeseries import Periodic_1d
37 | from person_activity import PersonActivity
38 |
39 | from lib.utils import compute_loss_all_batches
40 |
41 |
42 | SMALL_SIZE = 14
43 | MEDIUM_SIZE = 16
44 | BIGGER_SIZE = 18
45 | LARGE_SIZE = 22
46 |
47 | def init_fonts(main_font_size = LARGE_SIZE):
48 | plt.rc('font', size=main_font_size) # controls default text sizes
49 | plt.rc('axes', titlesize=main_font_size) # fontsize of the axes title
50 | plt.rc('axes', labelsize=main_font_size - 2) # fontsize of the x and y labels
51 | plt.rc('xtick', labelsize=main_font_size - 2) # fontsize of the tick labels
52 | plt.rc('ytick', labelsize=main_font_size - 2) # fontsize of the tick labels
53 | plt.rc('legend', fontsize=main_font_size - 2) # legend fontsize
54 | plt.rc('figure', titlesize=main_font_size) # fontsize of the figure title
55 |
56 |
57 | def plot_trajectories(ax, traj, time_steps, min_y = None, max_y = None, title = "",
58 | add_to_plot = False, label = None, add_legend = False, dim_to_show = 0,
59 | linestyle = '-', marker = 'o', mask = None, color = None, linewidth = 1):
60 | # expected shape of traj: [n_traj, n_timesteps, n_dims]
61 | # The function will produce one line per trajectory (n_traj lines in total)
62 | if not add_to_plot:
63 | ax.cla()
64 | ax.set_title(title)
65 | ax.set_xlabel('Time')
66 | ax.set_ylabel('x')
67 |
68 | if min_y is not None:
69 | ax.set_ylim(bottom = min_y)
70 |
71 | if max_y is not None:
72 | ax.set_ylim(top = max_y)
73 |
74 | for i in range(traj.size()[0]):
75 | d = traj[i].cpu().numpy()[:, dim_to_show]
76 | ts = time_steps.cpu().numpy()
77 | if mask is not None:
78 | m = mask[i].cpu().numpy()[:, dim_to_show]
79 | d = d[m == 1]
80 | ts = ts[m == 1]
81 | ax.plot(ts, d, linestyle = linestyle, label = label, marker=marker, color = color, linewidth = linewidth)
82 |
83 | if add_legend:
84 | ax.legend()
85 |
86 |
87 | def plot_std(ax, traj, traj_std, time_steps, min_y = None, max_y = None, title = "",
88 | add_to_plot = False, label = None, alpha=0.2, color = None):
89 |
90 | # take only the first (and only?) dimension
91 | mean_minus_std = (traj - traj_std).cpu().numpy()[:, :, 0]
92 | mean_plus_std = (traj + traj_std).cpu().numpy()[:, :, 0]
93 |
94 | for i in range(traj.size()[0]):
95 | ax.fill_between(time_steps.cpu().numpy(), mean_minus_std[i], mean_plus_std[i],
96 | alpha=alpha, color = color)
97 |
98 |
99 |
100 | def plot_vector_field(ax, odefunc, latent_dim, device):
101 | # Code borrowed from https://github.com/rtqichen/ffjord/blob/29c016131b702b307ceb05c70c74c6e802bb8a44/diagnostics/viz_toy.py
102 | K = 13j
103 | y, x = np.mgrid[-6:6:K, -6:6:K]
104 | K = int(K.imag)
105 | zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32)
106 | if latent_dim > 2:
107 | # Plots dimensions 0 and 1; the remaining latent dimensions are set to zero
108 | zs = torch.cat((zs, torch.zeros(K * K, latent_dim-2, device=device)), 1)
109 | dydt = odefunc(0, zs)
110 | dydt = -dydt.cpu().detach().numpy()
111 | if latent_dim > 2:
112 | dydt = dydt[:,:2]
113 |
114 | mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
115 | dydt = (dydt / mag)
116 | dydt = dydt.reshape(K, K, 2)
117 |
118 | ax.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], #color = dydt[:, :, 0],
119 | cmap="coolwarm", linewidth=2)
120 |
121 | # ax.quiver(
122 | # x, y, dydt[:, :, 0], dydt[:, :, 1],
123 | # np.exp(logmag), cmap="coolwarm", pivot="mid", scale = 100,
124 | # )
125 | ax.set_xlim(-6, 6)
126 | ax.set_ylim(-6, 6)
127 | #ax.axis("off")
128 |
129 |
130 |
131 | def get_meshgrid(npts, int_y1, int_y2):
132 | min_y1, max_y1 = int_y1
133 | min_y2, max_y2 = int_y2
134 |
135 | y1_grid = np.linspace(min_y1, max_y1, npts)
136 | y2_grid = np.linspace(min_y2, max_y2, npts)
137 |
138 | xx, yy = np.meshgrid(y1_grid, y2_grid)
139 |
140 | flat_inputs = np.concatenate((np.expand_dims(xx.flatten(),1), np.expand_dims(yy.flatten(),1)), 1)
141 | flat_inputs = torch.from_numpy(flat_inputs).float()
142 |
143 | return xx, yy, flat_inputs
144 |
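The ordering of `flat_inputs` returned by `get_meshgrid` matters when the flattened densities are later reshaped to `xx.shape`. With the default indexing of `np.meshgrid`, the first coordinate varies fastest. The pure-Python sketch below reproduces that ordering; `grid_points` is a hypothetical helper written for illustration and assumes `npts >= 2`.

```python
def grid_points(npts, int_y1, int_y2):
    """Return (y1, y2) pairs in the same order as the flattened meshgrid."""
    min_y1, max_y1 = int_y1
    min_y2, max_y2 = int_y2
    # Evenly spaced values including both end points (requires npts >= 2)
    y1_grid = [min_y1 + i * (max_y1 - min_y1) / (npts - 1) for i in range(npts)]
    y2_grid = [min_y2 + i * (max_y2 - min_y2) / (npts - 1) for i in range(npts)]
    # y2 is the outer loop and y1 the inner loop, matching row-major flattening
    return [(a, b) for b in y2_grid for a in y1_grid]

points = grid_points(2, (0.0, 1.0), (10.0, 20.0))
```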
145 |
146 | def add_white(cmap):
147 | cmaplist = [cmap(i) for i in range(cmap.N)]
148 | # force the first color entry to be white
149 | cmaplist[0] = (1.,1.,1.,1.0)
150 | # create the new map
151 | cmap = cmap.from_list('Custom cmap', cmaplist, cmap.N)
152 | return cmap
153 |
154 |
155 | class Visualizations():
156 | def __init__(self, device):
157 | self.init_visualization()
158 | init_fonts(SMALL_SIZE)
159 | self.device = device
160 |
161 | def init_visualization(self):
162 | self.fig = plt.figure(figsize=(12, 7), facecolor='white')
163 |
164 | self.ax_traj = []
165 | for i in range(1,4):
166 | self.ax_traj.append(self.fig.add_subplot(2,3,i, frameon=False))
167 |
168 | # self.ax_density = []
169 | # for i in range(4,7):
170 | # self.ax_density.append(self.fig.add_subplot(3,3,i, frameon=False))
171 |
172 | #self.ax_samples_same_traj = self.fig.add_subplot(3,3,7, frameon=False)
173 | self.ax_latent_traj = self.fig.add_subplot(2,3,4, frameon=False)
174 | self.ax_vector_field = self.fig.add_subplot(2,3,5, frameon=False)
175 | self.ax_traj_from_prior = self.fig.add_subplot(2,3,6, frameon=False)
176 |
177 | self.plot_limits = {}
178 | plt.show(block=False)
179 |
180 | def set_plot_lims(self, ax, name):
181 | if name not in self.plot_limits:
182 | self.plot_limits[name] = (ax.get_xlim(), ax.get_ylim())
183 | return
184 |
185 | xlim, ylim = self.plot_limits[name]
186 | ax.set_xlim(xlim)
187 | ax.set_ylim(ylim)
188 |
189 | def draw_one_density_plot(self, ax, model, data_dict, traj_id,
190 | multiply_by_poisson = False):
191 |
192 | scale = 5
193 | cmap = add_white(plt.cm.get_cmap('Blues', 9)) # plt.cm.BuGn_r
194 | cmap2 = add_white(plt.cm.get_cmap('Reds', 9)) # plt.cm.BuGn_r
195 | #cmap = plt.cm.get_cmap('viridis')
196 |
197 | data = data_dict["data_to_predict"]
198 | time_steps = data_dict["tp_to_predict"]
199 | mask = data_dict["mask_predicted_data"]
200 |
201 | observed_data = data_dict["observed_data"]
202 | observed_time_steps = data_dict["observed_tp"]
203 | observed_mask = data_dict["observed_mask"]
204 |
205 | npts = 50
206 | xx, yy, z0_grid = get_meshgrid(npts = npts, int_y1 = (-scale,scale), int_y2 = (-scale,scale))
207 | z0_grid = z0_grid.to(get_device(data))
208 |
209 | if model.latent_dim > 2:
210 | z0_grid = torch.cat((z0_grid, torch.zeros(z0_grid.size(0), model.latent_dim-2).to(get_device(data))), 1)
211 |
212 | if model.use_poisson_proc:
213 | n_traj, n_dims = z0_grid.size()
214 | # append a vector of zeros to compute the integral of lambda and also zeros for the first point of lambda
215 | zeros = torch.zeros([n_traj, model.input_dim + model.latent_dim]).to(get_device(data))
216 | z0_grid_aug = torch.cat((z0_grid, zeros), -1)
217 | else:
218 | z0_grid_aug = z0_grid
219 |
220 | # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents]
221 | sol_y = model.diffeq_solver(z0_grid_aug.unsqueeze(0), time_steps)
222 |
223 | if model.use_poisson_proc:
224 | sol_y, log_lambda_y, int_lambda, _ = model.diffeq_solver.ode_func.extract_poisson_rate(sol_y)
225 |
226 | assert(torch.sum(int_lambda[:,:,0,:]) == 0.)
227 | assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.)
228 |
229 | pred_x = model.decoder(sol_y)
230 |
231 | # Plot density for one trajectory
232 | one_traj = data[traj_id]
233 | mask_one_traj = None
234 | if mask is not None:
235 | mask_one_traj = mask[traj_id].unsqueeze(0)
236 | mask_one_traj = mask_one_traj.repeat(npts**2,1,1).unsqueeze(0)
237 |
238 | ax.cla()
239 |
240 | # Plot: prior
241 | prior_density_grid = model.z0_prior.log_prob(z0_grid.unsqueeze(0)).squeeze(0)
242 | # Sum the density over two dimensions
243 | prior_density_grid = torch.sum(prior_density_grid, -1)
244 |
245 | # =================================================
246 | # Plot: p(x | y(t0))
247 |
248 | masked_gaussian_log_density_grid = masked_gaussian_log_density(pred_x,
249 | one_traj.repeat(npts**2,1,1).unsqueeze(0),
250 | mask = mask_one_traj,
251 | obsrv_std = model.obsrv_std).squeeze(-1)
252 |
253 | # Plot p(t | y(t0))
254 | if model.use_poisson_proc:
255 | poisson_info = {}
256 | poisson_info["int_lambda"] = int_lambda[:,:,-1,:]
257 | poisson_info["log_lambda_y"] = log_lambda_y
258 |
259 | poisson_log_density_grid = compute_poisson_proc_likelihood(
260 | one_traj.repeat(npts**2,1,1).unsqueeze(0),
261 | pred_x, poisson_info, mask = mask_one_traj)
262 | poisson_log_density_grid = poisson_log_density_grid.squeeze(0)
263 |
264 | # =================================================
265 | # Plot: p(x , y(t0))
266 |
267 | log_joint_density = prior_density_grid + masked_gaussian_log_density_grid
268 | if multiply_by_poisson:
269 | log_joint_density = log_joint_density + poisson_log_density_grid
270 |
271 | density_grid = torch.exp(log_joint_density)
272 |
273 | density_grid = torch.reshape(density_grid, (xx.shape[0], xx.shape[1]))
274 | density_grid = density_grid.cpu().numpy()
275 |
276 | ax.contourf(xx, yy, density_grid, cmap=cmap, alpha=1)
277 |
278 | # =================================================
279 | # Plot: q(y(t0)| x)
280 | #self.ax_density.set_title("Red: q(y(t0) | x) Blue: p(x, y(t0))")
281 | ax.set_xlabel('z1(t0)')
282 | ax.set_ylabel('z2(t0)')
283 |
284 | data_w_mask = observed_data[traj_id].unsqueeze(0)
285 | if observed_mask is not None:
286 | data_w_mask = torch.cat((data_w_mask, observed_mask[traj_id].unsqueeze(0)), -1)
287 | z0_mu, z0_std = model.encoder_z0(
288 | data_w_mask, observed_time_steps)
289 |
290 | if model.use_poisson_proc:
291 | z0_mu = z0_mu[:, :, :model.latent_dim]
292 | z0_std = z0_std[:, :, :model.latent_dim]
293 |
294 | q_z0 = Normal(z0_mu, z0_std)
295 |
296 | q_density_grid = q_z0.log_prob(z0_grid)
297 | # Sum the density over two dimensions
298 | q_density_grid = torch.sum(q_density_grid, -1)
299 | density_grid = torch.exp(q_density_grid)
300 |
301 | density_grid = torch.reshape(density_grid, (xx.shape[0], xx.shape[1]))
302 | density_grid = density_grid.cpu().numpy()
303 |
304 | ax.contourf(xx, yy, density_grid, cmap=cmap2, alpha=0.3)
305 |
306 |
307 |
308 | def draw_all_plots_one_dim(self, data_dict, model,
309 | plot_name = "", save = False, experimentID = 0.):
310 |
311 | data = data_dict["data_to_predict"]
312 | time_steps = data_dict["tp_to_predict"]
313 | mask = data_dict["mask_predicted_data"]
314 |
315 | observed_data = data_dict["observed_data"]
316 | observed_time_steps = data_dict["observed_tp"]
317 | observed_mask = data_dict["observed_mask"]
318 |
319 | device = get_device(time_steps)
320 |
321 | time_steps_to_predict = time_steps
322 | if isinstance(model, LatentODE):
323 | # sample on a dense, evenly spaced grid of 100 time points
324 | time_steps_to_predict = utils.linspace_vector(time_steps[0], time_steps[-1], 100).to(device)
325 |
326 | reconstructions, info = model.get_reconstruction(time_steps_to_predict,
327 | observed_data, observed_time_steps, mask = observed_mask, n_traj_samples = 10)
328 |
329 | n_traj_to_show = 3
330 | # plot only the first n_traj_to_show trajectories
331 | data_for_plotting = observed_data[:n_traj_to_show]
332 | mask_for_plotting = observed_mask[:n_traj_to_show]
333 | reconstructions_for_plotting = reconstructions.mean(dim=0)[:n_traj_to_show]
334 | reconstr_std = reconstructions.std(dim=0)[:n_traj_to_show]
335 |
336 | dim_to_show = 0
337 | max_y = max(
338 | data_for_plotting[:,:,dim_to_show].cpu().numpy().max(),
339 | reconstructions[:,:,dim_to_show].cpu().numpy().max())
340 | min_y = min(
341 | data_for_plotting[:,:,dim_to_show].cpu().numpy().min(),
342 | reconstructions[:,:,dim_to_show].cpu().numpy().min())
343 |
344 | ############################################
345 | # Plot reconstructions, true posterior and approximate posterior
346 |
347 | cmap = plt.cm.get_cmap('Set1')
348 | for traj_id in range(3):
349 | # Plot observations
350 | plot_trajectories(self.ax_traj[traj_id],
351 | data_for_plotting[traj_id].unsqueeze(0), observed_time_steps,
352 | mask = mask_for_plotting[traj_id].unsqueeze(0),
353 | min_y = min_y, max_y = max_y, #title="True trajectories",
354 | marker = 'o', linestyle='', dim_to_show = dim_to_show,
355 | color = cmap(2))
356 | # Plot reconstructions
357 | plot_trajectories(self.ax_traj[traj_id],
358 | reconstructions_for_plotting[traj_id].unsqueeze(0), time_steps_to_predict,
359 | min_y = min_y, max_y = max_y, title="Sample {} (data space)".format(traj_id), dim_to_show = dim_to_show,
360 | add_to_plot = True, marker = '', color = cmap(3), linewidth = 3)
361 | # Plot variance estimated over multiple samples from approx posterior
362 | plot_std(self.ax_traj[traj_id],
363 | reconstructions_for_plotting[traj_id].unsqueeze(0), reconstr_std[traj_id].unsqueeze(0),
364 | time_steps_to_predict, alpha=0.5, color = cmap(3))
365 | self.set_plot_lims(self.ax_traj[traj_id], "traj_" + str(traj_id))
366 |
367 | # Plot true posterior and approximate posterior
368 | # self.draw_one_density_plot(self.ax_density[traj_id],
369 | # model, data_dict, traj_id = traj_id,
370 | # multiply_by_poisson = False)
371 | # self.set_plot_lims(self.ax_density[traj_id], "density_" + str(traj_id))
372 | # self.ax_density[traj_id].set_title("Sample {}: p(z0) and q(z0 | x)".format(traj_id))
373 | ############################################
374 | # Get several samples for the same trajectory
375 | # one_traj = data_for_plotting[:1]
376 | # first_point = one_traj[:,0]
377 |
378 | # samples_same_traj, _ = model.get_reconstruction(time_steps_to_predict,
379 | # observed_data[:1], observed_time_steps, mask = observed_mask[:1], n_traj_samples = 5)
380 | # samples_same_traj = samples_same_traj.squeeze(1)
381 |
382 | # plot_trajectories(self.ax_samples_same_traj, samples_same_traj, time_steps_to_predict, marker = '')
383 | # plot_trajectories(self.ax_samples_same_traj, one_traj, time_steps, linestyle = "",
384 | # label = "True traj", add_to_plot = True, title="Reconstructions for the same trajectory (data space)")
385 |
386 | ############################################
387 | # Plot trajectories from prior
388 |
389 | try:
390 | if isinstance(model, LatentODE):
391 | torch.manual_seed(1991)
392 | np.random.seed(1991)
393 |
394 | traj_from_prior = model.sample_traj_from_prior(time_steps_to_predict, n_traj_samples = 3)
395 | # Here n_traj = 1 and n_traj_samples is the requested number of samples from the prior, so squeeze the n_traj dimension
396 | traj_from_prior = traj_from_prior.squeeze(1)
397 |
398 | plot_trajectories(self.ax_traj_from_prior, traj_from_prior, time_steps_to_predict,
399 | marker = '', linewidth = 3)
400 | self.ax_traj_from_prior.set_title("Samples from prior (data space)", pad = 20)
401 | #self.set_plot_lims(self.ax_traj_from_prior, "traj_from_prior")
402 | except AttributeError:
403 | pass
404 | ################################################
405 |
406 | # Plot z0
407 | # first_point_mu, first_point_std, first_point_enc = info["first_point"]
408 |
409 | # dim1 = 0
410 | # dim2 = 1
411 | # self.ax_z0.cla()
412 | # # first_point_enc shape: [1, n_traj, n_dims]
413 | # self.ax_z0.scatter(first_point_enc.cpu()[0,:,dim1], first_point_enc.cpu()[0,:,dim2])
414 | # self.ax_z0.set_title("Encodings z0 of all test trajectories (latent space)")
415 | # self.ax_z0.set_xlabel('dim {}'.format(dim1))
416 | # self.ax_z0.set_ylabel('dim {}'.format(dim2))
417 |
418 | try:
419 | ################################################
420 | # Show vector field
421 | self.ax_vector_field.cla()
422 | plot_vector_field(self.ax_vector_field, model.diffeq_solver.ode_func, model.latent_dim, device)
423 | self.ax_vector_field.set_title("Slice of vector field (latent space)", pad = 20)
424 | self.set_plot_lims(self.ax_vector_field, "vector_field")
425 | #self.ax_vector_field.set_ylim((-0.5, 1.5))
426 | except AttributeError:
427 | pass
428 |
429 | ################################################
430 | # Plot trajectories in the latent space
431 |
432 | # shape before [1, n_traj, n_tp, n_latent_dims]
433 | # Take only the first sample from approx posterior
434 | latent_traj = info["latent_traj"][0,:n_traj_to_show]
435 | # shape after indexing: [n_traj_to_show, n_tp, n_latent_dims]
436 |
437 | self.ax_latent_traj.cla()
438 | cmap = plt.cm.get_cmap('Accent')
439 | n_latent_dims = latent_traj.size(-1)
440 |
441 | custom_labels = {}
442 | for i in range(n_latent_dims):
443 | col = cmap(i)
444 | plot_trajectories(self.ax_latent_traj, latent_traj, time_steps_to_predict,
445 | title="Latent trajectories z(t) (latent space)", dim_to_show = i, color = col,
446 | marker = '', add_to_plot = True,
447 | linewidth = 3)
448 | custom_labels['dim ' + str(i)] = Line2D([0], [0], color=col)
449 |
450 | self.ax_latent_traj.set_ylabel("z")
451 | self.ax_latent_traj.set_title("Latent trajectories z(t) (latent space)", pad = 20)
452 | self.ax_latent_traj.legend(custom_labels.values(), custom_labels.keys(), loc = 'lower left')
453 | self.set_plot_lims(self.ax_latent_traj, "latent_traj")
454 |
455 | ################################################
456 |
457 | self.fig.tight_layout()
458 | plt.draw()
459 |
460 | if save:
461 | dirname = "plots/" + str(experimentID) + "/"
462 | os.makedirs(dirname, exist_ok=True)
463 | self.fig.savefig(dirname + plot_name)
464 |
465 |
466 |
467 |
468 |
469 |
--------------------------------------------------------------------------------
/latentode/latent_ode/lib/rnn_baselines.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.functional import relu
10 |
11 | import lib.utils as utils
12 | from lib.utils import get_device
13 | from lib.encoder_decoder import *
14 | from lib.likelihood_eval import *
15 |
16 | from torch.distributions.multivariate_normal import MultivariateNormal
17 | from torch.distributions.normal import Normal
18 | from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase
19 |
20 | from torch.distributions.normal import Normal
21 | from torch.distributions import Independent
22 | from torch.nn.parameter import Parameter
23 | from lib.base_models import Baseline, VAE_Baseline
24 |
25 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
26 | # Exponential decay of the hidden states for RNN
27 | # adapted from GRU-D implementation: https://github.com/zhiyongc/GRU-D/
28 |
29 | # Exp decay between hidden states
30 | class GRUCellExpDecay(RNNCellBase):
31 | def __init__(self, input_size, input_size_for_decay, hidden_size, device, bias=True):
32 | super(GRUCellExpDecay, self).__init__(input_size, hidden_size, bias, num_chunks=3)
33 |
34 | self.device = device
35 | self.input_size_for_decay = input_size_for_decay
36 | self.decay = nn.Sequential(nn.Linear(input_size_for_decay, 1),)
37 | utils.init_network_weights(self.decay)
38 |
39 | def gru_exp_decay_cell(self, input, hidden, w_ih, w_hh, b_ih, b_hh):
40 | # IMPORTANT: assumes that cum delta t is the last dimension of the input
41 | batch_size, n_dims = input.size()
42 |
43 | # "input" contains the data, mask and also cumulative deltas for all inputs
44 | cum_delta_ts = input[:, -self.input_size_for_decay:]
45 | data = input[:, :-self.input_size_for_decay]
46 |
47 | decay = torch.exp( - torch.min(torch.max(
48 | torch.zeros([1]).to(self.device), self.decay(cum_delta_ts)),
49 | torch.ones([1]).to(self.device) * 1000 ))
50 |
51 | hidden = hidden * decay
52 |
53 | gi = torch.mm(data, w_ih.t()) + b_ih
54 | gh = torch.mm(hidden, w_hh.t()) + b_hh
55 | i_r, i_i, i_n = gi.chunk(3, 1)
56 | h_r, h_i, h_n = gh.chunk(3, 1)
57 |
58 | resetgate = torch.sigmoid(i_r + h_r)
59 | inputgate = torch.sigmoid(i_i + h_i)
60 | newgate = torch.tanh(i_n + resetgate * h_n)
61 | hy = newgate + inputgate * (hidden - newgate)
62 | return hy
63 |
64 | def forward(self, input, hx=None):
65 | # type: (Tensor, Optional[Tensor]) -> Tensor
66 | #self.check_forward_input(input)
67 | if hx is None:
68 | hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
69 | #self.check_forward_hidden(input, hx, '')
70 |
71 | return self.gru_exp_decay_cell(
72 | input, hx,
73 | self.weight_ih, self.weight_hh,
74 | self.bias_ih, self.bias_hh
75 | )
76 |
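The decay factor in `gru_exp_decay_cell` clips the learned pre-activation to the range [0, 1000] before taking `exp(-x)`, so the factor always stays in [0, 1] and a negative pre-activation cannot amplify the hidden state. A scalar sketch of that clamp follows; the function name `decay_factor` is an illustrative assumption, and the repository applies the same operation elementwise on tensors.

```python
import math

def decay_factor(raw):
    """exp(-clip(raw, 0, 1000)): 1.0 for non-positive inputs, shrinking towards 0 otherwise."""
    clipped = min(max(0.0, raw), 1000.0)
    return math.exp(-clipped)

no_decay = decay_factor(-3.0)     # negative pre-activation is clipped to 0
partial = decay_factor(2.0)       # moderate decay
saturated = decay_factor(1e9)     # clipped to 1000 before exponentiation
```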
77 |
78 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
79 | # Imputation with a weighted average of previous value and empirical mean
80 | # adapted from GRU-D implementation: https://github.com/zhiyongc/GRU-D/
81 | def get_cum_delta_ts(data, delta_ts, mask):
82 | n_traj, n_tp, n_dims = data.size()
83 |
84 | cum_delta_ts = delta_ts.repeat(1, 1, n_dims)
85 | missing_index = np.where(mask.cpu().numpy() == 0)
86 |
87 | for idx in range(missing_index[0].shape[0]):
88 | i = missing_index[0][idx]
89 | j = missing_index[1][idx]
90 | k = missing_index[2][idx]
91 |
92 | if j != 0 and j != (n_tp-1):
93 | cum_delta_ts[i,j+1,k] = cum_delta_ts[i,j+1,k] + cum_delta_ts[i,j,k]
94 | cum_delta_ts = cum_delta_ts / cum_delta_ts.max() # normalize
95 |
96 | return cum_delta_ts
97 |
98 |
99 | # adapted from GRU-D implementation: https://github.com/zhiyongc/GRU-D/
100 | # very slow
101 | def impute_using_input_decay(data, delta_ts, mask, w_input_decay, b_input_decay):
102 | n_traj, n_tp, n_dims = data.size()
103 |
104 | cum_delta_ts = delta_ts.repeat(1, 1, n_dims)
105 | missing_index = np.where(mask.cpu().numpy() == 0)
106 |
107 | data_last_obsv = np.copy(data.cpu().numpy())
108 | for idx in range(missing_index[0].shape[0]):
109 | i = missing_index[0][idx]
110 | j = missing_index[1][idx]
111 | k = missing_index[2][idx]
112 |
113 | if j != 0 and j != (n_tp-1):
114 | cum_delta_ts[i,j+1,k] = cum_delta_ts[i,j+1,k] + cum_delta_ts[i,j,k]
115 | if j != 0:
116 | data_last_obsv[i,j,k] = data_last_obsv[i,j-1,k] # last observation
117 | cum_delta_ts = cum_delta_ts / cum_delta_ts.max() # normalize
118 |
119 | data_last_obsv = torch.Tensor(data_last_obsv).to(get_device(data))
120 |
121 | zeros = torch.zeros([n_traj, n_tp, n_dims]).to(get_device(data))
122 | decay = torch.exp( - torch.min( torch.max(zeros,
123 | w_input_decay * cum_delta_ts + b_input_decay), zeros + 1000 ))
124 |
125 | data_means = torch.mean(data, 1).unsqueeze(1)
126 |
127 | data_imputed = data * mask + (1-mask) * (decay * data_last_obsv + (1-decay) * data_means)
128 | return data_imputed
129 |
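The last line of `impute_using_input_decay` keeps observed values as they are and replaces missing ones with a blend of the last observation and the empirical mean, weighted by the decay factor. The scalar sketch below isolates that formula. The function name and the example numbers are assumptions made for illustration.

```python
def impute_value(x, m, last_obs, mean, decay):
    """Scalar version of: data * mask + (1 - mask) * (decay * last + (1 - decay) * mean)."""
    return x * m + (1 - m) * (decay * last_obs + (1 - decay) * mean)

# Observed entry (m = 1): the value passes through unchanged
observed = impute_value(5.0, 1, last_obs=4.0, mean=2.0, decay=0.25)
# Missing entry (m = 0): 0.25 * 4.0 + 0.75 * 2.0
missing = impute_value(0.0, 0, last_obs=4.0, mean=2.0, decay=0.25)
```

A decay close to 1 (short gap since the last observation) favours the last observed value; a decay close to 0 falls back to the mean.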
130 |
131 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
132 |
133 | def run_rnn(inputs, delta_ts, cell, first_hidden=None,
134 | mask = None, feed_previous=False, n_steps=0,
135 | decoder = None, input_decay_params = None,
136 | feed_previous_w_prob = 0.,
137 | masked_update = True):
138 | if (feed_previous or feed_previous_w_prob) and decoder is None:
139 | raise Exception("feed_previous is set to True -- please specify RNN decoder")
140 |
141 | if n_steps == 0:
142 | n_steps = inputs.size(1)
143 |
144 | if (feed_previous or feed_previous_w_prob) and mask is None:
145 | mask = torch.ones((inputs.size(0), n_steps, inputs.size(-1))).to(get_device(inputs))
146 |
147 | if isinstance(cell, GRUCellExpDecay):
148 | cum_delta_ts = get_cum_delta_ts(inputs, delta_ts, mask)
149 |
150 | if input_decay_params is not None:
151 | w_input_decay, b_input_decay = input_decay_params
152 | inputs = impute_using_input_decay(inputs, delta_ts, mask,
153 | w_input_decay, b_input_decay)
154 |
155 | all_hiddens = []
156 | hidden = first_hidden
157 |
158 | if hidden is not None:
159 | all_hiddens.append(hidden)
160 | n_steps -= 1
161 |
162 | for i in range(n_steps):
163 | delta_t = delta_ts[:,i]
164 | if i == 0:
165 | rnn_input = inputs[:,i]
166 | elif feed_previous:
167 | rnn_input = decoder(hidden)
168 | elif feed_previous_w_prob > 0:
169 | feed_prev = np.random.uniform() > feed_previous_w_prob
170 | if feed_prev:
171 | rnn_input = decoder(hidden)
172 | else:
173 | rnn_input = inputs[:,i]
174 | else:
175 | rnn_input = inputs[:,i]
176 |
177 | if mask is not None:
178 | mask_i = mask[:,i,:]
179 | rnn_input = torch.cat((rnn_input, mask_i), -1)
180 |
181 | if isinstance(cell, GRUCellExpDecay):
182 | cum_delta_t = cum_delta_ts[:,i]
183 | input_w_t = torch.cat((rnn_input, cum_delta_t), -1).squeeze(1)
184 | else:
185 | input_w_t = torch.cat((rnn_input, delta_t), -1).squeeze(1)
186 |
187 | prev_hidden = hidden
188 | hidden = cell(input_w_t, hidden)
189 |
190 | if masked_update and (mask is not None) and (prev_hidden is not None):
191 | # update the hidden state only if at least one feature is observed at the current time point
192 | summed_mask = (torch.sum(mask_i, -1, keepdim = True) > 0).float()
193 | assert(not torch.isnan(summed_mask).any())
194 | hidden = summed_mask * hidden + (1-summed_mask) * prev_hidden
195 |
196 | all_hiddens.append(hidden)
197 |
198 | all_hiddens = torch.stack(all_hiddens, 0)
199 | all_hiddens = all_hiddens.permute(1,0,2).unsqueeze(0)
200 | return hidden, all_hiddens
201 |
202 |
203 |
204 |
205 | class Classic_RNN(Baseline):
206 | def __init__(self, input_dim, latent_dim, device,
207 | concat_mask = False, obsrv_std = 0.1,
208 | use_binary_classif = False,
209 | linear_classifier = False,
210 | classif_per_tp = False,
211 | input_space_decay = False,
212 | cell = "gru", n_units = 100,
213 | n_labels = 1,
214 | train_classif_w_reconstr = False):
215 |
216 | super(Classic_RNN, self).__init__(input_dim, latent_dim, device,
217 | obsrv_std = obsrv_std,
218 | use_binary_classif = use_binary_classif,
219 | classif_per_tp = classif_per_tp,
220 | linear_classifier = linear_classifier,
221 | n_labels = n_labels,
222 | train_classif_w_reconstr = train_classif_w_reconstr)
223 |
224 | self.concat_mask = concat_mask
225 |
226 | encoder_dim = int(input_dim)
227 | if concat_mask:
228 | encoder_dim = encoder_dim * 2
229 |
230 | self.decoder = nn.Sequential(
231 | nn.Linear(latent_dim, n_units),
232 | nn.Tanh(),
233 | nn.Linear(n_units, input_dim),)
234 |
235 | #utils.init_network_weights(self.encoder)
236 | utils.init_network_weights(self.decoder)
237 |
238 | if cell == "gru":
239 | self.rnn_cell = GRUCell(encoder_dim + 1, latent_dim) # +1 for delta t
240 | elif cell == "expdecay":
241 | self.rnn_cell = GRUCellExpDecay(
242 | input_size = encoder_dim,
243 | input_size_for_decay = input_dim,
244 | hidden_size = latent_dim,
245 | device = device)
246 | else:
247 | raise Exception("Unknown RNN cell: {}".format(cell))
248 |
249 | if input_space_decay:
250 | self.w_input_decay = Parameter(torch.Tensor(1, int(input_dim))).to(self.device)
251 | self.b_input_decay = Parameter(torch.Tensor(1, int(input_dim))).to(self.device)
252 | self.input_space_decay = input_space_decay
253 |
254 | self.z0_net = lambda hidden_state: hidden_state
255 |
256 |
257 | def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps,
258 | mask = None, n_traj_samples = 1, mode = None):
259 |
260 | assert(mask is not None)
261 | n_traj, n_tp, n_dims = data.size()
262 |
263 | if (len(truth_time_steps) != len(time_steps_to_predict)) or (torch.sum(time_steps_to_predict - truth_time_steps) != 0):
264 | raise Exception("Extrapolation mode not implemented for RNN models")
265 |
266 | # for classic RNN time_steps_to_predict should be the same as truth_time_steps
267 | assert(len(truth_time_steps) == len(time_steps_to_predict))
268 |
269 | batch_size = data.size(0)
270 | zero_delta_t = torch.Tensor([0.]).to(self.device)
271 |
272 | delta_ts = truth_time_steps[1:] - truth_time_steps[:-1]
273 | delta_ts = torch.cat((delta_ts, zero_delta_t))
274 | if len(delta_ts.size()) == 1:
275 | # delta_ts are shared for all trajectories in a batch
276 | assert(data.size(1) == delta_ts.size(0))
277 | delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size,1,1))
278 |
279 | input_decay_params = None
280 | if self.input_space_decay:
281 | input_decay_params = (self.w_input_decay, self.b_input_decay)
282 |
283 | if mask is not None:
284 | utils.check_mask(data, mask)
285 |
286 | hidden_state, all_hiddens = run_rnn(data, delta_ts,
287 | cell = self.rnn_cell, mask = mask,
288 | input_decay_params = input_decay_params,
289 | feed_previous_w_prob = (0. if self.use_binary_classif else 0.5),
290 | decoder = self.decoder)
291 |
292 | outputs = self.decoder(all_hiddens)
293 | # Shift outputs for computing the loss -- we should compare the first output to the second data point, etc.
294 | first_point = data[:,0,:]
295 | outputs = utils.shift_outputs(outputs, first_point)
296 |
297 | extra_info = {"first_point": (hidden_state.unsqueeze(0), 0.0, hidden_state.unsqueeze(0))}
298 |
299 | if self.use_binary_classif:
300 | if self.classif_per_tp:
301 | extra_info["label_predictions"] = self.classifier(all_hiddens)
302 | else:
303 | extra_info["label_predictions"] = self.classifier(hidden_state).reshape(1,-1)
304 |
305 | # outputs shape: [n_traj_samples, n_traj, n_tp, n_dims]
306 | return outputs, extra_info
307 |
308 |
309 |
310 | class RNN_VAE(VAE_Baseline):
311 | def __init__(self, input_dim, latent_dim, rec_dims,
312 | z0_prior, device,
313 | concat_mask = False, obsrv_std = 0.1,
314 | input_space_decay = False,
315 | use_binary_classif = False,
316 | classif_per_tp =False,
317 | linear_classifier = False,
318 | cell = "gru", n_units = 100,
319 | n_labels = 1,
320 | train_classif_w_reconstr = False):
321 |
322 | super(RNN_VAE, self).__init__(
323 | input_dim = input_dim, latent_dim = latent_dim,
324 | z0_prior = z0_prior,
325 | device = device, obsrv_std = obsrv_std,
326 | use_binary_classif = use_binary_classif,
327 | classif_per_tp = classif_per_tp,
328 | linear_classifier = linear_classifier,
329 | n_labels = n_labels,
330 | train_classif_w_reconstr = train_classif_w_reconstr)
331 |
332 | self.concat_mask = concat_mask
333 |
334 | encoder_dim = int(input_dim)
335 | if concat_mask:
336 | encoder_dim = encoder_dim * 2
337 |
338 | if cell == "gru":
339 | self.rnn_cell_enc = GRUCell(encoder_dim + 1, rec_dims) # +1 for delta t
340 | self.rnn_cell_dec = GRUCell(encoder_dim + 1, latent_dim) # +1 for delta t
341 | elif cell == "expdecay":
342 | self.rnn_cell_enc = GRUCellExpDecay(
343 | input_size = encoder_dim,
344 | input_size_for_decay = input_dim,
345 | hidden_size = rec_dims,
346 | device = device)
347 | self.rnn_cell_dec = GRUCellExpDecay(
348 | input_size = encoder_dim,
349 | input_size_for_decay = input_dim,
350 | hidden_size = latent_dim,
351 | device = device)
352 | else:
353 | raise Exception("Unknown RNN cell: {}".format(cell))
354 |
355 | self.z0_net = nn.Sequential(
356 | nn.Linear(rec_dims, n_units),
357 | nn.Tanh(),
358 | nn.Linear(n_units, latent_dim * 2),)
359 | utils.init_network_weights(self.z0_net)
360 |
361 | self.decoder = nn.Sequential(
362 | nn.Linear(latent_dim, n_units),
363 | nn.Tanh(),
364 | nn.Linear(n_units, input_dim),)
365 |
366 | #utils.init_network_weights(self.encoder)
367 | utils.init_network_weights(self.decoder)
368 |
369 | if input_space_decay:
370 | self.w_input_decay = Parameter(torch.Tensor(1, int(input_dim))).to(self.device)
371 | self.b_input_decay = Parameter(torch.Tensor(1, int(input_dim))).to(self.device)
372 | self.input_space_decay = input_space_decay
373 |
374 | def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps,
375 | mask = None, n_traj_samples = 1, mode = None):
376 |
377 | assert(mask is not None)
378 |
379 | batch_size = data.size(0)
380 | zero_delta_t = torch.Tensor([0.]).to(self.device)
381 |
382 | # run encoder backwards
383 | run_backwards = bool(time_steps_to_predict[0] < truth_time_steps[-1])
384 |
385 | if run_backwards:
386 | # Look at data in the reverse order: from later points to the first
387 | data = utils.reverse(data)
388 | mask = utils.reverse(mask)
389 |
390 | delta_ts = truth_time_steps[1:] - truth_time_steps[:-1]
391 | if run_backwards:
392 | # we are going backwards in time
393 | delta_ts = utils.reverse(delta_ts)
394 |
395 |
396 | delta_ts = torch.cat((delta_ts, zero_delta_t))
397 | if len(delta_ts.size()) == 1:
398 | # delta_ts are shared for all trajectories in a batch
399 | assert(data.size(1) == delta_ts.size(0))
400 | delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size,1,1))
401 |
402 | input_decay_params = None
403 | if self.input_space_decay:
404 | input_decay_params = (self.w_input_decay, self.b_input_decay)
405 |
406 | hidden_state, _ = run_rnn(data, delta_ts,
407 | cell = self.rnn_cell_enc, mask = mask,
408 | input_decay_params = input_decay_params)
409 |
410 | z0_mean, z0_std = utils.split_last_dim(self.z0_net(hidden_state))
411 | z0_std = z0_std.abs()
412 | z0_sample = utils.sample_standard_gaussian(z0_mean, z0_std)
413 |
414 | # Decoder # # # # # # # # # # # # # # # # # # # #
415 | delta_ts = torch.cat((zero_delta_t, time_steps_to_predict[1:] - time_steps_to_predict[:-1]))
416 | if len(delta_ts.size()) == 1:
417 | delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size,1,1))
418 |
419 | _, all_hiddens = run_rnn(data, delta_ts,
420 | cell = self.rnn_cell_dec,
421 | first_hidden = z0_sample, feed_previous = True,
422 | n_steps = time_steps_to_predict.size(0),
423 | decoder = self.decoder,
424 | input_decay_params = input_decay_params)
425 |
426 | outputs = self.decoder(all_hiddens)
427 | # Shift outputs for computing the loss -- we should compare the first output to the second data point, etc.
428 | first_point = data[:,0,:]
429 | outputs = utils.shift_outputs(outputs, first_point)
430 |
431 | extra_info = {"first_point": (z0_mean.unsqueeze(0), z0_std.unsqueeze(0), z0_sample.unsqueeze(0))}
432 |
433 | if self.use_binary_classif:
434 | if self.classif_per_tp:
435 | extra_info["label_predictions"] = self.classifier(all_hiddens)
436 | else:
437 | extra_info["label_predictions"] = self.classifier(z0_mean).reshape(1,-1)
438 |
439 | # outputs shape: [n_traj_samples, n_traj, n_tp, n_dims]
440 | return outputs, extra_info
441 |
442 |
443 |
444 |
--------------------------------------------------------------------------------
/latentode/latent_ode/mujoco_physics.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Authors: Yulia Rubanova and Ricky Chen
4 | ###########################
5 |
6 | import os
7 | import numpy as np
8 | import torch
9 | from lib.utils import get_dict_template
10 | import lib.utils as utils
11 | from torchvision.datasets.utils import download_url
12 |
13 | class HopperPhysics(object):
14 |
15 | T = 200
16 | D = 14
17 |
18 | n_training_samples = 10000
19 |
20 | training_file = 'training.pt'
21 |
22 | def __init__(self, root, download = True, generate=False, device = torch.device("cpu")):
23 | self.root = root
24 | if download:
25 | self._download()
26 |
27 | if generate:
28 | self._generate_dataset()
29 |
30 | if not self._check_exists():
31 | raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')
32 |
33 | data_file = os.path.join(self.data_folder, self.training_file)
34 |
35 | self.data = torch.Tensor(torch.load(data_file)).to(device)
36 | self.data, self.data_min, self.data_max = utils.normalize_data(self.data)
37 |
38 | self.device =device
39 |
40 | def visualize(self, traj, plot_name = 'traj', dirname='hopper_imgs', video_name = None):
41 | r"""Generates images of the trajectory and stores them as <dirname>/<plot_name>-<t>.jpg"""
42 |
43 | T, D = traj.size()
44 |
45 | traj = traj.cpu() * self.data_max.cpu() + self.data_min.cpu()
46 |
47 | try:
48 | from dm_control import suite # noqa: F401
49 | except ImportError as e:
50 | raise Exception('Deepmind Control Suite is required to visualize the dataset.') from e
51 |
52 | try:
53 | from PIL import Image # noqa: F401
54 | except ImportError as e:
55 | raise Exception('PIL is required to visualize the dataset.') from e
56 |
57 | def save_image(data, filename):
58 | im = Image.fromarray(data)
59 | im.save(filename)
60 |
61 | os.makedirs(dirname, exist_ok=True)
62 |
63 | env = suite.load('hopper', 'stand')
64 | physics = env.physics
65 |
66 | for t in range(T):
67 | with physics.reset_context():
68 | physics.data.qpos[:] = traj[t, :D // 2]
69 | physics.data.qvel[:] = traj[t, D // 2:]
70 | save_image(
71 | physics.render(height=480, width=640, camera_id=0),
72 | os.path.join(dirname, plot_name + '-{:03d}.jpg'.format(t))
73 | )
74 |
75 | def _generate_dataset(self):
76 | if self._check_exists():
77 | return
78 | os.makedirs(self.data_folder, exist_ok=True)
79 | print('Generating dataset...')
80 | train_data = self._generate_random_trajectories(self.n_training_samples)
81 | torch.save(train_data, os.path.join(self.data_folder, self.training_file))
82 |
83 | def _download(self):
84 | if self._check_exists():
85 | return
86 |
87 | print("Downloading the dataset [325MB] ...")
88 | os.makedirs(self.data_folder, exist_ok=True)
89 | url = "http://www.cs.toronto.edu/~rtqichen/datasets/HopperPhysics/training.pt"
90 | download_url(url, self.data_folder, "training.pt", None)
91 |
92 | def _generate_random_trajectories(self, n_samples):
93 |
94 | try:
95 | from dm_control import suite # noqa: F401
96 | except ImportError as e:
97 | raise Exception('Deepmind Control Suite is required to generate the dataset.') from e
98 |
99 | env = suite.load('hopper', 'stand')
100 | physics = env.physics
101 |
102 | # Store the state of the RNG to restore later.
103 | st0 = np.random.get_state()
104 | np.random.seed(123)
105 |
106 | data = np.zeros((n_samples, self.T, self.D))
107 | for i in range(n_samples):
108 | with physics.reset_context():
109 | # x and z positions of the hopper. We want z > 0 for the hopper to stay above ground.
110 | physics.data.qpos[:2] = np.random.uniform(0, 0.5, size=2)
111 | physics.data.qpos[2:] = np.random.uniform(-2, 2, size=physics.data.qpos[2:].shape)
112 | physics.data.qvel[:] = np.random.uniform(-5, 5, size=physics.data.qvel.shape)
113 | for t in range(self.T):
114 | data[i, t, :self.D // 2] = physics.data.qpos
115 | data[i, t, self.D // 2:] = physics.data.qvel
116 | physics.step()
117 |
118 | # Restore RNG.
119 | np.random.set_state(st0)
120 | return data
121 |
122 | def _check_exists(self):
123 | return os.path.exists(os.path.join(self.data_folder, self.training_file))
124 |
125 | @property
126 | def data_folder(self):
127 | return os.path.join(self.root, self.__class__.__name__)
128 |
129 | # def __getitem__(self, index):
130 | # return self.data[index]
131 |
132 | def get_dataset(self):
133 | return self.data
134 |
135 | def __len__(self):
136 | return len(self.data)
137 |
138 | def size(self, ind = None):
139 | if ind is not None:
140 | return self.data.shape[ind]
141 | return self.data.shape
142 |
143 | def __repr__(self):
144 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
145 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
146 | fmt_str += ' Root Location: {}\n'.format(self.root)
147 | return fmt_str
148 |
149 |
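`_generate_random_trajectories` saves the global NumPy RNG state, seeds it with a fixed value, and restores the state afterwards, so generating the dataset does not perturb the caller's random stream. The same pattern is sketched below with the standard-library `random` module standing in for `np.random`. The helper `sample_with_fixed_seed` is hypothetical, and the `try`/`finally` guard is an addition that the original code does not have.

```python
import random

def sample_with_fixed_seed(n, seed=123):
    """Draw n reproducible samples without disturbing the caller's RNG stream."""
    state = random.getstate()   # remember the global RNG state
    random.seed(seed)
    try:
        return [random.uniform(-5.0, 5.0) for _ in range(n)]
    finally:
        random.setstate(state)  # restore it even if sampling fails

# Reference: the first draw after seeding the global stream with 0.
random.seed(0)
expected_next = random.random()

# Same seed, but with two fixed-seed sampling calls in between.
random.seed(0)
first = sample_with_fixed_seed(3)
second = sample_with_fixed_seed(3)
actual_next = random.random()
```

Both calls return the same samples, and the draw that follows them matches the reference draw, which shows the outer stream was left untouched.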
--------------------------------------------------------------------------------
/latentode/latent_ode/person_activity.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Authors: Yulia Rubanova and Ricky Chen
4 | ###########################
5 |
6 | import os
7 |
8 | import lib.utils as utils
9 | import numpy as np
10 | import tarfile
11 | import torch
12 | from torch.utils.data import DataLoader
13 | from torchvision.datasets.utils import download_url
14 | from lib.utils import get_device
15 |
16 | # Adapted from: https://github.com/rtqichen/time-series-datasets
17 |
18 | class PersonActivity(object):
19 | urls = [
20 | 'https://archive.ics.uci.edu/ml/machine-learning-databases/00196/ConfLongDemo_JSI.txt',
21 | ]
22 |
23 | tag_ids = [
24 | "010-000-024-033", #"ANKLE_LEFT",
25 | "010-000-030-096", #"ANKLE_RIGHT",
26 | "020-000-033-111", #"CHEST",
27 | "020-000-032-221" #"BELT"
28 | ]
29 |
30 | tag_dict = {k: i for i, k in enumerate(tag_ids)}
31 |
32 | label_names = [
33 | "walking",
34 | "falling",
35 | "lying down",
36 | "lying",
37 | "sitting down",
38 | "sitting",
39 | "standing up from lying",
40 | "on all fours",
41 | "sitting on the ground",
42 | "standing up from sitting",
43 | "standing up from sit on grnd"
44 | ]
45 |
46 | #label_dict = {k: i for i, k in enumerate(label_names)}
47 |
48 | #Merge similar labels into one class
49 | label_dict = {
50 | "walking": 0,
51 | "falling": 1,
52 | "lying": 2,
53 | "lying down": 2,
54 | "sitting": 3,
55 | "sitting down" : 3,
56 | "standing up from lying": 4,
57 | "standing up from sitting": 4,
58 | "standing up from sit on grnd": 4,
59 | "on all fours": 5,
60 | "sitting on the ground": 6
61 | }
62 |
63 |
64 | def __init__(self, root, download=False,
65 | reduce='average', max_seq_length = 50,
66 | n_samples = None, device = torch.device("cpu")):
67 |
68 | self.root = root
69 | self.reduce = reduce
70 | self.max_seq_length = max_seq_length
71 |
72 | if download:
73 | self.download()
74 |
75 | if not self._check_exists():
76 | raise RuntimeError('Dataset not found. You can use download=True to download it')
77 |
78 | if device == torch.device("cpu"):
79 | self.data = torch.load(os.path.join(self.processed_folder, self.data_file), map_location='cpu')
80 | else:
81 | self.data = torch.load(os.path.join(self.processed_folder, self.data_file))
82 |
83 | if n_samples is not None:
84 | self.data = self.data[:n_samples]
85 |
86 | def download(self):
87 | if self._check_exists():
88 | return
89 |
90 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
91 |
92 | os.makedirs(self.raw_folder, exist_ok=True)
93 | os.makedirs(self.processed_folder, exist_ok=True)
94 |
95 | def save_record(records, record_id, tt, vals, mask, labels):
96 | tt = torch.tensor(tt).to(self.device)
97 |
98 | vals = torch.stack(vals)
99 | mask = torch.stack(mask)
100 | labels = torch.stack(labels)
101 |
102 | # flatten the measurements for different tags
103 | vals = vals.reshape(vals.size(0), -1)
104 | mask = mask.reshape(mask.size(0), -1)
105 | assert(len(tt) == vals.size(0))
106 | assert(mask.size(0) == vals.size(0))
107 | assert(labels.size(0) == vals.size(0))
108 |
109 | #records.append((record_id, tt, vals, mask, labels))
110 |
111 | seq_length = len(tt)
112 | # split the long time series into smaller ones
113 | offset = 0
114 | slide = self.max_seq_length // 2
115 |
116 | while (offset + self.max_seq_length < seq_length):
117 | idx = range(offset, offset + self.max_seq_length)
118 |
119 | first_tp = tt[idx][0]
120 | records.append((record_id, tt[idx] - first_tp, vals[idx], mask[idx], labels[idx]))
121 | offset += slide
122 |
123 | for url in self.urls:
124 | filename = url.rpartition('/')[2]
125 | download_url(url, self.raw_folder, filename, None)
126 |
127 | print('Processing {}...'.format(filename))
128 |
129 | dirname = os.path.join(self.raw_folder)
130 | records = []
131 | first_tp = None
132 |
133 | for txtfile in os.listdir(dirname):
134 | with open(os.path.join(dirname, txtfile)) as f:
135 | lines = f.readlines()
136 | prev_time = -1
137 | tt = []
138 |
139 | record_id = None
140 | for l in lines:
141 | cur_record_id, tag_id, time, date, val1, val2, val3, label = l.strip().split(',')
142 | value_vec = torch.Tensor((float(val1), float(val2), float(val3))).to(self.device)
143 | time = float(time)
144 |
145 | if cur_record_id != record_id:
146 | if record_id is not None:
147 | save_record(records, record_id, tt, vals, mask, labels)
148 | tt, vals, mask, nobs, labels = [], [], [], [], []
149 | record_id = cur_record_id
150 |
151 | tt = [torch.zeros(1).to(self.device)]
152 | vals = [torch.zeros(len(self.tag_ids),3).to(self.device)]
153 | mask = [torch.zeros(len(self.tag_ids),3).to(self.device)]
154 | nobs = [torch.zeros(len(self.tag_ids)).to(self.device)]
155 | labels = [torch.zeros(len(self.label_names)).to(self.device)]
156 |
157 | first_tp = time
158 | time = round((time - first_tp)/ 10**5)
159 | prev_time = time
160 | else:
161 | # for speed -- we actually don't need to quantize it in Latent ODE
162 | time = round((time - first_tp)/ 10**5) # quantizing to 10 ms bins (10**5 ticks): 10,000 is one millisecond, 10,000,000 is one second
163 |
164 | if time != prev_time:
165 | tt.append(time)
166 | vals.append(torch.zeros(len(self.tag_ids),3).to(self.device))
167 | mask.append(torch.zeros(len(self.tag_ids),3).to(self.device))
168 | nobs.append(torch.zeros(len(self.tag_ids)).to(self.device))
169 | labels.append(torch.zeros(len(self.label_names)).to(self.device))
170 | prev_time = time
171 |
172 | if tag_id in self.tag_ids:
173 | n_observations = nobs[-1][self.tag_dict[tag_id]]
174 | if (self.reduce == 'average') and (n_observations > 0):
175 | prev_val = vals[-1][self.tag_dict[tag_id]]
176 | new_val = (prev_val * n_observations + value_vec) / (n_observations + 1)
177 | vals[-1][self.tag_dict[tag_id]] = new_val
178 | else:
179 | vals[-1][self.tag_dict[tag_id]] = value_vec
180 |
181 | mask[-1][self.tag_dict[tag_id]] = 1
182 | nobs[-1][self.tag_dict[tag_id]] += 1
183 |
184 | if label in self.label_names:
185 | if torch.sum(labels[-1][self.label_dict[label]]) == 0:
186 | labels[-1][self.label_dict[label]] = 1
187 | else:
188 | assert tag_id == 'RecordID', 'Read unexpected tag id {}'.format(tag_id)
189 | save_record(records, record_id, tt, vals, mask, labels)
190 |
191 | torch.save(
192 | records,
193 | os.path.join(self.processed_folder, 'data.pt')
194 | )
195 |
196 | print('Done!')
197 |
198 | def _check_exists(self):
199 | for url in self.urls:
200 | filename = url.rpartition('/')[2]
201 | if not os.path.exists(
202 | os.path.join(self.processed_folder, 'data.pt')
203 | ):
204 | return False
205 | return True
206 |
207 | @property
208 | def raw_folder(self):
209 | return os.path.join(self.root, self.__class__.__name__, 'raw')
210 |
211 | @property
212 | def processed_folder(self):
213 | return os.path.join(self.root, self.__class__.__name__, 'processed')
214 |
215 | @property
216 | def data_file(self):
217 | return 'data.pt'
218 |
219 | def __getitem__(self, index):
220 | return self.data[index]
221 |
222 | def __len__(self):
223 | return len(self.data)
224 |
225 | def __repr__(self):
226 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
227 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
228 | fmt_str += ' Root Location: {}\n'.format(self.root)
229 | fmt_str += ' Max length: {}\n'.format(self.max_seq_length)
230 | fmt_str += ' Reduce: {}\n'.format(self.reduce)
231 | return fmt_str
232 |
233 | def get_person_id(record_id):
234 | # The first letter is the person id
235 | person_id = record_id[0]
236 | person_id = ord(person_id) - ord("A")
237 | return person_id
238 |
239 |
240 |
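The `save_record` helper inside `PersonActivity.download` splits each long recording into windows of `max_seq_length` points that overlap by half. A minimal index-only sketch of that loop follows; the function name `window_bounds` is hypothetical, and it returns `(start, stop)` pairs rather than slicing tensors.

```python
def window_bounds(seq_length, max_seq_length=50):
    """Return (start, stop) index pairs for half-overlapping windows."""
    slide = max_seq_length // 2
    bounds = []
    offset = 0
    # Strict inequality, as in save_record: a trailing partial window is dropped,
    # and a series of exactly max_seq_length points yields no window at all.
    while offset + max_seq_length < seq_length:
        bounds.append((offset, offset + max_seq_length))
        offset += slide
    return bounds

bounds = window_bounds(120, max_seq_length=50)
exact_fit = window_bounds(50, max_seq_length=50)
```

For a 120-point series this gives three windows starting at 0, 25, and 50, and the last 20 points are not covered by any window.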
241 | def variable_time_collate_fn_activity(batch, args, device = torch.device("cpu"), data_type = "train"):
242 | """
243 | Expects a batch of time series data in the form of (record_id, tt, vals, mask, labels) where
244 | - record_id is a patient id
245 | - tt is a 1-dimensional tensor containing T time values of observations.
246 | - vals is a (T, D) tensor containing observed values for D variables.
247 | - mask is a (T, D) tensor containing 1 where values were observed and 0 otherwise.
248 | - labels is a list of labels for the current patient, if labels are available. Otherwise None.
249 | Returns:
250 | combined_tt: The union of all time observations.
251 | combined_vals: (M, T, D) tensor containing the observed values.
252 | combined_mask: (M, T, D) tensor containing 1 where values were observed and 0 otherwise.
253 | """
254 | D = batch[0][2].shape[1]
255 | N = batch[0][-1].shape[1] # number of labels
256 |
257 | combined_tt, inverse_indices = torch.unique(torch.cat([ex[1] for ex in batch]), sorted=True, return_inverse=True)
258 | combined_tt = combined_tt.to(device)
259 |
260 | offset = 0
261 | combined_vals = torch.zeros([len(batch), len(combined_tt), D]).to(device)
262 | combined_mask = torch.zeros([len(batch), len(combined_tt), D]).to(device)
263 | combined_labels = torch.zeros([len(batch), len(combined_tt), N]).to(device)
264 |
265 | for b, (record_id, tt, vals, mask, labels) in enumerate(batch):
266 | tt = tt.to(device)
267 | vals = vals.to(device)
268 | mask = mask.to(device)
269 | labels = labels.to(device)
270 |
271 | indices = inverse_indices[offset:offset + len(tt)]
272 | offset += len(tt)
273 |
274 | combined_vals[b, indices] = vals
275 | combined_mask[b, indices] = mask
276 | combined_labels[b, indices] = labels
277 |
278 | combined_tt = combined_tt.float()
279 |
280 | if torch.max(combined_tt) != 0.:
281 | combined_tt = combined_tt / torch.max(combined_tt)
282 |
283 | data_dict = {
284 | "data": combined_vals,
285 | "time_steps": combined_tt,
286 | "mask": combined_mask,
287 | "labels": combined_labels}
288 |
289 | data_dict = utils.split_and_subsample_batch(data_dict, args, data_type = data_type)
290 | return data_dict
291 |
292 |
293 | if __name__ == '__main__':
294 | torch.manual_seed(1991)
295 |
296 | dataset = PersonActivity('data/PersonActivity', download=True)
297 | dataloader = DataLoader(dataset, batch_size=30, shuffle=True, collate_fn= variable_time_collate_fn_activity)
298 | next(iter(dataloader))
299 |
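`variable_time_collate_fn_activity` places every series in the batch on the union of all observed time points and records which entries were actually observed. The sketch below shows the same idea with plain lists for a single feature per series. The name `merge_time_grids` and the `(times, values)` input layout are assumptions made for illustration; the original uses `torch.unique(..., return_inverse=True)` on tensors.

```python
def merge_time_grids(batch):
    """batch: list of (times, values) pairs; times are sorted and unique per series."""
    # Sorted union of all time points in the batch.
    combined_tt = sorted({t for times, _ in batch for t in times})
    position = {t: i for i, t in enumerate(combined_tt)}
    combined_vals = [[0.0] * len(combined_tt) for _ in batch]
    combined_mask = [[0] * len(combined_tt) for _ in batch]
    for b, (times, values) in enumerate(batch):
        for t, v in zip(times, values):
            combined_vals[b][position[t]] = v
            combined_mask[b][position[t]] = 1  # 1 marks an observed entry
    # Rescale times to [0, 1], as the collate function does when the maximum is non-zero.
    t_max = max(combined_tt)
    if t_max != 0:
        combined_tt = [t / t_max for t in combined_tt]
    return combined_tt, combined_vals, combined_mask

tt, vals, mask = merge_time_grids([
    ([0, 2, 4], [1.0, 2.0, 3.0]),
    ([0, 1, 4], [5.0, 6.0, 7.0]),
])
```

Entries that a series did not observe stay at zero in `vals` and are flagged with 0 in `mask`, so downstream code can tell a missing value from a true zero.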
--------------------------------------------------------------------------------
/latentode/latent_ode/physionet.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Authors: Yulia Rubanova and Ricky Chen
4 | ###########################
5 |
6 | import os
7 | import matplotlib
8 | if os.path.exists("/Users/yulia"):
9 | matplotlib.use('TkAgg')
10 | else:
11 | matplotlib.use('Agg')
12 | import matplotlib.pyplot
13 | import matplotlib.pyplot as plt
14 |
15 | import lib.utils as utils
16 | import numpy as np
17 | import tarfile
18 | import torch
19 | from torch.utils.data import DataLoader
20 | from torchvision.datasets.utils import download_url
21 | from lib.utils import get_device
22 |
23 | # Adapted from: https://github.com/rtqichen/time-series-datasets
24 |
25 | # get minimum and maximum for each feature across the whole dataset
26 | def get_data_min_max(records):
27 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
28 |
29 | data_min, data_max = None, None
30 | inf = torch.Tensor([float("Inf")])[0].to(device)
31 |
32 | for b, (record_id, tt, vals, mask, labels) in enumerate(records):
33 | n_features = vals.size(-1)
34 |
35 | batch_min = []
36 | batch_max = []
37 | for i in range(n_features):
38 | non_missing_vals = vals[:,i][mask[:,i] == 1]
39 | if len(non_missing_vals) == 0:
40 | batch_min.append(inf)
41 | batch_max.append(-inf)
42 | else:
43 | batch_min.append(torch.min(non_missing_vals))
44 | batch_max.append(torch.max(non_missing_vals))
45 |
46 | batch_min = torch.stack(batch_min)
47 | batch_max = torch.stack(batch_max)
48 |
49 | if (data_min is None) and (data_max is None):
50 | data_min = batch_min
51 | data_max = batch_max
52 | else:
53 | data_min = torch.min(data_min, batch_min)
54 | data_max = torch.max(data_max, batch_max)
55 |
56 | return data_min, data_max
57 |
58 |
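`get_data_min_max` computes a per-feature minimum and maximum over observed entries only, leaving `inf` and `-inf` for features that are never observed. A list-based sketch of the same reduction is shown below. The name `masked_min_max` and the `(vals, mask)` record layout are hypothetical simplifications of the five-field records used in the file.

```python
import math

def masked_min_max(records):
    """records: list of (vals, mask), each a list of rows with one entry per feature."""
    n_features = len(records[0][0][0])
    data_min = [math.inf] * n_features
    data_max = [-math.inf] * n_features
    for vals, mask in records:
        for row_vals, row_mask in zip(vals, mask):
            for i in range(n_features):
                if row_mask[i] == 1:  # skip missing entries
                    data_min[i] = min(data_min[i], row_vals[i])
                    data_max[i] = max(data_max[i], row_vals[i])
    return data_min, data_max

records = [
    ([[1.0, 0.0], [3.0, 0.0]], [[1, 0], [1, 0]]),
    ([[-2.0, 9.0], [0.0, 0.0]], [[1, 0], [0, 0]]),
]
data_min, data_max = masked_min_max(records)
```

The second feature is masked out everywhere, so its placeholder value 9.0 is ignored and its bounds stay at `inf` and `-inf`.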
59 | class PhysioNet(object):
60 |
61 | urls = [
62 | 'https://physionet.org/files/challenge-2012/1.0.0/set-a.tar.gz?download',
63 | 'https://physionet.org/files/challenge-2012/1.0.0/set-b.tar.gz?download',
64 | ]
65 |
66 | outcome_urls = ['https://physionet.org/files/challenge-2012/1.0.0/Outcomes-a.txt']
67 |
68 | params = [
69 | 'Age', 'Gender', 'Height', 'ICUType', 'Weight', 'Albumin', 'ALP', 'ALT', 'AST', 'Bilirubin', 'BUN',
70 | 'Cholesterol', 'Creatinine', 'DiasABP', 'FiO2', 'GCS', 'Glucose', 'HCO3', 'HCT', 'HR', 'K', 'Lactate', 'Mg',
71 | 'MAP', 'MechVent', 'Na', 'NIDiasABP', 'NIMAP', 'NISysABP', 'PaCO2', 'PaO2', 'pH', 'Platelets', 'RespRate',
72 | 'SaO2', 'SysABP', 'Temp', 'TroponinI', 'TroponinT', 'Urine', 'WBC'
73 | ]
74 |
75 | params_dict = {k: i for i, k in enumerate(params)}
76 |
77 | labels = [ "SAPS-I", "SOFA", "Length_of_stay", "Survival", "In-hospital_death" ]
78 | labels_dict = {k: i for i, k in enumerate(labels)}
79 |
80 | def __init__(self, root, train=True, download=False,
81 | quantization = 0.1, n_samples = None, device = torch.device("cpu")):
82 |
83 | self.root = root
84 | self.train = train
85 | self.reduce = "average"
86 | self.quantization = quantization
87 |
88 | if download:
89 | self.download()
90 |
91 | if not self._check_exists():
92 | raise RuntimeError('Dataset not found. You can use download=True to download it')
93 |
94 | if self.train:
95 | data_file = self.training_file
96 | else:
97 | data_file = self.test_file
98 |
99 | if device == torch.device("cpu"):
100 | self.data = torch.load(os.path.join(self.processed_folder, data_file), map_location='cpu')
101 | self.labels = torch.load(os.path.join(self.processed_folder, self.label_file), map_location='cpu')
102 | else:
103 | self.data = torch.load(os.path.join(self.processed_folder, data_file))
104 | self.labels = torch.load(os.path.join(self.processed_folder, self.label_file))
105 |
106 | if n_samples is not None:
107 | self.data = self.data[:n_samples]
108 | self.labels = self.labels[:n_samples]
109 |
110 |
111 | def download(self):
112 | if self._check_exists():
113 | return
114 |
115 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
116 |
117 | os.makedirs(self.raw_folder, exist_ok=True)
118 | os.makedirs(self.processed_folder, exist_ok=True)
119 |
120 | # Download outcome data
121 | for url in self.outcome_urls:
122 | filename = url.rpartition('/')[2]
123 | download_url(url, self.raw_folder, filename, None)
124 |
125 | txtfile = os.path.join(self.raw_folder, filename)
126 | with open(txtfile) as f:
127 | lines = f.readlines()
128 | outcomes = {}
129 | for l in lines[1:]:
130 | l = l.rstrip().split(',')
131 | record_id, labels = l[0], np.array(l[1:]).astype(float)
132 | outcomes[record_id] = torch.Tensor(labels).to(self.device)
133 |
134 | torch.save(
135 | labels,
136 | os.path.join(self.processed_folder, filename.split('.')[0] + '.pt')
137 | )
138 |
139 | for url in self.urls:
140 | filename = url.rpartition('/')[2]
141 | download_url(url, self.raw_folder, filename, None)
142 | tar = tarfile.open(os.path.join(self.raw_folder, filename), "r:gz")
143 | tar.extractall(self.raw_folder)
144 | tar.close()
145 |
146 | print('Processing {}...'.format(filename))
147 |
148 | dirname = os.path.join(self.raw_folder, filename.split('.')[0])
149 | patients = []
150 | total = 0
151 | for txtfile in os.listdir(dirname):
152 | record_id = txtfile.split('.')[0]
153 | with open(os.path.join(dirname, txtfile)) as f:
154 | lines = f.readlines()
155 | prev_time = 0
156 | tt = [0.]
157 | vals = [torch.zeros(len(self.params)).to(self.device)]
158 | mask = [torch.zeros(len(self.params)).to(self.device)]
159 | nobs = [torch.zeros(len(self.params)).to(self.device)]
160 | for l in lines[1:]:
161 | total += 1
162 | time, param, val = l.split(',')
163 | # Time in hours
164 | time = float(time.split(':')[0]) + float(time.split(':')[1]) / 60.
165 | # round the time stamps to the nearest quantization step (0.1 h = 6 min by default)
166 | # used for speed -- we actually don't need to quantize it in Latent ODE
167 | time = round(time / self.quantization) * self.quantization
168 |
169 | if time != prev_time:
170 | tt.append(time)
171 | vals.append(torch.zeros(len(self.params)).to(self.device))
172 | mask.append(torch.zeros(len(self.params)).to(self.device))
173 | nobs.append(torch.zeros(len(self.params)).to(self.device))
174 | prev_time = time
175 |
176 | if param in self.params_dict:
177 | #vals[-1][self.params_dict[param]] = float(val)
178 | n_observations = nobs[-1][self.params_dict[param]]
179 | if self.reduce == 'average' and n_observations > 0:
180 | prev_val = vals[-1][self.params_dict[param]]
181 | new_val = (prev_val * n_observations + float(val)) / (n_observations + 1)
182 | vals[-1][self.params_dict[param]] = new_val
183 | else:
184 | vals[-1][self.params_dict[param]] = float(val)
185 | mask[-1][self.params_dict[param]] = 1
186 | nobs[-1][self.params_dict[param]] += 1
187 | else:
188 | assert param == 'RecordID', 'Read unexpected param {}'.format(param)
189 | tt = torch.tensor(tt).to(self.device)
190 | vals = torch.stack(vals)
191 | mask = torch.stack(mask)
192 |
193 | labels = None
194 | if record_id in outcomes:
195 | # Only training set has labels
196 | labels = outcomes[record_id]
197 | # Out of 5 label types provided for Physionet, take only the last one -- mortality
198 | labels = labels[4]
199 |
200 | patients.append((record_id, tt, vals, mask, labels))
201 |
202 | torch.save(
203 | patients,
204 | os.path.join(self.processed_folder,
205 | filename.split('.')[0] + "_" + str(self.quantization) + '.pt')
206 | )
207 |
208 | print('Done!')
209 |
210 | def _check_exists(self):
211 | for url in self.urls:
212 | filename = url.rpartition('/')[2]
213 |
214 | if not os.path.exists(
215 | os.path.join(self.processed_folder,
216 | filename.split('.')[0] + "_" + str(self.quantization) + '.pt')
217 | ):
218 | return False
219 | return True
220 |
221 | @property
222 | def raw_folder(self):
223 | return os.path.join(self.root, self.__class__.__name__, 'raw')
224 |
225 | @property
226 | def processed_folder(self):
227 | return os.path.join(self.root, self.__class__.__name__, 'processed')
228 |
229 | @property
230 | def training_file(self):
231 | return 'set-a_{}.pt'.format(self.quantization)
232 |
233 | @property
234 | def test_file(self):
235 | return 'set-b_{}.pt'.format(self.quantization)
236 |
237 | @property
238 | def label_file(self):
239 | return 'Outcomes-a.pt'
240 |
241 | def __getitem__(self, index):
242 | return self.data[index]
243 |
244 | def __len__(self):
245 | return len(self.data)
246 |
247 | def get_label(self, record_id):
248 | return self.labels[record_id]
249 |
250 | def __repr__(self):
251 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
252 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
253 | fmt_str += ' Split: {}\n'.format('train' if self.train is True else 'test')
254 | fmt_str += ' Root Location: {}\n'.format(self.root)
255 | fmt_str += ' Quantization: {}\n'.format(self.quantization)
256 | fmt_str += ' Reduce: {}\n'.format(self.reduce)
257 | return fmt_str
258 |
259 | def visualize(self, timesteps, data, mask, plot_name):
260 | width = 15
261 | height = 15
262 |
263 | non_zero_attributes = (torch.sum(mask,0) > 2).numpy()
264 | non_zero_idx = [i for i in range(len(non_zero_attributes)) if non_zero_attributes[i] == 1.]
265 | n_non_zero = sum(non_zero_attributes)
266 |
267 | mask = mask[:, non_zero_idx]
268 | data = data[:, non_zero_idx]
269 |
270 | params_non_zero = [self.params[i] for i in non_zero_idx]
271 | params_dict = {k: i for i, k in enumerate(params_non_zero)}
272 |
273 | n_col = 3
274 | n_row = n_non_zero // n_col + (n_non_zero % n_col > 0)
275 | fig, ax_list = plt.subplots(n_row, n_col, figsize=(width, height), facecolor='white')
276 |
277 | #for i in range(len(self.params)):
278 | for i in range(n_non_zero):
279 | param = params_non_zero[i]
280 | param_id = params_dict[param]
281 |
282 | tp_mask = mask[:,param_id].long()
283 |
284 | tp_cur_param = timesteps[tp_mask == 1.]
285 | data_cur_param = data[tp_mask == 1., param_id]
286 |
287 | ax_list[i // n_col, i % n_col].plot(tp_cur_param.numpy(), data_cur_param.numpy(), marker='o')
288 | ax_list[i // n_col, i % n_col].set_title(param)
289 |
290 | fig.tight_layout()
291 | fig.savefig(plot_name)
292 | plt.close(fig)
293 |
294 |
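The time handling in `PhysioNet.download` above (parse `HH:MM` into hours, snap to a multiple of the quantization step, average repeated readings that fall into the same bin) can be sketched in isolation. This is a minimal, dependency-free illustration and not the repository's code: the helper names `parse_hours`, `quantize` and `bin_and_average` are hypothetical, the step of 0.1 hours is taken from the "6 min by default" comment, and readings are grouped with a dictionary rather than by comparing against the previous time stamp.

```python
from collections import defaultdict


def parse_hours(stamp):
    """Convert an 'HH:MM' string to fractional hours."""
    hours, minutes = stamp.split(":")
    return float(hours) + float(minutes) / 60.0


def quantize(time_h, step=0.1):
    """Snap a time (in hours) to the nearest multiple of `step`."""
    return round(time_h / step) * step


def bin_and_average(records, step=0.1):
    """Group (stamp, param, value) records by quantized time and average
    repeated observations of the same parameter within one bin."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for stamp, param, value in records:
        key = (quantize(parse_hours(stamp), step), param)
        sums[key] += float(value)
        counts[key] += 1
    return {key: sums[key] / counts[key] for key in sums}


records = [
    ("00:07", "HR", "80"),  # 0.117 h falls into the 0.1 h bin
    ("00:08", "HR", "90"),  # 0.133 h falls into the same bin and is averaged
    ("01:00", "HR", "70"),  # 1.0 h gets its own bin
]
binned = bin_and_average(records)
```

The running-mean update in the original, `(prev_val * n + val) / (n + 1)`, yields the same average as the sum-and-count form used here, but only when `reduce == 'average'`; otherwise the original keeps the last reading.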
295 | def variable_time_collate_fn(batch, args, device = torch.device("cpu"), data_type = "train",
296 | data_min = None, data_max = None):
297 | """
298 | Expects a batch of time series data in the form of (record_id, tt, vals, mask, labels) where
299 | - record_id is a patient id
300 | - tt is a 1-dimensional tensor containing T time values of observations.
301 | - vals is a (T, D) tensor containing observed values for D variables.
302 | - mask is a (T, D) tensor containing 1 where values were observed and 0 otherwise.
303 | - labels is a list of labels for the current patient, if labels are available. Otherwise None.
304 | Returns:
305 | combined_tt: The union of all time observations.
306 | combined_vals: (M, T, D) tensor containing the observed values.
307 | combined_mask: (M, T, D) tensor containing 1 where values were observed and 0 otherwise.
308 | """
309 | D = batch[0][2].shape[1]
310 | combined_tt, inverse_indices = torch.unique(torch.cat([ex[1] for ex in batch]), sorted=True, return_inverse=True)
311 | combined_tt = combined_tt.to(device)
312 |
313 | offset = 0
314 | combined_vals = torch.zeros([len(batch), len(combined_tt), D]).to(device)
315 | combined_mask = torch.zeros([len(batch), len(combined_tt), D]).to(device)
316 |
317 | combined_labels = None
318 | N_labels = 1
319 |
320 | combined_labels = torch.zeros(len(batch), N_labels) + torch.tensor(float('nan'))
321 | combined_labels = combined_labels.to(device = device)
322 |
323 | for b, (record_id, tt, vals, mask, labels) in enumerate(batch):
324 | tt = tt.to(device)
325 | vals = vals.to(device)
326 | mask = mask.to(device)
327 | if labels is not None:
328 | labels = labels.to(device)
329 |
330 | indices = inverse_indices[offset:offset + len(tt)]
331 | offset += len(tt)
332 |
333 | combined_vals[b, indices] = vals
334 | combined_mask[b, indices] = mask
335 |
336 | if labels is not None:
337 | combined_labels[b] = labels
338 |
339 | combined_vals, _, _ = utils.normalize_masked_data(combined_vals, combined_mask,
340 | att_min = data_min, att_max = data_max)
341 |
342 | if torch.max(combined_tt) != 0.:
343 | combined_tt = combined_tt / torch.max(combined_tt)
344 |
345 | data_dict = {
346 | "data": combined_vals,
347 | "time_steps": combined_tt,
348 | "mask": combined_mask,
349 | "labels": combined_labels}
350 |
351 | data_dict = utils.split_and_subsample_batch(data_dict, args, data_type = data_type)
352 | return data_dict
353 |
354 | if __name__ == '__main__':
355 | torch.manual_seed(1991)
356 |
357 | dataset = PhysioNet('data/physionet', train=False, download=True)
358 | dataloader = DataLoader(dataset, batch_size=10, shuffle=True, collate_fn=variable_time_collate_fn)
359 | print(next(iter(dataloader)))
360 |
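The core of `variable_time_collate_fn` above is placing irregularly sampled series on the sorted union of their time stamps, with a mask marking which entries were observed. A minimal sketch of that step, assuming one scalar value per time stamp and plain Python lists instead of tensors (the function name `merge_on_union_grid` is hypothetical):

```python
def merge_on_union_grid(batch):
    """batch: list of (times, values) pairs with one scalar value per time.

    Returns the sorted union of all times plus, per series, a row of values
    and a 0/1 mask aligned to that union grid."""
    union = sorted({t for times, _ in batch for t in times})
    index = {t: i for i, t in enumerate(union)}
    values = [[0.0] * len(union) for _ in batch]
    mask = [[0] * len(union) for _ in batch]
    for b, (times, vals) in enumerate(batch):
        for t, v in zip(times, vals):
            values[b][index[t]] = v
            mask[b][index[t]] = 1
    # Rescale times to [0, 1], as the collate function does when the max is non-zero.
    t_max = max(union)
    if t_max != 0:
        union = [t / t_max for t in union]
    return union, values, mask


batch = [([0.0, 1.0, 4.0], [10.0, 11.0, 14.0]),
         ([0.0, 2.0], [20.0, 22.0])]
grid, values, mask = merge_on_union_grid(batch)
```

The original gets the same index mapping from `torch.unique(..., return_inverse=True)` and slices it per series with a running `offset`. It also normalizes the values and calls `utils.split_and_subsample_batch`, neither of which is covered here.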
--------------------------------------------------------------------------------
/latentode/latent_ode/train-activity.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | python3 run_models.py --niters 200 -n 10000 -l 15 --dataset activity --latent-ode --rec-dims 100 --rec-layers 4 --gen-layers 2 --units 500 --gru-units 50 --classif --linear-classif -b 100
4 |
--------------------------------------------------------------------------------
/latentode/latent_ode/train-ecg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | #python3 run_models.py --niters 200 -n 10000 -l 15 --dataset ecg --latent-ode --rec-dims 100 --rec-layers 4 --gen-layers 2 --units 50 --gru-units 50 --classif --linear-classif -b 1000
4 |
5 | #python3 run_models.py --niters 400 -n 10000 -l 15 --dataset ecg --latent-ode --rec-dims 100 --rec-layers 4 --gen-layers 2 --units 50 --gru-units 50 --classif --linear-classif -b 4000
6 |
7 | # smaller learning rate
8 | python3 run_models.py --niters 400 -n 10000 -l 15 --dataset ecg --latent-ode --rec-dims 100 --rec-layers 4 --gen-layers 2 --units 50 --gru-units 50 --classif --linear-classif -b 2000 --lr 0.0001
9 |
--------------------------------------------------------------------------------
/latentode/latent_ode/train-periodic100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | python3 run_models.py --niters 500 -n 1000 -s 50 -l 10 --dataset periodic --latent-ode --noise-weight 0.01 --viz -t 100
4 |
--------------------------------------------------------------------------------
/latentode/latent_ode/train-periodic1000.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | python3 run_models.py --niters 500 -n 1000 -s 50 -l 10 --dataset periodic --latent-ode --noise-weight 0.01 --viz -t 1000
4 |
--------------------------------------------------------------------------------
/latentode/ours_impl/__pycache__/lv_field.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/latentode/ours_impl/__pycache__/lv_field.cpython-38.pyc
--------------------------------------------------------------------------------
/latentode/ours_impl/lv_field.py:
--------------------------------------------------------------------------------
1 |
2 | import time
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | import FrEIA.framework as Ff
9 | import FrEIA.modules as Fm
10 |
11 |
12 | class LVField(nn.Module):
13 |
14 | def __init__(self, dim=3, augmented_dim=6, hidden_dim=200, num_layers=8, e_vec_factor=1.0, t_d_factor=1.0):
15 | super().__init__()
16 | vec = torch.zeros(augmented_dim)
17 | vec[:dim] = .1
18 | self.dim = dim
19 | self.augmented_dim = augmented_dim
20 | self.w_vec = nn.Parameter(torch.cat([vec.unsqueeze(0),torch.eye(augmented_dim)],dim=0))
21 | self.pad = nn.ConstantPad1d((0, augmented_dim - dim), 0)
22 |
23 | self.e_vec_factor = e_vec_factor
24 | self.t_d_factor = t_d_factor
25 |
26 | # build inn
27 | def subnet_fc(dims_in, dims_out):
28 | return nn.Sequential(nn.Linear(dims_in, hidden_dim), nn.ReLU(),
29 | nn.Linear(hidden_dim, dims_out))
30 |
31 | self.inn = Ff.SequenceINN(augmented_dim)
32 | for k in range(num_layers):
33 | self.inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc
34 | # ,permute_soft=True
35 | )
36 |
37 | # H = 1000
38 | # self._nn = torch.nn.Sequential(
39 | # torch.nn.Linear(augmented_dim, H),
40 | # torch.nn.ReLU(),
41 | # torch.nn.Linear(H, H),
42 | # torch.nn.ReLU(),
43 | # torch.nn.Linear(H, H),
44 | # torch.nn.ReLU(),
45 | # torch.nn.Linear(H, augmented_dim),
46 | # )
47 |
48 | # def node(self, t, x):
49 | # return self._nn(x)
50 |
51 |
52 | # def eigen_ode(self,w_vec,init_x,t_d):
53 | # e_val, e_vec= w_vec[0], w_vec[1:]
54 | # init_v = torch.mm(torch.linalg.inv(e_vec), init_x.reshape((-1,1)))
55 |
56 | # return torch.mm(init_v.T * e_vec, torch.exp(e_val[:, None] * t_d)).T
57 |
58 | def eigen_ode(self, w_vec,init_x,t_d):
59 | from torchdiffeq import odeint, odeint_adjoint
60 |
61 | #print(init_x.shape)
62 | #print(t_d.shape)
63 | # out = odeint(func=self.node, y0=init_x, t=t_d, method='midpoint', options={
64 | # 'step_size': 1,
65 | # })
66 | # out = out.transpose(0,1)
67 | # #print(out)
68 | # #print(out.shape)
69 | # return out
70 |
71 |
72 | e_val,e_vec=(w_vec[0],w_vec[1:])
73 |
74 | # #_e_vec = e_vec
75 | # #_t_d = t_d
76 | # #_t_d = t_d * 0.001
77 | #
78 | # # for act
79 | # _e_vec = e_vec * 0.001
80 | # _t_d = t_d * 0.1
81 | #
82 | # # for periodic
83 | # _e_vec = e_vec #* 0.001
84 | # _t_d = t_d #* 0.1
85 | #
86 | #
87 | # _e_vec = e_vec * 1.0
88 | # _t_d = t_d * 0.01
89 |
90 | _e_vec = e_vec * self.e_vec_factor
91 | _t_d = t_d * self.t_d_factor
92 |
93 | init_v=torch.bmm(torch.inverse(_e_vec).expand(init_x.shape[0],-1,-1),init_x[:, :, None])
94 | rs=torch.bmm((init_v.transpose(1,2)*_e_vec.expand((init_x.shape[0],-1,-1))),
95 | torch.exp(
96 | e_val[:,None] * _t_d
97 | #self._nn(t_d[:, None]).T
98 | )
99 | .expand((init_x.shape[0],-1,-1))).transpose(1,2)
100 | return rs
101 |
102 |
103 | def _forward(self, init_v, t_d, padding, remove_padding_after):
104 |
105 | if padding:
106 | init_v = self.pad(init_v)
107 |
108 | init_v_in = self.inn(init_v)[0]
109 | eval_lin = self.eigen_ode(self.w_vec,init_v_in,t_d)
110 |
111 | _ori_shape = eval_lin.shape
112 | out = self.inn(eval_lin.reshape(-1, eval_lin.shape[-1]),rev=True)[0]
113 | out = out.reshape(_ori_shape)
114 | if remove_padding_after:
115 | out = out[:, :, :self.dim]
116 | return out
117 |
118 | def forward(self,init_v,t_d, padding=True, remove_padding_after=True):
119 | # return self._forward(init_v, t_d, padding, remove_padding_after)
120 | # print(init_v.shape)
121 | if len(init_v.shape) == 2:
122 | # (batch, dim): a single batch of initial states
123 | return self._forward(init_v, t_d, padding, remove_padding_after)
124 | elif len(init_v.shape) == 3:
125 | # (outer, batch, dim): several batches, forwarded one at a time
126 | # TODO make it forward a whole batch
127 | out = []
128 | for i in range(init_v.shape[0]):
129 | out.append(self(init_v[i], t_d, padding, remove_padding_after))
130 | # timer.print_stats()
131 | return torch.stack(out)
132 | else:
133 | raise NotImplementedError(f"input has dimensionality {init_v.shape}")
134 |
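`LVField.eigen_ode` above evaluates a linear ODE in closed form: the first row of `w_vec` holds the eigenvalues, the remaining rows hold a matrix V whose columns act as eigenvectors, and the trajectory is z(t) = V diag(exp(λ t)) V⁻¹ z₀. A minimal sketch of that formula for a single 2-dimensional state, written without torch so it can be checked by hand (the function name `linear_flow_2d` is hypothetical, and the `e_vec_factor` / `t_d_factor` scalings are left out):

```python
import math


def linear_flow_2d(eigvals, eigvecs, z0, times):
    """Closed-form solution of dz/dt = A z with A = V diag(eigvals) V^{-1}.

    eigvecs is V given as ((v00, v01), (v10, v11)), eigenvectors in columns."""
    (a, b), (c, d) = eigvecs
    det = a * d - b * c
    # Coordinates of z0 in the eigenbasis: coeffs = V^{-1} z0.
    c0 = (d * z0[0] - b * z0[1]) / det
    c1 = (-c * z0[0] + a * z0[1]) / det
    out = []
    for t in times:
        g0 = c0 * math.exp(eigvals[0] * t)
        g1 = c1 * math.exp(eigvals[1] * t)
        # Map back to the original coordinates: z(t) = V (coeffs * exp(eigvals * t)).
        out.append((a * g0 + b * g1, c * g0 + d * g1))
    return out


traj = linear_flow_2d(eigvals=(-1.0, -2.0),
                      eigvecs=((1.0, 1.0), (0.0, 1.0)),
                      z0=(3.0, 2.0),
                      times=[0.0, 0.5, 1.0])
```

Because every time point is computed directly from z₀, no numerical integrator is stepped through the trajectory. The batched `torch.bmm` expression in `eigen_ode` does the same thing for all initial states and all time points at once.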
--------------------------------------------------------------------------------
/robotic/.combine_stats.py.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/robotic/.combine_stats.py.swp
--------------------------------------------------------------------------------
/robotic/.run_stiff_time.py.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/robotic/.run_stiff_time.py.swp
--------------------------------------------------------------------------------
/robotic/3D_Cshape_top.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/robotic/3D_Cshape_top.mat
--------------------------------------------------------------------------------
/robotic/S_gt_traj.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/robotic/S_gt_traj.npy
--------------------------------------------------------------------------------
/robotic/bench_output/out.txt:
--------------------------------------------------------------------------------
1 | ==============================
2 | dataset C
3 | dim=3
4 | seed = 0
5 | ==============================
6 | dataset S
7 | dim=2
8 | ==============================
9 | dataset S
10 | dim=2
11 | seed = 0
12 | ==============================
13 | dataset S
14 | dim=2
15 | seed = 0
16 |
--------------------------------------------------------------------------------
/robotic/cube_pick.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/robotic/cube_pick.npy
--------------------------------------------------------------------------------
/rotating_MNIST/LV_stat_dic_0.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/rotating_MNIST/LV_stat_dic_0.tar
--------------------------------------------------------------------------------
/rotating_MNIST/autoencoder_state_dic_0.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/rotating_MNIST/autoencoder_state_dic_0.tar
--------------------------------------------------------------------------------
/rotating_MNIST/evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as pl
3 | import torch
4 | import scipy
5 | from mnist import MNIST
6 | from scipy import ndimage
7 | import torch.nn as nn
8 | import torchdiffeq
9 | from scipy import ndimage
10 | import FrEIA.framework as Ff
11 | import FrEIA.modules as Fm
12 |
13 | mndata = MNIST('samples')
14 | images, labels = mndata.load_training()
15 | ds=np.array(images)
16 | d_l=np.array(labels)
17 | dt3=ds[d_l[:]==3]
18 | dt3_tor=torch.tensor(dt3).reshape((-1,28,28))
19 |
20 |
21 | class Autoencoder(nn.Module):
22 | def __init__(self):
23 | super(Autoencoder, self).__init__()
24 | self.encoder = nn.Sequential( # 1x28x28 image -> 64x1x1 latent code
25 | nn.Conv2d(1, 16, 3, stride=2, padding=1),
26 | nn.ReLU(),
27 | nn.Conv2d(16, 32, 3, stride=2, padding=1),
28 | nn.ReLU(),
29 | nn.Conv2d(32, 64, 7)
30 |
31 | )
32 | self.decoder = nn.Sequential(
33 | nn.ConvTranspose2d(64, 32, 7),
34 | nn.ReLU(),
35 | nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
36 | nn.ReLU(),
37 | nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),
38 | nn.Sigmoid()
39 | )
40 |
41 | def forward(self, x):
42 | x = self.encoder(x)
43 | #x=nn.Flatten()(x)
44 | #x=nn.Unflatten(dim=1,unflattened_size=(64,1,1))(x)
45 | x = self.decoder(x)
46 | return x
47 |
48 | def train(model, num_epochs=5, batch_size=64, learning_rate=1e-4):
49 | torch.manual_seed(42)
50 | criterion = nn.MSELoss()
51 | optimizer = torch.optim.Adam(model.parameters(),
52 | lr=learning_rate,
53 | weight_decay=1e-5)
54 | train_loader = torch.utils.data.DataLoader(all_imgs_l,
55 | batch_size=batch_size,
56 | shuffle=True)
57 | outputs = []
58 | for epoch in range(num_epochs):
59 | for data in train_loader:
60 | img = data
61 | recon = model(img)
62 | loss = criterion(recon, img)
63 | loss.backward()
64 | optimizer.step()
65 | optimizer.zero_grad()
66 | if((epoch+1)%5==0):
67 | print('Epoch:{}, Loss:{:.4f}'.format(epoch, float(loss)))
68 | outputs.append((epoch, img, recon),)
69 | return outputs
70 |
71 | model = Autoencoder()
72 | gt_angles=np.linspace(5,177,440)
73 | class LVField(nn.Module):
74 |
75 | def __init__(self, dim=64, augmented_dim=64+32, hidden_dim=500, num_layers=8, e_vec_factor=1e-7, t_d_factor=1e-4):
76 | super().__init__()
77 | vec = torch.zeros(augmented_dim)
78 | vec[:dim] = .1
79 | self.dim = dim
80 | self.augmented_dim = augmented_dim
81 | self.w_vec = nn.Parameter(torch.cat([vec.unsqueeze(0),torch.eye(augmented_dim)],dim=0))
82 | self.pad = nn.ConstantPad1d((0, augmented_dim - dim), 0)
83 |
84 | self.e_vec_factor = torch.tensor(e_vec_factor)
85 | self.t_d_factor = torch.tensor(t_d_factor)
86 |
87 | # build inn
88 | def subnet_fc(dims_in, dims_out):
89 | return nn.Sequential(nn.Linear(dims_in, hidden_dim), nn.ReLU(),
90 | nn.Linear(hidden_dim, dims_out))
91 |
92 | self.inn = Ff.SequenceINN(augmented_dim)
93 | for k in range(num_layers):
94 | self.inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc
95 | # ,permute_soft=True
96 | )
97 |
98 |
99 | def eigen_ode(self, w_vec,init_x,t_d):
100 | from torchdiffeq import odeint, odeint_adjoint
101 |
102 |
103 |
104 | e_val,e_vec=(w_vec[0],w_vec[1:])
105 |
106 | #e_val = -e_val**2 - 1e-10
107 |
108 |
109 | _e_vec = e_vec * self.e_vec_factor
110 | _t_d = t_d * self.t_d_factor
111 |
112 | init_v=torch.bmm(torch.inverse(_e_vec).expand(init_x.shape[0],-1,-1),init_x[:, :, None])
113 | rs=torch.bmm((init_v.transpose(1,2)*_e_vec.expand((init_x.shape[0],-1,-1))),
114 | torch.exp(
115 | e_val[:,None] * _t_d
116 | #self._nn(t_d[:, None]).T
117 | )
118 | .expand((init_x.shape[0],-1,-1))).transpose(1,2)
119 | return rs
120 |
121 |
122 | def _forward(self, init_v, t_d, padding, remove_padding_after):
123 |
124 | if padding:
125 | init_v = self.pad(init_v)
126 |
127 | init_v_in = self.inn(init_v)[0]
128 | eval_lin = self.eigen_ode(self.w_vec,init_v_in,t_d)
129 |
130 | _ori_shape = eval_lin.shape
131 | out = self.inn(eval_lin.reshape(-1, eval_lin.shape[-1]),rev=True)[0]
132 | out = out.reshape(_ori_shape)
133 | if remove_padding_after:
134 | out = out[:, :, :self.dim]
135 | return out
136 |
137 | def forward(self,init_v,t_d, padding=True, remove_padding_after=True):
138 | # return self._forward(init_v, t_d, padding, remove_padding_after)
139 | # print(init_v.shape)
140 | if len(init_v.shape) == 2:
141 | # (batch, dim): a single batch of initial states
142 | return self._forward(init_v, t_d, padding, remove_padding_after)
143 | elif len(init_v.shape) == 3:
144 | # (outer, batch, dim): several batches, forwarded one at a time
145 | # TODO make it forward a whole batch
146 | out = []
147 | for i in range(init_v.shape[0]):
148 | out.append(self(init_v[i], t_d, padding, remove_padding_after))
149 | # timer.print_stats()
150 | return torch.stack(out)
151 | else:
152 | raise NotImplementedError(f"input has dimensionality {init_v.shape}")
153 | #create ground truth
154 | save_list=[]
155 |
156 | for n in range(100,200):
157 | #print(n)
158 | s_list=[]
159 | for i in range(len(gt_angles)):
160 | #pl.figure()
161 | rotated = ndimage.rotate(dt3_tor[n], gt_angles[i],reshape=False)
162 | #rot_len=int(len(rotated)/2)
163 | #rotated=rotated[]
164 | s_list.append(torch.tensor(rotated)[None])
165 | s_list_tor=torch.cat(s_list)
166 | save_list.append(s_list_tor[None])
167 | all_tor_ds=torch.cat(save_list)
168 | all_tor_ds=all_tor_ds.reshape(-1,1,28,28).float()
169 | all_tor_ds_max=torch.max(all_tor_ds.reshape(44000,-1),dim=1)
170 | all_tor_ds_s=all_tor_ds/all_tor_ds_max.values.reshape((-1,1,1,1))
171 | m=LVField()
172 |
173 |
174 |
175 | #create ground truth
176 | save_list=[]
177 |
178 | for n in range(100,200):
179 | #print(n)
180 | s_list=[]
181 | for i in range(len(gt_angles)):
182 | #pl.figure()
183 | rotated = ndimage.rotate(dt3_tor[n], gt_angles[i],reshape=False)
184 | #rot_len=int(len(rotated)/2)
185 | #rotated=rotated[]
186 | s_list.append(torch.tensor(rotated)[None])
187 | s_list_tor=torch.cat(s_list)
188 | save_list.append(s_list_tor[None])
189 |
190 | all_tor_ds=torch.cat(save_list)
191 | all_tor_ds=all_tor_ds.reshape(-1,1,28,28).float()
192 | all_tor_ds_max=torch.max(all_tor_ds.reshape(44000,-1),dim=1)
193 | all_tor_ds_s=all_tor_ds/all_tor_ds_max.values.reshape((-1,1,1,1))
194 |
195 |
196 | model.load_state_dict(torch.load('autoencoder_state_dic_0.tar'))
197 | m.load_state_dict(torch.load('LV_stat_dic_0.tar'))
198 |
199 | all_tor_rr=all_tor_ds_s.reshape((100,440,1,28,28))
200 | test_c=all_tor_rr[:,0]
201 | enc=model.encoder(test_c)
202 |
203 | # device (set to 'cpu' here; the *_gpu names below are device-agnostic)
204 | device='cpu'
205 | m_gpu=m.to(device)
206 | st_gpu=enc[0,:,0,0][None].to(device)
207 |
208 |
209 |
210 | nt_gpu_exp=torch.arange(0,44,0.1).to(device)
211 |
212 | import time
213 | time_list=[]
214 | for i in range(len(enc)):
215 | st_gpu=enc[i,:,0,0][None].to(device)
216 | start_time = time.time()
217 | out_traj=m_gpu.forward(st_gpu,nt_gpu_exp)
218 | end_time=time.time()
219 | time_list.append(end_time-start_time)
220 | print("mean (ms):")
221 | print(np.array(time_list).mean()*1000)
222 | print('std:')
223 | print(np.array(time_list).std()*1000)
224 |
225 | st_gpu=enc[:,:,0,0][None].to(device)
226 | out_traj=m_gpu.forward(st_gpu,nt_gpu_exp)
227 |
228 | rot_t1=model.decoder(out_traj[:,:,:,None,None].to('cpu').reshape((-1,64,1,1))).detach()
229 |
230 | print('mse')
231 | print(nn.MSELoss()(rot_t1,all_tor_ds_s))
232 |
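The frame counts hard-coded in the two rotating-MNIST scripts (4400 for training, 44000 for evaluation) follow from the angle grids, and the arithmetic can be checked without any dependencies. The sketch below also compares the evaluation time grid with the ground-truth angles, under the assumption (not stated in the source) that training time index k corresponds to an angle of 5 + 4k degrees:

```python
n_digits = 100

# Training grid, as in run_rotate.py: range(5, 180, 4).
train_angles = list(range(5, 180, 4))

# Evaluation grids, as in evaluate.py.
n_eval = 440
gt_angles = [5 + (177 - 5) * i / (n_eval - 1) for i in range(n_eval)]  # np.linspace(5, 177, 440)
eval_times = [0.1 * i for i in range(n_eval)]                          # torch.arange(0, 44, 0.1)

n_train_frames = n_digits * len(train_angles)
n_eval_frames = n_digits * n_eval

# Assumption: time t maps to an angle of 5 + 4 * t degrees, as on the training grid.
implied_angles = [5 + 4 * t for t in eval_times]
max_gap = max(abs(a - g) for a, g in zip(implied_angles, gt_angles))
```

If that assumption holds, the last evaluation time (43.9) corresponds to about 180.6 degrees while the last ground-truth frame is rotated by 177 degrees, so the two grids drift apart by up to roughly 3.6 degrees. The source does not say whether this offset is intended.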
--------------------------------------------------------------------------------
/rotating_MNIST/readme.md:
--------------------------------------------------------------------------------
1 | Run `run_rotate.py` to train the autoencoder and the `LVField` model and save their weights, then run `evaluate.py` to evaluate them.
2 |
--------------------------------------------------------------------------------
/rotating_MNIST/rotating3_2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/rotating_MNIST/rotating3_2.pdf
--------------------------------------------------------------------------------
/rotating_MNIST/run_rotate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as pl
3 | import torch
4 | import scipy
5 | from mnist import MNIST
6 |
7 | mndata = MNIST('samples')
8 |
9 | images, labels = mndata.load_training()
10 | ds=np.array(images)
11 | d_l=np.array(labels)
12 | dt3=ds[d_l[:]==3]
13 | dt3_tor=torch.tensor(dt3).reshape((-1,28,28))
14 |
15 |
16 | from scipy import ndimage
17 | save_list=[]
18 | print('generating data')
19 | for n in range(0,100):
20 |
21 | s_list=[]
22 | for i in range(5,180,4):
23 | #pl.figure()
24 | rotated = ndimage.rotate(dt3_tor[n], i,reshape=False)
25 | #rot_len=int(len(rotated)/2)
26 | #rotated=rotated[]
27 | s_list.append(torch.tensor(rotated)[None])
28 | s_list_tor=torch.cat(s_list)
29 | save_list.append(s_list_tor[None])
30 | save_list_tor=torch.cat(save_list)
31 | print('finished generating data')
32 | save_list_tor_sample=save_list_tor
33 | traj=save_list_tor_sample
34 | import torch.nn as nn
35 | import torchdiffeq
36 |
37 | class Autoencoder(nn.Module):
38 | def __init__(self):
39 | super(Autoencoder, self).__init__()
40 | self.encoder = nn.Sequential( # 1x28x28 image -> 64x1x1 latent code
41 | nn.Conv2d(1, 16, 3, stride=2, padding=1),
42 | nn.ReLU(),
43 | nn.Conv2d(16, 32, 3, stride=2, padding=1),
44 | nn.ReLU(),
45 | nn.Conv2d(32, 64, 7)
46 |
47 | )
48 | self.decoder = nn.Sequential(
49 | nn.ConvTranspose2d(64, 32, 7),
50 | nn.ReLU(),
51 | nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
52 | nn.ReLU(),
53 | nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),
54 | nn.Sigmoid()
55 | )
56 |
57 | def forward(self, x):
58 | x = self.encoder(x)
59 | #x=nn.Flatten()(x)
60 | #x=nn.Unflatten(dim=1,unflattened_size=(64,1,1))(x)
61 | x = self.decoder(x)
62 | return x
63 |
64 | all_imgs=traj[:,:,:,:].reshape(-1,1,28,28).float()
65 | all_imgs_max=torch.max(all_imgs.reshape(4400,-1),dim=1)
66 | all_imgs=all_imgs/all_imgs_max.values.reshape((-1,1,1,1))
67 | all_imgs_l=list(all_imgs)
68 |
69 | def train(model, num_epochs=5, batch_size=64, learning_rate=1e-4):
70 | torch.manual_seed(42)
71 | criterion = nn.MSELoss() # mean square error loss
72 | optimizer = torch.optim.Adam(model.parameters(),
73 | lr=learning_rate,
74 | weight_decay=1e-5) # <--
75 | train_loader = torch.utils.data.DataLoader(all_imgs_l,
76 | batch_size=batch_size,
77 | shuffle=True)
78 | outputs = []
79 | for epoch in range(num_epochs):
80 | for data in train_loader:
81 | img = data
82 | recon = model(img)
83 | loss = criterion(recon, img)
84 | loss.backward()
85 | optimizer.step()
86 | optimizer.zero_grad()
87 |
88 | print('Epoch:{}, Loss:{:.4f}'.format(epoch+1, float(loss)))
89 | outputs.append((epoch, img, recon),)
90 | return outputs
91 |
92 | from scipy import ndimage
93 | save_list=[]
94 | print('generating test data')
95 | for n in range(100,200):
96 | print(n)
97 | s_list=[]
98 | for i in range(5,180,4):
99 | #pl.figure()
100 | rotated = ndimage.rotate(dt3_tor[n], i,reshape=False)
101 | #rot_len=int(len(rotated)/2)
102 | #rotated=rotated[]
103 | s_list.append(torch.tensor(rotated)[None])
104 | s_list_tor=torch.cat(s_list)
105 | save_list.append(s_list_tor[None])
106 | all_tor_ds=torch.cat(save_list)
107 |
108 | all_tor_ds=all_tor_ds.reshape(-1,1,28,28).float()
109 | all_tor_ds_max=torch.max(all_tor_ds.reshape(4400,-1),dim=1)
110 | all_tor_ds_s=all_tor_ds/all_tor_ds_max.values.reshape((-1,1,1,1))
111 | print('generated test data')
112 |
113 | print('train conv autoencoder')
114 | seed = 0
115 | torch.manual_seed(seed)
116 | import random
117 | random.seed(seed)
118 | np.random.seed(seed)
119 | device='cpu'
120 |
121 | model = Autoencoder()
122 |
123 | max_epochs = 50
124 | outputs = train(model, num_epochs=max_epochs)
125 |
126 | import FrEIA.framework as Ff
127 | import FrEIA.modules as Fm
128 |
129 |
130 | class LVField(nn.Module):
131 |
132 | def __init__(self, dim=64, augmented_dim=64+32, hidden_dim=500, num_layers=8, e_vec_factor=1e-7, t_d_factor=1e-4):
133 | super().__init__()
134 | vec = torch.zeros(augmented_dim)
135 | vec[:dim] = .1
136 | self.dim = dim
137 | self.augmented_dim = augmented_dim
138 | self.w_vec = nn.Parameter(torch.cat([vec.unsqueeze(0),torch.eye(augmented_dim)],dim=0))
139 | self.pad = nn.ConstantPad1d((0, augmented_dim - dim), 0)
140 |
141 | self.e_vec_factor = torch.tensor(e_vec_factor)
142 | self.t_d_factor = torch.tensor(t_d_factor)
143 |
144 | # build inn
145 | def subnet_fc(dims_in, dims_out):
146 | return nn.Sequential(nn.Linear(dims_in, hidden_dim), nn.ReLU(),
147 | nn.Linear(hidden_dim, dims_out))
148 |
149 | self.inn = Ff.SequenceINN(augmented_dim)
150 | for k in range(num_layers):
151 | self.inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc
152 | # ,permute_soft=True
153 | )
154 |
155 |
156 | def eigen_ode(self, w_vec,init_x,t_d):
157 | from torchdiffeq import odeint, odeint_adjoint
158 |
159 |
160 |
161 | e_val,e_vec=(w_vec[0],w_vec[1:])
162 |
163 | #e_val = -e_val**2 - 1e-10
164 |
165 |
166 | _e_vec = e_vec * self.e_vec_factor
167 | _t_d = t_d * self.t_d_factor
168 |
169 | init_v=torch.bmm(torch.inverse(_e_vec).expand(init_x.shape[0],-1,-1),init_x[:, :, None])
170 | rs=torch.bmm((init_v.transpose(1,2)*_e_vec.expand((init_x.shape[0],-1,-1))),
171 | torch.exp(
172 | e_val[:,None] * _t_d
173 | #self._nn(t_d[:, None]).T
174 | )
175 | .expand((init_x.shape[0],-1,-1))).transpose(1,2)
176 | return rs
177 |
178 |
179 | def _forward(self, init_v, t_d, padding, remove_padding_after):
180 |
181 | if padding:
182 | init_v = self.pad(init_v)
183 |
184 | init_v_in = self.inn(init_v)[0]
185 | eval_lin = self.eigen_ode(self.w_vec,init_v_in,t_d)
186 |
187 | _ori_shape = eval_lin.shape
188 | out = self.inn(eval_lin.reshape(-1, eval_lin.shape[-1]),rev=True)[0]
189 | out = out.reshape(_ori_shape)
190 | if remove_padding_after:
191 | out = out[:, :, :self.dim]
192 | return out
193 |
194 | def forward(self,init_v,t_d, padding=True, remove_padding_after=True):
195 | # return self._forward(init_v, t_d, padding, remove_padding_after)
196 | # print(init_v.shape)
197 | if len(init_v.shape) == 2:
198 | # (batch, dim): a single batch of initial states
199 | return self._forward(init_v, t_d, padding, remove_padding_after)
200 | elif len(init_v.shape) == 3:
201 | # (outer, batch, dim): several batches, forwarded one at a time
202 | # TODO make it forward a whole batch
203 | out = []
204 | for i in range(init_v.shape[0]):
205 | out.append(self(init_v[i], t_d, padding, remove_padding_after))
206 | # timer.print_stats()
207 | return torch.stack(out)
208 | else:
209 | raise NotImplementedError(f"input has dimensionality {init_v.shape}")
210 |
211 | # seed = 0
212 | # torch.manual_seed(seed)
213 | # import random
214 | # random.seed(seed)
215 | # np.random.seed(seed)
216 | m=LVField()
217 | opt_ours=torch.optim.Adam(m.parameters(),lr=0.0005,weight_decay=1e-5)
218 | all_tor_ds=all_tor_ds.reshape(-1,1,28,28).float()
219 | all_tor_ds_max=torch.max(all_tor_ds.reshape(4400,-1),dim=1)
220 | all_tor_ds_s=all_tor_ds/all_tor_ds_max.values.reshape((-1,1,1,1))
221 | tt=model.encoder(all_imgs)
222 | tt_re=tt.reshape((-1,44,64))
223 | nt=torch.arange(0,44).float()
224 | start_p=tt_re[:,0].detach().clone()
225 | for i in range(5000):
226 | opt_ours.zero_grad()
227 | out_traj_o=m.forward(start_p,nt)
228 |
229 | nloss=nn.MSELoss()(out_traj_o,tt_re.detach())
230 | if(i%20==0):
231 | print(str(i)+':'+str(float(nloss)))
232 | nloss.backward()
233 | opt_ours.step()
234 | torch.save(m.state_dict(), 'LV_stat_dic_0.tar')
235 | torch.save(model.state_dict(), 'autoencoder_state_dic_0.tar')
236 |
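`LVField._forward` above has three steps: map the initial state through the invertible network, evolve it with the closed-form linear dynamics, and map the result back with `rev=True`. A one-dimensional toy version shows why this composition still behaves like the flow of an ODE. Here `sinh` stands in for the invertible network, which is purely an assumption for illustration, and `lam` plays the role of a single eigenvalue:

```python
import math


def g(x):
    """Toy invertible map standing in for the INN (an assumption)."""
    return math.sinh(x)


def g_inv(z):
    """Exact inverse of g."""
    return math.asinh(z)


def flow(x0, t, lam=-0.5):
    """x(t) = g^{-1}(exp(lam * t) * g(x0)): linear dynamics in latent space."""
    return g_inv(math.exp(lam * t) * g(x0))


x0 = 1.5
one_step = flow(x0, 1.0)
two_steps = flow(flow(x0, 0.3), 0.7)  # same total time, split in two
```

Because `g` is inverted exactly, evolving for 0.3 and then 0.7 time units gives the same state as evolving for 1.0 at once, and `flow(x0, 0.0)` returns `x0`. This is the same structure that lets the model evaluate any time point directly.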
--------------------------------------------------------------------------------
/rotating_MNIST/samples/t10k-images-idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/rotating_MNIST/samples/t10k-images-idx3-ubyte
--------------------------------------------------------------------------------
/rotating_MNIST/samples/t10k-labels-idx1-ubyte:
--------------------------------------------------------------------------------
1 | '
--------------------------------------------------------------------------------
/rotating_MNIST/samples/train-images-idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/rotating_MNIST/samples/train-images-idx3-ubyte
--------------------------------------------------------------------------------
/rotating_MNIST/samples/train-labels-idx1-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/rotating_MNIST/samples/train-labels-idx1-ubyte
--------------------------------------------------------------------------------
/run-all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | (cd rotating_MNIST/ && python3 evaluate.py)
4 | (cd stiff_ode/ && python3 run_rober.py)
5 | (cd robotic/ && python3 run.py)
6 | (cd systems/ && python3 LV_train_run_ours.py)
7 | (cd systems/ && python3 LV_train_theirs.py)
8 | (cd systems/ && python3 LV_run_theirs.py)
9 | (cd systems/ && python3 lor_train_ours.py)
10 | (cd systems/ && python3 lor_train_theirs.py)
11 | (cd systems/ && python3 lor_run_all.py)
12 | (cd latentode/latent_ode/ && ./train-periodic100.sh)
13 | (cd latentode/latent_ode/ && ./train-periodic1000.sh)
14 | (cd latentode/latent_ode/ && ./train-activity.sh)
15 | (cd latentode/latent_ode/ && ./train-ecg.sh)
16 |
17 |
--------------------------------------------------------------------------------
/stiff_ode/run_rober.py:
--------------------------------------------------------------------------------
1 | from scipy.integrate import solve_ivp
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import platform
5 | def robertson_conserved ( t, y ):
6 | h = np.sum ( y, axis = 0 )
7 | return h
8 |
9 | def robertson_deriv ( t, y ):
10 | y1 = y[0]
11 | y2 = y[1]
12 | y3 = y[2]
13 | dydt = np.zeros(3)
14 | dydt[0] = - 0.04 * y1 + 10000.0 * y2 * y3
15 | dydt[1] = 0.04 * y1 - 10000.0 * y2 * y3 - 30000000.0 * y2 * y2
16 | dydt[2] = + 30000000.0 * y2 * y2
17 | return dydt
22 |
23 | tmin = 0.0
24 | tmin_logsp=-6
25 | tmax_logsp = 6#10000.0
26 | tmax=1000000.0
27 | y0 = np.array ( [ 1.0, 0.0, 0.0 ] )
28 | tspan = np.array ( [ tmin, tmax ] )
29 | t = np.logspace ( tmin_logsp, tmax_logsp, num=50 )
30 | #
31 | # Use the LSODA solver, which is suitable for stiff systems.
32 | #
33 | sol = solve_ivp ( robertson_deriv, tspan, y0,t_eval=t, method = 'LSODA' )
34 |
35 | import torch
36 | import torch.nn as n
37 | import FrEIA.framework as Ff
38 | import FrEIA.modules as Fm
39 |
40 | tt=torch.tensor(np.log10(sol.t))
41 | yy=torch.tensor(sol.y.T)
42 |
43 | seed = 0
44 | torch.manual_seed(seed)
45 | import random
46 | random.seed(seed)
47 | np.random.seed(seed)
48 | tt_tors=torch.tensor([[1.,0.,0.,0.,0.,0.]])
49 | device='cpu'
50 | t_d=tt
51 | t_d = t_d.to(device)
52 | xy_d =torch.tensor(yy).to(device)
53 |
54 | xy_d_max=torch.max(xy_d,dim=0).values
55 | factor=1./xy_d_max
56 | hidden_size=1500
57 | num_layers=5
58 |
59 | f_x=n.Sequential(
60 | n.Linear(6, 30),
61 | n.Tanh(),
62 | n.Linear(30, 30),
63 | n.Tanh(),
64 | n.Linear(30, 30),
65 | n.Tanh(),
66 | n.Linear(30, 6)
67 | )
68 |
69 | xy_d_max=torch.max(xy_d,dim=0).values
70 |
71 | def fx(t,x):
72 | return(f_x(x))
73 |
74 | def for_inn(x):
75 | return(inn(x)[0])
76 | def rev_inn(x):
77 | return(inn(x,rev=True)[0])
78 | def rev_mse_inn_eig(rf,x_gt):
79 | return(torch.mean(torch.norm(rf-x_gt,dim=1)))
80 | def linear_val_ode(w_vec,init_v,t_d):
81 | init_v_in=rev_inn(init_v)
82 | eval_lin=eigen_ode__(w_vec,init_v_in,t_d)
83 | ori_shape = eval_lin.shape
84 | eval_out=for_inn(eval_lin.reshape(-1, eval_lin.shape[-1]))
85 | return(eval_out.reshape(ori_shape))
86 | def linear_val_ode2(init_v,t_d):
87 | init_v_in=rev_inn(init_v)
88 | eval_lin=torchdiffeq.odeint(fx,init_v_in,t_d,atol=1e-5,method='dopri5')[:,0,:]#options={'step_size':0.01}
89 | eval_out=for_inn(eval_lin)
90 | return(eval_out)
91 |
92 |
93 |
94 | seed = 42
95 | torch.manual_seed(seed)
96 | import random
97 | random.seed(seed)
98 | np.random.seed(seed)
99 | tt_tors=torch.tensor([[1.,0.,0.,0.,0.,0.]])
100 | device='cpu'
101 |
102 | t_d = tt.to(device)
103 | xy_d =torch.tensor(yy).to(device)
104 |
105 | xy_d_max=torch.max(xy_d,dim=0).values
106 | factor=1./xy_d_max
107 | hidden_size=1500
108 | num_layers=5
109 |
110 | f_x=n.Sequential(
111 | n.Linear(6, 30),
112 | n.Tanh(),
113 | n.Linear(30, 30),
114 | n.Tanh(),
115 | n.Linear(30, 30),
116 | n.Tanh(),
117 | n.Linear(30, 6)
118 | )
119 |
120 | xy_d_max=torch.max(xy_d,dim=0).values
121 |
122 | def fx(t,x):
123 | return(f_x(x))
124 |
125 | def for_inn(x):
126 | return(inn(x)[0])
127 | def rev_inn(x):
128 | return(inn(x,rev=True)[0])
129 | def rev_mse_inn_eig(rf,x_gt):
130 | return(torch.mean(torch.norm(rf-x_gt,dim=1)))
131 | def linear_val_ode(w_vec,init_v,t_d):
132 | init_v_in=rev_inn(init_v)
133 | eval_lin=eigen_ode__(w_vec,init_v_in,t_d)
134 | ori_shape = eval_lin.shape
135 | eval_out=for_inn(eval_lin.reshape(-1, eval_lin.shape[-1]))
136 | return(eval_out.reshape(ori_shape))
137 | def linear_val_ode2(init_v,t_d):
138 | init_v_in=rev_inn(init_v)
139 | eval_lin=torchdiffeq.odeint(fx,init_v_in,t_d,atol=1e-5,method='dopri5')[:,0,:]#options={'step_size':0.01}
140 | eval_out=for_inn(eval_lin)
141 | return(eval_out)
142 |
143 |
144 |
145 | N_DIM = 6
146 | def subnet_fc(dims_in, dims_out):
147 | return n.Sequential(n.Linear(dims_in, hidden_size), n.ReLU(),
148 | n.Linear(hidden_size, dims_out))
149 |
150 | inn = Ff.SequenceINN(N_DIM)
151 | for k in range(num_layers):
152 | inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc,permute_soft=True)
153 |
154 |
155 | optimizer_comb = torch.optim.Adam(
156 | [{'params': f_x.parameters(),'lr': 0.0001},{'params': inn.parameters(),
157 | 'lr': 0.0001}])
158 | print(sum(p.numel() for p in inn.parameters()))
159 |
160 | startings=tt_tors.clone().detach()
161 |
162 | #Training loop
163 | import timeit
164 | epoch_time=[]
165 | import tqdm
166 | from tqdm import trange
167 |
168 | tt_tors = tt_tors.to(device)
169 | #t_d = t_d.to(device)
170 | inn.to(device)
171 | import torchdiffeq
172 | for i in trange(0, 5000):
173 | optimizer_comb.zero_grad()
174 | loss=0.0
175 | start = timeit.default_timer()
176 | eval_nl=linear_val_ode2(tt_tors,t_d)
177 | #eval_nl=torchdiffeq.odeint(fx,tt_tors,t_d,atol=1e-7,method='dopri5')[:,0,:]#options={'step_size':0.01}
178 | """
179 | for j in range(len(xy_d_list)):
180 | eval_nl=linear_val_ode(w_vec,tt_tors[j],t_d)
181 |
182 | #torchdiffeq.odeint(fx,
183 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5,
184 | # method='euler')[:,0,:]
185 |
186 | #loss_cur = rev_mse_inn_eig(eval_nl[:,:3],xy_d_list[j])
187 | #loss+=loss_cur
188 | """
189 | #factor
190 | eval_cal=eval_nl[:,:3]
191 | # Compare against the ground truth scaled per component by 1/max.
192 | eval_gt=factor*xy_d
193 | loss_cur = torch.mean(torch.norm(eval_cal-eval_gt,dim=1))
194 | loss+=loss_cur
195 |
196 | loss.backward()
197 | optimizer_comb.step()
198 | end = timeit.default_timer()
199 | epoch_time.append(end-start)
200 | if(i%100==0):
201 | print('Combined loss:'+str(i)+': '+str(loss))
202 | ep_time=np.array(epoch_time)
203 | #print(f'mean train time:{ep_time.mean():.3f} {ep_time.std():.3f}')
204 | #print(f'total: {ep_time.sum() / run_for * target:.2f}')
205 |
206 | torch.save(f_x.state_dict(),'n_stiff_fx.tar')
207 | torch.save(inn.state_dict(),'n_stiff_inn.tar')
208 |
209 | device2='cuda'
210 | tt_tors_n=torch.tensor([[1.0,0.,0.,0.,0.,0.]])
211 | texp = np.logspace ( tmin_logsp, tmax_logsp, num=500 )
212 | tt_tors_interp=torch.tensor(texp)
213 | t_d_exp=torch.log10(tt_tors_interp)
214 | #eval_nl=linear_val_ode2(tt_tors_n,t_d_exp)
215 | tt_tors_n = tt_tors_n.to(device2)
216 | t_d_exp=t_d_exp.to(device2)
217 | inn=inn.to(device2)
218 | f_x=f_x.to(device2)
219 |
220 | q_time=[]
221 | for i in range(10):
222 | start = timeit.default_timer()
223 | eval_nl=linear_val_ode2(tt_tors_n,t_d_exp)
224 | end = timeit.default_timer()
225 | q_time.append(end-start)
226 | q_time_np=np.array(q_time)
227 | print("mean time:")
228 | print(q_time_np.mean()*1000)
229 | print("std time:")
230 | print(q_time_np.std()*1000)
231 |
232 | eval_nl=linear_val_ode2(tt_tors_n,t_d_exp)  # use the initial state that lives on device2, like inn and f_x
233 | #eval_nl=torchdiffeq.odeint(fx,tt_tors,t_d,atol=1e-7,method='dopri5')[:,0,:]#options={'step_size':0.01}
234 | """
235 | for j in range(len(xy_d_list)):
236 | eval_nl=linear_val_ode(w_vec,tt_tors[j],t_d)
237 |
238 | #torchdiffeq.odeint(fx,
239 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5,
240 | # method='euler')[:,0,:]
241 |
242 | #loss_cur = rev_mse_inn_eig(eval_nl[:,:3],xy_d_list[j])
243 | #loss+=loss_cur
244 | """
245 | #factor
246 | sol = solve_ivp ( robertson_deriv, tspan, y0,t_eval=texp, method = 'LSODA' )
247 | yy_exp=torch.tensor(sol.y.T)
248 | eval_cal=eval_nl[:,:3]
249 | eval_gt=torch.tensor(yy_exp).to(device)
250 |
251 | print('MAE:')
252 | print(torch.mean(torch.norm(eval_nl.detach()[:,:3].to('cpu')-factor*yy_exp.to('cpu'),dim=1)))
253 |
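Note on `robertson_conserved`: the helper is defined in this script but never called. The sketch below, using only the standard library, shows the property it appears intended to track. It assumes the same rate constants as `robertson_deriv`. The three derivative components sum to zero, so `y1 + y2 + y3` stays at its initial value. The names `robertson_rhs` and `euler_steps` are illustrative and do not exist in the repository, and the explicit Euler loop is only a stand-in for a proper stiff solver over a very short time span.

```python
def robertson_rhs(y):
    """Right-hand side of the Robertson system (same constants as robertson_deriv)."""
    y1, y2, y3 = y
    return [
        -0.04 * y1 + 1.0e4 * y2 * y3,
        0.04 * y1 - 1.0e4 * y2 * y3 - 3.0e7 * y2 * y2,
        3.0e7 * y2 * y2,
    ]


def euler_steps(y, dt, n_steps):
    """Take n_steps explicit Euler steps; only reasonable here because dt is tiny."""
    for _ in range(n_steps):
        dy = robertson_rhs(y)
        y = [yi + dt * di for yi, di in zip(y, dy)]
    return y


y0 = [1.0, 0.0, 0.0]
# Integrate from t = 0 to t = 1e-3 with a step far below the fast time scale.
y_end = euler_steps(y0, dt=1.0e-6, n_steps=1000)

# The derivative components cancel, so the total is conserved.
total_rate = sum(robertson_rhs(y_end))
total_mass = sum(y_end)
print("sum of derivatives:", total_rate)
print("y1 + y2 + y3:", total_mass)
```

A check of this kind could also be applied to the learned trajectory `eval_nl[:, :3]`, keeping in mind that the script compares against the ground truth after scaling by `factor`.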
--------------------------------------------------------------------------------
/stiff_ode/run_rober_comp_dopri.py:
--------------------------------------------------------------------------------
1 | from scipy.integrate import solve_ivp
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import platform
5 | def robertson_conserved ( t, y ):
6 | h = np.sum ( y, axis = 0 )
7 | return h
8 |
9 | def robertson_deriv ( t, y ):
10 | y1 = y[0]
11 | y2 = y[1]
12 | y3 = y[2]
13 | dydt = np.zeros(3)
14 | dydt[0] = - 0.04 * y1 + 10000.0 * y2 * y3
15 | dydt[1] = 0.04 * y1 - 10000.0 * y2 * y3 - 30000000.0 * y2 * y2
16 | dydt[2] = + 30000000.0 * y2 * y2
17 | return dydt
22 |
23 | tmin = 0.0
24 | tmin_logsp=-6
25 | tmax_logsp = 6#10000.0
26 | tmax=1000000.0
27 | y0 = np.array ( [ 1.0, 0.0, 0.0 ] )
28 | tspan = np.array ( [ tmin, tmax ] )
29 | t = np.logspace ( tmin_logsp, tmax_logsp, num=50 )
30 | #
31 | # Use the LSODA solver, which is suitable for stiff systems.
32 | #
33 | sol = solve_ivp ( robertson_deriv, tspan, y0,t_eval=t, method = 'LSODA' )
34 |
35 | import torch
36 | import torch.nn as n
37 | import FrEIA.framework as Ff
38 | import FrEIA.modules as Fm
39 |
40 | tt=torch.tensor(np.log10(sol.t))
41 | yy=torch.tensor(sol.y.T)
42 |
43 | seed = 0
44 | torch.manual_seed(seed)
45 | import random
46 | random.seed(seed)
47 | np.random.seed(seed)
48 | tt_tors=torch.tensor([[1.,0.,0.,0.,0.,0.]])
49 | device='cpu'
50 | t_d=tt
51 | t_d = t_d.to(device)
52 | xy_d =torch.tensor(yy).to(device)
53 |
54 | xy_d_max=torch.max(xy_d,dim=0).values
55 | factor=1./xy_d_max
56 | hidden_size=1500
57 | num_layers=5
58 |
59 | f_x=n.Sequential(
60 | n.Linear(6, 30),
61 | n.Tanh(),
62 | n.Linear(30, 30),
63 | n.Tanh(),
64 | n.Linear(30, 30),
65 | n.Tanh(),
66 | n.Linear(30, 6)
67 | )
68 |
69 | xy_d_max=torch.max(xy_d,dim=0).values
70 |
71 | def fx(t,x):
72 | return(f_x(x))
73 |
74 | def for_inn(x):
75 | return(inn(x)[0])
76 | def rev_inn(x):
77 | return(inn(x,rev=True)[0])
78 | def rev_mse_inn_eig(rf,x_gt):
79 | return(torch.mean(torch.norm(rf-x_gt,dim=1)))
80 | def linear_val_ode(w_vec,init_v,t_d):
81 | init_v_in=rev_inn(init_v)
82 | eval_lin=eigen_ode__(w_vec,init_v_in,t_d)
83 | ori_shape = eval_lin.shape
84 | eval_out=for_inn(eval_lin.reshape(-1, eval_lin.shape[-1]))
85 | return(eval_out.reshape(ori_shape))
86 | def linear_val_ode2(init_v,t_d):
87 | init_v_in=rev_inn(init_v)
88 | eval_lin=torchdiffeq.odeint(fx,init_v_in,t_d,atol=1e-5,method='dopri5')[:,0,:]#options={'step_size':0.01}
89 | eval_out=for_inn(eval_lin)
90 | return(eval_out)
91 |
92 |
93 |
94 | seed = 42
95 | torch.manual_seed(seed)
96 | import random
97 | import torchdiffeq
98 | random.seed(seed)
99 | np.random.seed(seed)
100 | f_x2=n.Sequential(
101 | n.Linear(6, 150),
102 | n.Tanh(),
103 | n.Linear(150, 150),
104 | n.Tanh(),
105 | n.Linear(150, 150),
106 | n.Tanh(),
107 | n.Linear(150, 6)
108 | )
109 | device='cpu'
110 | t_d = tt.to(device)
111 | tt_tors=torch.tensor([[1.,0.,0.,0.,0.,0.]])
112 | tt_tors=tt_tors.to(device)
113 | xy_d =torch.tensor(yy).to(device)
114 | xy_d_max=torch.max(xy_d,dim=0).values
115 | factor=1./xy_d_max
116 | def fx2(t,x):
117 | return(f_x2(x))
118 | optimizer = torch.optim.Adam(
119 | [{'params': f_x2.parameters(),'lr': 0.0001}])
120 | import timeit
121 | epoch_time=[]
122 | import tqdm
123 | from tqdm import trange
124 | for i in trange(0, 5000):
125 | optimizer.zero_grad()
126 | loss=0.0
127 | start = timeit.default_timer()
128 | eval_nl=torchdiffeq.odeint(fx2,tt_tors,t_d,atol=1e-5,method='dopri5')[:,0,:]#options={'step_size':0.01}
129 | """
130 | for j in range(len(xy_d_list)):
131 | eval_nl=linear_val_ode(w_vec,tt_tors[j],t_d)
132 |
133 | #torchdiffeq.odeint(fx,
134 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5,
135 | # method='euler')[:,0,:]
136 |
137 | #loss_cur = rev_mse_inn_eig(eval_nl[:,:3],xy_d_list[j])
138 | #loss+=loss_cur
139 | """
140 | #factor
141 | eval_cal=eval_nl[:,:3]
142 | eval_gt=factor*xy_d
143 | loss_cur = torch.mean(torch.norm(eval_cal-eval_gt,dim=1))
144 | loss+=loss_cur
145 |
146 | loss.backward()
147 | optimizer.step()
148 | end = timeit.default_timer()
149 | epoch_time.append(end-start)
150 | if(i%100==0):
151 | print('Combined loss:'+str(i)+': '+str(loss))
152 | ep_time=np.array(epoch_time)
153 | torch.save(f_x2.state_dict(),'f_x2.tar')
154 |
155 |
156 | device2='cuda'
157 | tt_tors_n=torch.tensor([[1.0,0.,0.,0.,0.,0.]])
158 | texp = np.logspace ( tmin_logsp, tmax_logsp, num=500 )
159 | tt_tors_interp=torch.tensor(texp)
160 | t_d_exp=torch.log10(tt_tors_interp)
161 | #eval_nl=linear_val_ode2(tt_tors_n,t_d_exp)
162 | tt_tors_n = tt_tors_n.to(device2)
163 | t_d_exp=t_d_exp.to(device2)
164 | f_x2=f_x2.to(device2)
165 |
166 | q_time=[]
167 | for i in range(10):
168 | start = timeit.default_timer()
169 | eval_nl=torchdiffeq.odeint(fx2,tt_tors_n,t_d_exp,atol=1e-5,method='dopri5')[:,0,:]
170 | end = timeit.default_timer()
171 | q_time.append(end-start)
172 | q_time_np=np.array(q_time)
173 | print("mean time:")
174 | print(q_time_np.mean()*1000)
175 | print("std time:")
176 | print(q_time_np.std()*1000)
177 |
178 | eval_nl=torchdiffeq.odeint(fx2,tt_tors_n,t_d_exp,atol=1e-5,method='dopri5')[:,0,:]
179 | #eval_nl=torchdiffeq.odeint(fx,tt_tors,t_d,atol=1e-7,method='dopri5')[:,0,:]#options={'step_size':0.01}
180 | """
181 | for j in range(len(xy_d_list)):
182 | eval_nl=linear_val_ode(w_vec,tt_tors[j],t_d)
183 |
184 | #torchdiffeq.odeint(fx,
185 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5,
186 | # method='euler')[:,0,:]
187 |
188 | #loss_cur = rev_mse_inn_eig(eval_nl[:,:3],xy_d_list[j])
189 | #loss+=loss_cur
190 | """
191 | #factor
192 | sol = solve_ivp ( robertson_deriv, tspan, y0,t_eval=texp, method = 'LSODA' )
193 | yy_exp=torch.tensor(sol.y.T)
194 | eval_cal=eval_nl[:,:3]
195 | eval_gt=torch.tensor(yy_exp).to(device)
196 |
197 | print('MAE:')
198 | print(torch.mean(torch.norm(eval_nl.detach()[:,:3].to('cpu')-factor*yy_exp.to('cpu'),dim=1)))
199 |
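Note on the time grid: both stiff-ODE scripts sample `t` with `np.logspace(tmin_logsp, tmax_logsp, ...)` and then pass `log10(t)` to `torchdiffeq.odeint`, so the learned dynamics are integrated over the interval [-6, 6] rather than [1e-6, 1e6]. The sketch below, using only the standard library, illustrates that the transformed grid is uniformly spaced. The `logspace` function is a hypothetical stand-in for `np.logspace` with base 10 and is not part of the repository.

```python
import math


def logspace(lo_exp, hi_exp, num):
    """Pure-Python stand-in for np.logspace(lo_exp, hi_exp, num) with base 10."""
    step = (hi_exp - lo_exp) / (num - 1)
    return [10.0 ** (lo_exp + k * step) for k in range(num)]


# Same exponents and number of points as the training grid in the script.
t = logspace(-6, 6, 50)

# The scripts train on log10(t), which turns the grid into evenly spaced points.
log_t = [math.log10(v) for v in t]
gaps = [b - a for a, b in zip(log_t, log_t[1:])]

print("raw time range:", t[0], t[-1])
print("log-time range:", log_t[0], log_t[-1])
print("smallest and largest gap in log-time:", min(gaps), max(gaps))
```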
--------------------------------------------------------------------------------
/systems/LV_run_theirs.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import numpy as np
4 | import matplotlib.pyplot as pl
5 | import torchdiffeq
6 | import torch.nn as n
7 | import FrEIA.framework as Ff
8 | import FrEIA.modules as Fm
9 | from mpl_toolkits import mplot3d
10 |
11 | ##########################################################3
12 | # system dynamics
13 |
14 | tt=torch.tensor([0.,5.])
15 | t_d=torch.linspace(0,7,700)
16 |
17 | devicestr="cpu"
18 | device_str="cuda"
19 |
20 | def test_fun(t,x):
21 | A=0.75
22 | B=0.75
23 | C=0.75
24 | D=0.75
25 | E=0.75
26 | F=0.75
27 | G=0.75
28 | vel=torch.zeros((3,1))
29 | vel[0]=x[0]*(A-B*x[1])
30 | vel[1]=x[1]*(-C+D*x[0]-E*x[2])
31 | vel[2]=x[2]*(-F+G*x[1])
32 | return(vel)
33 | tt_tors=torch.tensor([[3,3.,1.,0.,0.,0.],
34 | [2,2.,2.,0.,0.,0.],
35 | [4,4.,3.,0.,0.,0.],
36 | [3,3.,4.,0.,0.,0.],
37 | [1,1.,5.,0.,0.,0.],
38 | [5,5.,1.,0.,0.,0.],
39 | [2,6.,2.,0.,0.,0.],
40 | [3,1.,4.,0.,0.,0.],
41 | [7,1.,2.,0.,0.,0.],
42 | [6,2.,4.,0.,0.,0.]])
43 |
44 | xy_d_list=[]
45 | for i in range(len(tt_tors)):
46 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.05*torch.ones((len(t_d),3)))
47 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T,t_d).reshape((-1,3))#+noise_c
48 | xy_d_list.append(xy_d.clone().detach())
49 |
50 | ##########################################################3
51 |
52 | def test_theirs(method):
53 | global tt_tors, t_d
54 | def test_fun(t,x):
55 | A=0.75
56 | B=0.75
57 | C=0.75
58 | D=0.75
59 | E=0.75
60 | F=0.75
61 | G=0.75
62 | vel=torch.zeros((3,1))
63 | vel[0]=x[0]*(A-B*x[1])
64 | vel[1]=x[1]*(-C+D*x[0]-E*x[2])
65 | vel[2]=x[2]*(-F+G*x[1])
66 | return(vel)
67 | tt_tors=torch.tensor([[3,3.,1.,0.,0.,0.],
68 | [2,2.,2.,0.,0.,0.],
69 | [4,4.,3.,0.,0.,0.],
70 | [3,3.,4.,0.,0.,0.],
71 | [1,1.,5.,0.,0.,0.],
72 | [5,5.,1.,0.,0.,0.],
73 | [2,6.,2.,0.,0.,0.],
74 | [3,1.,4.,0.,0.,0.],
75 | [7,1.,2.,0.,0.,0.],
76 | [6,2.,4.,0.,0.,0.]])
77 |
78 | xy_d_list=[]
79 | for i in range(len(tt_tors)):
80 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.05*torch.ones((len(t_d),3)))
81 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T,t_d).reshape((-1,3))#+noise_c
82 | xy_d_list.append(xy_d.clone().detach())
83 |
84 | if method == 'euler':
85 | weight_fn = 'LV_euler0'
86 | elif method == 'midpoint':
87 | weight_fn = 'LV_mid0'
88 | elif method == 'rk4':
89 | weight_fn = 'LV_rk40'
92 | elif method == 'dopri5':
93 | weight_fn = 'LV_dopri0'
94 | f_x=n.Sequential(
95 | n.Linear(6, 150),
96 | n.Tanh(),
97 | n.Linear(150, 150),
98 | n.Tanh(),
99 | n.Linear(150, 150),
100 | n.Tanh(),
101 | n.Linear(150, 150),
102 | n.Tanh(),
103 | n.Linear(150, 150),
104 | n.Tanh(),
105 | n.Linear(150, 6)
106 | )
107 | def fx(t,x):
108 | return(f_x(x))
109 |
110 | f_x.load_state_dict(torch.load(weight_fn))
111 | import time
112 | time_list=[]
113 |
114 | f_x = f_x.to(device_str)
115 | tt_tors = tt_tors.to(device_str)
116 | t_d = t_d.to(device_str)
117 | for i in range(10):
118 | tic = time.perf_counter()
119 | tx=torchdiffeq.odeint(fx,tt_tors,t_d,method=method,atol=1e-5,rtol=1e-5)
120 | toc = time.perf_counter()
121 | time_list.append(toc-tic)
122 | tx=tx.permute(1,0,2)
123 | #interpolation
124 | sum_num=0
125 |
126 | for i in range(1,len(xy_d_list)):
127 | xy_d=xy_d_list[i]
128 | error=torch.mean(torch.norm(tx[i][:,:3].to('cpu')-xy_d,dim=1)**2)
129 | sum_num+=error
130 |
131 | print(f'Interpolation MSE: {sum_num/len(xy_d_list):.4f}')
132 |
133 | times=np.array(time_list)
134 |
135 | print(method)
136 | print('mean time:')
137 | print(times.mean())
138 | print('time std:')
139 | print(times.std())
140 | t_d_test_tt=torch.linspace(0,7,700)
141 | test_tors_list=[]
142 | for i in range(2,6,1):
143 | for j in range(2,6,1):
144 | test_tors_list.append(torch.tensor([[float(i),float(j),2.,0.,0.,0.]]))
145 |
146 | test_tors=torch.cat(test_tors_list)
147 | test_d_list=[]
148 | for i in range(len(test_tors)):
149 | xy_d=torchdiffeq.odeint(test_fun,test_tors[i,:3][None].T,t_d_test_tt.to('cpu')).reshape((-1,3))
150 | test_d_list.append(xy_d.clone().detach())
151 |
152 | pushed_list_tt=[]
153 | diff_list=[]
154 |
155 | #w_vec=w_vec.to('cuda:0')
156 | #inn=inn.to('cuda:0')
157 | #t_d_test_tt=t_d_test_tt.to('cuda:0')
158 | f_x = f_x.to('cpu')
159 | for j in range(len(test_d_list)):
160 | pushed=torchdiffeq.odeint(fx,test_tors[j],t_d,method=method,atol=1e-5,rtol=1e-5)
161 | pushed_list_tt.append(pushed.clone().detach())
162 | print('Extrapolation Results:')
163 | i=0
164 | xy_d=test_d_list[i]
165 | sum_num=0
166 | for i in range(1,len(test_d_list)):
167 | xy_d=test_d_list[i].to(device_str)
168 | error=torch.mean(torch.norm(pushed_list_tt[i][...,:3].to('cpu')-xy_d.to('cpu'),dim=1)**2)
169 | sum_num+=error
170 |
171 | print(f'Extrapolation MSE: {sum_num/len(test_d_list):.4f}')
172 |
173 |
174 | test_theirs('euler')
175 | test_theirs('rk4')
176 | test_theirs('midpoint')
177 | test_theirs('dopri5')
178 | ##########################################################3
179 |
180 |
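Note on `test_fun`: it implements a three-species Lotka-Volterra chain with every coefficient set to 0.75. The sketch below is a plain-Python version of the same field, written under that assumption. It checks two things that follow from the equations: `(2, 1, 1)` is a fixed point (all velocities vanish when `x1 = 1` and `x0 - x2 = 1`), and the first training initial condition `(3, 3, 1)` has a nonzero velocity. The name `lv_field` is illustrative and not part of the repository.

```python
def lv_field(x, coeff=0.75):
    """Three-species Lotka-Volterra velocity, mirroring test_fun with A..G = coeff."""
    x0, x1, x2 = x
    return [
        x0 * (coeff - coeff * x1),
        x1 * (-coeff + coeff * x0 - coeff * x2),
        x2 * (-coeff + coeff * x1),
    ]


# Fixed point: x1 = 1 zeroes the first and third components,
# and x0 - x2 = 1 zeroes the second.
fixed_point_velocity = lv_field([2.0, 1.0, 1.0])

# Velocity at the first initial condition listed in tt_tors (first three entries).
start_velocity = lv_field([3.0, 3.0, 1.0])

print("velocity at (2, 1, 1):", fixed_point_velocity)
print("velocity at (3, 3, 1):", start_velocity)
```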
--------------------------------------------------------------------------------
/systems/LV_train_run_ours.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as pl
4 | import torchdiffeq
5 | import torch.nn as n
6 | import FrEIA.framework as Ff
7 | import FrEIA.modules as Fm
8 |
9 | device = 'cpu'
10 |
11 | xx=torch.tensor([0.,0.,0.])
12 | tt=torch.tensor([0.,5.])
13 | t_d=torch.linspace(0,7,70)
14 | torch.manual_seed(0)
15 | import random
16 | random.seed(0)
17 | np.random.seed(0)
18 | def test_fun(t,x):
19 | A=0.75
20 | B=0.75
21 | C=0.75
22 | D=0.75
23 | E=0.75
24 | F=0.75
25 | G=0.75
26 | vel=torch.zeros((3,1))
27 | vel[0]=x[0]*(A-B*x[1])
28 | vel[1]=x[1]*(-C+D*x[0]-E*x[2])
29 | vel[2]=x[2]*(-F+G*x[1])
30 | return(vel)
31 | tt_tors=torch.tensor([[3,3.,1.,0.,0.,0.],
32 | [2,2.,2.,0.,0.,0.],
33 | [4,4.,3.,0.,0.,0.],
34 | [3,3.,4.,0.,0.,0.],
35 | [1,1.,5.,0.,0.,0.],
36 | [5,5.,1.,0.,0.,0.],
37 | [2,6.,2.,0.,0.,0.],
38 | [3,1.,4.,0.,0.,0.],
39 | [7,1.,2.,0.,0.,0.],
40 | [6,2.,4.,0.,0.,0.]])
41 |
42 |
43 | def eigen_ode(w_vec,init_x,t_d):
44 | e_val,e_vec=(w_vec[0],w_vec[1:])
45 | init_v=torch.mm(torch.linalg.inv(e_vec),init_x.reshape((-1,1)))
46 | #print(init_v)
47 | int_list=[]
48 |
49 | for i in range(0,6):
50 | int_c=0
51 | for j in range(0,6):
52 | int_c+=init_v[j]*e_vec[i,j]*torch.exp(e_val[j]*t_d)
53 | int_c.reshape((-1,1))
54 | int_list.append(int_c[None].clone())
55 | return(torch.cat(int_list).T)
56 |
57 | def eigen_ode__(w_vec,init_x,t_d):
58 |
59 | e_val,e_vec=(w_vec[0],w_vec[1:])
60 |
61 | _e_vec = e_vec * 1
62 | _t_d = t_d * 1
63 |
64 |
65 | init_v=torch.bmm(torch.inverse(_e_vec).expand(init_x.shape[0],-1,-1),init_x[:, :, None])
66 | rs=torch.bmm((init_v.transpose(1,2)*_e_vec.expand((init_x.shape[0],-1,-1))),
67 | torch.exp(
68 | e_val[:,None] * _t_d
69 | #self._nn(t_d[:, None]).T
70 | )
71 | .expand((init_x.shape[0],-1,-1))).transpose(1,2)
72 | return rs
73 |
74 | def loss_er(x_pred,x_gt):
75 | mse_l=torch.norm(x_pred-x_gt,dim=1)
76 | sum_c=0.0
77 | for i in range(len(mse_l)):
78 | sum_c+=(mse_l[i]**2)#*(1/float(i+1))
79 | return(sum_c/len(mse_l))
80 |
81 |
82 |
83 |
84 |
85 |
86 | def train(hidden_size, num_layers):
87 | print('='*20)
88 | #print(f'hidden_size={hidden_size}, num_layers={num_layers}')
89 | global tt_tors, xy_d_list, t_d
90 |
91 | def for_inn(x):
92 | return(inn(x)[0])
93 | def rev_inn(x):
94 | return(inn(x,rev=True)[0])
95 | def rev_mse_inn_eig(rf,x_gt):
96 | return(torch.mean(torch.norm(rf-x_gt,dim=1)))
97 | def linear_val_ode(w_vec,init_v,t_d):
98 | init_v_in=rev_inn(init_v)
99 | eval_lin=eigen_ode__(w_vec,init_v_in,t_d)
100 | ori_shape = eval_lin.shape
101 | eval_out=for_inn(eval_lin.reshape(-1, eval_lin.shape[-1]))
102 | return(eval_out.reshape(ori_shape))
103 |
104 | seed = 123
105 | torch.manual_seed(seed)
106 | import random
107 | random.seed(seed)
108 | np.random.seed(seed)
109 |
110 |
111 | e_val=torch.tensor([[0.1,0.1,0.1,0.,0.,0.]])
112 | e_vec=torch.eye(6)
113 | w_vec=torch.cat([e_val,e_vec],dim=0)
114 | w_vec.requires_grad=True
115 |
116 |
117 |
118 | xy_d_list=[]
119 | for i in range(len(tt_tors)):
120 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.05*torch.ones((len(t_d),3)))
121 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T.to('cpu'),t_d.to('cpu')).reshape((-1,3))+noise_c
122 | xy_d_list.append(xy_d.clone().detach())
123 |
124 |
125 | N_DIM = 6
126 | def subnet_fc(dims_in, dims_out):
127 | return n.Sequential(n.Linear(dims_in, hidden_size), n.ReLU(),
128 | n.Linear(hidden_size, dims_out))
129 |
130 | inn = Ff.SequenceINN(N_DIM)
131 | for k in range(num_layers):
132 | inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc,permute_soft=True)
133 |
134 |
135 | optimizer_comb = torch.optim.Adam(
136 | [{'params': w_vec,'lr': 0.0001},{'params': inn.parameters(),
137 | 'lr': 0.0001}])
138 | print(sum(p.numel() for p in inn.parameters()))
139 | opt=torch.optim.Adam([w_vec],lr=0.005)
140 | startings=tt_tors.clone().detach()
141 | #quick initialise
142 | #for i in range(0,1):
143 | for i in range(0,200):
144 | opt.zero_grad()
145 | loss = 0.
146 | for j in range(len(xy_d_list)):
147 | start_d=startings[j][None]
148 | output = eigen_ode(w_vec,start_d.to('cpu'),t_d.to('cpu'))
149 | loss_cur = loss_er(output[:,:3], xy_d_list[j][:,:3].to('cpu'))
150 | loss+=loss_cur
151 | loss.backward()
152 | opt.step()
153 |
154 | #Training loop
155 | import timeit
156 | epoch_time=[]
157 | import tqdm
158 | from tqdm import trange
159 |
160 | w_vec = w_vec.to(device)
161 | tt_tors = tt_tors.to(device)
162 | #t_d = t_d.to(device)
163 | inn.to(device)
164 | t_d = t_d.to(device)
165 | xy_d_list = torch.stack(xy_d_list).to(device)
166 |
167 |
168 |
169 | for i in trange(0, 5000):
170 | optimizer_comb.zero_grad()
171 | loss=0.0
172 | start = timeit.default_timer()
173 | eval_nl=linear_val_ode(w_vec,tt_tors,t_d)
174 | """
175 | for j in range(len(xy_d_list)):
176 | eval_nl=linear_val_ode(w_vec,tt_tors[j],t_d)
177 |
178 | #torchdiffeq.odeint(fx,
179 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5,
180 | # method='euler')[:,0,:]
181 |
182 | #loss_cur = rev_mse_inn_eig(eval_nl[:,:3],xy_d_list[j])
183 | #loss+=loss_cur
184 | """
185 | loss_cur = rev_mse_inn_eig(eval_nl[...,:3],xy_d_list)
186 | loss+=loss_cur
187 |
188 | loss.backward()
189 | optimizer_comb.step()
190 | end = timeit.default_timer()
191 | epoch_time.append(end-start)
192 | if(i%100==0):
193 | print('Combined loss:'+str(i)+': '+str(loss))
194 | ep_time=np.array(epoch_time)
195 | #print(f'mean train time:{ep_time.mean():.3f} {ep_time.std():.3f}')
196 | #print(f'total: {ep_time.sum() / run_for * target:.2f}')
197 |
198 |
199 | ###############################
200 | #Evaluation (Interpolation):
201 | ###############################
202 |
203 | xx=torch.tensor([0.,0.])
204 | tt=torch.tensor([0.,5.])
205 | t_d=torch.linspace(0,7,700)
206 |
207 |
208 | xy_d_list=[]
209 | for i in range(len(tt_tors)):
210 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T.to('cpu'),t_d).reshape((-1,3))
211 | xy_d_list.append(xy_d.clone().detach())
212 |
213 | xy_d_list = torch.stack(xy_d_list).to(device)
214 |
215 | test_d_list = xy_d_list
216 |
217 |
218 | t_d_test_tt=torch.linspace(0,7,700).to(device)
219 |
220 |
221 | import timeit
222 |
223 | pushed_list_tt=[]
224 | diff_list=[]
225 |
226 | #w_vec=w_vec.to('cuda:0')
227 | #inn=inn.to('cuda:0')
228 | #t_d_test_tt=t_d_test_tt.to('cuda:0')
229 |
230 | for j in range(len(tt_tors)):
231 | #test_tors[j]=test_tors[j].to('cuda:0')
232 | start = timeit.default_timer()
233 | #linear_val_ode(w_vec,tt_tors[j],t_d_test_tt)
234 | #pushed=torchdiffeq.odeint(fx,
235 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5,
236 | # method='euler')[:,0,:]
237 | pushed = linear_val_ode(w_vec,tt_tors[j][None],t_d_test_tt)  # keep a batch dimension: eigen_ode__ indexes init_x as 2-D
238 | stop=timeit.default_timer()
239 | diff_list.append(stop-start)
240 | pushed_list_tt.append(pushed.clone().detach())
241 |
242 | print('Interpolation:')
243 | print(f'time: {np.mean(np.array(diff_list)):.3f} {np.std(np.array(diff_list)):.3f}')
244 | #pl.figure(figsize=(4,4))
245 |
246 | i=0
247 | xy_d=test_d_list[i]
248 | #pl.plot(xy_d[:,0],xy_d[:,1],c='b',alpha=0.5,label='Ground truth')
249 |
250 | sum_num=0
251 |
252 | for i in range(1,len(xy_d_list)):
253 | xy_d=xy_d_list[i]
254 | error=torch.mean(torch.norm(pushed_list_tt[i][...,:3]-xy_d,dim=2)**2)
255 | sum_num+=error
256 |
257 | print(f'Interpolation MSE: {sum_num/len(xy_d_list):.4f}')
258 |
259 |
260 | ###############################
261 | #Evaluation (Extrapolation):
262 | ###############################
263 |
264 |
265 | test_tors_list=[]
266 | for i in range(2,6,1):
267 | for j in range(2,6,1):
268 | test_tors_list.append(torch.tensor([[float(i),float(j),2.,0.,0.,0.]]))
269 |
270 | test_tors=torch.cat(test_tors_list)
271 | test_d_list=[]
272 | for i in range(len(test_tors)):
273 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.05*torch.ones((len(t_d),3)))
274 | xy_d=torchdiffeq.odeint(test_fun,test_tors[i,:3][None].T,t_d_test_tt.to('cpu')).reshape((-1,3))
275 | test_d_list.append(xy_d.clone().detach())
276 |
277 | test_tors=test_tors.to(device)
278 | import timeit
279 |
280 | pushed_list_tt=[]
281 | diff_list=[]
282 |
283 | #w_vec=w_vec.to('cuda:0')
284 | #inn=inn.to('cuda:0')
285 | #t_d_test_tt=t_d_test_tt.to('cuda:0')
286 |
287 | for j in range(len(test_d_list)):
288 | #test_tors[j]=test_tors[j].to('cuda:0')
289 | start = timeit.default_timer()
290 | #linear_val_ode(w_vec,test_tors[j],t_d_test_tt)
291 | pushed = linear_val_ode(w_vec,test_tors[j][None],t_d_test_tt)  # keep a batch dimension: eigen_ode__ indexes init_x as 2-D
292 | stop=timeit.default_timer()
293 | diff_list.append(stop-start)
294 | pushed_list_tt.append(pushed.clone().detach())
295 | print('Extrapolation Results:')
296 | print(f'time: {np.mean(np.array(diff_list)):.3f} {np.std(np.array(diff_list)):.3f}')
297 |
298 |
299 | i=0
300 | xy_d=test_d_list[i]
301 | sum_num=0
302 | for i in range(1,len(test_d_list)):
303 | xy_d=test_d_list[i].to(device)
304 | error=torch.mean(torch.norm(pushed_list_tt[i][...,:3]-xy_d,dim=2)**2)
305 | sum_num+=error
306 |
307 | print(f'Extrapolation MSE: {sum_num/len(test_d_list):.4f}')
308 |
309 |
310 |
311 | train(1500, 5)
312 |
313 |
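Note on `eigen_ode`: it evaluates a linear ODE in closed form from eigenvalues and eigenvectors instead of calling a numerical solver. The sketch below reproduces that double sum in plain Python on a 2x2 example, assuming real eigenvalues and a hand-picked eigenvector matrix whose inverse is known. It then checks the result against the matrix `A = V diag(e_val) V^{-1}` by finite differences. The name `eigen_solution` and the example numbers are illustrative and not part of the repository. In the script the linear system acts on the state returned by `rev_inn`, so this covers only the linear part.

```python
import math


def eigen_solution(e_val, e_vec, e_vec_inv, x0, t):
    """x(t) = V exp(diag(e_val) t) V^{-1} x0, written out as explicit sums."""
    dim = len(x0)
    # Coordinates of the initial state in the eigenbasis: c = V^{-1} x0.
    c = [sum(e_vec_inv[j][k] * x0[k] for k in range(dim)) for j in range(dim)]
    # Each output component is a weighted sum of exponentials, as in eigen_ode.
    return [
        sum(c[j] * e_vec[i][j] * math.exp(e_val[j] * t) for j in range(dim))
        for i in range(dim)
    ]


e_val = [-1.0, -2.0]
e_vec = [[1.0, 1.0], [0.0, 1.0]]       # columns are the eigenvectors
e_vec_inv = [[1.0, -1.0], [0.0, 1.0]]  # inverse of e_vec, computed by hand
A = [[-1.0, -1.0], [0.0, -2.0]]        # equals V diag(e_val) V^{-1}
x0 = [3.0, 1.0]

x_start = eigen_solution(e_val, e_vec, e_vec_inv, x0, 0.0)

# Central finite difference of the closed-form solution at t = 0.5.
t, h = 0.5, 1.0e-5
x_t = eigen_solution(e_val, e_vec, e_vec_inv, x0, t)
x_plus = eigen_solution(e_val, e_vec, e_vec_inv, x0, t + h)
x_minus = eigen_solution(e_val, e_vec, e_vec_inv, x0, t - h)
dx_dt = [(p - m) / (2.0 * h) for p, m in zip(x_plus, x_minus)]

# The derivative should match A x(t) if the closed form solves dx/dt = A x.
Ax = [sum(A[i][k] * x_t[k] for k in range(2)) for i in range(2)]
print("x(0):", x_start)
print("dx/dt:", dx_dt)
print("A x:", Ax)
```

Because no time stepping is involved, the cost of evaluating the trajectory does not depend on step size or stiffness, only on the number of query times.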
--------------------------------------------------------------------------------
/systems/LV_train_theirs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as pl
4 | import torchdiffeq
5 | import torch.nn as n
6 | import FrEIA.framework as Ff
7 | import FrEIA.modules as Fm
8 |
9 | #xx=torch.tensor([0.,0.,0.])
10 | tt=torch.tensor([0.,5.])
11 | t_d=torch.linspace(0,7,70)
12 |
13 | import random
14 |
15 | random.seed(10)
16 | np.random.seed(0)
17 |
18 | device_str='cpu'
19 | def test_fun(t,x):
20 | A=0.75
21 | B=0.75
22 | C=0.75
23 | D=0.75
24 | E=0.75
25 | F=0.75
26 | G=0.75
27 | vel=torch.zeros((3,1))
28 | vel[0]=x[0]*(A-B*x[1])
29 | vel[1]=x[1]*(-C+D*x[0]-E*x[2])
30 | vel[2]=x[2]*(-F+G*x[1])
31 | return(vel)
32 | tt_tors=torch.tensor([[3,3.,1.,0.,0.,0.],
33 | [2,2.,2.,0.,0.,0.],
34 | [4,4.,3.,0.,0.,0.],
35 | [3,3.,4.,0.,0.,0.],
36 | [1,1.,5.,0.,0.,0.],
37 | [5,5.,1.,0.,0.,0.],
38 | [2,6.,2.,0.,0.,0.],
39 | [3,1.,4.,0.,0.,0.],
40 | [7,1.,2.,0.,0.,0.],
41 | [6,2.,4.,0.,0.,0.]])
42 |
43 | xy_d_list=[]
44 | for i in range(len(tt_tors)):
45 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.05*torch.ones((len(t_d),3)))
46 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T,t_d).reshape((-1,3))+noise_c
47 | xy_d_list.append(xy_d.clone().detach())
48 |
49 |
50 | for kk in range(1):
51 | f_x=n.Sequential(
52 | n.Linear(6, 150),
53 | n.Tanh(),
54 | n.Linear(150, 150),
55 | n.Tanh(),
56 | n.Linear(150, 150),
57 | n.Tanh(),
58 | n.Linear(150, 150),
59 | n.Tanh(),
60 | n.Linear(150, 150),
61 | n.Tanh(),
62 | n.Linear(150, 6)
63 | )
64 |
65 | def fx(t,x):
66 | return(f_x(x))
67 |
68 |
69 | optimizer = torch.optim.Adam(
70 | [{'params': f_x.parameters(),'lr': 0.0001}])
71 |
72 |
73 | for i in range(0,5000):
74 | loss_all=0.
75 | optimizer.zero_grad()
76 | for j in range(0,len(tt_tors)):
77 | xx=xy_d_list[j]
78 | tx=torchdiffeq.odeint(fx,
79 | tt_tors[j][None],t_d,atol=1e-2,#,rtol=1e-5,
80 | method='euler')[:,0,:]
81 | loss_c=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1))
82 | loss_all+=loss_c
83 | if(i%100==0):
84 | print(str(i)+'euler:'+str(loss_all/10))
85 | loss_all.backward()
86 | optimizer.step()
87 | torch.save(f_x.state_dict(),'LV_euler'+str(kk))
88 |
89 | f_x=n.Sequential(
90 | n.Linear(6, 150),
91 | n.Tanh(),
92 | n.Linear(150, 150),
93 | n.Tanh(),
94 | n.Linear(150, 150),
95 | n.Tanh(),
96 | n.Linear(150, 150),
97 | n.Tanh(),
98 | n.Linear(150, 150),
99 | n.Tanh(),
100 | n.Linear(150, 6)
101 | )
102 |
103 | def fx(t,x):
104 | return(f_x(x))
105 |
106 |
107 | optimizer = torch.optim.Adam(
108 | [{'params': f_x.parameters(),'lr': 0.0001}])
109 |
110 | print('Mid')
111 |
112 | for i in range(0,5000):
113 | loss_all=0.
114 | optimizer.zero_grad()
115 | for j in range(0,len(tt_tors)):
116 | xx=xy_d_list[j]
117 | tx=torchdiffeq.odeint(fx,
118 | tt_tors[j][None],t_d,atol=1e-2,#,rtol=1e-5,
119 | method='midpoint')[:,0,:]
120 | loss_c=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1))
121 | loss_all+=loss_c
122 | if(i%100==0):
123 | print(str(i)+'mid:'+str(loss_all/10))
124 | loss_all.backward()
125 | optimizer.step()
126 | torch.save(f_x.state_dict(),'LV_mid'+str(kk))
127 |
128 | f_x=n.Sequential(
129 | n.Linear(6, 150),
130 | n.Tanh(),
131 | n.Linear(150, 150),
132 | n.Tanh(),
133 | n.Linear(150, 150),
134 | n.Tanh(),
135 | n.Linear(150, 150),
136 | n.Tanh(),
137 | n.Linear(150, 150),
138 | n.Tanh(),
139 | n.Linear(150, 6)
140 | )
141 |
142 | def fx(t,x):
143 | return(f_x(x))
144 |
145 |
146 | optimizer = torch.optim.Adam(
147 | [{'params': f_x.parameters(),'lr': 0.0001}])
148 |
149 | print('RK')
150 | for i in range(0,5000):
151 | loss_all=0.
152 | optimizer.zero_grad()
153 | for j in range(0,len(tt_tors)):
154 | xx=xy_d_list[j]
155 | tx=torchdiffeq.odeint(fx,
156 | tt_tors[j][None],t_d,atol=1e-2,#,rtol=1e-5,
157 | method='rk4')[:,0,:]
158 | loss_c=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1))
159 | loss_all+=loss_c
160 | if(i%100==0):
161 | print(str(i)+'rk:'+str(loss_all/10))
162 | loss_all.backward()
163 | optimizer.step()
164 |
165 | torch.save(f_x.state_dict(),'LV_rk4'+str(kk))
166 |
167 | f_x=n.Sequential(
168 | n.Linear(6, 150),
169 | n.Tanh(),
170 | n.Linear(150, 150),
171 | n.Tanh(),
172 | n.Linear(150, 150),
173 | n.Tanh(),
174 | n.Linear(150, 150),
175 | n.Tanh(),
176 | n.Linear(150, 150),
177 | n.Tanh(),
178 | n.Linear(150, 6)
179 | )
180 |
181 | def fx(t,x):
182 | return(f_x(x))
183 |
184 |
185 | optimizer = torch.optim.Adam(
186 | [{'params': f_x.parameters(),'lr': 0.0001}])
187 |
188 | print('Dop')
189 |
190 | for i in range(0,5000):
191 | loss_all=0.
192 | optimizer.zero_grad()
193 | for j in range(0,len(tt_tors)):
194 | xx=xy_d_list[j]
195 | tx=torchdiffeq.odeint(fx,
196 | tt_tors[j][None],t_d,rtol=1e-5, atol=1e-5,#,rtol=1e-5,
197 | method='dopri5')[:,0,:]
198 | loss_c=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1))
199 | loss_all+=loss_c
200 | if(i%100==0):
201 |             print(str(i)+'dopri:'+str(loss_all/10))
202 | loss_all.backward()
203 | optimizer.step()
204 |
205 | torch.save(f_x.state_dict(),'LV_dopri'+str(kk))
206 |
--------------------------------------------------------------------------------
/systems/f_x_base_save_good_eod2.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/systems/f_x_base_save_good_eod2.tar
--------------------------------------------------------------------------------
/systems/inn2_save_good_eod2.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhi/ODElearning_INN/41aaf87818bae2b340c7063b8d79654344b37c41/systems/inn2_save_good_eod2.tar
--------------------------------------------------------------------------------
/systems/lor_run_all.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import numpy as np
4 | import matplotlib.pyplot as pl
5 | import torchdiffeq
6 | import torch.nn as n
7 | import FrEIA.framework as Ff
8 | import FrEIA.modules as Fm
9 | from mpl_toolkits import mplot3d
10 |
11 |
12 | ##########################################################3
13 | # system dynamics
14 |
15 | xx=torch.tensor([0.,0.,0.])
16 | tt=torch.tensor([0.,5.])
17 | t_d=torch.linspace(0,2,800)
18 | # Use CUDA when it is available; otherwise fall back to the CPU.
19 | device_str="cuda" if torch.cuda.is_available() else "cpu"
20 |
21 | def test_fun(t,x):
22 | sig=10.
23 | rho=28.
24 | beta=8/3
25 | vel=torch.zeros((3,1))
26 | vel[0]=sig*(x[1]-x[0])
27 | vel[1]=x[0]*(rho-x[2])-x[1]
28 | vel[2]=x[0]*x[1]-beta*x[2]
29 | return(vel)
30 |
31 | tt_tors=torch.tensor([[.15,.15,.15,0.,0.,0.]])
32 | xy_d_list=[]
33 | for i in range(len(tt_tors)):
34 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.01*torch.ones((len(t_d),3)))
35 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T,t_d).reshape((-1,3))#+noise_c #no noise for test
36 | xy_d_list.append(xy_d.clone().detach())
37 |
38 | pl.figure(figsize=(12,10))
39 | for i in range(len(xy_d_list)):
40 | xy_d=xy_d_list[i]
41 | pl.plot(xy_d[:,0],xy_d[:,1],marker='o')
42 |
43 | xx=xy_d_list[0]
44 |
45 | ##########################################################3
46 |
47 | def test_theirs(method):
48 | global tt_tors, t_d
49 | if method == 'euler':
50 | weight_fn = 'lor_euler.tar'
51 | elif method == 'midpoint':
52 | weight_fn = 'lor_mid.tar'
53 | elif method == 'rk4':
54 | weight_fn = 'lor_rk4.tar'
55 | elif method == 'dopri5':
56 | weight_fn = 'lor_dopri.tar'
57 | f_x=n.Sequential(
58 | n.Linear(6, 150),
59 | n.Tanh(),
60 | n.Linear(150, 150),
61 | n.Tanh(),
62 | n.Linear(150, 150),
63 | n.Tanh(),
64 | n.Linear(150, 150),
65 | n.Tanh(),
66 | n.Linear(150, 150),
67 | n.Tanh(),
68 | n.Linear(150, 6)
69 | )
70 | def fx(t,x):
71 | return(f_x(x))
72 |
73 | f_x.load_state_dict(torch.load(weight_fn))
74 | import time
75 | time_list=[]
76 |
77 | f_x = f_x.to(device_str)
78 | tt_tors = tt_tors.to(device_str)
79 | t_d = t_d.to(device_str)
80 | for i in range(10):
81 | tic = time.perf_counter()
82 | tx=torchdiffeq.odeint(fx,tt_tors,t_d,method=method,rtol=1e-5, atol=1e-5)
83 | toc = time.perf_counter()
84 | time_list.append(toc-tic)
85 |
86 | tx=tx.permute(1,0,2)
87 | #interpolation
88 | sum_num=0
89 |
90 |     for i in range(len(xy_d_list)):
91 | xy_d=xy_d_list[i]
92 | error=torch.mean(torch.norm(tx[i][:,:3].to('cpu')-xy_d,dim=1))
93 | sum_num+=error
94 |
95 | print(f'Interpolation MAE: {sum_num/len(xy_d_list):.4f}')
96 |
97 | times=np.array(time_list)
98 |
99 | print(method)
100 | print('mean:')
101 | print(times.mean())
102 |     print('std:')
103 | print(times.std())
104 |
105 |
106 | #test_theirs('euler')
107 | #test_theirs('rk4')
108 | #test_theirs('midpoint')
109 |
110 | ##########################################################3
111 |
112 | N_DIM = 6
113 | def subnet_fc(dims_in, dims_out):
114 | return n.Sequential(n.Linear(dims_in, 1500), n.ReLU(),
115 | n.Linear(1500, dims_out))
116 | inn2 = Ff.SequenceINN(N_DIM)
117 | for k in range(5):
118 | inn2.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc)
119 |
120 | f_x_base=n.Sequential(
121 | n.Linear(6, 30),
122 | n.Tanh(),
123 | n.Linear(30, 30),
124 | n.Tanh(),
125 | n.Linear(30, 30),
126 | n.Tanh(),
127 | n.Linear(30, 6)
128 | )
129 | def fx_base(t,x):
130 | return(f_x_base(x))
131 |
132 | def for_inn2(x):
133 | return(inn2(x)[0])
134 | def rev_inn2(x):
135 | return(inn2(x,rev=True)[0])
136 | def linear_val_ode2(init_v,t_d):
137 | init_v_in=rev_inn2(init_v)
138 | eval_lin=torchdiffeq.odeint(fx_base,init_v_in,t_d,
139 | method='euler')[:,0,:]#options={'step_size':0.01}
140 | eval_out=for_inn2(eval_lin)
141 | return(eval_out)
142 |
143 | f_x_base.load_state_dict(torch.load('f_x_base_save_good_eod2.tar'))
144 | inn2.load_state_dict(torch.load('inn2_save_good_eod2.tar'))
145 | f_x_base = f_x_base.to(device_str)
146 | inn2 = inn2.to(device_str)
147 |
148 | f_x_base = f_x_base.to(device_str)
149 | tt_tors = tt_tors.to(device_str)
150 | t_d = t_d.to(device_str)
151 |
152 | time_list=[]
153 | for i in range(10):
154 | tic = time.perf_counter()
155 | tx=linear_val_ode2(tt_tors,t_d)
156 | toc = time.perf_counter()
157 | time_list.append(toc-tic)
158 |
159 |
160 | ours=np.array(time_list)
161 | print('ours time mean:')
162 | print(ours.mean())
163 | print('ours time std:')
164 | print(ours.std())
165 |
166 |
167 | sum_num=0
168 |
169 | for i in range(0,len(tt_tors)):
170 | xy_d=xy_d_list[i]
171 | print(xy_d.shape)
172 | error=torch.mean(torch.norm(tx[:,:3].to('cpu')-xy_d,dim=1))
173 | sum_num+=error
174 |
175 | print(f'Ours Interpolation MAE: {sum_num:.4f}')
176 |
177 | test_theirs('euler')
178 | test_theirs('rk4')
179 | test_theirs('midpoint')
180 |
181 |
--------------------------------------------------------------------------------
/systems/lor_train_ours.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as pl
4 | import torchdiffeq
5 | import torch.nn as n
6 | import FrEIA.framework as Ff
7 | import FrEIA.modules as Fm
8 |
9 | device = 'cpu'
10 | seed = 0
11 | torch.manual_seed(seed)
12 | import random
13 | random.seed(seed)
14 | np.random.seed(seed)
15 | xx=torch.tensor([0.,0.,0.])
16 | tt=torch.tensor([0.,5.])
17 | t_d=torch.linspace(0,2,80)
18 |
19 |
20 | device_str='cpu'
21 | def test_fun(t,x):
22 | sig=10.
23 | rho=28.
24 | beta=8/3
25 | vel=torch.zeros((3,1))
26 | vel[0]=sig*(x[1]-x[0])
27 | vel[1]=x[0]*(rho-x[2])-x[1]
28 | vel[2]=x[0]*x[1]-beta*x[2]
29 | return(vel)
30 |
31 | tt_tors=torch.tensor([[.15,.15,.15,0.,0.,0.]])
32 | xy_d_list=[]
33 | for i in range(len(tt_tors)):
34 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.01*torch.ones((len(t_d),3)))
35 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T,t_d).reshape((-1,3))+noise_c
36 | xy_d_list.append(xy_d.clone().detach())
37 |
38 |
39 |
40 | def loss_er(x_pred,x_gt):
41 | mse_l=torch.norm(x_pred-x_gt,dim=1)
42 | sum_c=0.0
43 | for i in range(len(mse_l)):
44 | sum_c+=(mse_l[i])#*(1/float(i+1))
45 | return(sum_c/len(mse_l))
46 |
47 |
48 |
49 |
50 |
51 |
52 | def train(hidden_size, num_layers):
53 | print('='*20)
54 | #print(f'hidden_size={hidden_size}, num_layers={num_layers}')
55 | global tt_tors, xy_d_list, t_d
56 | #seed = 123
57 | torch.manual_seed(seed)
58 | import random
59 | random.seed(seed)
60 | np.random.seed(seed)
61 | f_x=n.Sequential(
62 | n.Linear(6, 30),
63 | n.Tanh(),
64 | n.Linear(30, 30),
65 | n.Tanh(),
66 | n.Linear(30, 30),
67 | n.Tanh(),
68 | n.Linear(30, 6)
69 | )
70 | def fx(t,x):
71 | return(f_x(x))
72 |
73 | def for_inn(x):
74 | return(inn(x)[0])
75 | def rev_inn(x):
76 | return(inn(x,rev=True)[0])
77 | def rev_mse_inn_eig(rf,x_gt):
78 | return(torch.mean(torch.norm(rf-x_gt,dim=1)))
79 | def linear_val_ode(w_vec,init_v,t_d):
80 | init_v_in=rev_inn(init_v)
81 | eval_lin=eigen_ode__(w_vec,init_v_in,t_d)
82 | ori_shape = eval_lin.shape
83 | eval_out=for_inn(eval_lin.reshape(-1, eval_lin.shape[-1]))
84 | return(eval_out.reshape(ori_shape))
85 | def linear_val_ode2(init_v,t_d):
86 | init_v_in=rev_inn(init_v)
87 | eval_lin=torchdiffeq.odeint(fx,init_v_in,t_d,method='dopri5',atol=1e-5,rtol=1e-5)[:,0,:]#options={'step_size':0.01}
88 | eval_out=for_inn(eval_lin)
89 | return(eval_out)
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 | N_DIM = 6
98 | def subnet_fc(dims_in, dims_out):
99 | return n.Sequential(n.Linear(dims_in, hidden_size), n.ReLU(),
100 | n.Linear(hidden_size, dims_out))
101 |
102 | inn = Ff.SequenceINN(N_DIM)
103 | for k in range(num_layers):
104 | inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc,permute_soft=True)
105 |
106 |
107 | optimizer_comb = torch.optim.Adam(
108 | [{'params': f_x.parameters(),'lr': 0.0001},{'params': inn.parameters(),
109 | 'lr': 0.0001}])
110 | print(sum(p.numel() for p in inn.parameters()))
111 |
112 | startings=tt_tors.clone().detach()
113 |
114 | #Training loop
115 | import timeit
116 | epoch_time=[]
117 | import tqdm
118 | from tqdm import trange
119 |
120 | tt_tors = tt_tors.to(device)
121 | #t_d = t_d.to(device)
122 | inn.to(device)
123 | t_d = t_d.to(device)
124 | xy_d_list = torch.stack(xy_d_list).to(device)
125 |
126 | print('training INN')
127 | #for i in trange(0, 5000):
128 | for i in trange(0, 5000):
129 | optimizer_comb.zero_grad()
130 | loss=0.0
131 | start = timeit.default_timer()
132 | eval_nl=linear_val_ode2(tt_tors,t_d)
133 | """
134 | for j in range(len(xy_d_list)):
135 | eval_nl=linear_val_ode(w_vec,tt_tors[j],t_d)
136 |
137 | #torchdiffeq.odeint(fx,
138 | # tt_tors[j][None],t_d_test_tt,atol=1e-2,#,rtol=1e-5,
139 | # method='euler')[:,0,:]
140 |
141 | #loss_cur = rev_mse_inn_eig(eval_nl[:,:3],xy_d_list[j])
142 | #loss+=loss_cur
143 | """
144 |         loss_cur = torch.mean(torch.norm(eval_nl[:,:3]-xy_d_list[0],dim=1))
145 | loss+=loss_cur
146 |
147 | loss.backward()
148 | optimizer_comb.step()
149 | end = timeit.default_timer()
150 | epoch_time.append(end-start)
151 | if(i%100==0):
152 | print('Combined loss:'+str(i)+': '+str(loss))
153 | ep_time=np.array(epoch_time)
154 | #print(f'mean train time:{ep_time.mean():.3f} {ep_time.std():.3f}')
155 | #print(f'total: {ep_time.sum() / run_for * target:.2f}')
156 | torch.save(f_x.state_dict(),'f_x_base_save_good_eod2.tar')
157 | torch.save(inn.state_dict(),'inn2_save_good_eod2.tar')
158 |
159 |
160 |
161 |
162 | train(1500, 5)
163 |
--------------------------------------------------------------------------------
/systems/lor_train_theirs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as pl
4 | import torchdiffeq
5 | import torch.nn as n
6 | import FrEIA.framework as Ff
7 | import FrEIA.modules as Fm
8 |
9 | xx=torch.tensor([0.,0.,0.])
10 | tt=torch.tensor([0.,5.])
11 | t_d=torch.linspace(0,2,80)
12 |
13 |
14 | device_str='cpu'
15 | def test_fun(t,x):
16 | sig=10.
17 | rho=28.
18 | beta=8/3
19 | vel=torch.zeros((3,1))
20 | vel[0]=sig*(x[1]-x[0])
21 | vel[1]=x[0]*(rho-x[2])-x[1]
22 | vel[2]=x[0]*x[1]-beta*x[2]
23 | return(vel)
24 |
25 | tt_tors=torch.tensor([[.15,.15,.15,0.,0.,0.]])
26 | xy_d_list=[]
27 | for i in range(len(tt_tors)):
28 | noise_c=torch.normal(torch.zeros((len(t_d),3)),0.01*torch.ones((len(t_d),3)))
29 | xy_d=torchdiffeq.odeint(test_fun,tt_tors[i,:3][None].T,t_d).reshape((-1,3))+noise_c
30 | xy_d_list.append(xy_d.clone().detach())
31 |
32 | pl.figure(figsize=(12,10))
33 | for i in range(len(xy_d_list)):
34 | xy_d=xy_d_list[i]
35 | pl.plot(xy_d[:,0],xy_d[:,1],marker='o')
36 | pl.savefig('lorr.png')
37 | xx=xy_d_list[0]
38 | f_x=n.Sequential(
39 | n.Linear(6, 150),
40 | n.Tanh(),
41 | n.Linear(150, 150),
42 | n.Tanh(),
43 | n.Linear(150, 150),
44 | n.Tanh(),
45 | n.Linear(150, 150),
46 | n.Tanh(),
47 | n.Linear(150, 150),
48 | n.Tanh(),
49 | n.Linear(150, 6)
50 | )
51 |
52 | def fx(t,x):
53 | return(f_x(x))
54 |
55 |
56 | optimizer = torch.optim.Adam(
57 | [{'params': f_x.parameters(),'lr': 0.0001}])
58 |
59 | for j in range(0,5000):
60 | optimizer.zero_grad()
61 | tx=torchdiffeq.odeint(fx,
62 | tt_tors,t_d,#,rtol=1e-5,
63 | method='euler')[:,0,:]
64 | loss=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1))
65 | if(j%100==0):
66 | print('euler: '+str(j)+': '+str(loss))
67 | loss.backward()
68 | optimizer.step()
69 | torch.save(f_x.state_dict(),'lor_euler.tar')
70 |
71 | f_x=n.Sequential(
72 | n.Linear(6, 150),
73 | n.Tanh(),
74 | n.Linear(150, 150),
75 | n.Tanh(),
76 | n.Linear(150, 150),
77 | n.Tanh(),
78 | n.Linear(150, 150),
79 | n.Tanh(),
80 | n.Linear(150, 150),
81 | n.Tanh(),
82 | n.Linear(150, 6)
83 | )
84 |
85 | def fx(t,x):
86 | return(f_x(x))
87 |
88 |
89 | optimizer = torch.optim.Adam(
90 | [{'params': f_x.parameters(),'lr': 0.0001}])
91 |
92 | for j in range(0,5000):
93 | optimizer.zero_grad()
94 | tx=torchdiffeq.odeint(fx,
95 | tt_tors,t_d,#,rtol=1e-5,
96 | method='midpoint')[:,0,:]
97 | loss=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1))
98 | if(j%100==0):
99 | print('midpoint: '+str(j)+': '+str(loss))
100 | loss.backward()
101 | optimizer.step()
102 | torch.save(f_x.state_dict(),'lor_mid.tar')
103 |
104 | f_x=n.Sequential(
105 | n.Linear(6, 150),
106 | n.Tanh(),
107 | n.Linear(150, 150),
108 | n.Tanh(),
109 | n.Linear(150, 150),
110 | n.Tanh(),
111 | n.Linear(150, 150),
112 | n.Tanh(),
113 | n.Linear(150, 150),
114 | n.Tanh(),
115 | n.Linear(150, 6)
116 | )
117 |
118 | def fx(t,x):
119 | return(f_x(x))
120 |
121 |
122 | optimizer = torch.optim.Adam(
123 | [{'params': f_x.parameters(),'lr': 0.0001}])
124 |
125 | for j in range(0,5000):
126 | optimizer.zero_grad()
127 | tx=torchdiffeq.odeint(fx,
128 | tt_tors,t_d,#,rtol=1e-5,
129 | method='rk4')[:,0,:]
130 | loss=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1))
131 | if(j%100==0):
132 | print('rk4: '+str(j)+': '+str(loss))
133 |
134 | loss.backward()
135 | optimizer.step()
136 | torch.save(f_x.state_dict(),'lor_rk4.tar')
137 |
138 | f_x=n.Sequential(
139 | n.Linear(6, 150),
140 | n.Tanh(),
141 | n.Linear(150, 150),
142 | n.Tanh(),
143 | n.Linear(150, 150),
144 | n.Tanh(),
145 | n.Linear(150, 150),
146 | n.Tanh(),
147 | n.Linear(150, 150),
148 | n.Tanh(),
149 | n.Linear(150, 6)
150 | )
151 |
152 | def fx(t,x):
153 | return(f_x(x))
154 |
155 |
156 | optimizer = torch.optim.Adam(
157 | [{'params': f_x.parameters(),'lr': 0.0001}])
158 |
159 | for j in range(0,5000):
160 | optimizer.zero_grad()
161 | tx=torchdiffeq.odeint(fx,
162 | tt_tors,t_d,rtol=1e-5, atol=1e-5,
163 | method='dopri5')[:,0,:]
164 | loss=torch.mean(torch.norm(tx[:,:3]-xx,dim=1))#+0.1*torch.mean(torch.norm(tx,dim=1))
165 | if(j%100==0):
166 | print('dopri: '+str(j)+': '+str(loss))
167 | loss.backward()
168 | optimizer.step()
169 | torch.save(f_x.state_dict(),'lor_dopri.tar')
170 |
--------------------------------------------------------------------------------