├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── PLACEHOLDER.py
│   ├── crossing_oscillation.png
│   ├── pred_t.png
│   ├── revised_tfm_fig1.png
│   └── traj_concept.png
├── data
│   └── toy_data.pkl
├── environment.yml
├── notebook
│   └── 3Oscillation.ipynb
└── src
    ├── conf
    │   ├── config.yaml
    │   ├── data
    │   │   ├── eICU.yaml
    │   │   ├── eICU_ablated.yaml
    │   │   ├── eICU_multdim.yaml
    │   │   └── mimic_liver.yaml
    │   ├── model
    │   │   ├── cfm.yaml
    │   │   ├── fm.yaml
    │   │   ├── latentODE.yaml
    │   │   ├── ode.yaml
    │   │   ├── ode_256.yaml
    │   │   ├── sde.yaml
    │   │   ├── sde_256.yaml
    │   │   ├── tfm_ablated_size_memory_uncertainty.yaml
    │   │   ├── tfm_ablated_uncertainty.yaml
    │   │   ├── tfm_ode.yaml
    │   │   ├── tfm_ode_ablated_size_memory_uncertainty.yaml
    │   │   ├── tfm_ode_ablated_uncertainty.yaml
    │   │   └── tfm_sde.yaml
    │   ├── perturb_config.yaml
    │   └── trainer.yaml
    ├── data
    │   └── datamodule.py
    ├── main.py
    ├── model
    │   ├── FM_baseline.py
    │   ├── base_models.py
    │   ├── cfm_baseline.py
    │   ├── components
    │   │   ├── grad_util.py
    │   │   ├── mlp.py
    │   │   ├── positional_encoding.py
    │   │   └── sde_func_solver.py
    │   ├── latent_ode.py
    │   ├── mlp_memory.py
    │   ├── mlp_noise.py
    │   └── ode_baseline.py
    └── utils
        ├── latent_ode_utils.py
        ├── loss.py
        ├── metric_calc.py
        ├── mmd.py
        ├── sde.py
        ├── visualize.py
        └── wrapper.py
/.gitignore:
--------------------------------------------------------------------------------
1 | src/outputs/*
2 | outputs/*
3 | *pycache*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Xi (Nicole) Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Trajectory Flow Matching
4 |
9 |
10 | [Paper (arXiv:2410.21154)](http://arxiv.org/abs/2410.21154)
11 | [PyTorch](https://pytorch.org/get-started/locally/)
12 | [PyTorch Lightning](https://pytorchlightning.ai/)
13 | [Hydra](https://hydra.cc/)
14 | [License: MIT](https://opensource.org/license/mit)
15 |
20 | ## Description
21 | Trajectory Flow Matching (TFM) applies flow matching, a technique from generative modeling, to time series. Training is simulation-free, which allows stochastic differential equations to be fitted to time-series data efficiently. Augmented with memory, time-interval prediction, and uncertainty prediction, TFM can better model irregularly sampled, stochastic trajectories such as clinical time series.
22 |
27 | The core idea of TFM is to use flow matching to predict both the next value in the time series and its stochastic uncertainty. Each prediction is conditioned on past data and on conditioning variables.
28 |
31 | ## How to run
32 |
37 |
38 | ### Initialize environment
39 | ```bash
40 | # clone project
41 | git clone https://github.com/nZhangx/TrajectoryFlowMatching.git
42 | cd TrajectoryFlowMatching
43 |
44 | # create the conda environment defined in environment.yml (named ti-env there)
45 | conda env create -f environment.yml
46 | conda activate ti-env
50 | ```
51 |
52 | ### Run experiments
53 | Under `src`, create a new `DATA_NAME.yaml` in `conf/data` and a new `MODEL_NAME.yaml` in `conf/model` with the desired configuration. Then set the `data` and `model` entries in `conf/config.yaml` to your `DATA_NAME` and `MODEL_NAME`, and run:
54 | ```bash
55 | python src/main.py
56 | ```
57 |
58 | ### Demo
59 | We have included an example of TFM modeling three crossing oscillations in a self-contained Jupyter notebook `notebook/3Oscillation.ipynb`.
60 |
61 | ### Implemented models
62 |
63 | - TFM and ablations (size ablated, uncertainty ablated)
64 | - Aligned Flow Matching [[Liu et al., 2023](https://arxiv.org/abs/2302.05872)][[Somnath et al., 2023](https://arxiv.org/abs/2302.11419)]
65 | - NeuralODE [[Chen et al., 2018](https://arxiv.org/abs/1806.07366)]
66 | - NeuralSDE [[Li et al., 2020](https://arxiv.org/abs/2001.01328)] [[Kidger et al., 2021](https://arxiv.org/abs/2102.03657)]
67 | - LatentODE [[Rubanova et al. 2019](https://arxiv.org/abs/1907.03907)]
68 |
69 | | Model | ICU Sepsis | ICU Cardiac Arrest | ICU GIB | ED GIB |
70 | |-------------------------------|----------------------------|----------------------------|----------------------------|-----------------------------|
71 | | NeuralODE | 4.776 $\pm$ 0.000 | 6.153 $\pm$ 0.000 | 3.170 $\pm$ 0.000 | 10.859 $\pm$ 0.000 |
72 | | FM baseline ODE | 4.671 $\pm$ 0.791 | 10.207 $\pm$ 1.076 | 118.439 $\pm$ 17.947 | 11.923 $\pm$ 1.123 |
73 | | LatentODE-RNN | 61.806 $\pm$ 46.573 | 386.190 $\pm$ 558.140 | 422.886 $\pm$ 431.954 | 980.228 $\pm$ 1032.393 |
74 | | **TFM-ODE (ours)** | **0.793 $\pm$ 0.017** | 2.762 $\pm$ 0.021 | 2.673 $\pm$ 0.069 | **8.245 $\pm$ 0.495** |
75 | | | | | | |
76 | | NeuralSDE | 4.747 $\pm$ 0.000 | 3.250 $\pm$ 0.024 | 3.186 $\pm$ 0.000 | 10.850 $\pm$ 0.043 |
77 | | **TFM (ours)** | 0.796 $\pm$ 0.026 | **2.755 $\pm$ 0.015** | **2.596 $\pm$ 0.079** | 8.613 $\pm$ 0.260 |
78 |
79 |
80 | ### Available datasets
81 |
82 | We plan to share on [PhysioNet](https://physionet.org/) the clinical data we used, which come from the [eICU Collaborative Research Database v2.0](https://physionet.org/content/eicu-crd/2.0/) (ICU Sepsis and ICU Cardiac Arrest) and the [Medical Information Mart for Intensive Care III (MIMIC-III) critical care database](https://physionet.org/content/mimiciii/1.4/) (ICU GIB).
83 |
84 | ## How to cite
85 |
86 | This repository contains the code to reproduce the main experiments and illustrations of the preprint [Trajectory Flow Matching with Applications to Clinical Time Series Modeling](https://arxiv.org/abs/2410.21154). We are excited that it was marked as a **spotlight** presentation.
88 |
89 | If you find this code useful in your research, please cite:
95 |
96 | ```bibtex
97 | @article{TFM,
98 | title = {Trajectory Flow Matching with Applications to Clinical Time Series Modelling},
99 | author = {Zhang, Xi and Pu, Yuan and Kawamura, Yuki and Loza, Andrew and Bengio, Yoshua and Shung, Dennis and Tong, Alexander},
100 | year = 2024,
101 | journal = {NeurIPS},
102 | }
103 | ```
104 |
105 |
106 |
107 |
108 | ## License
109 | This repo is licensed under the [MIT License](https://opensource.org/license/mit).
110 |
--------------------------------------------------------------------------------
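The README describes TFM as flow matching with simulation-free training. As a rough illustration only (the repository's actual models live under `src/model` and are not reproduced here), the sketch below builds one conditional flow matching training sample from a pair of consecutive observations. It assumes a straight-line path between the two observations; the function name `cfm_training_sample` is hypothetical, and the default `sigma=0.1` simply mirrors the `sigma: 0.1` value found in the model configs.

```python
import random


def cfm_training_sample(x0, x1, t0, t1, sigma=0.1, rng=random):
    """Build one (time, input, target) triple for regressing a velocity field.

    Assumption: a straight-line conditional path between observation x0 (at time t0)
    and observation x1 (at time t1), perturbed by Gaussian noise of scale sigma.
    """
    if t1 <= t0:
        raise ValueError("t1 must be later than t0")
    s = rng.random()              # relative position along the path, in [0, 1)
    t = t0 + s * (t1 - t0)        # absolute time of the sampled point
    # Point on the straight line between x0 and x1, plus isotropic noise.
    xt = [(1 - s) * a + s * b + sigma * rng.gauss(0.0, 1.0) for a, b in zip(x0, x1)]
    # Regression target: the constant velocity of the straight-line path.
    ut = [(b - a) / (t1 - t0) for a, b in zip(x0, x1)]
    return t, xt, ut
```

No trajectory is simulated to build the sample: a network would be trained to map `(t, xt)` to `ut` directly, which is what makes this style of training simulation-free.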
/assets/PLACEHOLDER.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nZhangx/TrajectoryFlowMatching/b562e386074cfec62ee7ddd23e826c3f33dd8ba2/assets/PLACEHOLDER.py
--------------------------------------------------------------------------------
/assets/crossing_oscillation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nZhangx/TrajectoryFlowMatching/b562e386074cfec62ee7ddd23e826c3f33dd8ba2/assets/crossing_oscillation.png
--------------------------------------------------------------------------------
/assets/pred_t.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nZhangx/TrajectoryFlowMatching/b562e386074cfec62ee7ddd23e826c3f33dd8ba2/assets/pred_t.png
--------------------------------------------------------------------------------
/assets/revised_tfm_fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nZhangx/TrajectoryFlowMatching/b562e386074cfec62ee7ddd23e826c3f33dd8ba2/assets/revised_tfm_fig1.png
--------------------------------------------------------------------------------
/assets/traj_concept.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nZhangx/TrajectoryFlowMatching/b562e386074cfec62ee7ddd23e826c3f33dd8ba2/assets/traj_concept.png
--------------------------------------------------------------------------------
/data/toy_data.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nZhangx/TrajectoryFlowMatching/b562e386074cfec62ee7ddd23e826c3f33dd8ba2/data/toy_data.pkl
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: ti-env
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1
7 | - _openmp_mutex=4.5
8 | - asttokens=2.4.1
9 | - bzip2=1.0.8
10 | - ca-certificates=2024.6.2
11 | - comm=0.2.2
12 | - debugpy=1.6.7
13 | - decorator=5.1.1
14 | - executing=2.0.1
15 | - importlib-metadata=7.2.1
16 | - importlib_metadata=7.2.1
17 | - ipykernel=6.29.4
18 | - ipython=8.25.0
19 | - jedi=0.19.1
20 | - jupyter_client=8.6.2
21 | - jupyter_core=5.7.2
22 | - ld_impl_linux-64=2.38
23 | - libffi=3.4.4
24 | - libgcc-ng=13.2.0
25 | - libgomp=13.2.0
26 | - libsodium=1.0.18
27 | - libstdcxx-ng=11.2.0
28 | - libuuid=1.41.5
29 | - matplotlib-inline=0.1.7
30 | - ncurses=6.4
31 | - nest-asyncio=1.6.0
32 | - openssl=3.3.1
33 | - packaging=24.1
34 | - parso=0.8.4
35 | - pexpect=4.9.0
36 | - pickleshare=0.7.5
37 | - pip=24.0
38 | - platformdirs=4.2.2
39 | - prompt-toolkit=3.0.47
40 | - ptyprocess=0.7.0
41 | - pure_eval=0.2.2
42 | - pygments=2.18.0
43 | - python=3.10.14
44 | - python_abi=3.10
45 | - pyzmq=25.1.2
46 | - readline=8.2
47 | - setuptools=69.5.1
48 | - six=1.16.0
49 | - sqlite=3.45.3
50 | - stack_data=0.6.2
51 | - tk=8.6.14
52 | - tornado=6.4.1
53 | - traitlets=5.14.3
54 | - typing_extensions=4.12.2
55 | - wcwidth=0.2.13
56 | - wheel=0.43.0
57 | - xz=5.4.6
58 | - zeromq=4.3.5
59 | - zipp=3.19.2
60 | - zlib=1.2.13
61 | - pip:
62 |     - aiohttp==3.9.5
63 |     - aiosignal==1.3.1
64 |     - alembic==1.13.1
65 |     - anndata==0.10.7
66 |     - antlr4-python3-runtime==4.9.3
67 |     - array-api-compat==1.7.1
68 |     - async-timeout==4.0.3
69 |     - attrs==23.2.0
70 |     - autopage==0.5.2
71 |     - black==24.4.2
72 |     - certifi==2024.6.2
73 |     - cfgv==3.4.0
74 |     - charset-normalizer==3.3.2
75 |     - click==8.1.7
76 |     - cliff==4.7.0
77 |     - cloudpickle==3.0.0
78 |     - cmaes==0.10.0
79 |     - cmd2==2.4.3
80 |     - colorlog==6.8.2
81 |     - contourpy==1.2.1
82 |     - cycler==0.12.1
83 |     - distlib==0.3.8
84 |     - docker-pycreds==0.4.0
85 |     - exceptiongroup==1.2.1
86 |     - filelock==3.15.1
87 |     - fire==0.6.0
88 |     - flake8==7.0.0
89 |     - flake8-pyproject==1.2.3
90 |     - fonttools==4.53.0
91 |     - frozenlist==1.4.1
92 |     - fsspec==2024.6.0
93 |     - gitdb==4.0.11
94 |     - gitpython==3.1.43
95 |     - greenlet==3.0.3
96 |     - h5py==3.11.0
97 |     - huggingface-hub==0.23.4
98 |     - hydra-colorlog==1.2.0
99 |     - hydra-core==1.2.0
100 |     - hydra-optuna-sweeper==1.2.0
101 |     - hydra-submitit-launcher==1.2.0
102 |     - identify==2.5.36
103 |     - idna==3.7
104 |     - iniconfig==2.0.0
105 |     - ipywidgets==8.1.3
106 |     - isort==5.13.2
107 |     - joblib==1.4.2
108 |     - jupyterlab-widgets==3.0.11
109 |     - kiwisolver==1.4.5
110 |     - legacy-api-wrap==1.4
111 |     - lightning-bolts==0.6.0.post1
112 |     - lightning-utilities==0.3.0
113 |     - llvmlite==0.43.0
114 |     - mako==1.3.5
115 |     - markdown-it-py==3.0.0
116 |     - markupsafe==2.1.5
117 |     - matplotlib==3.9.0
118 |     - mccabe==0.7.0
119 |     - mdurl==0.1.2
120 |     - multidict==6.0.5
121 |     - mypy-extensions==1.0.0
122 |     - natsort==8.4.0
123 |     - networkx==3.3
124 |     - nodeenv==1.9.1
125 |     - numba==0.60.0
126 |     - numpy==1.26.4
127 |     - omegaconf==2.3.0
128 |     - optuna==2.10.1
129 |     - pandas==2.0.3
130 |     - pastel==0.2.1
131 |     - pathspec==0.12.1
132 |     - patsy==0.5.6
133 |     - pbr==6.0.0
134 |     - pillow==10.3.0
135 |     - pluggy==1.5.0
136 |     - poethepoet==0.10.0
137 |     - pot==0.9.3
138 |     - pre-commit==3.7.1
139 |     - prettytable==3.10.0
140 |     - protobuf==5.27.1
141 |     - psutil==5.9.8
142 |     - pycodestyle==2.11.1
143 |     - pyflakes==3.2.0
144 |     - pynndescent==0.5.12
145 |     - pyparsing==3.1.2
146 |     - pyperclip==1.8.2
147 |     - pyrootutils==1.0.4
148 |     - pytest==8.2.2
149 |     - python-dateutil==2.9.0.post0
150 |     - python-dotenv==1.0.1
151 |     # - python-graphviz==0.20.3
152 |     - pytorch-lightning==1.8.3
153 |     - pytz==2024.1
154 |     - pyyaml==6.0.1
155 |     - requests==2.32.3
156 |     - rich==13.7.1
157 |     - safetensors==0.4.3
158 |     - scanpy==1.10.1
159 |     - scikit-learn==1.5.0
160 |     - scipy==1.13.1
161 |     - scprep==1.2.3
162 |     - seaborn==0.13.2
163 |     - sentry-sdk==2.5.1
164 |     - session-info==1.0.0
165 |     - setproctitle==1.3.3
166 |     - smmap==5.0.1
167 |     - sqlalchemy==2.0.30
168 |     - statsmodels==0.14.2
169 |     - stdlib-list==0.10.0
170 |     - stevedore==5.2.0
171 |     - submitit==1.5.1
172 |     - tensorboardx==2.6.2.2
173 |     - termcolor==2.4.0
174 |     - threadpoolctl==3.5.0
175 |     - timm==1.0.3
176 |     - tomli==2.0.1
177 |     - tomlkit==0.12.5
178 |     - torch==1.12.1+cu113
179 |     - torchaudio==0.12.1+cu113
180 |     - torchcde==0.2.5
181 |     # - torchcubicspline==0.0.3
182 |     - torchdiffeq==0.2.4
183 |     - torchdyn==1.0.6
184 |     - torchmetrics==0.11.0
185 |     - torchsde==0.2.6
186 |     - torchvision==0.13.1+cu113
187 |     - torchviz==0.0.2
188 |     - tqdm==4.66.4
189 |     - trampoline==0.1.2
190 |     - tzdata==2024.1
191 |     - umap-learn==0.5.6
192 |     - urllib3==2.2.1
193 |     - virtualenv==20.26.2
194 |     - wandb==0.17.1
195 |     - widgetsnbextension==4.0.11
196 |     - yarl==1.9.4
197 |
--------------------------------------------------------------------------------
/src/conf/config.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - data: eICU
4 | - model: tfm_sde
5 |
6 | wandb_logging: true
7 | wandb_project: clinical_trajectory
8 | wandb_dir: wandb_log/
9 | ckpt_dir: checkpoints/
10 | max_epochs: 200
11 | max_time: 00:48:00:00
12 | check_val_every_n_epoch: 10
13 | skip_training: false
14 | mode: None
15 | seed: 42
--------------------------------------------------------------------------------
/src/conf/data/eICU.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | data_module:
3 |   _target_: data.datamodule.clinical_DataModule
4 |   train_consecutive: true
5 |   file_path: data/toy_data.pkl
6 |   naming: eICU_DataModule_J30_v1
7 |   t_headings: time_scaled_v1
8 |   x_headings:
9 |     - hr_normalized
10 |     - map_normalized
11 |   cond_headings:
12 |     - apache_outcome_prob
13 |     - norepi_inf_scaled
14 |   memory: 0
15 |
--------------------------------------------------------------------------------
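The `_target_` key follows Hydra's convention of naming a class by its dotted import path; Hydra's `hydra.utils.instantiate` resolves such a path and calls the result with the remaining keys as keyword arguments. How `main.py` consumes these configs is not shown in this excerpt, so the following is only a stdlib stand-in for the mechanism. The helper `instantiate_from_config` is hypothetical, and `fractions.Fraction` is used as a placeholder target instead of `data.datamodule.clinical_DataModule`.

```python
import importlib


def instantiate_from_config(cfg):
    """Resolve cfg['_target_'] to a callable and call it with the other keys."""
    kwargs = {key: value for key, value in cfg.items() if key != "_target_"}
    # Split "package.module.ClassName" into the module path and the attribute name.
    module_path, _, attr_name = cfg["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    return target(**kwargs)


# Placeholder config: a stdlib class stands in for the repository's DataModule.
cfg = {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4}
obj = instantiate_from_config(cfg)
```

Read this way, every key nested under `data_module` other than `_target_` has to match a constructor argument of the target class, which is the case for `clinical_DataModule.__init__` in `src/data/datamodule.py`.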
/src/conf/data/eICU_ablated.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | data_module:
3 |   _target_: data.datamodule.clinical_DataModule
4 |   train_consecutive: true
5 |   file_path: data/toy_data.pkl
6 |   naming: eICU_DataModule_J30_v1
7 |   t_headings: time_scaled_v1
8 |   x_headings:
9 |     - hr_normalized
10 |     - map_normalized
11 |   cond_headings:
12 |     - apache_outcome_prob
13 |   memory: 0
--------------------------------------------------------------------------------
/src/conf/data/eICU_multdim.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | data_module:
3 |   _target_: data.datamodule.clinical_DataModule
4 |   train_consecutive: true
5 |   file_path: eICU_CART_downsampled.pkl
6 |   naming: eICU-CART_DataModule_APR16
7 |   t_headings: time_scaled_v1
8 |   x_headings:
9 |     - hr_normalized_scaled
10 |     - dbp_normalized_scaled
11 |     - rr_normalized_scaled
12 |   cond_headings:
13 |     - AGE_AT_ADM_normalized
14 |   memory: 0
--------------------------------------------------------------------------------
/src/conf/data/mimic_liver.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | data_module:
3 |   _target_: data.mimic_liver_data.MIMICLiverDataModule
4 |   naming: MIMICLiverDataModule
5 |   train_consecutive: true
6 |   memory: 0
7 |   x_headings:
8 |     - 1
9 |     - MAP
10 |   cond_headings:
11 |     - prbc_outcome
12 |     - pressor
13 |     - bloodprod
14 |     - severe_liver
15 |
--------------------------------------------------------------------------------
/src/conf/model/cfm.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.cfm_baseline.MLP_CFM
4 |   time_varying: true
5 |   conditional: true
6 |   treatment_cond: 0
7 |   clip: 1e-2
8 |   sigma: 0.1
9 |   dim: 2
10 |   metrics:
11 |     - variance_dist
12 |     - mse_loss
13 |     - l1_loss
--------------------------------------------------------------------------------
/src/conf/model/fm.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.FM_baseline.MLP_FM
4 |   time_varying: true
5 |   conditional: false
6 |   treatment_cond: 0
7 |   clip: 1e-2
8 |   sigma: 0.1
9 |   dim: 2
10 |   metrics:
11 |     - variance_dist
12 |     - mse_loss
13 |     - l1_loss
--------------------------------------------------------------------------------
/src/conf/model/latentODE.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.latent_ode.LatentODE_pl
4 |   input_dim: 2
5 |   latent_dim: 2
6 |   output_dim: 2
7 |   metrics:
8 |     - variance_dist
9 |     - mse_loss
10 |     - l1_loss
--------------------------------------------------------------------------------
/src/conf/model/ode.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.ode_baseline.ODEBaseline
4 |   dim: 2
5 |   w: 64
6 |   lr: 1e-5
7 |   metrics:
8 |     - variance_dist
9 |     - mse_loss
10 |     - l1_loss
--------------------------------------------------------------------------------
/src/conf/model/ode_256.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.ode_baseline.ODEBaseline
4 |   dim: 2
5 |   w: 256
6 |   lr: 1e-5
7 |   metrics:
8 |     - variance_dist
9 |     - mse_loss
10 |     - l1_loss
--------------------------------------------------------------------------------
/src/conf/model/sde.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.ode_baseline.SDEBaseline
4 |   dim: 2
5 |   w: 64
6 |   lr: 1e-5
7 |   metrics:
8 |     - variance_dist
9 |     - mse_loss
10 |     - l1_loss
--------------------------------------------------------------------------------
/src/conf/model/sde_256.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.ode_baseline.SDEBaseline
4 |   dim: 2
5 |   w: 256
6 |   lr: 1e-5
7 |   metrics:
8 |     - variance_dist
9 |     - mse_loss
10 |     - l1_loss
--------------------------------------------------------------------------------
/src/conf/model/tfm_ablated_size_memory_uncertainty.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.mlp_memory.MLP_Cond_Memory_Module
4 |   time_varying: true
5 |   conditional: true
6 |   treatment_cond: 0
7 |   memory: 3
8 |   dim: 2
9 |   implementation: SDE
10 |   sde_noise: 0.1
11 |   clip: 1e-2
12 |   sigma: 0.1
13 |   metrics:
14 |     - variance_dist
15 |     - mse_loss
16 |     - l1_loss
--------------------------------------------------------------------------------
/src/conf/model/tfm_ablated_uncertainty.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.mlp_memory.MLP_Cond_Memory_Module
4 |   time_varying: true
5 |   conditional: true
6 |   treatment_cond: 0
7 |   memory: 3
8 |   w: 256
9 |   implementation: SDE
10 |   clip: 1e-2
11 |   sigma: 0.1
12 |   dim: 2
13 |   metrics:
14 |     - variance_dist
15 |     - mse_loss
16 |     - l1_loss
--------------------------------------------------------------------------------
/src/conf/model/tfm_ode.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.mlp_noise.Noise_MLP_Cond_Memory_Module
4 |   time_varying: true
5 |   conditional: true
6 |   treatment_cond: 0
7 |   memory: 3
8 |   dim: 2
9 |   w: 256
10 |   clip: 1e-2
11 |   sigma: 0.1
12 |   metrics:
13 |     - variance_dist
14 |     - mse_loss
15 |     - l1_loss
16 |
--------------------------------------------------------------------------------
/src/conf/model/tfm_ode_ablated_size_memory_uncertainty.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.mlp_memory.MLP_Cond_Memory_Module
4 |   time_varying: true
5 |   conditional: true
6 |   treatment_cond: 0
7 |   memory: 3
8 |   dim: 2
9 |   clip: 1e-2
10 |   sigma: 0.1
11 |   metrics:
12 |     - variance_dist
13 |     - mse_loss
14 |     - l1_loss
15 |
--------------------------------------------------------------------------------
/src/conf/model/tfm_ode_ablated_uncertainty.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.mlp_memory.MLP_Cond_Memory_Module
4 |   time_varying: true
5 |   conditional: true
6 |   treatment_cond: 0
7 |   memory: 3
8 |   w: 256
9 |   clip: 1e-2
10 |   sigma: 0.1
11 |   dim: 2
12 |   metrics:
13 |     - variance_dist
14 |     - mse_loss
15 |     - l1_loss
16 |
--------------------------------------------------------------------------------
/src/conf/model/tfm_sde.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model_module:
3 |   _target_: model.mlp_noise.Noise_MLP_Cond_Memory_Module
4 |   time_varying: true
5 |   conditional: true
6 |   treatment_cond: 0
7 |   memory: 3
8 |   dim: 2
9 |   w: 256
10 |   clip: 1e-2
11 |   sigma: 0.1
12 |   metrics:
13 |     - variance_dist
14 |     - mse_loss
15 |     - l1_loss
16 |   implementation: SDE
--------------------------------------------------------------------------------
/src/conf/perturb_config.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - data: eICU
4 | - model: tfm_sde
5 |
6 |
7 | wandb_logging: true
8 | wandb_project: clinical_trajectory
9 | wandb_dir: wandb_log/
10 | ckpt_dir: checkpoints/
11 | max_epochs: 200
12 | max_time: 00:48:00:00
13 | check_val_every_n_epoch: 10
14 | skip_training: false
15 | mode: None
16 | seed: 42
--------------------------------------------------------------------------------
/src/conf/trainer.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 |   max_epochs: 20000
3 |   check_val_every_n_epoch: 100
4 |   gpus: 1
5 |   precision: 32
6 |
--------------------------------------------------------------------------------
/src/data/datamodule.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import pandas as pd
3 | import pickle
4 | import numpy as np
5 |
6 | # from pytorch_lightning.utilities.types import EVAL_DATALOADERS
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 |
10 | """Helper class for specific training/testing dataset"""
11 | class TrainingDataset(Dataset):
12 |     def __init__(self, x0_values, x0_classes, x1_values, times_x0, times_x1):
13 |         self.x0_values = x0_values
14 |         self.x0_classes = x0_classes
15 |         self.x1_values = x1_values
16 |         self.times_x0 = times_x0
17 |         self.times_x1 = times_x1
18 |
19 |     def __len__(self):
20 |         return len(self.x0_values)
21 |
22 |     def __getitem__(self, idx):
23 |         return (self.x0_values[idx], self.x0_classes[idx], self.x1_values[idx], self.times_x0[idx], self.times_x1[idx])
24 |
25 | class PatientDataset(Dataset):
26 |     def __init__(self, patient_data):
27 |         self.patient_data = patient_data
28 |
29 |     def __len__(self):
30 |         return len(self.patient_data)
31 |
32 |     def __getitem__(self, idx):
33 |         return self.patient_data[idx]
34 |
35 | class clinical_DataModule(pl.LightningDataModule):
36 |
37 |     """returns:
38 |     x0_values, x0_classes, x1_values, times_x0, times_x1
39 |     """
40 |
41 |     def __init__(self,
42 |                  train_consecutive=False,
43 |                  batch_size=256,
44 |                  file_path=None,
45 |                  naming = None,
46 |                  t_headings = None,
47 |                  x_headings = [None],
48 |                  cond_headings = [None],
49 |                  memory=0):
50 |         super().__init__()
51 |         self.batch_size = batch_size
52 |         self.file_path = file_path
53 |         self.x_headings = x_headings
54 |         self.cond_headings = cond_headings
55 |         self.t_headings = t_headings
56 |         self.input_dim = len(self.x_headings) + len(self.cond_headings)
57 |         self.output_dim = len(self.x_headings)
58 |         self.naming = naming
59 |         self.memory = memory
60 |         self.min_timept = 5 + self.memory
61 |         self.train_consecutive = train_consecutive
62 |         print("DataModule initialized to x_headings: ", self.x_headings, " cond_headings: ", self.cond_headings, " t_headings: ", self.t_headings, " train_consecutive: ", self.train_consecutive)
63 |
64 |     def prepare_data(self):
65 |         pass
66 |
67 |     def __filter_data(self, data_set):
68 |         # keep only admissions (HADM_ID groups) with more than min_timept = 5 + memory time points
69 |         return data_set.groupby('HADM_ID').filter(lambda x: len(x) > self.min_timept)
70 |
71 |     def __unpack__(self, data_set):
72 |         x = data_set[self.x_headings].values
73 |         cond = data_set[self.cond_headings].values
74 |         t = data_set[self.t_headings].values
75 |         return x, cond, t
76 |
77 |     def __getitem__(self, idx):
78 |         sample = self.data.iloc[idx].values.astype(np.float32)
79 |         return sample
80 |
81 |     def setup(self, stage=None):
82 |         self.data = pd.read_pickle(self.file_path)
83 |         if stage == 'fit' or stage is None:
84 |             self.train = self.__filter_data(self.data['train'])
85 |             self.val = self.__filter_data(self.data['val'])
86 |         if stage == 'test' or stage is None:
87 |             self.test = self.__filter_data(self.data['test'])
88 |
89 |     def __sort_group__(self, data_set):
90 |         grouped = data_set.groupby('HADM_ID')
91 |         grouped_sorted = grouped.apply(lambda x: x.sort_values([self.t_headings], ascending = True)).reset_index(drop=True)
92 |         return grouped_sorted
93 |
94 |     def train_dataloader(self, shuffle=True):
95 |         if not self.train_consecutive:
96 |             train_data = self.create_pairs(self.train)
97 |             # print(len(train_data))
98 |             train_dataset = TrainingDataset(*train_data)
99 |             return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=1)
100 |         else:
101 |             train_data = self.create_patient_data_t0(self.train)
102 |             train_dataset = PatientDataset(train_data)
103 |             return DataLoader(train_dataset, batch_size=1, shuffle=shuffle, num_workers=1)
104 |
105 |     def val_dataloader(self):
106 |         if self.train_consecutive:
107 |             val_data = self.create_patient_data_t0(self.val)
108 |             val_dataset = PatientDataset(val_data)
109 |             return DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
110 |         val_data = self.create_patient_data(self.val)
111 |         val_dataset = PatientDataset(val_data)
112 |         return DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
113 |
114 |     def test_dataloader(self):
115 |         if self.train_consecutive:
116 |             test_data = self.create_patient_data_t0(self.test)
117 |             test_dataset = PatientDataset(test_data)
118 |             return DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)
119 |         test_data = self.create_patient_data(self.test)
120 |         test_dataset = PatientDataset(test_data)
121 |         return DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)
122 |
123 |     def create_patient_data(self, df):
124 |         """Create per-patient data from the DataFrame, one tuple per HADM_ID:
125 |         (x0_values, x0_classes, x1_values, times_x0, times_x1).
126 |
127 |         Args:
128 |             df (pandas.DataFrame): rows with 'HADM_ID' plus the x, cond, and t columns.
129 |                 Rows are sorted by `self.t_headings` within each group.
130 |         """
131 |         patient_lst = []
132 |         for _, group in df.groupby('HADM_ID'):
133 |
134 |             x0_values = []
135 |             x0_classes = []
136 |             x1_values = []
137 |             times_x0 = []
138 |             times_x1 = []
139 |
140 |             sorted_group = group.sort_values(by=self.t_headings)
141 |             x0_values, x0_classes, x1_values, times_x0, times_x1 = self.create_pairs(sorted_group)
142 |
143 |             if len(self.cond_headings)<2:
144 |                 x0_classes = np.expand_dims(x0_classes, axis=1)
145 |             else:
146 |                 x0_classes = x0_classes.squeeze().astype(np.float32)
147 |
148 |
149 |             patient_lst.append((x0_values.squeeze().astype(np.float32),
150 |                                 x0_classes,
151 |                                 x1_values.squeeze().astype(np.float32),
152 |                                 times_x0.squeeze().astype(np.float32),
153 |                                 times_x1.squeeze().astype(np.float32)))
154 |         return patient_lst
155 |
156 |
157 |     def create_pairs(self, df):
158 |         """Create pairs of consecutive points from the DataFrame (for training the model).
159 |
160 |         Args:
161 |             df (pandas.DataFrame): rows with 'HADM_ID' plus the x, cond, and t columns;
162 |                 each group is sorted by `self.t_headings` before pairing.
163 |
164 |         Returns:
165 |             tuple of numpy.ndarray: x0_values, x0_classes, x1_values, times_x0, times_x1
166 |
167 |         """
168 |         # Initialize empty lists to store the components of the pairs
169 |         x0_values = []
170 |         x0_classes = []
171 |         x1_values = []
172 |         times_x0 = []
173 |         times_x1 = []
174 |
175 |         # Group the DataFrame by HADM_ID and iterate through each group
176 |         for _, group in df.groupby('HADM_ID'):
177 |             # Sort the group by the time column (self.t_headings)
178 |             sorted_group = group.sort_values(by=self.t_headings)
179 |
180 |             # Iterate through the sorted group to create pairs of consecutive points
181 |             for i in range(self.memory,len(sorted_group) - 1):
182 |                 x0 = sorted_group.iloc[i]
183 |                 x0_class = x0[self.cond_headings].values
184 |                 x0_value = x0[self.x_headings].values
185 |
186 |                 x1 = sorted_group.iloc[i + 1]
187 |                 x1_value = x1[self.x_headings].values
188 |
189 |                 # memory component: append the `memory` preceding observations, flattened
190 |                 if self.memory>0:
191 |                     x0_memory = sorted_group.iloc[i - self.memory:i]
192 |                     x0_memory_flatten = x0_memory[self.x_headings].values.flatten()
193 |                     x0_class = np.append(x0_class, x0_memory_flatten)
194 |
195 |                 x0_values.append(x0_value)
196 |                 x0_classes.append(x0_class)
197 |                 x1_values.append(x1_value)
198 |                 times_x0.append(x0[self.t_headings])
199 |                 times_x1.append(x1[self.t_headings])
200 |
201 |         # Convert the lists to arrays
202 |         x0_values = np.array(x0_values).squeeze().astype(np.float32)
203 |         x0_classes = np.array(x0_classes).squeeze().astype(np.float32)
204 |         x1_values = np.array(x1_values).squeeze().astype(np.float32)
205 |         times_x0 = np.array(times_x0).squeeze().astype(np.float32)
206 |         times_x1 = np.array(times_x1).squeeze().astype(np.float32)
207 |
208 |         if len(self.cond_headings)<2:
209 |             x0_classes = np.expand_dims(x0_classes, axis=1)
210 |
211 |         return x0_values, x0_classes, x1_values, times_x0, times_x1
212 |
213 |
214 |     def create_patient_data_t0(self, df):
215 |         """Create per-patient data from the DataFrame, one tuple per HADM_ID.
216 |         Unlike `create_patient_data`, x0 (with its class and time) is held fixed at
217 |         the first pair, while x1 and times_x1 vary.
218 |
219 |         Args:
220 |             df (pandas.DataFrame): rows with 'HADM_ID' plus the x, cond, and t columns.
221 |
222 |         """
223 |         patient_lst = []
224 |         for _, group in df.groupby('HADM_ID'):
225 |
226 |             x0_values = []
227 |             x0_classes = []
228 |             x1_values = []
229 |             times_x0 = []
230 |             times_x1 = []
231 |
232 |             sorted_group = group.sort_values(by=self.t_headings)
233 |             x0_values, x0_classes, x1_values, times_x0, times_x1 = self.create_pairs(sorted_group)
234 |
235 |             if len(self.cond_headings)<2:
236 |                 x0_classes = np.expand_dims(x0_classes, axis=1)
237 |             else:
238 |                 x0_classes = x0_classes.squeeze().astype(np.float32)
239 |
240 |             # repeat the first point for x0_values and x0_classes, and times_x0
241 |             x0_values = np.repeat(x0_values[0][None, :], len(x0_values), axis=0)
242 |             x0_classes = np.repeat(x0_classes[0][None, :], len(x0_values), axis=0)
243 |             times_x0 = np.repeat(times_x0[0], len(x0_values))
244 |
245 |             patient_lst.append((x0_values.squeeze().astype(np.float32),
246 |                                 x0_classes,
247 |                                 x1_values.squeeze().astype(np.float32),
248 |                                 times_x0.squeeze().astype(np.float32),
249 |                                 times_x1.squeeze().astype(np.float32)))
250 |         return patient_lst
251 |
252 | @property
253 | def dims(self):
254 | # x, cond, t
255 | return len(self.x_headings), len(self.cond_headings), len(self.t_headings)
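
The pairing logic above (consecutive observations, with the `memory` preceding points appended to the conditioning vector) can be sketched on plain lists. This is a minimal illustration, not the data module's actual code path: `make_pairs` and the toy trajectory are hypothetical names introduced here, and pandas/NumPy are deliberately left out.

```python
def make_pairs(values, cond, memory=0):
    """Build (x0, x0_class, x1) triples from one time-sorted trajectory.

    values: list of observations (each a list of floats), sorted by time
    cond:   list of conditioning vectors, same length as values
    memory: number of earlier observations appended to the conditioning vector
    """
    pairs = []
    # Start at `memory` so every x0 has a full history window behind it.
    for i in range(memory, len(values) - 1):
        x0, x1 = values[i], values[i + 1]
        x0_class = list(cond[i])
        if memory > 0:
            # Flatten the `memory` observations that precede x0.
            for past in values[i - memory:i]:
                x0_class.extend(past)
        pairs.append((x0, x0_class, x1))
    return pairs


# Hypothetical toy trajectory: five 1-D observations, one constant condition.
traj = [[10.0], [11.0], [12.0], [13.0], [14.0]]
cond = [[0.0]] * len(traj)
pairs = make_pairs(traj, cond, memory=2)
```

With `memory=2`, a five-point trajectory yields two pairs, because the first two points are only ever used as history and the last point is only ever used as a target.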
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | # training script for lightning model
2 | """
3 | run via:
4 | (changing data and model)
5 | python main.py data=data/data3.yaml model=model/model3.yaml
6 |
7 | to test:
8 | python main.py skip_training=true
9 | """
10 | import pytorch_lightning as pl
11 | import torch
12 | from torch import nn
13 | import pandas as pd
14 | import numpy as np
15 | import wandb
16 | from pytorch_lightning.loggers import WandbLogger
17 | import os
18 |
19 |
20 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
21 |
22 |
23 | import hydra
24 | from omegaconf import OmegaConf
25 |
26 | from hydra.utils import instantiate
27 |
28 | from pytorch_lightning.strategies import DDPStrategy
29 |
30 | @hydra.main(config_path="conf", config_name="config")
31 | def train_model(cfg):
32 | # set seed
33 | pl.seed_everything(cfg.seed)
34 |
35 | print(cfg)
36 | if 'memory' in cfg['model_module'].keys(): # if is key
37 | memory = cfg['model_module']['memory']
38 | cfg.data_module.memory = memory
39 |
40 | data_module = instantiate(cfg.data_module)
41 |
42 | # correct dim
43 | x_dim = data_module.dims[0]
44 | if 'dim' in cfg.model_module.keys():
45 | cfg.model_module.dim = x_dim
46 | elif 'input_dim' in cfg.model_module.keys():
47 | cfg.model_module.input_dim = x_dim
48 | cfg.model_module.output_dim = x_dim
49 |
50 |
51 | if 'treatment_cond' in cfg.model_module.keys():
52 | # for conditional models, need this to configure
53 | cfg.model_module.treatment_cond = len(data_module.cond_headings)
54 |
55 | model = instantiate(cfg.model_module)
56 |
57 | # conditional models need train_consecutive false!
58 | if not('Cond' in model.naming):
59 | cfg.data_module.train_consecutive = True
60 | data_module = instantiate(cfg.data_module)
61 | else:
62 | cfg.data_module.train_consecutive = False
63 | data_module = instantiate(cfg.data_module)
64 |
65 | wandb_config = {key: value for key, value in cfg.model_module.items() if key not in ['_target_']}
66 | wandb_config['model'] = model.naming
67 | wandb_config['data'] = data_module.naming
68 | wandb_config['mode'] = 'batch_run'
69 | wandb_config['x_headings'] = data_module.x_headings
70 | wandb_config['cond_headings'] = data_module.cond_headings
71 | wandb_config['t_headings'] = data_module.t_headings
72 | wandb_config['seed'] = cfg.seed
73 |
74 | if cfg.wandb_logging and not(cfg.skip_training):
75 | wandb.init(project="clinical_trajectory",
76 | dir = '/home/mila/x/xi.zhang/scratch/shung_ICU/wandb_log/',
77 | config = wandb_config
78 | )
79 | wandb_logger = WandbLogger()
80 |
81 | ckpt_savedir = '/home/mila/x/xi.zhang/scratch/shung_ICU/checkpoints/'+model.naming+'_'+data_module.naming+'/'
82 | if not os.path.exists(ckpt_savedir):
83 | os.makedirs(ckpt_savedir)
84 |
85 | checkpoint_callback = ModelCheckpoint(
86 | dirpath=ckpt_savedir,
87 | filename='best_model',
88 | save_top_k=1,
89 | verbose=True,
90 | monitor='val_loss',
91 | mode='min',
92 | save_last=True
93 | )
94 |
95 | early_stopping_callback = EarlyStopping(
96 | monitor='val_loss',
97 | patience=3,
98 | verbose=True,
99 | mode='min'
100 | )
101 |
102 | strategy_ddps = DDPStrategy(find_unused_parameters=True)
103 |
104 | trainer = pl.Trainer(
105 | max_epochs=cfg.max_epochs,
106 | max_time=cfg.max_time,
107 | check_val_every_n_epoch=50,
108 | callbacks=[checkpoint_callback, early_stopping_callback],
109 | accelerator='gpu' if torch.cuda.is_available() else 'cpu',
110 | logger=wandb_logger if (cfg.wandb_logging and not cfg.skip_training) else None, # wandb_logger only exists when wandb.init ran
111 | limit_train_batches=0 if cfg.skip_training else 1.0,
112 | strategy=strategy_ddps,
113 | )
114 |
115 | # Train the model
116 | trainer.fit(model, datamodule=data_module)
117 |
118 | # Test the model
119 | trainer.test(model, datamodule=data_module)
120 |
121 | wandb.finish()
122 |
123 | def main():
124 | train_model()
125 |
126 | if __name__ == '__main__':
127 | main()
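
The way `train_model` copies data-dependent sizes into the model config before instantiation can be sketched with a plain dictionary. This is an illustration only: `sync_model_dims` is a hypothetical helper, a `dict` stands in for the Hydra/OmegaConf config, and it assumes `output_dim` is set in the same branch as `input_dim`.

```python
def sync_model_dims(model_cfg, x_dim, n_cond):
    """Return a copy of model_cfg with data-dependent sizes filled in."""
    cfg = dict(model_cfg)
    # Models expose either `dim` or an `input_dim`/`output_dim` pair.
    if "dim" in cfg:
        cfg["dim"] = x_dim
    elif "input_dim" in cfg:
        cfg["input_dim"] = x_dim
        cfg["output_dim"] = x_dim
    # Conditional models also need the number of conditioning columns.
    if "treatment_cond" in cfg:
        cfg["treatment_cond"] = n_cond
    return cfg


# Hypothetical config values, chosen only for the example.
example = sync_model_dims({"dim": 2, "treatment_cond": 0, "w": 64}, x_dim=3, n_cond=5)
```

Keys that are absent from the model config are left absent, so a model without `treatment_cond` is never handed one.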
--------------------------------------------------------------------------------
/src/model/FM_baseline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import time
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import ot as pot
8 | import torchdyn
9 | from torchdyn.core import NeuralODE
10 | import pytorch_lightning as pl
11 | from torch import optim
12 | import torch.nn.functional as F
13 | import wandb
14 |
15 | from utils.visualize import *
16 | from utils.metric_calc import *
17 | from utils.sde import SDE
18 |
19 | from model.components.positional_encoding import *
20 | from model.components.grad_util import torch_wrapper_tv
21 |
22 |
23 | class MLP(torch.nn.Module):
24 | def __init__(self, dim, out_dim=None, w=64, time_varying=False):
25 | super().__init__()
26 | self.time_varying = time_varying
27 | if out_dim is None:
28 | out_dim = dim
29 | self.net = torch.nn.Sequential(
30 | torch.nn.Linear(dim + (1 if time_varying else 0), w),
31 | torch.nn.SELU(),
32 | torch.nn.Linear(w, w),
33 | torch.nn.SELU(),
34 | torch.nn.Linear(w, w),
35 | torch.nn.SELU(),
36 | torch.nn.Linear(w, out_dim),
37 | )
38 |
39 | def forward(self, x, *args, **kwargs):
40 | return self.net(x)
41 |
42 |
43 | # conditional liver model
44 | class MLP_conditional_liver(torch.nn.Module):
45 | """ Conditional with many available classes
46 |
47 | return the class as is
48 | """
49 | def __init__(self, dim, treatment_cond, out_dim=None, w=64, time_varying=False, conditional=False):
50 | super().__init__()
51 | self.time_varying = time_varying
52 | # fall back to the input dimension when out_dim is not given
53 | self.out_dim = dim if out_dim is None else out_dim
54 | self.treatment_cond = treatment_cond
55 | self.dim = dim
56 | self.indim = dim + (1 if time_varying else 0) + (self.treatment_cond if conditional else 0)
57 | self.net = torch.nn.Sequential(
58 | torch.nn.Linear(self.indim, w),
59 | torch.nn.SELU(),
60 | torch.nn.Linear(w, w),
61 | torch.nn.SELU(),
62 | torch.nn.Linear(w, w),
63 | torch.nn.SELU(),
64 | torch.nn.Linear(w,self.out_dim),
65 | )
66 | self.default_class = 0
67 |
68 |
69 | def forward(self, x):
70 | """forward pass
71 | Assume first two dimensions are x, c, then t
72 | """
73 | result = self.net(x)
74 | return torch.cat([result, x[:,self.dim:-1]], dim=1)
75 |
76 | class FM_baseline(torch.nn.Module):
77 | """ MLP vector field with a positional encoding of time
78 |
79 | the last output predicts the time remaining until the next observation
80 | """
81 | def __init__(self, dim,
82 | out_dim=None,
83 | w=64,
84 | time_varying=False,
85 | conditional=False,
86 | treatment_cond = 0,
87 | time_dim = NUM_FREQS * 2,
88 | clip = None):
89 | super().__init__()
90 | self.time_varying = time_varying
91 | # fall back to the input dimension when out_dim is not given
92 | self.out_dim = dim if out_dim is None else out_dim
93 | self.out_dim += 1
94 | self.treatment_cond = treatment_cond
95 | self.indim = dim + (time_dim if time_varying else 0) + (self.treatment_cond if conditional else 0)
96 | self.net = torch.nn.Sequential(
97 | torch.nn.Linear(self.indim, w),
98 | torch.nn.SELU(),
99 | torch.nn.Linear(w, w),
100 | torch.nn.SELU(),
101 | torch.nn.Linear(w, w),
102 | torch.nn.SELU(),
103 | torch.nn.Linear(w,self.out_dim),
104 | )
105 | self.default_class = 0
106 | # self.encoding_function = positional_encoding_tensor()
107 | self.clip = clip
108 |
109 | def encoding_function(self, time_tensor):
110 | return positional_encoding_tensor(time_tensor)
111 |
112 | def forward_train(self, x):
113 | """forward pass
114 | Assume the last column of x is the time t; the remaining columns are the state
115 | input: [x_t, t]
116 | output: [vt, predicted time to the next observation]
117 | """
118 | time_tensor = x[:,-1]
119 | encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2)
120 | new_x = torch.cat([x[:,:-1], encoded_time_span], dim = 1).to(torch.float32)
121 | result = self.net(new_x)
122 | return torch.cat([result[:,:-1], result[:,-1].unsqueeze(1)], dim=1)
123 |
124 | def forward(self,x):
125 | """Function for simulation testing
126 |
127 |
128 | Args:
129 | x (torch.Tensor): state with the time as the last column
130 |
131 | Returns:
132 | torch.Tensor: predicted velocity, i.e. forward_train(x) without its last (time) column
133 | """
134 | return self.forward_train(x)[:,:-1]
135 |
136 |
137 |
138 | """ Lightning module """
139 | def mse_loss(pred, true):
140 | return torch.mean((pred - true) ** 2)
141 |
142 | def l1_loss(pred, true):
143 | return torch.mean(torch.abs(pred - true))
144 |
145 | class MLP_FM(pl.LightningModule):
146 | def __init__(self,
147 | treatment_cond,
148 | dim=2,
149 | w=64,
150 | time_varying=True,
151 | conditional=False,
152 | lr=1e-6,
153 | sigma = 0.1,
154 | loss_fn = mse_loss,
155 | metrics = ['mse_loss', 'l1_loss'],
156 | implementation = "ODE", # can be SDE
157 | sde_noise = 0.1,
158 | clip = None, # float
159 | naming = None,
160 | ):
161 | super().__init__()
162 | self.model = FM_baseline(dim=dim,
163 | w=w,
164 | time_varying=time_varying,
165 | conditional=conditional, # no conditional for baseline
166 | clip = clip,
167 | treatment_cond=treatment_cond)
168 | self.loss_fn = loss_fn
169 | self.save_hyperparameters()
170 | self.dim = dim
171 | # self.out_dim = out_dim
172 | self.w = w
173 | self.time_varying = time_varying
174 | self.conditional = conditional
175 | self.treatment_cond = treatment_cond
176 | self.lr = lr
177 | self.sigma = sigma
178 | self.naming = "FM_baseline_"+implementation
179 | self.metrics = metrics
180 | self.implementation = implementation
181 | self.sde_noise = sde_noise
182 | self.clip = clip
183 |
184 |
185 | def __convert_tensor__(self, tensor):
186 | return tensor.to(torch.float32)
187 |
188 | def __x_processing__(self, x0, x1, t0, t1):
189 | # squeeze xs (prevent mismatch)
190 | x0 = x0.squeeze(0)
191 | x1 = x1.squeeze(0)
192 | t0 = t0.squeeze()
193 | t1 = t1.squeeze()
194 |
195 | t = torch.rand(x0.shape[0],1).to(x0.device)
196 | mu_t = x0 * (1 - t) + x1 * t
197 | data_t_diff = (t1 - t0).unsqueeze(1)
198 | x = mu_t + self.sigma * torch.randn(x0.shape[0], self.dim).to(x0.device)
199 | ut = (x1 - x0) / (data_t_diff + 1e-4)
200 | t_model = t * data_t_diff + t0.unsqueeze(1)
201 | futuretime = t1 - t_model.squeeze(1) # shape (batch,), matching vt[:,-1]; t1 - t_model would broadcast to (batch, batch)
202 | return x, ut, t_model, futuretime, t
203 |
204 | def training_step(self, batch, batch_idx):
205 | """Flow matching training step
206 |
207 | Args:
208 | batch (list of tensors): x0_values, x0_classes, x1_values, times_x0, times_x1
209 | batch_idx (int): index of the batch
210 |
211 | Returns:
212 | torch.Tensor: velocity regression loss plus time-to-next-observation loss
213 | """
214 | x0, x0_class, x1, x0_time, x1_time = batch
215 | x0, x0_class, x1, x0_time, x1_time = self.__convert_tensor__(x0), self.__convert_tensor__(x0_class), self.__convert_tensor__(x1), self.__convert_tensor__(x0_time), self.__convert_tensor__(x1_time)
216 |
217 |
218 | x, ut, t_model, futuretime, t = self.__x_processing__(x0, x1, x0_time, x1_time)
219 |
220 |
221 | if len(x0_class.shape) == 3:
222 | x0_class = x0_class.squeeze(0)
223 |
224 | # in_tensor = torch.cat([x,x0_class, t_model], dim = -1)
225 | in_tensor = torch.cat([x, t_model], dim = -1)
226 | vt = self.model.forward_train(in_tensor)
227 |
228 | # SDE: inject noise in the loss
229 | if self.implementation == "SDE":
230 | variance = t*(1-t)*(self.sde_noise ** 2)
231 | noise = torch.randn_like(vt[:,:self.dim]) * torch.sqrt(variance)
232 | loss = self.loss_fn(vt[:,:self.dim]+noise, ut) + self.loss_fn(vt[:,-1], futuretime)
233 | else:
234 | loss = self.loss_fn(vt[:,:self.dim], ut) + self.loss_fn(vt[:,-1], futuretime)
235 | self.log('train_loss', loss)
236 | return loss
237 |
238 | def config_optimizer(self):
239 | return torch.optim.Adam(self.parameters(), lr=self.lr)
240 |
241 | def validation_step(self, batch, batch_idx):
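242 | """validation_step
243 |
244 | Args:
245 | batch (list of tensors): one patient per batch (batch size 1, since trajectory lengths are uneven)
246 | batch_idx (int): index of the patient
247 |
248 | Returns:
249 | dict: 'val_loss' and the 'traj_pairs' of (true, predicted) trajectories
250 | """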
251 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='val')
252 | self.log('val_loss', loss)
253 | for key, value in metricD.items():
254 | self.log(key+"_val", value)
255 | # return total_loss, traj_pairs
256 | return {'val_loss':loss, 'traj_pairs':pairs}
257 |
258 | def test_step(self, batch, batch_idx):
259 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='test')
260 | self.log('test_loss', loss)
261 | for key, value in metricD.items():
262 | self.log(key+"_test", value)
263 | # return total_loss, traj_pairs
264 | return {'test_loss':loss, 'traj_pairs':pairs}
265 |
266 | def test_func_step(self, batch, batch_idx, mode='none'):
267 | """assuming each is one patient/batch"""
268 | total_loss = []
269 | traj_pairs = []
270 |
271 | x0_values, x0_classes, x1_values, times_x0, times_x1 = batch
272 | times_x0 = times_x0.squeeze()
273 | times_x1 = times_x1.squeeze()
274 |
275 | full_traj = torch.cat([x0_values[0,0,:self.dim].unsqueeze(0),
276 | x1_values[0,:,:self.dim]],
277 | dim=0)
278 | full_time = torch.cat([times_x0[0].unsqueeze(0), times_x1], dim=0)
279 | ind_loss, pred_traj = self.test_trajectory(batch)
280 | total_loss.append(ind_loss)
281 | traj_pairs.append([full_traj, pred_traj])
282 |
283 | full_traj = full_traj.detach().cpu().numpy()
284 | pred_traj = pred_traj.detach().cpu().numpy()
285 | full_time = full_time.detach().cpu().numpy()
286 |
287 | # graph
288 | fig = plot_3d_path_ind(pred_traj,
289 | full_traj,
290 | t_span=full_time,
291 | title="{}_trajectory_patient_{}".format(mode, batch_idx))
292 | if self.logger:
293 | # may cause problem if wandb disabled
294 | self.logger.experiment.log({"{}_trajectory_patient_{}".format(mode, batch_idx): wandb.Image(fig)})
295 |
296 | plt.close(fig)
297 |
298 | # metrics
299 | metricD = metrics_calculation(pred_traj, full_traj, metrics=self.metrics)
300 | return np.mean(total_loss), traj_pairs, metricD
301 |
302 | def test_trajectory(self,pt_tensor):
303 | if self.implementation == "ODE":
304 | return self.test_trajectory_ode(pt_tensor)
305 | elif self.implementation == "SDE":
306 | return self.test_trajectory_sde(pt_tensor)
307 |
308 | def test_trajectory_ode(self,pt_tensor):
309 | """test_trajectory
310 |
311 | Args:
312 | pt_tensor (tuple of tensors): (x0_values, x0_classes, x1_values, times_x0, times_x1),
313 |
314 |
315 | Returns:
316 | mse_all, total_pred_tensor: mean per-step loss and the predicted trajectory (start point included)
317 | """
318 | node = NeuralODE(
319 | torch_wrapper_tv(self.model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
320 | )
321 | total_pred = []
322 | mse = []
323 | t_max = 0
324 |
325 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor
326 | # squeeze all
327 | x0_values = x0_values.squeeze(0)
328 | x1_values = x1_values.squeeze(0)
329 | times_x0 = times_x0.squeeze()
330 | times_x1 = times_x1.squeeze()
331 | x0_classes = x0_classes.squeeze()
332 |
333 | if len(x0_classes.shape) == 1:
334 | x0_classes = x0_classes.unsqueeze(1)
335 |
336 |
337 |
338 | total_pred.append(x0_values[0].unsqueeze(0))
339 | len_path = x0_values.shape[0]
340 | assert len_path == x1_values.shape[0]
341 | for i in range(len_path):
342 | # print(i)
343 | t_max = (times_x1[i]-times_x0[i]) # calculate time difference (cumulative)
344 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device)
345 | # encoded_time_span = positional_encoding_tensor(time_span).squeeze(1).reshape(-1, NUM_FREQS * 2)
346 | with torch.no_grad():
347 | # get last pred, if none then use startpt
348 | if i == 0:
349 | testpt = torch.cat([x0_values[i].unsqueeze(0)],dim=1)
350 | else: # incorporate last prediction
351 | testpt = pred_traj
352 | # print(testpt.shape)
353 | traj = node.trajectory(
354 | testpt,
355 | t_span=time_span,
356 | )
357 | pred_traj = traj[-1,:,:self.dim]
358 | total_pred.append(pred_traj)
359 | ground_truth_coords = x1_values[i]
360 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy())
361 | mse_all = np.mean(mse)
362 | total_pred_tensor = torch.stack(total_pred).squeeze(1)
363 | return mse_all, total_pred_tensor
364 |
365 | def configure_optimizers(self):
366 | # Define the optimizer
367 | optimizer = optim.Adam(self.parameters(), lr=self.lr)
368 | return optimizer
369 |
370 | def test_trajectory_sde(self,pt_tensor):
371 | """test_trajectory
372 |
373 | Args:
374 | pt_tensor (tuple of tensors): (x0_values, x0_classes, x1_values, times_x0, times_x1),
375 |
376 |
377 | Returns:
378 | mse_all, total_pred_tensor: mean per-step loss and the predicted trajectory (start point included)
379 | """
380 | sde = SDE(self.model, noise=self.sde_noise)
381 | total_pred = []
382 | mse = []
383 | t_max = 0
384 |
385 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor
386 | # squeeze all
387 | x0_values = x0_values.squeeze(0)
388 | x1_values = x1_values.squeeze(0)
389 | times_x0 = times_x0.squeeze()
390 | times_x1 = times_x1.squeeze()
391 | x0_classes = x0_classes.squeeze()
392 |
393 | if len(x0_classes.shape) == 1:
394 | x0_classes = x0_classes.unsqueeze(1)
395 |
396 |
397 |
398 | total_pred.append(x0_values[0].unsqueeze(0))
399 | len_path = x0_values.shape[0]
400 | assert len_path == x1_values.shape[0]
401 | for i in range(len_path):
402 | # print(i)
403 | t_max = (times_x1[i]-times_x0[i]) # calculate time difference (cumulative)
404 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device)
405 |
406 | with torch.no_grad():
407 | # get last pred, if none then use startpt
408 | if i == 0:
409 | testpt = torch.cat([x0_values[i].unsqueeze(0)],dim=1)
410 | else:
411 | testpt = pred_traj
412 | # print(testpt.shape)
413 | traj = self._sde_solver(sde, testpt, time_span)
414 |
415 | pred_traj = traj[-1,:,:self.dim]
416 | total_pred.append(pred_traj)
417 | ground_truth_coords = x1_values[i]
418 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy())
419 | mse_all = np.mean(mse)
420 | total_pred_tensor = torch.stack(total_pred).squeeze(1)
421 | return mse_all, total_pred_tensor
422 |
423 |
424 | def _sde_solver(self, sde, initial_state, time_span):
425 | dt = time_span[1] - time_span[0] # Time step
426 | current_state = initial_state
427 | trajectory = [current_state]
428 |
429 | for t in time_span[1:]:
430 | drift = sde.f(t, current_state)
431 | diffusion = sde.g(t, current_state)
432 | noise = torch.randn_like(current_state) * torch.sqrt(dt)
433 | current_state = current_state + drift * dt + diffusion * noise
434 | trajectory.append(current_state)
435 |
436 | return torch.stack(trajectory)
437 |
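
The training target built in `__x_processing__` can be sketched for a single scalar observation pair. This is a minimal illustration under stated assumptions: `fm_training_target` is a hypothetical name, the standard-library `random` module replaces the tensor sampling, and the numbers in the usage line are made up.

```python
import random


def fm_training_target(x0, x1, t0, t1, sigma=0.1, eps=1e-4, rng=random):
    """Sample one flow-matching training point between two observations."""
    t = rng.random()                        # interpolation fraction in [0, 1)
    dt = t1 - t0                            # elapsed data time between the observations
    mu_t = x0 * (1 - t) + x1 * t            # linear interpolation of the state
    x = mu_t + sigma * rng.gauss(0.0, 1.0)  # noisy sample around the interpolant
    ut = (x1 - x0) / (dt + eps)             # target velocity in data-time units
    t_model = t0 + t * dt                   # fraction mapped back to data time
    future_time = t1 - t_model              # time remaining until the next observation
    return x, ut, t_model, future_time


# Hypothetical values: state goes from 1.0 to 3.0 over two time units.
x, ut, t_model, future_time = fm_training_target(
    1.0, 3.0, 0.0, 2.0, sigma=0.0, rng=random.Random(0)
)
```

In the batched tensor version, the state and the time to the next observation should carry the same leading shape as the corresponding model outputs, otherwise the loss silently broadcasts.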
438 |
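
The fixed-step update used by `_sde_solver` is an Euler-Maruyama scheme. The scalar sketch below mirrors it under stated assumptions: a uniform time grid, the drift evaluated at the end of each step as in the method above, and hypothetical `drift`/`diffusion` callables standing in for `sde.f` and `sde.g`.

```python
import math
import random


def euler_maruyama(drift, diffusion, x0, times, rng=random):
    """Integrate dx = drift(t, x) dt + diffusion(t, x) dW on a uniform grid."""
    dt = times[1] - times[0]  # uniform step, taken from the first interval
    x = x0
    path = [x]
    for t in times[1:]:
        # Brownian increment: standard normal scaled by sqrt(dt).
        dw = rng.gauss(0.0, 1.0) * math.sqrt(dt)
        x = x + drift(t, x) * dt + diffusion(t, x) * dw
        path.append(x)
    return path


# Hypothetical check: constant drift 1 and zero diffusion over [0, 1].
times = [i * 0.1 for i in range(11)]
path = euler_maruyama(lambda t, x: 1.0, lambda t, x: 0.0, 0.0, times)
```

With the diffusion set to zero the scheme reduces to the explicit Euler method, which is a convenient sanity check before adding noise.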
--------------------------------------------------------------------------------
/src/model/base_models.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | ###########################
4 | # Adapted from:
5 | # Latent ODEs for Irregularly-Sampled Time Series
6 | # Author: Yulia Rubanova
7 | ###########################
8 |
9 | import utils.latent_ode_utils as utils
10 | from utils.latent_ode_utils import get_device
11 |
12 | from torch.distributions.multivariate_normal import MultivariateNormal
13 | from torch.distributions.normal import Normal
14 | from torch.distributions import kl_divergence, Independent
15 |
16 |
17 | def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices = None):
18 | n_data_points = mu_2d.size()[-1]
19 |
20 | if n_data_points > 0:
21 | gaussian = Independent(Normal(loc = mu_2d, scale = obsrv_std.repeat(n_data_points)), 1)
22 | log_prob = gaussian.log_prob(data_2d)
23 | log_prob = log_prob / n_data_points
24 | else:
25 | log_prob = torch.zeros([1]).to(get_device(data_2d)).squeeze()
26 | return log_prob
27 |
28 |
29 | def poisson_log_likelihood(masked_log_lambdas, masked_data, indices, int_lambdas):
30 | # masked_log_lambdas and masked_data
31 | n_data_points = masked_data.size()[-1]
32 |
33 | if n_data_points > 0:
34 | log_prob = torch.sum(masked_log_lambdas) - int_lambdas[indices]
35 | #log_prob = log_prob / n_data_points
36 | else:
37 | log_prob = torch.zeros([1]).to(get_device(masked_data)).squeeze()
38 | return log_prob
39 |
40 |
41 |
42 | def compute_binary_CE_loss(label_predictions, mortality_label):
43 | #print("Computing binary classification loss: compute_CE_loss")
44 |
45 | mortality_label = mortality_label.reshape(-1)
46 |
47 | if len(label_predictions.size()) == 1:
48 | label_predictions = label_predictions.unsqueeze(0)
49 |
50 | n_traj_samples = label_predictions.size(0)
51 | label_predictions = label_predictions.reshape(n_traj_samples, -1)
52 |
53 | idx_not_nan = ~torch.isnan(mortality_label)
54 | if torch.sum(idx_not_nan) == 0:
55 | print("All labels are NaNs!")
56 | ce_loss = torch.tensor(0.).to(get_device(mortality_label))
57 |
58 | label_predictions = label_predictions[:,idx_not_nan]
59 | mortality_label = mortality_label[idx_not_nan]
60 |
61 | if torch.sum(mortality_label == 0.) == 0 or torch.sum(mortality_label == 1.) == 0:
62 | print("Warning: all examples in a batch belong to the same class -- please increase the batch size.")
63 |
64 | assert(not torch.isnan(label_predictions).any())
65 | assert(not torch.isnan(mortality_label).any())
66 |
67 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
68 | mortality_label = mortality_label.repeat(n_traj_samples, 1)
69 | ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label)
70 |
71 | # divide by number of patients in a batch
72 | ce_loss = ce_loss / n_traj_samples
73 | return ce_loss
74 |
75 |
76 | def compute_multiclass_CE_loss(label_predictions, true_label, mask):
77 | #print("Computing multi-class classification loss: compute_multiclass_CE_loss")
78 |
79 | if (len(label_predictions.size()) == 3):
80 | label_predictions = label_predictions.unsqueeze(0)
81 |
82 | n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size()
83 |
84 | # assert(not torch.isnan(label_predictions).any())
85 | # assert(not torch.isnan(true_label).any())
86 |
87 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
88 | true_label = true_label.repeat(n_traj_samples, 1, 1)
89 |
90 | label_predictions = label_predictions.reshape(n_traj_samples * n_traj * n_tp, n_dims)
91 | true_label = true_label.reshape(n_traj_samples * n_traj * n_tp, n_dims)
92 |
93 | # choose time points with at least one measurement
94 | mask = torch.sum(mask, -1) > 0
95 |
96 | # repeat the mask for each label to mark that the label for this time point is present
97 | pred_mask = mask.repeat(n_dims, 1,1).permute(1,2,0)
98 |
99 | label_mask = mask
100 | pred_mask = pred_mask.repeat(n_traj_samples,1,1,1)
101 | label_mask = label_mask.repeat(n_traj_samples,1,1,1)
102 |
103 | pred_mask = pred_mask.reshape(n_traj_samples * n_traj * n_tp, n_dims)
104 | label_mask = label_mask.reshape(n_traj_samples * n_traj * n_tp, 1)
105 |
106 | if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1):
107 | assert(label_predictions.size(-1) == true_label.size(-1))
108 | # targets are in one-hot encoding -- convert to indices
109 | _, true_label = true_label.max(-1)
110 |
111 | res = []
112 | for i in range(true_label.size(0)):
113 | pred_masked = torch.masked_select(label_predictions[i], pred_mask[i].bool())
114 | labels = torch.masked_select(true_label[i], label_mask[i].bool())
115 |
116 | pred_masked = pred_masked.reshape(-1, n_dims)
117 |
118 | if (len(labels) == 0):
119 | continue
120 |
121 | ce_loss = nn.CrossEntropyLoss()(pred_masked, labels.long())
122 | res.append(ce_loss)
123 |
124 | ce_loss = torch.stack(res, 0).to(get_device(label_predictions))
125 | ce_loss = torch.mean(ce_loss)
126 | # # divide by number of patients in a batch
127 | # ce_loss = ce_loss / n_traj_samples
128 | return ce_loss
129 |
130 |
131 |
132 |
133 | def compute_masked_likelihood(mu, data, mask, likelihood_func):
134 | # Compute the likelihood per patient and per attribute so that we don't prioritize patients with more measurements
135 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size()
136 |
137 | res = []
138 | for i in range(n_traj_samples):
139 | for k in range(n_traj):
140 | for j in range(n_dims):
141 | data_masked = torch.masked_select(data[i,k,:,j], mask[i,k,:,j].bool())
142 |
143 | #assert(torch.sum(data_masked == 0.) < 10)
144 |
145 | mu_masked = torch.masked_select(mu[i,k,:,j], mask[i,k,:,j].bool())
146 | log_prob = likelihood_func(mu_masked, data_masked, indices = (i,k,j))
147 | res.append(log_prob)
148 | # shape: [n_traj*n_traj_samples, 1]
149 |
150 | res = torch.stack(res, 0).to(get_device(data))
151 | res = res.reshape((n_traj_samples, n_traj, n_dims))
152 | # Take mean over the number of dimensions
153 | res = torch.mean(res, -1) # !!!!!!!!!!! changed from sum to mean
154 | res = res.transpose(0,1)
155 | return res
156 |
157 |
158 | def masked_gaussian_log_density(mu, data, obsrv_std, mask = None):
159 | # these cases are for plotting through plot_estim_density
160 | if (len(mu.size()) == 3):
161 | # add additional dimension for gp samples
162 | mu = mu.unsqueeze(0)
163 |
164 | if (len(data.size()) == 2):
165 | # add additional dimension for gp samples and time step
166 | data = data.unsqueeze(0).unsqueeze(2)
167 | elif (len(data.size()) == 3):
168 | # add additional dimension for gp samples
169 | data = data.unsqueeze(0)
170 |
171 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size()
172 |
173 | assert(data.size()[-1] == n_dims)
174 |
175 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims]
176 | if mask is None:
177 | mu_flat = mu.reshape(n_traj_samples*n_traj, n_timepoints * n_dims)
178 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size()
179 | data_flat = data.reshape(n_traj_samples*n_traj, n_timepoints * n_dims)
180 |
181 | res = gaussian_log_likelihood(mu_flat, data_flat, obsrv_std)
182 | res = res.reshape(n_traj_samples, n_traj).transpose(0,1)
183 | else:
184 | # Compute the likelihood per patient so that we don't prioritize patients with more measurements
185 | func = lambda mu, data, indices: gaussian_log_likelihood(mu, data, obsrv_std = obsrv_std, indices = indices)
186 | res = compute_masked_likelihood(mu, data, mask, func)
187 | return res
188 |
189 |
190 |
191 | def mse(mu, data, indices = None):
192 | n_data_points = mu.size()[-1]
193 |
194 | if n_data_points > 0:
195 | mse = nn.MSELoss()(mu, data)
196 | else:
197 | mse = torch.zeros([1]).to(get_device(data)).squeeze()
198 | return mse
199 |
200 |
201 | def compute_mse(mu, data, mask = None):
202 | # these cases are for plotting through plot_estim_density
203 | if (len(mu.size()) == 3):
204 | # add additional dimension for gp samples
205 | mu = mu.unsqueeze(0)
206 |
207 | if (len(data.size()) == 2):
208 | # add additional dimension for gp samples and time step
209 | data = data.unsqueeze(0).unsqueeze(2)
210 | elif (len(data.size()) == 3):
211 | # add additional dimension for gp samples
212 | data = data.unsqueeze(0)
213 |
214 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size()
215 | assert(data.size()[-1] == n_dims)
216 |
217 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims]
218 | if mask is None:
219 | mu_flat = mu.reshape(n_traj_samples*n_traj, n_timepoints * n_dims)
220 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size()
221 | data_flat = data.reshape(n_traj_samples*n_traj, n_timepoints * n_dims)
222 | res = mse(mu_flat, data_flat)
223 | else:
224 | # Compute the likelihood per patient so that we don't prioritize patients with more measurements
225 | res = compute_masked_likelihood(mu, data, mask, mse)
226 | return res
227 |
228 |
229 |
230 |
231 | def compute_poisson_proc_likelihood(truth, pred_y, info, mask = None):
232 | # Compute Poisson likelihood
233 | # https://math.stackexchange.com/questions/344487/log-likelihood-of-a-realization-of-a-poisson-process
234 | # Sum log lambdas across all time points
235 | if mask is None:
236 | poisson_log_l = torch.sum(info["log_lambda_y"], 2) - info["int_lambda"]
237 | # Sum over data dims
238 | poisson_log_l = torch.mean(poisson_log_l, -1)
239 | else:
240 | # Compute likelihood of the data under the predictions
241 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1)
242 | mask_repeated = mask.repeat(pred_y.size(0), 1, 1, 1)
243 |
244 | # Compute the likelihood per patient and per attribute so that we don't prioritize patients with more measurements
245 | int_lambda = info["int_lambda"]
246 | f = lambda log_lam, data, indices: poisson_log_likelihood(log_lam, data, indices, int_lambda)
247 | poisson_log_l = compute_masked_likelihood(info["log_lambda_y"], truth_repeated, mask_repeated, f)
248 | poisson_log_l = poisson_log_l.permute(1,0)
249 | # Take mean over n_traj
250 | #poisson_log_l = torch.mean(poisson_log_l, 1)
251 |
252 | # poisson_log_l shape: [n_traj_samples, n_traj]
253 | return poisson_log_l
254 |
255 |
256 | import numpy as np
257 | import torch
258 | import torch.nn as nn
259 | from torch.nn.functional import relu
260 |
261 | from utils.latent_ode_utils import *
262 | from model.latent_ode import *
263 | # from lib.likelihood_eval import *
264 |
265 | from torch.distributions.multivariate_normal import MultivariateNormal
266 | from torch.distributions.normal import Normal
267 | from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase
268 |
269 |
270 |
271 | from torch.distributions.normal import Normal
272 | from torch.distributions import Independent
273 | from torch.nn.parameter import Parameter
274 |
275 | def create_classifier(z0_dim, n_labels):
276 | return nn.Sequential(
277 | nn.Linear(z0_dim, 300),
278 | nn.ReLU(),
279 | nn.Linear(300, 300),
280 | nn.ReLU(),
281 | nn.Linear(300, n_labels),)
282 |
283 | class VAE_Baseline(nn.Module):
284 | def __init__(self, input_dim, latent_dim,
285 | z0_prior, device,
286 | obsrv_std = 0.01,
287 | use_binary_classif = False,
288 | classif_per_tp = False,
289 | use_poisson_proc = False,
290 | linear_classifier = False,
291 | n_labels = 1,
292 | train_classif_w_reconstr = False):
293 |
294 | super(VAE_Baseline, self).__init__()
295 |
296 | self.input_dim = input_dim
297 | self.latent_dim = latent_dim
298 | self.device = device
299 | self.n_labels = n_labels
300 |
301 | self.obsrv_std = torch.Tensor([obsrv_std]).to(device)
302 |
303 | self.z0_prior = z0_prior
304 | self.use_binary_classif = use_binary_classif
305 | self.classif_per_tp = classif_per_tp
306 | self.use_poisson_proc = use_poisson_proc
307 | self.linear_classifier = linear_classifier
308 | self.train_classif_w_reconstr = train_classif_w_reconstr
309 |
310 | z0_dim = latent_dim
311 | if use_poisson_proc:
312 | z0_dim += latent_dim
313 |
314 | if use_binary_classif:
315 | if linear_classifier:
316 | self.classifier = nn.Sequential(
317 | nn.Linear(z0_dim, n_labels))
318 | else:
319 | self.classifier = create_classifier(z0_dim, n_labels)
320 | utils.init_network_weights(self.classifier)
321 |
322 |
323 | def get_gaussian_likelihood(self, truth, pred_y, mask = None):
324 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim]
325 | # truth shape [n_traj, n_tp, n_dim]
326 | n_traj, n_tp, n_dim = truth.size()
327 |
328 | # Compute likelihood of the data under the predictions
329 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1)
330 |
331 | if mask is not None:
332 | mask = mask.repeat(pred_y.size(0), 1, 1, 1)
333 | log_density_data = masked_gaussian_log_density(pred_y, truth_repeated,
334 | obsrv_std = self.obsrv_std, mask = mask)
335 | log_density_data = log_density_data.permute(1,0)
336 | log_density = torch.mean(log_density_data, 1)
337 |
338 | # shape: [n_traj_samples]
339 | return log_density
340 |
341 |
342 | def get_mse(self, truth, pred_y, mask = None):
343 |
344 | # Repeat the ground truth once per trajectory sample so it lines up with pred_y
345 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1)
346 |
347 | if mask is not None:
348 | mask = mask.repeat(pred_y.size(0), 1, 1, 1)
349 |
350 | # Compute the (optionally masked) mean squared error between predictions and truth
351 | mse = compute_mse(pred_y, truth_repeated, mask = mask)
352 | # shape: [1]
353 | return torch.mean(mse)
354 |
355 |
356 | def compute_all_losses(self, batch_dict, n_traj_samples = 1, kl_coef = 1.):
357 | # Condition on subsampled points
358 | # Make predictions for all the points
359 | pred_y, info = self.get_reconstruction(batch_dict["tp_to_predict"],
360 | batch_dict["observed_data"], batch_dict["observed_tp"],
361 | mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples,
362 | mode = batch_dict["mode"])
363 | fp_mu, fp_std, fp_enc = info["first_point"]
364 | fp_std = fp_std.abs()
365 | fp_distr = Normal(fp_mu, fp_std)
366 |
367 | assert(torch.sum(fp_std < 0) == 0.)
368 |
369 | kldiv_z0 = kl_divergence(fp_distr, self.z0_prior)
370 |
371 | if torch.isnan(kldiv_z0).any():
372 | print(fp_mu)
373 | print(fp_std)
374 | raise Exception("kldiv_z0 is Nan!")
375 |
376 | # Mean over trajectories (dim 1) and latent dimensions (dim 2)
377 | kldiv_z0 = torch.mean(kldiv_z0,(1,2))
378 |
379 | # Compute likelihood of all the points
380 | rec_likelihood = self.get_gaussian_likelihood(
381 | batch_dict["data_to_predict"], pred_y,
382 | mask = batch_dict["mask_predicted_data"])
383 |
384 | mse = self.get_mse(
385 | batch_dict["data_to_predict"], pred_y,
386 | mask = batch_dict["mask_predicted_data"])
387 |
388 | pois_log_likelihood = torch.Tensor([0.]).to(get_device(batch_dict["data_to_predict"]))
389 | if self.use_poisson_proc:
390 | pois_log_likelihood = compute_poisson_proc_likelihood(
391 | batch_dict["data_to_predict"], pred_y,
392 | info, mask = batch_dict["mask_predicted_data"])
393 | # Take mean over n_traj
394 | pois_log_likelihood = torch.mean(pois_log_likelihood, 1)
395 |
396 |
397 | # IWAE loss
398 | loss = - torch.logsumexp(rec_likelihood - kl_coef * kldiv_z0,0)
399 | if torch.isnan(loss):
400 | loss = - torch.mean(rec_likelihood - kl_coef * kldiv_z0,0)
401 |
402 | results = {}
403 | results["loss"] = torch.mean(loss)
404 | results["likelihood"] = torch.mean(rec_likelihood).detach()
405 | results["mse"] = torch.mean(mse).detach()
406 | results["kl_first_p"] = torch.mean(kldiv_z0).detach()
407 | results["std_first_p"] = torch.mean(fp_std).detach()
408 |
409 | if batch_dict["labels"] is not None and self.use_binary_classif:
410 | results["label_predictions"] = info["label_predictions"].detach()
411 |
412 | return results
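
The loss in `compute_all_losses` combines the reconstruction log-likelihood and the KL term across trajectory samples with a log-sum-exp. The following is a minimal, framework-free sketch of that reduction on plain Python floats; `logsumexp` and `iwae_style_loss` are hypothetical helper names used only for illustration and are not part of this repository. The source additionally falls back to a plain mean when the result is NaN, which is omitted here.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if math.isinf(m):
        # All terms are -inf (or one is +inf); the max is already the answer.
        return m
    return m + math.log(sum(math.exp(v - m) for v in values))


def iwae_style_loss(rec_likelihood, kldiv_z0, kl_coef=1.0):
    """Scalar analogue of -logsumexp(rec_likelihood - kl_coef * kldiv_z0, 0).

    Each list holds one value per trajectory sample.
    """
    terms = [r - kl_coef * k for r, k in zip(rec_likelihood, kldiv_z0)]
    return -logsumexp(terms)


# With a single sample the loss reduces to -(likelihood - kl_coef * KL).
loss_single = iwae_style_loss([-3.0], [0.5])

# With several samples the best-scoring sample dominates the sum.
loss_multi = iwae_style_loss([-3.0, -2.0, -4.0], [0.5, 0.5, 0.5])
```

Subtracting the maximum before exponentiating keeps the sum finite even when the log-likelihoods are large and negative, which is the usual reason to prefer a log-sum-exp over exponentiating directly.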
--------------------------------------------------------------------------------
/src/model/cfm_baseline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import time
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import ot as pot
8 | import torchdyn
9 | from torchdyn.core import NeuralODE
10 | import pytorch_lightning as pl
11 | from torch import optim
12 | import torch.nn.functional as F
13 | import wandb
14 | from model.components.grad_util import *
15 |
16 | from utils.visualize import *
17 | from utils.metric_calc import *
18 | from utils.sde import SDE
19 | from model.components.positional_encoding import *
20 | from utils.loss import mse_loss, l1_loss
21 |
22 |
23 |
24 | class MLP(torch.nn.Module):
25 | def __init__(self, dim, out_dim=None, w=64, time_varying=False):
26 | super().__init__()
27 | self.time_varying = time_varying
28 | if out_dim is None:
29 | out_dim = dim
30 | self.net = torch.nn.Sequential(
31 | torch.nn.Linear(dim + (1 if time_varying else 0), w),
32 | torch.nn.SELU(),
33 | torch.nn.Linear(w, w),
34 | torch.nn.SELU(),
35 | torch.nn.Linear(w, w),
36 | torch.nn.SELU(),
37 | torch.nn.Linear(w, out_dim),
38 | )
39 |
40 | def forward(self, x, *args, **kwargs):
41 | return self.net(x)
42 |
43 |
44 | # conditional liver model
45 | class MLP_conditional_liver(torch.nn.Module):
46 | """ Conditional with many available classes
47 |
48 | return the class as is
49 | """
50 | def __init__(self, dim, treatment_cond, out_dim=None, w=64, time_varying=False, conditional=False):
51 | super().__init__()
52 | self.time_varying = time_varying
53 | # Use dim as the output size unless out_dim is given explicitly.
54 | self.out_dim = dim if out_dim is None else out_dim
55 | self.treatment_cond = treatment_cond
56 | self.dim = dim
57 | self.indim = dim + (1 if time_varying else 0) + (self.treatment_cond if conditional else 0)
58 | self.net = torch.nn.Sequential(
59 | torch.nn.Linear(self.indim, w),
60 | torch.nn.SELU(),
61 | torch.nn.Linear(w, w),
62 | torch.nn.SELU(),
63 | torch.nn.Linear(w, w),
64 | torch.nn.SELU(),
65 | torch.nn.Linear(w,self.out_dim),
66 | )
67 | self.default_class = 0
68 |
69 |
70 | def forward(self, x):
71 | """forward pass
72 | Assume first two dimensions are x, c, then t
73 | """
74 | result = self.net(x)
75 | return torch.cat([result, x[:,self.dim:-1]], dim=1)
76 |
77 | class FM_baseline(torch.nn.Module):
78 | """ Conditional with many available classes
79 |
80 | return the class as is
81 | """
82 | def __init__(self, dim,
83 | out_dim=None,
84 | w=64,
85 | time_varying=False,
86 | conditional=False,
87 | treatment_cond = 0,
88 | time_dim = NUM_FREQS * 2,
89 | clip = None):
90 | super().__init__()
91 | self.dim = dim
92 | self.time_varying = time_varying
93 | # Use dim as the output size unless out_dim is given explicitly.
94 | self.out_dim = dim if out_dim is None else out_dim
95 | self.out_dim += 1
96 | self.treatment_cond = treatment_cond
97 | self.indim = dim + (time_dim if time_varying else 0) + (self.treatment_cond if conditional else 0)
98 | self.net = torch.nn.Sequential(
99 | torch.nn.Linear(self.indim, w),
100 | torch.nn.SELU(),
101 | torch.nn.Linear(w, w),
102 | torch.nn.SELU(),
103 | torch.nn.Linear(w, w),
104 | torch.nn.SELU(),
105 | torch.nn.Linear(w,self.out_dim),
106 | )
107 | self.default_class = 0
108 | self.clip = clip
109 |
110 | def encoding_function(self, time_tensor):
111 | return positional_encoding_tensor(time_tensor)
112 |
113 | def forward_train(self, x):
114 | """forward pass
115 | Assume first two dimensions are x, c, then t
116 | input: x0
117 | output: vt
118 | """
119 | time_tensor = x[:,-1]
120 | encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2)
121 | new_x = torch.cat([x[:,:-1], encoded_time_span], dim = 1).to(torch.float32)
122 | result = self.net(new_x)
123 | return torch.cat([result[:,:-1], x[:,self.dim:-1],result[:,-1].unsqueeze(1)], dim=1)
124 |
125 | def forward(self,x):
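126 | """Velocity field used when simulating trajectories at test time.
127 |
128 |
129 | Args:
130 | x (torch.Tensor): state (plus any condition columns) with time as the last column
131 |
132 | Returns:
133 | torch.Tensor: forward_train(x) with the trailing predicted-time column dropped
134 | """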
135 | return self.forward_train(x)[:,:-1]
136 |
137 |
138 |
139 | """ Lightning module """
140 |
141 | class MLP_CFM(pl.LightningModule):
142 | def __init__(self,
143 | treatment_cond,
144 | dim=2,
145 | w=64,
146 | time_varying=True,
147 | conditional=True,
148 | lr=1e-6,
149 | sigma = 0.1,
150 | loss_fn = mse_loss,
151 | metrics = ['mse_loss', 'l1_loss'],
152 | implementation = "ODE", # can be SDE
153 | sde_noise = 0.1,
154 | clip = None, # float
155 | naming = None,
156 | ):
157 | super().__init__()
158 | self.model = FM_baseline(dim=dim,
159 | w=w,
160 | time_varying=time_varying,
161 | conditional=conditional, # no conditional for baseline
162 | clip = clip,
163 | treatment_cond=treatment_cond)
164 | self.loss_fn = loss_fn
165 | self.save_hyperparameters()
166 | self.dim = dim
167 | # self.out_dim = out_dim
168 | self.w = w
169 | self.time_varying = time_varying
170 | self.conditional = conditional
171 | self.treatment_cond = treatment_cond
172 | self.lr = lr
173 | self.sigma = sigma
174 | self.naming = "CFM_baseline_"+implementation if naming is None else naming
175 | self.metrics = metrics
176 | self.implementation = implementation
177 | self.sde_noise = sde_noise
178 | self.clip = clip
179 |
180 |
181 | def __convert_tensor__(self, tensor):
182 | return tensor.to(torch.float32)
183 |
184 | def __x_processing__(self, x0, x1, t0, t1):
185 | x0 = x0.squeeze(0)
186 | x1 = x1.squeeze(0)
187 | t0 = t0.squeeze()
188 | t1 = t1.squeeze()
189 |
190 | t = torch.rand(x0.shape[0],1).to(x0.device)
191 | mu_t = x0 * (1 - t) + x1 * t
192 | data_t_diff = (t1 - t0).unsqueeze(1)
193 | x = mu_t + self.sigma * torch.randn(x0.shape[0], self.dim).to(x0.device)
194 | ut = (x1 - x0) / (data_t_diff + 1e-4)
195 | t_model = t * data_t_diff + t0.unsqueeze(1)
196 | futuretime = t1 - t_model
197 | return x, ut, t_model, futuretime, t
198 |
199 | def training_step(self, batch, batch_idx):
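200 | """Compute the flow-matching training loss for one batch.
201 |
202 | Args:
203 | batch (tuple): x0_values, x0_classes, x1_values, times_x0, times_x1
204 | batch_idx (int): index of the batch (unused)
205 |
206 | Returns:
207 | torch.Tensor: velocity loss plus the loss on the predicted time to the next observation
208 | """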
209 | x0, x0_class, x1, x0_time, x1_time = batch
210 | x0, x0_class, x1, x0_time, x1_time = self.__convert_tensor__(x0), self.__convert_tensor__(x0_class), self.__convert_tensor__(x1), self.__convert_tensor__(x0_time), self.__convert_tensor__(x1_time)
211 |
212 |
213 | x, ut, t_model, futuretime, t = self.__x_processing__(x0, x1, x0_time, x1_time)
214 |
215 |
216 | if len(x0_class.shape) == 3:
217 | x0_class = x0_class.squeeze(0)
218 |
219 | in_tensor = torch.cat([x,x0_class, t_model], dim = -1)
220 | vt = self.model.forward_train(in_tensor)
221 |
222 | # SDE: inject noise in the loss
223 | if self.implementation == "SDE":
224 | variance = t*(1-t)*(self.sde_noise ** 2)
225 | noise = torch.randn_like(vt[:,:self.dim]) * torch.sqrt(variance)
226 | loss = self.loss_fn(vt[:,:self.dim]+noise, ut) + self.loss_fn(vt[:,-1], futuretime)
227 | else:
228 | loss = self.loss_fn(vt[:,:self.dim], ut) + self.loss_fn(vt[:,-1], futuretime)
229 | self.log('train_loss', loss)
230 | return loss
231 |
232 | def config_optimizer(self):
233 | return torch.optim.Adam(self.parameters(), lr=self.lr)
234 |
235 | def validation_step(self, batch, batch_idx):
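236 | """Roll out one validation trajectory and log its loss and metrics.
237 |
238 | Args:
239 | batch (tuple): one patient per batch (batch size of 1, since trajectory lengths are uneven)
240 | batch_idx (int): index of the batch, used to title the logged plot
241 |
242 | Returns:
243 | dict: 'val_loss' and 'traj_pairs' (ground-truth and predicted trajectories)
244 | """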
245 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='val')
246 | self.log('val_loss', loss)
247 | for key, value in metricD.items():
248 | self.log(key+"_val", value)
249 | return {'val_loss':loss, 'traj_pairs':pairs}
250 |
251 | def test_step(self, batch, batch_idx):
252 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='test')
253 | self.log('test_loss', loss)
254 | for key, value in metricD.items():
255 | self.log(key+"_test", value)
256 | return {'test_loss':loss, 'traj_pairs':pairs}
257 |
258 | def test_func_step(self, batch, batch_idx, mode='none'):
259 | """assuming each is one patient/batch"""
260 | total_loss = []
261 | traj_pairs = []
262 |
263 | x0_values, x0_classes, x1_values, times_x0, times_x1 = batch
264 | times_x0 = times_x0.squeeze()
265 | times_x1 = times_x1.squeeze()
266 |
267 | full_traj = torch.cat([x0_values[0,0,:self.dim].unsqueeze(0),
268 | x1_values[0,:,:self.dim]],
269 | dim=0)
270 | full_time = torch.cat([times_x0[0].unsqueeze(0), times_x1], dim=0)
271 | ind_loss, pred_traj = self.test_trajectory(batch)
272 | total_loss.append(ind_loss)
273 | traj_pairs.append([full_traj, pred_traj])
274 |
275 | full_traj = full_traj.detach().cpu().numpy()
276 | pred_traj = pred_traj.detach().cpu().numpy()
277 | full_time = full_time.detach().cpu().numpy()
278 |
279 | # graph
280 | fig = plot_3d_path_ind(pred_traj,
281 | full_traj,
282 | t_span=full_time,
283 | title="{}_trajectory_patient_{}".format(mode, batch_idx))
284 | if self.logger:
285 | # may cause problem if wandb disabled
286 | self.logger.experiment.log({"{}_trajectory_patient_{}".format(mode, batch_idx): wandb.Image(fig)})
287 |
288 | plt.close(fig)
289 |
290 | # metrics
291 | metricD = metrics_calculation(pred_traj, full_traj, metrics=self.metrics)
292 | return np.mean(total_loss), traj_pairs, metricD
293 |
294 | def test_trajectory(self,pt_tensor):
295 | if self.implementation == "ODE":
296 | return self.test_trajectory_ode(pt_tensor)
297 | elif self.implementation == "SDE":
298 | return self.test_trajectory_sde(pt_tensor)
299 |
300 | def test_trajectory_ode(self,pt_tensor):
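301 | """Roll out a trajectory interval by interval with a neural ODE solver.
302 |
303 | Args:
304 | pt_tensor (tuple): (x0_values, x0_classes, x1_values, times_x0, times_x1),
305 |
306 |
307 | Returns:
308 | mse_all, total_pred_tensor: mean per-interval loss and the stacked predictions (starting point first)
309 | """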
310 | node = NeuralODE(
311 | torch_wrapper_tv(self.model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
312 | )
313 | total_pred = []
314 | mse = []
315 | t_max = 0
316 |
317 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor
318 | # squeeze all
319 | x0_values = x0_values.squeeze(0)
320 | x1_values = x1_values.squeeze(0)
321 | times_x0 = times_x0.squeeze()
322 | times_x1 = times_x1.squeeze()
323 | x0_classes = x0_classes.squeeze()
324 |
325 | if len(x0_classes.shape) == 1:
326 | x0_classes = x0_classes.unsqueeze(1)
327 |
328 |
329 |
330 | total_pred.append(x0_values[0].unsqueeze(0))
331 | len_path = x0_values.shape[0]
332 | assert len_path == x1_values.shape[0]
333 | for i in range(len_path):
334 | t_max = (times_x1[i]-times_x0[i])
335 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device)
336 | with torch.no_grad():
337 | if i == 0:
338 | testpt = torch.cat([x0_values[i].unsqueeze(0),x0_classes[i].unsqueeze(0)],dim=1)
339 | else: # incorporate last prediction
340 | testpt = torch.cat([pred_traj, x0_classes[i].unsqueeze(0)], dim=1)
341 | traj = node.trajectory(
342 | testpt,
343 | t_span=time_span,
344 | )
345 | pred_traj = traj[-1,:,:self.dim]
346 | total_pred.append(pred_traj)
347 | ground_truth_coords = x1_values[i]
348 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy())
349 | mse_all = np.mean(mse)
350 | total_pred_tensor = torch.stack(total_pred).squeeze(1)
351 | return mse_all, total_pred_tensor
352 |
353 | def configure_optimizers(self):
354 | optimizer = optim.Adam(self.parameters(), lr=self.lr)
355 | return optimizer
356 |
357 | def test_trajectory_sde(self,pt_tensor):
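358 | """Roll out a trajectory interval by interval with the Euler-Maruyama SDE solver.
359 |
360 | Args:
361 | pt_tensor (tuple): (x0_values, x0_classes, x1_values, times_x0, times_x1),
362 |
363 |
364 | Returns:
365 | mse_all, total_pred_tensor: mean per-interval loss and the stacked predictions (starting point first)
366 | """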
367 | sde = SDE(self.model, noise=self.sde_noise)
368 | total_pred = []
369 | mse = []
370 | t_max = 0
371 |
372 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor
373 | # squeeze all
374 | x0_values = x0_values.squeeze(0)
375 | x1_values = x1_values.squeeze(0)
376 | times_x0 = times_x0.squeeze()
377 | times_x1 = times_x1.squeeze()
378 | x0_classes = x0_classes.squeeze()
379 |
380 | if len(x0_classes.shape) == 1:
381 | x0_classes = x0_classes.unsqueeze(1)
382 |
383 |
384 |
385 | total_pred.append(x0_values[0].unsqueeze(0))
386 | len_path = x0_values.shape[0]
387 | assert len_path == x1_values.shape[0]
388 | for i in range(len_path):
389 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device)
390 |
391 | with torch.no_grad():
392 | # get last pred, if none then use startpt
393 | if i == 0:
394 | testpt = torch.cat([x0_values[i].unsqueeze(0),x0_classes[i].unsqueeze(0)],dim=1)
395 | else: # incorporate last prediction
396 | testpt = torch.cat([pred_traj, x0_classes[i].unsqueeze(0)], dim=1)
397 | traj = self._sde_solver(sde, testpt, time_span)
398 |
399 | pred_traj = traj[-1,:,:self.dim]
400 | total_pred.append(pred_traj)
401 | ground_truth_coords = x1_values[i]
402 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy())
403 | mse_all = np.mean(mse)
404 | total_pred_tensor = torch.stack(total_pred).squeeze(1)
405 | return mse_all, total_pred_tensor
406 |
407 |
408 | def _sde_solver(self, sde, initial_state, time_span):
409 | dt = time_span[1] - time_span[0] # Time step
410 | current_state = initial_state
411 | trajectory = [current_state]
412 |
413 | for t in time_span[1:]:
414 | drift = sde.f(t, current_state)
415 | diffusion = sde.g(t, current_state)
416 | noise = torch.randn_like(current_state) * torch.sqrt(dt)
417 | current_state = current_state + drift * dt + diffusion * noise
418 | trajectory.append(current_state)
419 |
420 | return torch.stack(trajectory)
421 |
422 |
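
The `_sde_solver` method above is a fixed-step Euler-Maruyama integrator. Below is a minimal sketch of the same update rule in plain Python, assuming a uniformly spaced time grid and drift and diffusion callables that return one value per state dimension. The name `euler_maruyama` and its arguments are illustrative assumptions, not part of this repository. Like the method above, it evaluates drift and diffusion at the end time of each step.

```python
import math
import random


def euler_maruyama(drift, diffusion, x0, time_span, rng=random):
    """Integrate dx = f(t, x) dt + g(t, x) dW on a uniform time grid.

    drift and diffusion take (t, state) and return a list with one entry
    per state dimension. Returns the list of visited states, x0 included.
    """
    dt = time_span[1] - time_span[0]  # uniform step size
    state = list(x0)
    trajectory = [list(state)]
    for t in time_span[1:]:
        f = drift(t, state)
        g = diffusion(t, state)
        # Brownian increment: standard normal scaled by sqrt(dt).
        state = [
            x + fi * dt + gi * rng.gauss(0.0, 1.0) * math.sqrt(dt)
            for x, fi, gi in zip(state, f, g)
        ]
        trajectory.append(list(state))
    return trajectory


# With zero diffusion the scheme reduces to explicit Euler:
# constant drift 1 over [0, 1] moves the state from 0 to 1.
times = [0.0, 0.5, 1.0]
path = euler_maruyama(
    lambda t, x: [1.0 for _ in x],
    lambda t, x: [0.0 for _ in x],
    [0.0],
    times,
)
```

Scaling the noise by `sqrt(dt)` rather than `dt` is what makes the accumulated variance grow linearly in time, matching a Wiener process.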
--------------------------------------------------------------------------------
/src/model/components/grad_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class GradModel(torch.nn.Module):
4 | def __init__(self, action):
5 | super().__init__()
6 | self.action = action
7 |
8 | def forward(self, x):
9 | x = x.requires_grad_(True)
10 | grad = torch.autograd.grad(torch.sum(self.action(x)), x, create_graph=True)[0]
11 | return grad[:, :-1]
12 |
13 |
14 | class torch_wrapper(torch.nn.Module):
15 | """Wraps model to torchdyn compatible format."""
16 |
17 | def __init__(self, model):
18 | super().__init__()
19 | self.model = model
20 |
21 | def forward(self, t, x):
22 | return self.model(x)
23 | # Time t is not passed to the model here; torch_wrapper_tv is the variant that appends it.
24 |
25 | class torch_wrapper_tv(torch.nn.Module):
26 | """Wraps model to torchdyn compatible format."""
27 |
28 | def __init__(self, model):
29 | super().__init__()
30 | self.model = model
31 |
32 | def forward(self, t, x, *args, **kwargs):
33 | # print(x.shape, t.shape)
34 | # print(t)
35 | # return self.model(x)
36 | return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))
37 |
38 |
39 |
40 | class torch_wrapper_cond(torch.nn.Module):
41 | """Wraps model to torchdyn compatible format."""
42 |
43 | def __init__(self, model):
44 | super().__init__()
45 | self.model = model
46 |
47 | def forward(self, t, x, *args, **kwargs):
48 | # unpack the input
49 | # class_cond = torch.zeros(x.shape[0],1)
50 | # Append the scalar solver time t as an extra column for every row of x.
51 | model_input = torch.cat([x, t.repeat(x.shape[0])[:, None]], 1)
52 | return self.model(model_input)
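
The wrappers in this file adapt a model that expects time as the last input column to the `f(t, x)` call signature used by ODE solvers. The sketch below shows the same idea with plain Python lists instead of tensors; `append_time_column` and `TimeVaryingWrapper` are hypothetical names chosen for illustration and do not exist in this repository.

```python
def append_time_column(rows, t):
    """Return a copy of rows with the scalar time t appended to every row."""
    return [list(row) + [t] for row in rows]


class TimeVaryingWrapper:
    """Adapts model(x_with_time) to the solver-style signature f(t, x)."""

    def __init__(self, model):
        self.model = model

    def __call__(self, t, x, *args, **kwargs):
        # The solver supplies one scalar time for the whole batch.
        return self.model(append_time_column(x, t))


# Toy model: reads the trailing time column and returns it as the velocity.
wrapped = TimeVaryingWrapper(lambda rows: [[row[-1]] for row in rows])
velocity = wrapped(0.25, [[1.0, 2.0], [3.0, 4.0]])
```

Keeping the time handling in a wrapper means the model itself can be trained on explicit `[x, t]` inputs and reused unchanged during simulation.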
--------------------------------------------------------------------------------
/src/model/components/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import time
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import ot as pot
8 | import torchdyn
9 | from torchdyn.core import NeuralODE
10 | import pytorch_lightning as pl
11 | from torch import optim
12 | import torch.nn.functional as F
13 | import wandb
14 |
15 | from utils.visualize import *
16 | from utils.metric_calc import *
17 | from utils.sde import SDE
18 |
19 | from model.components.positional_encoding import *
20 | from model.components.grad_util import torch_wrapper_tv
21 | from utils.loss import mse_loss
22 |
23 |
24 | class MLP(torch.nn.Module):
25 | def __init__(self, dim, out_dim=None, w=64, time_varying=False):
26 | super().__init__()
27 | self.time_varying = time_varying
28 | if out_dim is None:
29 | out_dim = dim
30 | self.net = torch.nn.Sequential(
31 | torch.nn.Linear(dim + (1 if time_varying else 0), w),
32 | torch.nn.SELU(),
33 | torch.nn.Linear(w, w),
34 | torch.nn.SELU(),
35 | torch.nn.Linear(w, w),
36 | torch.nn.SELU(),
37 | torch.nn.Linear(w, out_dim),
38 | )
39 |
40 | def forward(self, x, *args, **kwargs):
41 | return self.net(x)
42 |
43 |
44 | # conditional liver model
45 | class MLP_conditional_liver(torch.nn.Module):
46 | """ Conditional with many available classes
47 |
48 | return the class as is
49 | """
50 | def __init__(self, dim, treatment_cond, out_dim=None, w=64, time_varying=False, conditional=False):
51 | super().__init__()
52 | self.time_varying = time_varying
53 | # Use dim as the output size unless out_dim is given explicitly.
54 | self.out_dim = dim if out_dim is None else out_dim
55 | self.treatment_cond = treatment_cond
56 | self.indim = dim + (1 if time_varying else 0) + (self.treatment_cond if conditional else 0)
57 | self.net = torch.nn.Sequential(
58 | torch.nn.Linear(self.indim, w),
59 | torch.nn.SELU(),
60 | torch.nn.Linear(w, w),
61 | torch.nn.SELU(),
62 | torch.nn.Linear(w, w),
63 | torch.nn.SELU(),
64 | torch.nn.Linear(w,self.out_dim),
65 | )
66 | self.default_class = 0
67 |
68 |
69 | def forward(self, x):
70 | """forward pass
71 | Assume first two dimensions are x, c, then t
72 | """
73 | result = self.net(x)
74 | # print(result.shape)
75 | # print(x[:,2:-2].shape)
76 | return torch.cat([result, x[:,2:-1]], dim=1)
77 |
78 | class MLP_conditional_liver_pe(torch.nn.Module):
79 | """ Conditional with many available classes
80 |
81 | return the class as is
82 | """
83 | def __init__(self, dim, treatment_cond, out_dim=None, w=64, time_varying=False, conditional=False, time_dim = NUM_FREQS * 2, clip = None):
84 | super().__init__()
85 | self.time_varying = time_varying
86 | # Use dim as the output size unless out_dim is given explicitly.
87 | self.out_dim = dim if out_dim is None else out_dim
88 | self.out_dim += 1
89 | self.treatment_cond = treatment_cond
90 | self.dim = dim
91 | self.indim = dim + (time_dim if time_varying else 0) + (self.treatment_cond if conditional else 0)
92 | self.net = torch.nn.Sequential(
93 | torch.nn.Linear(self.indim, w),
94 | torch.nn.SELU(),
95 | torch.nn.Linear(w, w),
96 | torch.nn.SELU(),
97 | torch.nn.Linear(w, w),
98 | torch.nn.SELU(),
99 | torch.nn.Linear(w,self.out_dim),
100 | )
101 | self.default_class = 0
102 | self.clip = clip
103 |
104 | def encoding_function(self, time_tensor):
105 | return positional_encoding_tensor(time_tensor)
106 |
107 | def forward_train(self, x):
108 | """forward pass
109 | Assume first two dimensions are x, c, then t
110 | """
111 | time_tensor = x[:,-1]
112 | encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2)
113 | new_x = torch.cat([x[:,:-1], encoded_time_span], dim = 1).to(torch.float32)
114 | result = self.net(new_x)
115 | return torch.cat([result[:,:-1], x[:,self.dim:-1],result[:,-1].unsqueeze(1)], dim=1)
116 |
117 | def forward(self, x):
118 | """ call forward_train for training
119 | x here is x_t
120 | xt = (t)x1 + (1-t)x0
121 | (xt - tx1)/(1-t) = x0
122 | """
123 | x1 = self.forward_train(x)
124 | x1_coord = x1[:,:self.dim]
125 | t = x[:,-1]
126 | pred_time_till_t1 = x1[:,-1]
127 | x_coord = x[:,:self.dim]
128 | if self.clip is None:
129 | vt = (x1_coord - x_coord)/(pred_time_till_t1)
130 | else:
131 | vt = (x1_coord - x_coord)/torch.clip((pred_time_till_t1),min=self.clip)
132 | final_vt = torch.cat([vt, torch.zeros_like(x[:,self.dim:-1])], dim=1)
133 | return final_vt
134 |
135 | class MLP_Cond_Module(pl.LightningModule):
136 | def __init__(self,
137 | treatment_cond,
138 | dim=2,
139 | w=64,
140 | time_varying=True,
141 | conditional=True,
142 | lr=1e-6,
143 | sigma = 0.1,
144 | loss_fn = mse_loss,
145 | metrics = ['mse_loss', 'l1_loss'],
146 | implementation = "ODE", # can be SDE
147 | sde_noise = 0.1,
148 | clip = None, # float
149 | naming = None,
150 | ):
151 | super().__init__()
152 | self.model = MLP_conditional_liver_pe(dim=dim,
153 | w=w,
154 | time_varying=time_varying,
155 | conditional=conditional,
156 | treatment_cond=treatment_cond,
157 | clip = clip)
158 | self.loss_fn = loss_fn
159 | self.save_hyperparameters()
160 | self.dim = dim
161 | # self.out_dim = out_dim
162 | self.w = w
163 | self.time_varying = time_varying
164 | self.conditional = conditional
165 | self.treatment_cond = treatment_cond
166 | self.lr = lr
167 | self.sigma = sigma
168 | self.naming = "MLP_Cond_Module_"+implementation if naming is None else naming
169 | self.metrics = metrics
170 | self.implementation = implementation
171 | self.sde_noise = sde_noise
172 | self.clip = clip
173 |
174 |
175 | def __convert_tensor__(self, tensor):
176 | return tensor.to(torch.float32)
177 |
178 | def __x_processing__(self, x0, x1, t0, t1):
179 | # squeeze xs (prevent mismatch)
180 | x0 = x0.squeeze(0)
181 | x1 = x1.squeeze(0)
182 | t0 = t0.squeeze()
183 | t1 = t1.squeeze()
184 |
185 | t = torch.rand(x0.shape[0],1).to(x0.device)
186 | mu_t = x0 * (1 - t) + x1 * t
187 | data_t_diff = (t1 - t0).unsqueeze(1)
188 | x = mu_t + self.sigma * torch.randn(x0.shape[0], self.dim).to(x0.device)
189 | ut = (x1 - x0) / (data_t_diff + 1e-4)
190 | t_model = t * data_t_diff + t0.unsqueeze(1)
191 | futuretime = t1 - t_model
192 | return x, ut, t_model, futuretime, t
193 |
194 | def training_step(self, batch, batch_idx):
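195 | """Compute the training loss for one batch.
196 |
197 | Args:
198 | batch (tuple): x0_values, x0_classes, x1_values, times_x0, times_x1
199 | batch_idx (int): index of the batch (unused)
200 |
201 | Returns:
202 | torch.Tensor: loss on the predicted target state x1 plus the loss on the predicted time to the next observation
203 | """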
204 | x0, x0_class, x1, x0_time, x1_time = batch
205 | x0, x0_class, x1, x0_time, x1_time = self.__convert_tensor__(x0), self.__convert_tensor__(x0_class), self.__convert_tensor__(x1), self.__convert_tensor__(x0_time), self.__convert_tensor__(x1_time)
206 |
207 |
208 | x, ut, t_model, futuretime, t = self.__x_processing__(x0, x1, x0_time, x1_time)
209 |
210 |
211 | if len(x0_class.shape) == 3:
212 | x0_class = x0_class.squeeze(0)
213 |
214 | if self.conditional:
215 | in_tensor = torch.cat([x,x0_class, t_model], dim = -1)
216 | else:
217 | in_tensor = torch.cat([x, t_model], dim = -1)
218 | xt = self.model.forward_train(in_tensor)
219 |
220 | if self.implementation == "SDE":
221 | variance = t*(1-t)*(self.sde_noise ** 2)
222 | noise = torch.randn_like(xt[:,:self.dim]) * torch.sqrt(variance)
223 | loss = self.loss_fn(xt[:,:self.dim]+noise, x1) + self.loss_fn(xt[:,-1], futuretime)
224 | else:
225 | loss = self.loss_fn(xt[:,:self.dim], x1) + self.loss_fn(xt[:,-1], futuretime)
226 | self.log('train_loss', loss)
227 | return loss
228 |
229 | def config_optimizer(self):
230 | return torch.optim.Adam(self.parameters(), lr=self.lr)
231 |
232 | def validation_step(self, batch, batch_idx):
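233 | """Roll out one validation trajectory and log its loss and metrics.
234 |
235 | Args:
236 | batch (tuple): one patient per batch (batch size of 1, since trajectory lengths are uneven)
237 | batch_idx (int): index of the batch, used to title the logged plot
238 |
239 | Returns:
240 | dict: 'val_loss' and 'traj_pairs' (ground-truth and predicted trajectories)
241 | """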
242 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='val')
243 | self.log('val_loss', loss)
244 | for key, value in metricD.items():
245 | self.log(key+"_val", value)
246 | # return total_loss, traj_pairs
247 | return {'val_loss':loss, 'traj_pairs':pairs}
248 |
249 | def test_step(self, batch, batch_idx):
250 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='test')
251 | self.log('test_loss', loss)
252 | for key, value in metricD.items():
253 | self.log(key+"_test", value)
254 | return {'test_loss':loss, 'traj_pairs':pairs}
255 |
256 | def test_func_step(self, batch, batch_idx, mode='none'):
257 | """assuming each is one patient/batch"""
258 | total_loss = []
259 | traj_pairs = []
260 |
261 | x0_values, x0_classes, x1_values, times_x0, times_x1 = batch
262 | times_x0 = times_x0.squeeze()
263 | times_x1 = times_x1.squeeze()
264 |
265 | # print(x0_values.shape)
266 | # print(x1_values.shape)
267 | full_traj = torch.cat([x0_values[0,0,:self.dim].unsqueeze(0),
268 | x1_values[0,:,:self.dim]],
269 | dim=0)
270 | full_time = torch.cat([times_x0[0].unsqueeze(0), times_x1], dim=0)
271 | ind_loss, pred_traj = self.test_trajectory(batch)
272 | total_loss.append(ind_loss)
273 | traj_pairs.append([full_traj, pred_traj])
274 |
275 | full_traj = full_traj.detach().cpu().numpy()
276 | pred_traj = pred_traj.detach().cpu().numpy()
277 | full_time = full_time.detach().cpu().numpy()
278 |
279 | # graph
280 | fig = plot_3d_path_ind(pred_traj,
281 | full_traj,
282 | t_span=full_time,
283 | title="{}_trajectory_patient_{}".format(mode, batch_idx))
284 | if self.logger:
285 | # may cause problem if wandb disabled
286 | self.logger.experiment.log({"{}_trajectory_patient_{}".format(mode, batch_idx): wandb.Image(fig)})
287 |
288 | plt.close(fig)
289 |
290 | # metrics
291 | metricD = metrics_calculation(pred_traj, full_traj, metrics=self.metrics)
292 | return np.mean(total_loss), traj_pairs, metricD
293 |
294 | def test_trajectory(self,pt_tensor):
295 | if self.implementation == "ODE":
296 | return self.test_trajectory_ode(pt_tensor)
297 | elif self.implementation == "SDE":
298 | return self.test_trajectory_sde(pt_tensor)
299 |
300 | def test_trajectory_ode(self,pt_tensor):
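301 | """Roll out a trajectory interval by interval with a neural ODE solver.
302 |
303 | Args:
304 | pt_tensor (tuple): (x0_values, x0_classes, x1_values, times_x0, times_x1),
305 |
306 |
307 | Returns:
308 | mse_all, total_pred_tensor: mean per-interval loss and the stacked predictions (starting point first)
309 | """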
310 | node = NeuralODE(
311 | torch_wrapper_tv(self.model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
312 | )
313 | total_pred = []
314 | mse = []
315 | t_max = 0
316 |
317 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor
318 | # squeeze all
319 | x0_values = x0_values.squeeze(0)
320 | x1_values = x1_values.squeeze(0)
321 | times_x0 = times_x0.squeeze()
322 | times_x1 = times_x1.squeeze()
323 | x0_classes = x0_classes.squeeze()
324 |
325 | if len(x0_classes.shape) == 1:
326 | x0_classes = x0_classes.unsqueeze(1)
327 |
328 |
329 |
330 | total_pred.append(x0_values[0].unsqueeze(0))
331 | len_path = x0_values.shape[0]
332 | assert len_path == x1_values.shape[0]
333 | for i in range(len_path):
334 | t_max = (times_x1[i]-times_x0[i]) # length of the current interval (not cumulative)
335 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device)
336 | with torch.no_grad():
337 | # get last pred, if none then use startpt
338 | if i == 0:
339 | if self.conditional:
340 | testpt = torch.cat([x0_values[i].unsqueeze(0),x0_classes[i].unsqueeze(0)],dim=1)
341 | else:
342 | testpt = x0_values[i].unsqueeze(0)
343 | else: # incorporate last prediction
344 | if self.conditional:
345 | testpt = torch.cat([pred_traj, x0_classes[i].unsqueeze(0)], dim=1)
346 | else:
347 | testpt = pred_traj
348 | traj = node.trajectory(
349 | testpt,
350 | t_span=time_span,
351 | )
352 | pred_traj = traj[-1,:,:self.dim]
353 | total_pred.append(pred_traj)
354 | ground_truth_coords = x1_values[i]
355 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy())
356 | mse_all = np.mean(mse)
357 | total_pred_tensor = torch.stack(total_pred).squeeze(1)
358 | return mse_all, total_pred_tensor
359 |
360 | def configure_optimizers(self):
361 | # Define the optimizer
362 | optimizer = optim.Adam(self.parameters(), lr=self.lr)
363 | return optimizer
364 |
365 | def test_trajectory_sde(self,pt_tensor):
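366 | """Roll out a trajectory interval by interval with the Euler-Maruyama SDE solver.
367 |
368 | Args:
369 | pt_tensor (tuple): (x0_values, x0_classes, x1_values, times_x0, times_x1),
370 |
371 |
372 | Returns:
373 | mse_all, total_pred_tensor: mean per-interval loss and the stacked predictions (starting point first)
374 | """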
375 | sde = SDE(self.model, noise=self.sde_noise)
376 | total_pred = []
377 | mse = []
378 | t_max = 0
379 |
380 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor
381 | # squeeze all
382 | x0_values = x0_values.squeeze(0)
383 | x1_values = x1_values.squeeze(0)
384 | times_x0 = times_x0.squeeze()
385 | times_x1 = times_x1.squeeze()
386 | x0_classes = x0_classes.squeeze()
387 |
388 | if len(x0_classes.shape) == 1:
389 | x0_classes = x0_classes.unsqueeze(1)
390 |
391 |
392 |
393 | total_pred.append(x0_values[0].unsqueeze(0))
394 | len_path = x0_values.shape[0]
395 | assert len_path == x1_values.shape[0]
396 | for i in range(len_path):
397 | t_max = (times_x1[i]-times_x0[i])
398 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device)
399 |
400 | with torch.no_grad():
401 | # get last pred, if none then use startpt
402 | if i == 0:
403 | if self.conditional:
404 | testpt = torch.cat([x0_values[i].unsqueeze(0),x0_classes[i].unsqueeze(0)],dim=1)
405 | else:
406 | testpt = x0_values[i].unsqueeze(0)
407 | else: # incorporate last prediction
408 | if self.conditional:
409 | testpt = torch.cat([pred_traj, x0_classes[i].unsqueeze(0)], dim=1)
410 | else:
411 | testpt = pred_traj
412 | traj = self._sde_solver(sde, testpt, time_span)
413 |
414 | pred_traj = traj[-1,:,:self.dim]
415 | total_pred.append(pred_traj)
416 | ground_truth_coords = x1_values[i]
417 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy())
418 | mse_all = np.mean(mse)
419 | total_pred_tensor = torch.stack(total_pred).squeeze(1)
420 | return mse_all, total_pred_tensor
421 |
422 |
423 | def _sde_solver(self, sde, initial_state, time_span):
424 | dt = time_span[1] - time_span[0] # Time step
425 | current_state = initial_state
426 | trajectory = [current_state]
427 |
428 | for t in time_span[1:]:
429 | drift = sde.f(t, current_state)
430 | diffusion = sde.g(t, current_state)
431 | noise = torch.randn_like(current_state) * torch.sqrt(dt)
432 | current_state = current_state + drift * dt + diffusion * noise
433 | trajectory.append(current_state)
434 |
435 | return torch.stack(trajectory)
436 |
437 |
438 |
439 |
440 | class MLP_conditional_liver_pe_memory(torch.nn.Module):
441 | """ Conditional with many available classes
442 |
443 | return the class as is
444 | """
445 | def __init__(self,
446 | dim,
447 | treatment_cond,
448 | memory, # how many time steps
449 | out_dim=None,
450 | w=64,
451 | time_varying=False,
452 | conditional=False,
453 | time_dim = NUM_FREQS * 2,
454 | clip = None,
455 | ):
456 | super().__init__()
457 | self.time_varying = time_varying
458 | if out_dim is None:
459 | self.out_dim = dim
460 | self.out_dim += 1 # for the time dimension
461 | self.treatment_cond = treatment_cond
462 | self.memory = memory
463 | self.dim = dim
464 | self.indim = dim + (time_dim if time_varying else 0) + (self.treatment_cond if conditional else 0) + (dim * memory)
465 | self.net = torch.nn.Sequential(
466 | torch.nn.Linear(self.indim, w),
467 | torch.nn.SELU(),
468 | torch.nn.Linear(w, w),
469 | torch.nn.SELU(),
470 | torch.nn.Linear(w, w),
471 | torch.nn.SELU(),
472 | torch.nn.Linear(w,self.out_dim),
473 | )
474 | self.default_class = 0
475 | self.clip = clip
476 |
477 | def encoding_function(self, time_tensor):
478 | return positional_encoding_tensor(time_tensor)
479 |
480 | def forward_train(self, x):
481 | """forward pass
482 | Assumes the columns of x are ordered as coordinates x, conditions c, then time t (last column)
483 | """
484 | time_tensor = x[:,-1]
485 | encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2)
486 | new_x = torch.cat([x[:,:-1], encoded_time_span], dim=1)
487 | result = self.net(new_x)
488 | return torch.cat([result[:,:-1], x[:,self.dim:-1], result[:,-1].unsqueeze(1)], dim=1)
489 |
490 | def forward(self, x):
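491 | """Velocity field used at inference; call forward_train for training.
492 | x here is x_t, with xt = (t)x1 + (1-t)x0, so (xt - tx1)/(1-t) = x0.
493 | forward_train predicts x1 and the time remaining until t1;
494 | the velocity is (predicted x1 - x_t) / (predicted time remaining).
495 | """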
496 | x1 = self.forward_train(x)
497 | x1_coord = x1[:,:self.dim]
498 | t = x[:,-1]
499 | pred_time_till_t1 = x1[:,-1]
500 | x_coord = x[:,:self.dim]
501 | if self.clip is None:
502 | vt = (x1_coord - x_coord)/(pred_time_till_t1)
503 | else:
504 | vt = (x1_coord - x_coord)/torch.clip((pred_time_till_t1),min=self.clip)
505 |
506 | final_vt = torch.cat([vt, torch.zeros_like(x[:,self.dim:-1])], dim=1)
507 | return final_vt
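The `forward` method above turns a predicted target point and a predicted time-to-target into a velocity, optionally clipping the denominator from below. A minimal plain-Python sketch of that arithmetic follows. The function name `velocity_from_target` and the list-based signature are assumptions for illustration, not part of the repository.

```python
from typing import List, Optional


def velocity_from_target(x_t: List[float],
                         x1_hat: List[float],
                         time_left: float,
                         clip: Optional[float] = None) -> List[float]:
    """Hypothetical helper: (predicted x1 - x_t) / predicted time remaining.

    When `clip` is given, the denominator is bounded below by it,
    mirroring torch.clip(pred_time_till_t1, min=clip) in the model.
    """
    denom = time_left if clip is None else max(time_left, clip)
    return [(target - current) / denom for current, target in zip(x_t, x1_hat)]


# Unclipped: half a time unit left to cover a displacement of (1, 2).
v_plain = velocity_from_target([0.0, 0.0], [1.0, 2.0], 0.5)

# Clipped: a predicted time of 0 would divide by zero without the lower bound.
v_clipped = velocity_from_target([0.0, 0.0], [1.0, 2.0], 0.0, clip=0.1)
```

Without `clip`, a predicted time close to zero makes the velocity blow up, which is presumably why the model exposes the option.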
--------------------------------------------------------------------------------
/src/model/components/positional_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | PE_BASE = 0.012 # 0.012615662610100801
4 | NUM_FREQS = 10
5 |
6 | def positional_encoding_tensor(time_tensor, num_frequencies=NUM_FREQS, base=PE_BASE):
7 | # Ensure the time tensor is in the range [0, 1]
8 | time_tensor = time_tensor.clamp(0, 1).unsqueeze(1) # Clamp and add dimension for broadcasting
9 |
10 | # Compute the arguments for the sine and cosine functions using the custom base
11 | frequencies = torch.pow(base, -torch.arange(0, num_frequencies, dtype=torch.float32) / num_frequencies).to(time_tensor.device)
12 | angles = time_tensor * frequencies
13 |
14 | # Compute the sine and cosine for even and odd indices respectively
15 | sine = torch.sin(angles)
16 | cosine = torch.cos(angles)
17 |
18 | # Stack them along the last dimension
19 | pos_encoding = torch.stack((sine, cosine), dim=-1)
20 | pos_encoding = pos_encoding.flatten(start_dim=2)
21 |
22 | # Normalize to have values between 0 and 1
23 | pos_encoding = (pos_encoding + 1) / 2 # Now values are between 0 and 1
24 |
25 | return pos_encoding
26 |
27 | def positional_encoding_df(df, col_mod = "time_normalized"):
28 | pe_tensors = torch.tensor(df[col_mod].values, dtype=torch.float32)
29 | pos_encoding = positional_encoding_tensor(pe_tensors)
30 | pos_encoding_array = pos_encoding.numpy().reshape(-1, NUM_FREQS * 2)
31 | return pos_encoding_array
32 |
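To make the layout of the encoding concrete, here is a plain-Python restatement of `positional_encoding_tensor` for a single scalar time. It is a sketch under the assumption that the caller reshapes the result to `NUM_FREQS * 2` features, as `positional_encoding_df` does. The name `positional_encoding_scalar` is hypothetical.

```python
import math

PE_BASE = 0.012
NUM_FREQS = 10


def positional_encoding_scalar(t, num_frequencies=NUM_FREQS, base=PE_BASE):
    """Hypothetical scalar version: returns [sin_0, cos_0, sin_1, cos_1, ...] in [0, 1]."""
    t = min(max(t, 0.0), 1.0)  # clamp to [0, 1], like time_tensor.clamp(0, 1)
    features = []
    for k in range(num_frequencies):
        freq = base ** (-k / num_frequencies)  # grows from 1 because base < 1
        angle = t * freq
        features.append((math.sin(angle) + 1) / 2)  # rescale [-1, 1] to [0, 1]
        features.append((math.cos(angle) + 1) / 2)
    return features


enc_zero = positional_encoding_scalar(0.0)
enc_half = positional_encoding_scalar(0.5)
```

At `t = 0` every sine feature is 0.5 and every cosine feature is 1.0. Times outside `[0, 1]` are clamped, so they encode the same as the nearest endpoint.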
--------------------------------------------------------------------------------
/src/model/components/sde_func_solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class SDE_func_solver(torch.nn.Module):
4 |
5 | noise_type = "diagonal"
6 | sde_type = "ito"
7 |
8 | # noise plays the role of sigma in the diffusion term sigma * sqrt(t * (1 - t)); see g() below
9 | def __init__(self, ode_drift, noise, reverse=False):
10 | super().__init__()
11 | self.drift = ode_drift
12 | self.reverse = reverse
13 | self.noise = noise # changeable, a model itself
14 |
15 | # Drift
16 | def f(self, t, y):
17 | if self.reverse:
18 | t = 1 - t
19 | if len(t.shape) == len(y.shape):
20 | x = torch.cat([y, t], 1)
21 | else:
22 | x = torch.cat([y, t.repeat(y.shape[0])[:, None]], 1)
23 | return self.drift(x)
24 |
25 | # Diffusion
26 | def g(self, t, y):
27 | if self.reverse:
28 | t = 1 - t
29 | if len(t.shape) == len(y.shape):
30 | x = torch.cat([y, t], 1)
31 | else:
32 | x = torch.cat([y, t.repeat(y.shape[0])[:, None]], 1)
33 | noise_result = self.noise(x)
34 | return noise_result* torch.sqrt(t * (1 - t))
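The drift `f` and diffusion `g` defined above are the two ingredients of an Euler-Maruyama step, which is what the `_sde_solver` methods elsewhere in this repository implement. The sketch below shows the same update for a scalar state in plain Python. It assumes a constant `sigma` in place of the learned noise model, and the names `bridge_diffusion` and `euler_maruyama` are hypothetical.

```python
import math
import random


def bridge_diffusion(t, sigma):
    """Diffusion scale sigma * sqrt(t * (1 - t)); zero at t = 0 and t = 1."""
    t = min(max(t, 0.0), 1.0)  # guard against a negative argument to sqrt
    return sigma * math.sqrt(t * (1.0 - t))


def euler_maruyama(drift, diffusion, x0, time_span, rng=None):
    """Integrate dx = drift dt + diffusion dW on an evenly spaced time grid."""
    rng = rng or random.Random(0)
    dt = time_span[1] - time_span[0]
    x = x0
    path = [x]
    for t in time_span[1:]:
        dw = rng.gauss(0.0, 1.0) * math.sqrt(dt)  # Brownian increment
        x = x + drift(t, x) * dt + diffusion(t, x) * dw
        path.append(x)
    return path


grid = [i / 9 for i in range(10)]
# With zero diffusion the scheme reduces to explicit Euler on dx/dt = -x.
ode_path = euler_maruyama(lambda t, x: -x, lambda t, x: 0.0, 1.0, grid)
sde_path = euler_maruyama(lambda t, x: -x,
                          lambda t, x: bridge_diffusion(t, 0.1), 1.0, grid)
```

Because the scale vanishes at both ends of `[0, 1]`, the noise is largest mid-interval. Outside `[0, 1]` the unguarded `torch.sqrt(t * (1 - t))` would return NaN, so the clamp in this sketch is an added safety assumption.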
--------------------------------------------------------------------------------
/src/model/latent_ode.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pytorch_lightning as pl
4 | from torchdiffeq import odeint
5 | from utils.metric_calc import *
6 | from utils.latent_ode_utils import *
7 |
8 | class Encoder(nn.Module):
9 | def __init__(self, input_dim, latent_dim):
10 | super().__init__()
11 | self.net = nn.Sequential(
12 | nn.Linear(input_dim, 20),
13 | nn.Tanh(),
14 | nn.Linear(20, latent_dim),
15 | )
16 |
17 | def forward(self, x):
18 | return self.net(x)
19 |
20 |
21 |
22 | class Encoder_z0_ODE_RNN(nn.Module):
23 | # Derive z0 by running ode backwards.
24 | # For every y_i we have two versions: encoded from data and derived from ODE by running it backwards from t_i+1 to t_i
25 | # Compute a weighted sum of y_i from data and y_i from ode. Use weighted y_i as an initial value for ODE running from t_i to t_i-1
26 | # Continue until we get to z0
27 | def __init__(self, latent_dim, input_dim, z0_diffeq_solver = None,
28 | z0_dim = None, GRU_update = None,
29 | n_gru_units = 100,
30 | ):
31 |
32 | super(Encoder_z0_ODE_RNN, self).__init__()
33 |
34 | self.z0_dim = latent_dim
35 | self.GRU_update = GRU_unit(latent_dim, input_dim,
36 | n_units = n_gru_units)
37 | # device=device).to(device)
38 |
39 | self.z0_diffeq_solver = z0_diffeq_solver
40 | self.latent_dim = latent_dim
41 | self.input_dim = input_dim
42 | # self.device = device
43 | self.extra_info = None
44 |
45 | self.transform_z0 = nn.Sequential(
46 | nn.Linear(latent_dim * 2, 100),
47 | nn.Tanh(),
48 | nn.Linear(100, self.z0_dim * 2),)
49 |
50 |
51 | def forward(self, x1):
52 | # data, time_steps -- observations and their time stamps
53 | # IMPORTANT: assumes that 'data' already has mask concatenated to it
54 | # xi is timepoint i of the data
55 | device_of_x1 = x1.device
56 | x1_len = x1.shape[1]
57 |
58 | prev_y = torch.zeros((1, x1_len, self.latent_dim)).to(device_of_x1)
59 | prev_std = torch.zeros((1, x1_len, self.latent_dim)).to(device_of_x1)
60 |
61 | # x1 = x1.unsqueeze(0)
62 |
63 | last_yi, last_yi_std = self.GRU_update(prev_y, prev_std, x1)
64 |
65 | means_z0 = last_yi.reshape(1, x1_len, self.latent_dim)
66 | std_z0 = last_yi_std.reshape(1, x1_len, self.latent_dim)
67 |
68 | mean_z0, std_z0 = split_last_dim(self.transform_z0( torch.cat((means_z0, std_z0), -1)))
69 | std_z0 = std_z0.abs()
70 |
71 | return mean_z0, std_z0
72 |
73 |
74 |
75 | from torch.distributions.multivariate_normal import MultivariateNormal
76 | from torch.distributions.normal import Normal
77 | from torch.distributions import kl_divergence, Independent
78 | from model.base_models import VAE_Baseline
79 |
80 |
81 | class VAE(nn.Module):
82 | def __init__(self, encoder, decoder, ode_func, latent_dim):
83 | super(VAE, self).__init__()
84 | self.encoder = encoder
85 | self.decoder = decoder
86 | self.ode_func = ode_func
87 | self.latent_dim = latent_dim
88 |
89 | def forward(self, x, t, reverse=False):
90 | """Forward pass for training with the option to reverse time.
91 |
92 | Args:
93 | x: Input data (e.g., sequence of observations).
94 | t: Corresponding time points for the input data.
95 | reverse: Whether to reverse time for the ODE solver.
96 | """
97 | if len(t.shape)>1:
98 | t = t.squeeze()
99 |
100 | if reverse:
101 | # Training mode: reverse time
102 | t = torch.flip(t, dims=[0])
103 | x = torch.flip(x, dims=[1])
104 |
105 | mu, std_z0 = self.encoder(x)
106 | z = self.reparameterize(mu, std_z0)
107 | z_pred = odeint(self.ode_func, z, t)
108 | x_pred = self.decoder(z_pred)
109 |
110 | if reverse:
111 | # Flip predictions back to original time order
112 | x_pred = torch.flip(x_pred, dims=[1])
113 |
114 | return x_pred, mu, std_z0
115 |
116 | def reparameterize(self, mu, logvar):
117 | std = torch.exp(0.5 * logvar)
118 | eps = torch.randn_like(std)
119 | return mu + eps * std
120 |
121 |
122 | def extrapolate(self,z0, t):
123 | """predict future states - call during val/test
124 |
125 | Args:
126 | z0: Initial latent state from which to extrapolate.
127 | t: Future time points to predict.
128 | """
129 | # do one at a time
130 | z_pred = odeint(self.ode_func, z0, t)
131 | return self.decoder(z_pred)
132 |
133 |
134 |
135 | class Decoder(nn.Module):
136 | def __init__(self, latent_dim, output_dim):
137 | super().__init__()
138 | self.net = nn.Sequential(
139 | nn.Linear(latent_dim, output_dim),
140 | )
141 |
142 | def forward(self, x):
143 | return self.net(x)
144 |
145 | class LatentODEFunc(nn.Module):
146 | def __init__(self, latent_dim):
147 | super().__init__()
148 | self.net = nn.Sequential(
149 | nn.Linear(latent_dim, 50),
150 | nn.Tanh(),
151 | nn.Linear(50, 50),
152 | nn.Tanh(),
153 | nn.Linear(50, latent_dim),
154 | )
155 |
156 | def forward(self, t, x):
157 | return self.net(x)
158 |
159 | class LatentODE_pl(pl.LightningModule):
160 | def __init__(self,
161 | input_dim,
162 | latent_dim,
163 | output_dim,
164 | lr=1e-3,
165 | loss_fn=nn.MSELoss(),
166 | metrics = ['mse_loss', 'l1_loss'],
167 | ):
168 | super().__init__()
169 | self.encoder = Encoder_z0_ODE_RNN(latent_dim,input_dim)
170 | self.decoder = Decoder(latent_dim, output_dim)
171 | self.vae = VAE(self.encoder, self.decoder, LatentODEFunc(latent_dim), latent_dim)
172 | self.lr = lr
173 | self.loss_fn = loss_fn
174 | self.naming = 'LatentODE_RNN'
175 | self.metrics = metrics
176 |
177 | def forward(self, x, x_time, mode='train'):
178 | if mode == 'train':
179 | reverse = True
180 | else:
181 | reverse = False
182 | x_pred, mu, std_z0 = self.vae(x, x_time, reverse=reverse)
183 | return x_pred, mu, std_z0
184 |
185 | def training_step(self, batch, batch_idx):
186 | x0, x0_class, x1, x0_time, x1_time = batch
187 | t_span = x1_time.squeeze()
188 | x_pred, mu, std_z0 = self.forward(x1, x1_time, mode='train')
189 | # loss = self.loss_fn(x_pred[-1], x1)
190 | loss = self.loss_function(x_pred[-1], x0, mu, std_z0)
191 | self.log('train_loss', loss)
192 | return loss
193 |
194 | def configure_optimizers(self):
195 | return torch.optim.Adam(self.parameters(), lr=self.lr)
196 |
197 | def validation_step(self, batch, batch_idx):
198 | x0, x0_class, x1, x0_time, x1_time = batch
199 | x_pred, x1 = self.testing_vae(batch)
200 | loss = self.loss_fn(x_pred[-1], x1)
201 |
202 | # metrics
203 | metricD = metrics_calculation(x_pred[-1], x1, metrics=self.metrics)
204 | for key, value in metricD.items():
205 | self.log(key+"_val", value)
206 |
207 | loss = self.loss_fn(x_pred[-1], x1)
208 | self.log('val_loss', loss)
209 | return loss
210 |
211 | def test_step(self, batch, batch_idx):
212 | x0, x0_class, x1, x0_time, x1_time = batch
213 | x_pred, x1 = self.testing_vae(batch)
214 | loss = self.loss_fn(x_pred[-1], x1)
215 |
216 | # Calculate metrics
217 | metricD = metrics_calculation(x_pred[-1], x1, metrics=self.metrics)
218 |
219 | for key, value in metricD.items():
220 | self.log(key+"_test", value)
221 |
222 | self.log('test_loss', loss)
223 | return loss
224 |
225 |
226 |
227 | def loss_function(self, recon_x, x, mu, logvar):
228 | recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
229 | kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
230 | return recon_loss + kld_loss
231 |
232 |
233 | def sample(self, num_samples=1):
234 | """Sample from the latent space.
235 |
236 | Args:
237 | num_samples: Number of samples to generate.
238 | """
239 | z = torch.randn(num_samples, self.vae.latent_dim, device=self.device)  # latent_dim is stored on self.vae, not on this module
240 | return self.decoder(z)
241 |
242 | def testing_vae(self, batch):
243 | x0, x0_class, x1, x0_time, x1_time = batch
244 | t_span = x1_time.squeeze()
245 | z0_mean, z0_logvar = self.encoder(x0)
246 | z0 = self.vae.reparameterize(z0_mean, z0_logvar) # sample
247 | #@HERE: squeeze z0?
248 | x_pred = self.vae.extrapolate(z0, t_span)
249 | return x_pred, x1
250 |
251 |
252 | class LatentODE_deprecated(pl.LightningModule):
253 | def __init__(self,
254 | input_dim,
255 | latent_dim,
256 | output_dim,
257 | lr=1e-3,
258 | loss_fn=nn.MSELoss(),
259 | metrics = ['mse_loss', 'l1_loss']):
260 | super().__init__()
261 | self.encoder = Encoder(input_dim, latent_dim)
262 | self.decoder = Decoder(latent_dim, output_dim)
263 | self.ode_func = LatentODEFunc(latent_dim)
264 | self.lr = lr
265 | self.loss_fn = loss_fn
266 | self.naming = 'LatentODE'
267 | self.metrics = metrics
268 |
269 | def forward(self, x0, t_span):
270 | z0 = self.encoder(x0)
271 | z_pred = odeint(self.ode_func, z0, t_span)
272 | x_pred = self.decoder(z_pred)
273 | return x_pred
274 |
275 | def training_step(self, batch, batch_idx):
276 | x0, x0_class, x1, x0_time, x1_time = batch
277 | t_span = x1_time.squeeze()
278 | x_pred = self.forward(x0, t_span)
279 | loss = self.loss_fn(x_pred[-1], x1)
280 | self.log('train_loss', loss)
281 | return loss
282 |
283 | def configure_optimizers(self):
284 | return torch.optim.Adam(self.parameters(), lr=self.lr)
285 |
286 | def validation_step(self, batch, batch_idx):
287 | x0, x0_class, x1, x0_time, x1_time = batch
288 | t_span = x1_time.squeeze()
289 | x_pred = self.forward(x0, t_span)
290 |
291 | # metrics
292 | metricD = metrics_calculation(x_pred[-1], x1, metrics=self.metrics)
293 | for key, value in metricD.items():
294 | self.log(key+"_val", value)
295 |
296 | loss = self.loss_fn(x_pred[-1], x1)
297 | self.log('val_loss', loss)
298 | return loss
299 |
300 | def test_step(self, batch, batch_idx):
301 | x0, x0_class, x1, x0_time, x1_time = batch
302 | t_span = x1_time.squeeze()
303 | x_pred = self.forward(x0, t_span)
304 | loss = self.loss_fn(x_pred[-1], x1)
305 |
306 | # Calculate metrics
307 | metricD = metrics_calculation(x_pred[-1], x1, metrics=self.metrics)
308 |
309 | for key, value in metricD.items():
310 | self.log(key+"_test", value)
311 |
312 | self.log('test_loss', loss)
313 | return loss
314 |
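The `reparameterize` and `loss_function` methods above follow the usual diagonal-Gaussian VAE formulas. A plain-Python sketch of both is given below. It assumes the second argument is a log-variance, as the parameter names `logvar` suggest, and the function names are hypothetical.

```python
import math
import random


def reparameterize(mu, logvar, rng=None):
    """Sample z = mu + eps * std with std = exp(0.5 * logvar), eps ~ N(0, 1)."""
    rng = rng or random.Random(0)
    return [m + rng.gauss(0.0, 1.0) * math.exp(0.5 * lv)
            for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv)
                      for m, lv in zip(mu, logvar))


kl_zero = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])  # posterior equals prior
kl_shift = kl_to_standard_normal([1.0], [0.0])           # unit mean shift
z_sample = reparameterize([0.0, 1.0], [0.0, 0.0])
```

The KL term is zero exactly when the posterior matches the standard normal prior, and a unit shift in the mean costs 0.5 per dimension. Note that `Encoder_z0_ODE_RNN` returns `std_z0.abs()`, so whether that value is meant as a standard deviation or a log-variance is worth checking against these formulas.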
--------------------------------------------------------------------------------
/src/model/mlp_memory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import time
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import ot as pot
8 | import torchdyn
9 | from torchdyn.core import NeuralODE
10 | import pytorch_lightning as pl
11 | from torch import optim
12 | import torch.nn.functional as F
13 | import wandb
14 |
15 | from utils.visualize import *
16 | from utils.metric_calc import *
17 | from utils.sde import SDE
18 | from model.components.positional_encoding import *
19 | from model.components.mlp import * # in case need wrapper
20 | from model.components.grad_util import torch_wrapper_tv
21 | from utils.loss import mse_loss, l1_loss
22 |
23 |
24 |
25 | class MLP_Cond_Memory_Module(pl.LightningModule):
26 | def __init__(self,
27 | treatment_cond,
28 | memory=3, # can increase / tune to see effect
29 | dim=2,
30 | w=64,
31 | time_varying=True,
32 | conditional=True,
33 | lr=1e-6,
34 | sigma = 0.1,
35 | loss_fn = mse_loss,
36 | metrics = ['mse_loss', 'l1_loss'],
37 | implementation = "ODE", # can be SDE
38 | sde_noise = 0.1,
39 | clip = None,
40 | naming = None,
41 |
42 | ):
43 | super().__init__()
44 | self.model = MLP_conditional_liver_pe_memory(dim=dim,
45 | w=w,
46 | time_varying=time_varying,
47 | conditional=conditional,
48 | treatment_cond=treatment_cond,
49 | memory=memory,
50 | clip = clip)
51 | self.loss_fn = loss_fn
52 | self.save_hyperparameters()
53 | self.dim = dim
54 | # self.out_dim = out_dim
55 | self.w = w
56 | self.time_varying = time_varying
57 | self.conditional = conditional
58 | self.treatment_cond = treatment_cond
59 | self.lr = lr
60 | self.sigma = sigma
61 | self.naming = "MLP_Cond_memory_Module_"+implementation if naming is None else naming
62 | self.metrics = metrics
63 | self.implementation = implementation
64 | self.memory = memory
65 | self.sde_noise = sde_noise
66 | self.clip = clip
67 | if self.memory > 1:
68 | self.naming += "_Memory_"+str(self.memory)
69 |
70 | def __convert_tensor__(self, tensor):
71 | return tensor.to(torch.float32)
72 |
73 | def __x_processing__(self, x0, x1, t0, t1):
74 | # squeeze xs (prevent mismatch)
75 | x0 = x0.squeeze(0)
76 | x1 = x1.squeeze(0)
77 | t0 = t0.squeeze()
78 | t1 = t1.squeeze()
79 |
80 | t = torch.rand(x0.shape[0],1).to(x0.device)
81 | mu_t = x0 * (1 - t) + x1 * t
82 | data_t_diff = (t1 - t0).unsqueeze(1)
83 | x = mu_t + self.sigma * torch.randn(x0.shape[0], self.dim).to(x0.device)
84 |
85 | ut = (x1 - x0) / (data_t_diff + 1e-4)
86 | t_model = t * data_t_diff + t0.unsqueeze(1)
87 | futuretime = t1 - t_model
88 |
89 | return x, ut, t_model, futuretime, t
90 |
91 | def training_step(self, batch, batch_idx):
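92 | """Compute the training loss for one batch.
93 |
94 | Args:
95 | batch (list of torch.Tensor): x0_values, x0_classes, x1_values, times_x0, times_x1
96 | batch_idx (int): index of the batch
97 |
98 | Returns:
99 | torch.Tensor: loss on the predicted x1 plus loss on the predicted time remaining until t1
100 | """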
101 | x0, x0_class, x1, x0_time, x1_time = batch
102 | x0, x0_class, x1, x0_time, x1_time = self.__convert_tensor__(x0), self.__convert_tensor__(x0_class), self.__convert_tensor__(x1), self.__convert_tensor__(x0_time), self.__convert_tensor__(x1_time)
103 |
104 |
105 | x, ut, t_model, futuretime, t = self.__x_processing__(x0, x1, x0_time, x1_time)
106 |
107 |
108 | if len(x0_class.shape) == 3:
109 | x0_class = x0_class.squeeze()
110 |
111 | in_tensor = torch.cat([x,x0_class, t_model], dim = -1)
112 | xt = self.model.forward_train(in_tensor)
113 |
114 | if self.implementation == "SDE":
115 | variance = t*(1-t)*self.sde_noise
116 | noise = torch.randn_like(xt[:,:self.dim]) * torch.sqrt(variance)
117 | loss = self.loss_fn(xt[:,:self.dim] + noise, x1) + self.loss_fn(xt[:,-1], futuretime)
118 | else:
119 | loss = self.loss_fn(xt[:,:self.dim], x1) + self.loss_fn(xt[:,-1], futuretime)
120 | self.log('train_loss', loss)
121 | return loss
122 |
123 | def config_optimizer(self):
124 | return torch.optim.Adam(self.parameters(), lr=self.lr)
125 |
126 | def validation_step(self, batch, batch_idx):
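127 | """Roll out one trajectory and log the validation loss and metrics.
128 |
129 | Args:
130 | batch (list of torch.Tensor): batch size of 1 (since trajectory lengths are uneven)
131 | batch_idx (int): index of the batch
132 |
133 | Returns:
134 | dict: 'val_loss' and 'traj_pairs' (ground-truth and predicted trajectories)
135 | """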
136 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='val')
137 | self.log('val_loss', loss)
138 | for key, value in metricD.items():
139 | self.log(key+"_val", value)
140 | return {'val_loss':loss, 'traj_pairs':pairs}
141 |
142 | def test_step(self, batch, batch_idx):
143 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='test')
144 | self.log('test_loss', loss)
145 | for key, value in metricD.items():
146 | self.log(key+"_test", value)
147 | return {'test_loss':loss, 'traj_pairs':pairs}
148 |
149 | def test_func_step(self, batch, batch_idx, mode='none'):
150 | """assuming each is one patient/batch"""
151 | total_loss = []
152 | traj_pairs = []
153 |
154 | x0_values, x0_classes, x1_values, times_x0, times_x1 = batch
155 | times_x0 = times_x0.squeeze()
156 | times_x1 = times_x1.squeeze()
157 |
158 | full_traj = torch.cat([x0_values[0,0,:self.dim].unsqueeze(0),
159 | x1_values[0,:,:self.dim]],
160 | dim=0)
161 | full_time = torch.cat([times_x0[0].unsqueeze(0), times_x1], dim=0)
162 | ind_loss, pred_traj = self.test_trajectory(batch)
163 | total_loss.append(ind_loss)
164 | traj_pairs.append([full_traj, pred_traj])
165 |
166 | full_traj = full_traj.detach().cpu().numpy()
167 | pred_traj = pred_traj.detach().cpu().numpy()
168 | full_time = full_time.detach().cpu().numpy()
169 |
170 | # graph
171 | fig = plot_3d_path_ind(pred_traj,
172 | full_traj,
173 | t_span=full_time,
174 | title="{}_trajectory_patient_{}".format(mode, batch_idx))
175 | if self.logger:
176 | # may cause problem if wandb disabled
177 | self.logger.experiment.log({"{}_trajectory_patient_{}".format(mode, batch_idx): wandb.Image(fig)})
178 |
179 | plt.close(fig)
180 |
181 | # metrics
182 | metricD = metrics_calculation(pred_traj, full_traj, metrics=self.metrics)
183 | return np.mean(total_loss), traj_pairs, metricD
184 |
185 | def test_trajectory(self,pt_tensor):
186 | if self.implementation == "ODE":
187 | return self.test_trajectory_ode(pt_tensor)
188 | elif self.implementation == "SDE":
189 | return self.test_trajectory_sde(pt_tensor)
190 |
191 | def test_trajectory_ode(self,pt_tensor):
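192 | """Roll out a trajectory autoregressively with the neural ODE solver.
193 |
194 | Args:
195 | pt_tensor (tuple of torch.Tensor): (x0_values, x0_classes, x1_values, times_x0, times_x1)
196 | for one trajectory (batch size 1).
197 |
198 | Returns:
199 | mse_all, total_pred_tensor: mean per-step loss and the stacked predictions (start point first).
200 | """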
201 | node = NeuralODE(
202 | torch_wrapper_tv(self.model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
203 | )
204 | total_pred = []
205 | mse = []
206 | t_max = 0
207 |
208 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor
209 | # squeeze all
210 | x0_values = x0_values.squeeze(0)
211 | x1_values = x1_values.squeeze(0)
212 | times_x0 = times_x0.squeeze()
213 | times_x1 = times_x1.squeeze()
214 | x0_classes = x0_classes.squeeze()
215 |
216 | if len(x0_classes.shape) == 1:
217 | x0_classes = x0_classes.unsqueeze(1)
218 |
219 |
220 | total_pred.append(x0_values[0].unsqueeze(0))
221 | len_path = x0_values.shape[0]
222 | assert len_path == x1_values.shape[0]
223 |
224 | time_history = x0_classes[0][-(self.memory*self.dim):]
225 |
226 | for i in range(len_path):
227 | # print(i)
228 | t_max = (times_x1[i]-times_x0[i])
229 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device)
230 |
231 | new_x_classes = torch.cat([x0_classes[i][:-(self.memory*self.dim)].unsqueeze(0), time_history.unsqueeze(0)], dim=1)
232 | with torch.no_grad():
233 | # get last pred, if none then use startpt
234 | if i == 0:
235 | testpt = torch.cat([x0_values[i].unsqueeze(0),new_x_classes],dim=1)
236 | else: # incorporate last prediction
237 | testpt = torch.cat([pred_traj, new_x_classes], dim=1)
238 | # print(testpt.shape)
239 | traj = node.trajectory(
240 | testpt,
241 | t_span=time_span,
242 | )
243 | pred_traj = traj[-1,:,:self.dim]
244 | total_pred.append(pred_traj)
245 | ground_truth_coords = x1_values[i]
246 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy())
247 |
248 | # history update
249 | flattened_coords = pred_traj.flatten()
250 | time_history = torch.cat([time_history[self.dim:].unsqueeze(0), flattened_coords.unsqueeze(0)], dim=1).squeeze()
251 |
252 | mse_all = np.mean(mse)
253 | total_pred_tensor = torch.stack(total_pred).squeeze(1)
254 | return mse_all, total_pred_tensor
255 |
256 | def configure_optimizers(self):
257 | # Define the optimizer
258 | optimizer = optim.Adam(self.parameters(), lr=self.lr)
259 | return optimizer
260 |
261 | def test_trajectory_sde(self,pt_tensor):
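262 | """Roll out a trajectory autoregressively with the SDE solver.
263 |
264 | Args:
265 | pt_tensor (tuple of torch.Tensor): (x0_values, x0_classes, x1_values, times_x0, times_x1)
266 | for one trajectory (batch size 1).
267 |
268 | Returns:
269 | mse_all, total_pred_tensor: mean per-step loss and the stacked predictions (start point first).
270 | """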
271 | sde = SDE(self.model, noise=self.sde_noise)  # use the configured noise level (default 0.1)
272 | total_pred = []
273 | mse = []
274 | t_max = 0
275 |
276 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor
277 | # squeeze all
278 | x0_values = x0_values.squeeze(0)
279 | x1_values = x1_values.squeeze(0)
280 | times_x0 = times_x0.squeeze()
281 | times_x1 = times_x1.squeeze()
282 | x0_classes = x0_classes.squeeze()
283 |
284 | if len(x0_classes.shape) == 1:
285 | x0_classes = x0_classes.unsqueeze(1)
286 |
287 |
288 |
289 | total_pred.append(x0_values[0].unsqueeze(0))
290 | len_path = x0_values.shape[0]
291 | assert len_path == x1_values.shape[0]
292 |
293 | time_history = x0_classes[0][-(self.memory*self.dim):]
294 |
295 | for i in range(len_path):
296 |
297 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device)
298 |
299 | new_x_classes = torch.cat([x0_classes[i][:-(self.memory*self.dim)].unsqueeze(0), time_history.unsqueeze(0)], dim=1)
300 | with torch.no_grad():
301 | # get last pred, if none then use startpt
302 | if i == 0:
303 | testpt = torch.cat([x0_values[i].unsqueeze(0),new_x_classes],dim=1)
304 | else: # incorporate last prediction
305 | testpt = torch.cat([pred_traj, new_x_classes], dim=1)
306 | # print(testpt.shape)
307 | traj = self._sde_solver(sde, testpt, time_span)
308 |
309 | pred_traj = traj[-1,:,:self.dim]
310 | total_pred.append(pred_traj)
311 | ground_truth_coords = x1_values[i]
312 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy())
313 |
314 | # history update
315 | flattened_coords = pred_traj.flatten()
316 | time_history = torch.cat([time_history[self.dim:].unsqueeze(0), flattened_coords.unsqueeze(0)], dim=1).squeeze()
317 |
318 | mse_all = np.mean(mse)
319 | total_pred_tensor = torch.stack(total_pred).squeeze(1)
320 | return mse_all, total_pred_tensor
321 |
322 |
323 | def _sde_solver(self, sde, initial_state, time_span):
324 | dt = time_span[1] - time_span[0] # Time step
325 | current_state = initial_state
326 | trajectory = [current_state]
327 |
328 | for t in time_span[1:]:
329 | drift = sde.f(t, current_state)
330 | diffusion = sde.g(t, current_state)
331 | noise = torch.randn_like(current_state) * torch.sqrt(dt)
332 | current_state = current_state + drift * dt + diffusion * noise
333 | trajectory.append(current_state)
334 |
335 | return torch.stack(trajectory)
336 |
337 |
338 |
339 |
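Both `test_trajectory_ode` and `test_trajectory_sde` above feed each prediction back in as the next start point and slide a fixed-length memory window over past coordinates. The sketch below isolates that bookkeeping in plain Python. `roll_history`, `rollout` and `step_fn` are hypothetical names, and `step_fn` stands in for the ODE or SDE solve over one interval.

```python
def roll_history(history, new_point):
    """Drop the oldest `dim` entries and append the newest prediction."""
    dim = len(new_point)
    return history[dim:] + list(new_point)


def rollout(step_fn, x0, history, n_steps):
    """Autoregressive rollout: each prediction becomes the next start point."""
    preds = [list(x0)]  # start point first, as in total_pred
    x = list(x0)
    for _ in range(n_steps):
        x = step_fn(x, history)           # stand-in for one solver call
        preds.append(x)
        history = roll_history(history, x)
    return preds, history


# Toy 1-D dynamics with a memory of 3 steps: each step adds 1.
preds, final_history = rollout(lambda x, h: [x[0] + 1.0], [0.0], [0.0, 0.0, 0.0], 3)
rolled = roll_history([1, 2, 3, 4, 5, 6], [7, 8])
```

The window keeps a constant length of `memory * dim`. That matters because the network's input width `indim` is fixed at construction time.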
--------------------------------------------------------------------------------
/src/model/mlp_noise.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import time
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import ot as pot
8 | import torchdyn
9 | from torchdyn.core import NeuralODE
10 | import pytorch_lightning as pl
11 | from torch import optim
12 | import torch.nn.functional as F
13 | import wandb
14 |
15 | from utils.visualize import *
16 | from utils.metric_calc import *
17 | from utils.sde import SDE
18 | from model.components.positional_encoding import *
19 | from model.components.mlp import *
20 | from model.components.sde_func_solver import *
21 | from model.components.grad_util import *
22 |
23 | class MLP_conditional_memory(torch.nn.Module):
24 | """ Conditional with many available classes
25 |
26 | return the class as is
27 | """
28 | def __init__(self,
29 | dim,
30 | treatment_cond,
31 | memory, # how many time steps
32 | out_dim=None,
33 | w=64,
34 | time_varying=False,
35 | conditional=False,
36 | time_dim = NUM_FREQS * 2,
37 | clip = None,
38 | ):
39 | super().__init__()
40 | self.time_varying = time_varying
41 | if out_dim is None:
42 | self.out_dim = dim
43 | self.out_dim += 1 # for the time dimension
44 | self.treatment_cond = treatment_cond
45 | self.memory = memory
46 | self.dim = dim
47 | self.indim = dim + (time_dim if time_varying else 0) + (self.treatment_cond if conditional else 0) + (dim * memory)
48 | self.net = torch.nn.Sequential(
49 | torch.nn.Linear(self.indim, w),
50 | torch.nn.SELU(),
51 | torch.nn.Linear(w, w),
52 | torch.nn.SELU(),
53 | torch.nn.Linear(w, w),
54 | torch.nn.SELU(),
55 | torch.nn.Linear(w,self.out_dim),
56 | )
57 | self.default_class = 0
58 | self.clip = clip
59 | # self.encoding_function = positional_encoding_tensor()
60 |
61 | def encoding_function(self, time_tensor):
62 | return positional_encoding_tensor(time_tensor)
63 |
64 | def forward_train(self, x):
65 | """forward pass
66 | Assumes the columns of x are ordered as coordinates x, conditions c, then time t (last column)
67 | """
68 | time_tensor = x[:,-1]
69 | encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2)
70 | new_x = torch.cat([x[:,:-1], encoded_time_span], dim=1)
71 | result = self.net(new_x)
72 | return torch.cat([result[:,:-1], x[:,self.dim:-1], result[:,-1].unsqueeze(1)], dim=1)
73 |
74 | def forward(self, x):
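75 | """Velocity field used at inference; call forward_train for training.
76 | x here is x_t, with xt = (t)x1 + (1-t)x0, so (xt - tx1)/(1-t) = x0.
77 | forward_train predicts x1 and the time remaining until t1;
78 | the velocity is (predicted x1 - x_t) / (predicted time remaining).
79 | """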
80 | x1 = self.forward_train(x)
81 | x1_coord = x1[:,:self.dim]
82 | t = x[:,-1]
83 | pred_time_till_t1 = x1[:,-1]
84 | x_coord = x[:,:self.dim]
85 | if self.clip is None:
86 | vt = (x1_coord - x_coord)/(pred_time_till_t1)
87 | else:
88 | vt = (x1_coord - x_coord)/torch.clip((pred_time_till_t1),min=self.clip)
89 |
90 | final_vt = torch.cat([vt, torch.zeros_like(x[:,self.dim:-1])], dim=1)
91 | return final_vt
92 |
93 |
94 | class MLP_conditional_memory_sde_noise(torch.nn.Module):
95 | """ Conditional with many available classes
96 |
97 | return the class as is
98 | """
99 | def __init__(self,
100 | dim,
101 | treatment_cond,
102 | memory, # how many time steps
103 | out_dim=None,
104 | w=64,
105 | time_varying=False,
106 | conditional=False,
107 | time_dim = NUM_FREQS * 2,
108 | clip = None,
109 | ):
110 | super().__init__()
111 | self.time_varying = time_varying
112 | if out_dim is None:
113 | self.out_dim = 1 # for noise
114 | self.treatment_cond = treatment_cond
115 | self.memory = memory
116 | self.dim = dim
117 | self.indim = dim + (time_dim if time_varying else 0) + (self.treatment_cond if conditional else 0) + (dim * memory)
118 | self.net = torch.nn.Sequential(
119 | torch.nn.Linear(self.indim, w),
120 | torch.nn.SELU(),
121 | torch.nn.Linear(w, w),
122 | torch.nn.SELU(),
123 | torch.nn.Linear(w, w),
124 | torch.nn.SELU(),
125 | torch.nn.Linear(w,self.out_dim),
126 | )
127 | self.default_class = 0
128 | self.clip = clip
129 | # self.encoding_function = positional_encoding_tensor()
130 |
131 | def encoding_function(self, time_tensor):
132 | return positional_encoding_tensor(time_tensor)
133 |
134 | def forward_train(self, x):
135 | """forward pass
136 | Assumes the columns of x are ordered as coordinates x, conditions c, then time t (last column)
137 | """
138 | time_tensor = x[:,-1]
139 | encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2)
140 | new_x = torch.cat([x[:,:-1], encoded_time_span], dim=1)
141 | result = self.net(new_x)
142 | return result
143 |
144 | def forward(self,x):
145 | result = self.forward_train(x)
146 | return torch.cat([result, torch.zeros_like(x[:,1:-1])], dim=1)
147 |
148 | """ Lightning module """
149 | def mse_loss(pred, true):
150 | return torch.mean((pred - true) ** 2)
151 |
152 | def l1_loss(pred, true):
153 | return torch.mean(torch.abs(pred - true))
154 |
155 | class Noise_MLP_Cond_Memory_Module(pl.LightningModule):
156 | def __init__(self,
157 | treatment_cond,
158 | memory=3, # can increase / tune to see effect
159 | dim=2,
160 | w=64,
161 | time_varying=True,
162 | conditional=True,
163 | lr=1e-6,
164 | sigma = 0.1,
165 | loss_fn = mse_loss,
166 | metrics = ['mse_loss', 'l1_loss'],
167 | implementation = "ODE", # can be SDE
168 | sde_noise = 0.1,
169 | clip = None,
170 | naming = None,
171 |
172 | ):
173 | super().__init__()
174 | self.flow_model = MLP_conditional_memory(dim=dim,
175 | w=w,
176 | time_varying=time_varying,
177 | conditional=conditional,
178 | treatment_cond=treatment_cond,
179 | memory=memory,
180 | clip = clip)
181 | if implementation == "SDE":
182 | self.noise_model = MLP_conditional_memory_sde_noise(dim=dim, # @TODO: give \hat{x_1}?
183 | w=w,
184 | time_varying=time_varying,
185 | conditional=conditional,
186 | treatment_cond=treatment_cond,
187 | memory=memory,
188 | clip = clip)
189 | else:
190 | self.noise_model = MLP_conditional_memory(dim=dim, # @TODO: give \hat{x_1}?
191 | w=w,
192 | time_varying=time_varying,
193 | conditional=conditional,
194 | treatment_cond=treatment_cond,
195 | memory=memory,
196 | clip = clip)
197 | self.automatic_optimization = False
198 | self.loss_fn = loss_fn
199 | self.save_hyperparameters()
200 | self.dim = dim
201 | # self.out_dim = out_dim
202 | self.w = w
203 | self.time_varying = time_varying
204 | self.conditional = conditional
205 | self.treatment_cond = treatment_cond
206 | self.lr = lr
207 | self.sigma = sigma
208 | self.naming = "Noise_MLP_Cond_memory_Module_"+implementation if naming is None else naming
209 | self.metrics = metrics
210 | self.implementation = implementation
211 | self.memory = memory
212 | self.sde_noise = sde_noise
213 | self.clip = clip
214 | if self.memory > 1:
215 | self.naming += "_Memory_"+str(self.memory)
216 |
217 | def __convert_tensor__(self, tensor):
218 | return tensor.to(torch.float32)
219 |
220 | def __x_processing__(self, x0, x1, t0, t1):
221 | # squeeze xs (prevent mismatch)
222 | x0 = x0.squeeze(0)
223 | x1 = x1.squeeze(0)
224 | t0 = t0.squeeze()
225 | t1 = t1.squeeze()
226 |
227 | t = torch.rand(x0.shape[0],1).to(x0.device)
228 | mu_t = x0 * (1 - t) + x1 * t
229 | data_t_diff = (t1 - t0).unsqueeze(1)
230 | x = mu_t + self.sigma * torch.randn(x0.shape[0], self.dim).to(x0.device)
231 |
232 | ut = (x1 - x0) / (data_t_diff + 1e-4)
233 | t_model = t * data_t_diff + t0.unsqueeze(1)
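234 | futuretime = (t1.unsqueeze(1) - t_model).squeeze(1) # shape (batch,); avoids (batch,) - (batch, 1) broadcasting to (batch, batch)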
235 |
236 | return x, ut, t_model, futuretime, t
237 |
238 | def training_step(self, batch, batch_idx):
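239 | """Run one manually optimized training step for the flow and noise models.
240 |
241 | Args:
242 | batch (tuple): x0_values, x0_classes, x1_values, times_x0, times_x1
243 | batch_idx (int): index of the batch within the epoch
244 |
245 | Returns:
246 | torch.Tensor: sum of the flow loss and the noise loss (both optimizers are stepped inside this method)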
247 | """
248 | flow_opt, noise_opt = self.optimizers()
249 |
250 | x0, x0_class, x1, x0_time, x1_time = batch
251 | x0, x0_class, x1, x0_time, x1_time = self.__convert_tensor__(x0), self.__convert_tensor__(x0_class), self.__convert_tensor__(x1), self.__convert_tensor__(x0_time), self.__convert_tensor__(x1_time)
252 |
253 |
254 | x, ut, t_model, futuretime, t = self.__x_processing__(x0, x1, x0_time, x1_time)
255 |
256 |
257 | if len(x0_class.shape) == 3:
258 | x0_class = x0_class.squeeze()
259 |
260 | in_tensor = torch.cat([x,x0_class, t_model], dim = -1)
261 | xt = self.flow_model.forward_train(in_tensor)
262 |
263 | if self.implementation == "SDE":
264 | sde_noise = self.noise_model.forward_train(in_tensor)
265 | noise_scale = torch.sqrt(t*(1-t))*sde_noise # sqrt(t(1-t)) is a standard deviation (Brownian bridge), scaled by the learned noise
266 | noise = torch.randn_like(xt[:,:self.dim]) * noise_scale
267 | loss = self.loss_fn(xt[:,:self.dim] + noise.clone().detach(), x1) + self.loss_fn(xt[:,-1], futuretime)
268 | uncertainty =(xt[:,:self.dim].clone().detach() + noise)
269 | noise_loss = self.loss_fn(uncertainty,x1)
270 | else:
271 | loss = self.loss_fn(xt[:,:self.dim], x1) + self.loss_fn(xt[:,-1], futuretime)
272 | uncertainty = torch.abs(xt[:,:self.dim].clone().detach() - x1)
273 | # noise model incorporation (model loss)
274 | noise_loss = self.loss_fn(self.noise_model.forward_train(in_tensor)[:,:self.dim], uncertainty)
275 |
276 | flow_opt.zero_grad()
277 | self.manual_backward(loss)
278 | flow_opt.step()
279 |
280 | noise_opt.zero_grad()
281 | self.manual_backward(noise_loss)
282 | noise_opt.step()
283 |
284 | self.log('train_loss', loss)
285 | self.log('noise_loss', noise_loss)
286 | return loss + noise_loss
287 |
288 | def configure_optimizers(self):
289 | print("configuring optimizers")
290 | self.flow_optimizer = torch.optim.Adam(self.flow_model.parameters(), lr=self.lr)
291 | self.noise_optimizer = torch.optim.Adam(self.noise_model.parameters(), lr=self.lr)
292 | return [self.flow_optimizer, self.noise_optimizer]
293 |
294 | def validation_step(self, batch, batch_idx):
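295 | """Evaluate one validation trajectory and log its losses and metrics.
296 |
297 | Args:
298 | batch (tuple): a single patient trajectory (batch size 1, since trajectory lengths are uneven)
299 | batch_idx (int): index of the batch, used as the patient id in plot titles
300 |
301 | Returns:
302 | dict: 'val_loss' and 'traj_pairs' (pairs of ground-truth and predicted trajectories)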
303 | """
304 | loss, pairs, metricD, noise_loss, noise_pair = self.test_func_step(batch, batch_idx, mode='val')
305 | self.log('val_loss', loss, on_epoch=True, on_step=False, sync_dist=True)
306 | self.log('noise_val_loss', noise_loss, on_epoch=True, on_step=False, sync_dist=True)
307 | for key, value in metricD.items():
308 | self.log(key+"_val", value, on_epoch=True, on_step=False, sync_dist=True)
309 | # return total_loss, traj_pairs
310 | return {'val_loss':loss, 'traj_pairs':pairs}
311 |
312 | def test_step(self, batch, batch_idx):
313 | loss, pairs, metricD, noise_loss, noise_pair = self.test_func_step(batch, batch_idx, mode='test')
314 | self.log('test_loss', loss, on_epoch=True, on_step=False, sync_dist=True)
315 | self.log('noise_test_loss', noise_loss, on_epoch=True, on_step=False, sync_dist=True)
316 | for key, value in metricD.items():
317 | self.log(key+"_test", value, on_epoch=True, on_step=False, sync_dist=True)
318 | # return total_loss, traj_pairs
319 | return {'test_loss':loss, 'traj_pairs':pairs}
320 |
321 | def test_func_step(self, batch, batch_idx, mode='none'):
322 | """assuming each is one patient/batch"""
323 | total_loss = []
324 | traj_pairs = []
325 |
326 | total_noise_loss = []
327 | noise_pairs = []
328 |
329 | x0_values, x0_classes, x1_values, times_x0, times_x1 = batch
330 | times_x0 = times_x0.squeeze()
331 | times_x1 = times_x1.squeeze()
332 |
333 | # print(x0_values.shape)
334 | # print(x1_values.shape)
335 | full_traj = torch.cat([x0_values[0,0,:self.dim].unsqueeze(0),
336 | x1_values[0,:,:self.dim]],
337 | dim=0)
338 | full_time = torch.cat([times_x0[0].unsqueeze(0), times_x1], dim=0)
339 | ind_loss, pred_traj, noise_mse, noise_pred = self.test_trajectory(batch)
340 | total_loss.append(ind_loss)
341 | traj_pairs.append([full_traj, pred_traj])
342 | noise_pairs.append([full_traj, noise_pred])
343 | total_noise_loss.append(noise_mse)
344 |
345 | full_traj = full_traj.detach().cpu().numpy()
346 | pred_traj = pred_traj.detach().cpu().numpy()
347 | full_time = full_time.detach().cpu().numpy()
348 |
349 | # graph
350 | fig = plot_3d_path_ind_noise(pred_traj,
351 | full_traj,
352 | noise_pred,
353 | t_span=full_time,
354 | title="{}_trajectory_patient_{}".format(mode, batch_idx))
355 | if self.logger:
356 | # may cause problem if wandb disabled
357 | self.logger.experiment.log({"{}_trajectory_patient_{}".format(mode, batch_idx): wandb.Image(fig)})
358 |
359 | plt.close(fig)
360 |
361 | # metrics
362 | metricD = metrics_calculation(pred_traj, full_traj, metrics=self.metrics)
363 | return np.mean(total_loss), traj_pairs, metricD, np.mean(total_noise_loss), noise_pairs
364 |
365 | def test_trajectory(self,pt_tensor):
366 | if self.implementation == "ODE":
367 | return self.test_trajectory_ode(pt_tensor)
368 | elif self.implementation == "SDE":
369 | return self.test_trajectory_sde(pt_tensor)
370 |
371 | def test_trajectory_ode(self,pt_tensor):
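372 | """Roll out one patient trajectory autoregressively with the ODE flow and noise models.
373 |
374 | Args:
375 | pt_tensor (tuple of torch.Tensor): (x0_values, x0_classes, x1_values, times_x0, times_x1)
376 |
377 |
378 | Returns:
379 | tuple: mean MSE over steps, predicted trajectory, noise-model MSE, and stacked noise predictions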
380 | """
381 | node = NeuralODE(
382 | torch_wrapper_tv(self.flow_model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
383 | )
384 | node_noise = NeuralODE(
385 | torch_wrapper_tv(self.noise_model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
386 | )
387 | total_pred = []
388 | noise_pred = []
389 | mse = []
390 | noise_mse = []
391 | # t_max = 0
392 |
393 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor
394 | # squeeze all
395 | x0_values = x0_values.squeeze(0)
396 | x1_values = x1_values.squeeze(0)
397 | times_x0 = times_x0.squeeze()
398 | times_x1 = times_x1.squeeze()
399 | x0_classes = x0_classes.squeeze()
400 |
401 | if len(x0_classes.shape) == 1:
402 | x0_classes = x0_classes.unsqueeze(1)
403 |
404 | total_pred.append(x0_values[0].unsqueeze(0))
405 | len_path = x0_values.shape[0]
406 | assert len_path == x1_values.shape[0]
407 |
408 | time_history = x0_classes[0][-(self.memory*self.dim):]
409 |
410 | for i in range(len_path):
411 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device)
412 |
413 | new_x_classes = torch.cat([x0_classes[i][:-(self.memory*self.dim)].unsqueeze(0), time_history.unsqueeze(0)], dim=1)
414 | with torch.no_grad():
415 | # get last pred, if none then use startpt
416 | if i == 0:
417 | testpt = torch.cat([x0_values[i].unsqueeze(0),new_x_classes],dim=1)
418 | else: # incorporate last prediction
419 | testpt = torch.cat([pred_traj, new_x_classes], dim=1)
420 | # print(testpt.shape)
421 | traj = node.trajectory(
422 | testpt,
423 | t_span=time_span,
424 | )
425 | # add noise prediction
426 | noise_traj = node_noise.trajectory(
427 | testpt,
428 | t_span=time_span,
429 | )
430 |
431 | pred_traj = traj[-1,:,:self.dim]
432 | noise_traj = noise_traj[-1,:,:self.dim]
433 | total_pred.append(pred_traj)
434 | noise_pred.append(noise_traj)
435 |
436 | ground_truth_coords = x1_values[i]
437 | mse_traj = self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy()
438 | mse.append(mse_traj)
439 | uncertainty_traj = ground_truth_coords - pred_traj
440 | noise_mse_traj = self.loss_fn(noise_traj, uncertainty_traj).detach().cpu().numpy()
441 | noise_mse.append(noise_mse_traj)
442 |
443 | # history update
444 | flattened_coords = pred_traj.flatten()
445 | time_history = torch.cat([time_history[self.dim:].unsqueeze(0), flattened_coords.unsqueeze(0)], dim=1).squeeze()
446 |
447 | mse_all = np.mean(mse)
448 | total_pred_tensor = torch.stack(total_pred).squeeze(1)
449 | noise_pred = torch.stack(noise_pred).squeeze(1)
450 | return mse_all, total_pred_tensor, np.mean(noise_mse), noise_pred # scalar noise MSE, matching test_trajectory_sde
451 |
452 |
453 | def test_trajectory_sde(self,pt_tensor):
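454 | """Roll out one patient trajectory autoregressively with the SDE solver.
455 |
456 | Args:
457 | pt_tensor (tuple of torch.Tensor): (x0_values, x0_classes, x1_values, times_x0, times_x1)
458 |
459 |
460 | Returns:
461 | tuple: mean MSE over steps, predicted trajectory, noise MSE (here equal to the trajectory MSE), and stacked diffusion increments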
462 | """
463 | sde = SDE_func_solver(self.flow_model, noise=self.noise_model)
464 | total_pred = []
465 | mse = []
466 | noise_pred = []
467 | noise_mse = []
468 | t_max = 0
469 |
470 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor
471 | # squeeze all
472 | x0_values = x0_values.squeeze(0)
473 | x1_values = x1_values.squeeze(0)
474 | times_x0 = times_x0.squeeze()
475 | times_x1 = times_x1.squeeze()
476 | x0_classes = x0_classes.squeeze()
477 |
478 | if len(x0_classes.shape) == 1:
479 | x0_classes = x0_classes.unsqueeze(1)
480 |
481 |
482 |
483 | total_pred.append(x0_values[0].unsqueeze(0))
484 | len_path = x0_values.shape[0]
485 | assert len_path == x1_values.shape[0]
486 |
487 | time_history = x0_classes[0][-(self.memory*self.dim):]
488 |
489 | for i in range(len_path):
490 |
491 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device)
492 |
493 | new_x_classes = torch.cat([x0_classes[i][:-(self.memory*self.dim)].unsqueeze(0), time_history.unsqueeze(0)], dim=1)
494 | with torch.no_grad():
495 | # get last pred, if none then use startpt
496 | if i == 0:
497 | testpt = torch.cat([x0_values[i].unsqueeze(0),new_x_classes],dim=1)
498 | else: # incorporate last prediction
499 | testpt = torch.cat([pred_traj, new_x_classes], dim=1)
500 | traj, noise_traj = self._sde_solver(sde, testpt, time_span)
501 |
502 | pred_traj = traj[-1,:,:self.dim]
503 | noise_traj = noise_traj[-1,:,:self.dim]
504 |
505 | total_pred.append(pred_traj)
506 | noise_pred.append(noise_traj)
507 |
508 | ground_truth_coords = x1_values[i]
509 | calculated_mse = self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy()
510 | mse.append(calculated_mse)
511 | noise_mse.append(calculated_mse)
512 |
513 | # history update
514 | flattened_coords = pred_traj.flatten()
515 | time_history = torch.cat([time_history[self.dim:].unsqueeze(0), flattened_coords.unsqueeze(0)], dim=1).squeeze()
516 |
517 |
518 |
519 | mse_all = np.mean(mse)
520 | noise_mse_all = np.mean(noise_mse)
521 | total_pred_tensor = torch.stack(total_pred).squeeze(1)
522 | noise_pred_tensor = torch.stack(noise_pred).squeeze(1)
523 | return mse_all, total_pred_tensor, noise_mse_all, noise_pred_tensor
524 |
525 |
526 | def _sde_solver(self, sde, initial_state, time_span):
527 | dt = time_span[1] - time_span[0] # Time step
528 | current_state = initial_state
529 | trajectory = [current_state]
530 | noise_trajectory = []
531 |
532 | for t in time_span[1:]:
533 | drift = sde.f(t, current_state)
534 | diffusion = sde.g(t, current_state)
535 | noise = torch.randn_like(current_state) * torch.sqrt(dt)
536 | current_state = current_state + drift * dt + diffusion * noise # @NEED this or not?
537 | trajectory.append(current_state)
538 | pred_diff = diffusion * noise
539 | noise_trajectory.append(pred_diff)
540 |
541 | return torch.stack(trajectory), torch.stack(noise_trajectory)
542 |
543 |
544 |
545 |
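The time bookkeeping in `__x_processing__` above can be checked by hand on a single observation pair. The following is a minimal sketch in plain Python, not the module's tensor code: `interpolate_pair` is a hypothetical helper introduced only for illustration, the Gaussian perturbation scaled by `sigma` is left out, and only the `1e-4` stabilizer is taken from the source. It shows that the time remaining until `t1` equals `(1 - t) * (t1 - t0)`.

```python
def interpolate_pair(x0, x1, t0, t1, t):
    """Return (x_t, u_t, t_model, time_left) for one pair of observations.

    t in [0, 1] is the fraction of the way from time t0 to time t1.
    """
    dt = t1 - t0
    # point on the straight line between the two observations
    x_t = [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]
    # constant velocity along that line, in data-time units
    u_t = [(b - a) / (dt + 1e-4) for a, b in zip(x0, x1)]
    # map the unit-interval time back onto the data clock
    t_model = t0 + t * dt
    # time still to elapse before the next observation
    time_left = t1 - t_model
    return x_t, u_t, t_model, time_left


x_t, u_t, t_model, time_left = interpolate_pair(
    [0.0, 2.0], [4.0, 6.0], t0=1.0, t1=3.0, t=0.25
)
print(x_t, t_model, time_left)
```

In this sketch every quantity is per sample, so the remaining time is a single number for each pair; the batched version should likewise produce one value per row.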
546 |
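The loop in `_sde_solver` above is an Euler-Maruyama style update. As a point of comparison, here is a minimal scalar sketch in plain Python under stated assumptions: `euler_maruyama`, `drift` and `diffusion` are hypothetical names, the step size is recomputed per interval rather than taken once from the first two grid points, and the coefficients are evaluated at the start of each step (the textbook convention), whereas the loop above passes the end time of each step.

```python
import math
import random


def euler_maruyama(drift, diffusion, x0, t_grid, rng):
    """Integrate dX = drift(t, X) dt + diffusion(t, X) dW on t_grid."""
    xs = [x0]
    x = x0
    for t_prev, t_next in zip(t_grid[:-1], t_grid[1:]):
        dt = t_next - t_prev
        # Brownian increment over the step: N(0, dt)
        dw = rng.gauss(0.0, 1.0) * math.sqrt(dt)
        x = x + drift(t_prev, x) * dt + diffusion(t_prev, x) * dw
        xs.append(x)
    return xs


rng = random.Random(0)
t_grid = [i / 9 for i in range(10)]  # 10 points, like torch.linspace(t0, t1, 10)
# With zero diffusion the scheme reduces to explicit Euler for dx/dt = -x.
path = euler_maruyama(lambda t, x: -x, lambda t, x: 0.0, 1.0, t_grid, rng)
print(len(path), path[-1])
```

Setting the diffusion to zero gives a deterministic check: each step multiplies the state by `1 - dt`, so the final value can be compared against `(8/9) ** 9`.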
--------------------------------------------------------------------------------
/src/model/ode_baseline.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | Neural ODE and SDE baseline models.
4 | lr = 1e-3
5 |
6 | """
7 |
8 |
9 | import torch
10 | import torch.nn as nn
11 | import pytorch_lightning as pl
12 | from torchdiffeq import odeint
13 | from utils.metric_calc import *
14 | class ODEFunc(nn.Module):
15 | def __init__(self, dim, w=64):
16 | super().__init__()
17 | self.net = nn.Sequential(
18 | nn.Linear(dim, w),
19 | nn.Tanh(),
20 | nn.Linear(w, w),
21 | nn.Tanh(),
22 | nn.Linear(w, dim),
23 | )
24 |
25 | def forward(self, t, x):
26 | return self.net(x)
27 |
28 | class ODEBaseline(pl.LightningModule):
29 | def __init__(self,
30 | dim=2,
31 | w=64,
32 | lr=1e-5,
33 | loss_fn=nn.MSELoss(),
34 | metrics = ['mse_loss', 'l1_loss']):
35 | super().__init__()
36 | self.ode_func = ODEFunc(dim, w)
37 | self.lr = lr
38 | self.loss_fn = loss_fn
39 | self.naming = 'ODEBaseline'
40 | self.metrics = metrics
41 |
42 | def forward(self, x0, t_span):
43 | return odeint(self.ode_func, x0, t_span)
44 |
45 | def training_step(self, batch, batch_idx):
46 | """x0, x0_class, x1, x0_time, x1_time """
47 | x0, x0_class, x1, x0_time, x1_time = batch
48 | t_span = x1_time.squeeze() #- x0_time
49 | # integrate over the batch's time grid; the final solver state is the prediction for x1
50 | # print(x0.shape, x1.shape, t_span.shape) # torch.Size([256, 2]) torch.Size([256, 2]) torch.Size([256])
51 | x_pred = self.forward(x0, t_span)
52 | loss = self.loss_fn(x_pred[-1], x1.squeeze())
53 | self.log('train_loss', loss)
54 | return loss
55 |
56 | def configure_optimizers(self):
57 | return torch.optim.Adam(self.parameters(), lr=self.lr)
58 |
59 | def validation_step(self, batch, batch_idx):
60 | x0, x0_class, x1, x0_time, x1_time = batch
61 | t_span = x1_time.squeeze()
62 | x_pred = self.forward(x0, t_span)
63 | loss = self.loss_fn(x_pred[-1], x1.squeeze())
64 |
65 | # metrics
66 | metricsD = metrics_calculation(x_pred[-1], x1.squeeze(), metrics=self.metrics) # same target shape as the loss above
67 | for k, v in metricsD.items():
68 | self.log(f'{k}_val', v)
69 |
70 | self.log('val_loss', loss)
71 | return loss
72 |
73 | def test_step(self, batch, batch_idx):
74 | x0, x0_class, x1, x0_time, x1_time = batch
75 | t_span = x1_time.squeeze()
76 | x_pred = self.forward(x0, t_span)
77 | loss = self.loss_fn(x_pred[-1], x1.squeeze())
78 | # metrics
79 | metricsD = metrics_calculation(x_pred[-1], x1.squeeze(), metrics=self.metrics) # same target shape as the loss above
80 | for k, v in metricsD.items():
81 | self.log(f'{k}_test', v)
82 |
83 | self.log('test_loss', loss)
84 | return loss
85 |
86 | # sde
87 | import torchsde
88 |
89 | class SDEFunc(nn.Module):
90 | noise_type = 'diagonal' # diagonal = noise is uncorrelated across dimensions
91 | sde_type = 'ito'
92 | def __init__(self, dim, w=64):
93 | super().__init__()
94 | self.mu = nn.Sequential(
95 | nn.Linear(dim, w),
96 | nn.Tanh(),
97 | nn.Linear(w, w),
98 | nn.Tanh(),
99 | nn.Linear(w, dim),
100 | )
101 | self.sigma = nn.Sequential(
102 | nn.Linear(dim, w),
103 | nn.Tanh(),
104 | nn.Linear(w, w),
105 | nn.Tanh(),
106 | nn.Linear(w, dim),
107 | )
108 |
109 | def f(self, t, x):
110 | return self.mu(x)
111 |
112 | def g(self, t, x):
113 | return self.sigma(x)
114 |
115 | class SDEBaseline(pl.LightningModule):
116 | def __init__(self,
117 | dim=2,
118 | w=64,
119 | lr=1e-5,
120 | loss_fn=nn.MSELoss(),
121 | metrics = ['mse_loss', 'l1_loss']):
122 | super().__init__()
123 | self.sde_func = SDEFunc(dim, w)
124 | self.lr = lr
125 | self.loss_fn = loss_fn
126 | self.naming = 'SDEBaseline'
127 | self.metrics = metrics
128 |
129 | def forward(self, x0, t_span):
130 | # print(x0.shape)
131 | x0 = x0.squeeze()
132 | batch_size, dim = x0.shape
133 | bm = torchsde.BrownianInterval(t0=t_span[0],
134 | t1=t_span[-1],
135 | dtype=x0.dtype,
136 | device=x0.device,
137 | size=(batch_size, dim),
138 | levy_area_approximation="space-time")
139 | return torchsde.sdeint(self.sde_func, x0, t_span, bm=bm)
140 |
141 | def training_step(self, batch, batch_idx):
142 | x0, x0_class, x1, x0_time, x1_time = batch
143 | t_span = x1_time.squeeze() #- x0_time
144 | # x0, x1, t_span = batch
145 | x_pred = self.forward(x0, t_span)
146 | loss = self.loss_fn(x_pred[-1], x1.squeeze())
147 | self.log('train_loss', loss)
148 | return loss
149 |
150 | def configure_optimizers(self):
151 | return torch.optim.Adam(self.parameters(), lr=self.lr)
152 |
153 | def validation_step(self, batch, batch_idx):
154 | x0, x0_class, x1, x0_time, x1_time = batch
155 | t_span = x1_time.squeeze()
156 | x_pred = self.forward(x0, t_span)
157 | loss = self.loss_fn(x_pred[-1], x1.squeeze())
158 | # metrics
159 | metricsD = metrics_calculation(x_pred[-1], x1.squeeze(), metrics=self.metrics) # same target shape as the loss above
160 | for k, v in metricsD.items():
161 | self.log(f'{k}_val', v)
162 |
163 | self.log('val_loss', loss)
164 | return loss
165 |
166 | def test_step(self, batch, batch_idx):
167 | x0, x0_class, x1, x0_time, x1_time = batch
168 | t_span = x1_time.squeeze()
169 | x_pred = self.forward(x0, t_span)
170 | loss = self.loss_fn(x_pred[-1], x1.squeeze())
171 | # metrics
172 | metricsD = metrics_calculation(x_pred[-1], x1.squeeze(), metrics=self.metrics) # same target shape as the loss above
173 | for k, v in metricsD.items():
174 | self.log(f'{k}_test', v)
175 |
176 | self.log('test_loss', loss)
177 | return loss
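Both baselines above integrate from `x0` over `t_span` and compare only the last returned state with `x1`. The sketch below illustrates that convention with a fixed-step explicit Euler integrator in plain Python. It is a stand-in for illustration, not what `torchdiffeq.odeint` does internally: `euler_odeint` and `substeps` are hypothetical names, and the state is a scalar rather than a batch of tensors.

```python
import math


def euler_odeint(func, x0, t_span, substeps=100):
    """Return the state at every time in t_span, starting from x0 at t_span[0]."""
    states = [x0]
    x = x0
    for t_prev, t_next in zip(t_span[:-1], t_span[1:]):
        h = (t_next - t_prev) / substeps
        t = t_prev
        for _ in range(substeps):
            x = x + h * func(t, x)  # explicit Euler update
            t += h
        states.append(x)
    return states


# dx/dt = -x has the exact solution x(t) = exp(-t) for x(0) = 1.
states = euler_odeint(lambda t, x: -x, 1.0, [0.0, 0.5, 1.0])
prediction = states[-1]  # plays the role of x_pred[-1] in the steps above
error = abs(prediction - math.exp(-1.0))
print(len(states), prediction, error)
```

One state is returned per requested time, so indexing with `[-1]` selects the value at the final entry of `t_span`; the time grid therefore has to be increasing for the prediction to correspond to the intended target time.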
--------------------------------------------------------------------------------
/src/utils/latent_ode_utils.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import os
7 | import logging
8 | import pickle
9 |
10 | import torch
11 | import torch.nn as nn
12 | import numpy as np
13 | import pandas as pd
14 | import math
15 | import glob
16 | import re
17 | from shutil import copyfile
18 | import sklearn as sk
19 | import subprocess
20 | import datetime
21 |
22 | def makedirs(dirname):
23 | if not os.path.exists(dirname):
24 | os.makedirs(dirname)
25 |
26 | def save_checkpoint(state, save, epoch):
27 | if not os.path.exists(save):
28 | os.makedirs(save)
29 | filename = os.path.join(save, 'checkpt-%04d.pth' % epoch)
30 | torch.save(state, filename)
31 |
32 |
33 | def get_logger(logpath, filepath, package_files=[],
34 | displaying=True, saving=True, debug=False):
35 | logger = logging.getLogger()
36 | if debug:
37 | level = logging.DEBUG
38 | else:
39 | level = logging.INFO
40 | logger.setLevel(level)
41 | if saving:
42 | info_file_handler = logging.FileHandler(logpath, mode='w')
43 | info_file_handler.setLevel(level)
44 | logger.addHandler(info_file_handler)
45 | if displaying:
46 | console_handler = logging.StreamHandler()
47 | console_handler.setLevel(level)
48 | logger.addHandler(console_handler)
49 | logger.info(filepath)
50 |
51 | for f in package_files:
52 | logger.info(f)
53 | with open(f, 'r') as package_f:
54 | logger.info(package_f.read())
55 |
56 | return logger
57 |
58 |
59 | def inf_generator(iterable):
60 | """Allows training with DataLoaders in a single infinite loop:
61 | for i, (x, y) in enumerate(inf_generator(train_loader)):
62 | """
63 | iterator = iterable.__iter__()
64 | while True:
65 | try:
66 | yield iterator.__next__()
67 | except StopIteration:
68 | iterator = iterable.__iter__()
69 |
70 | def dump_pickle(data, filename):
71 | with open(filename, 'wb') as pkl_file:
72 | pickle.dump(data, pkl_file)
73 |
74 | def load_pickle(filename):
75 | with open(filename, 'rb') as pkl_file:
76 | filecontent = pickle.load(pkl_file)
77 | return filecontent
78 |
79 | def make_dataset(dataset_type = "spiral",**kwargs):
80 | if dataset_type == "spiral":
81 | data_path = "data/spirals.pickle"
82 | dataset = load_pickle(data_path)["dataset"]
83 | chiralities = load_pickle(data_path)["chiralities"]
84 | elif dataset_type == "chiralspiral":
85 | data_path = "data/chiral-spirals.pickle"
86 | dataset = load_pickle(data_path)["dataset"]
87 | chiralities = load_pickle(data_path)["chiralities"]
88 | else:
89 | raise Exception("Unknown dataset type " + dataset_type)
90 | return dataset, chiralities
91 |
92 |
93 | def split_last_dim(data):
94 | last_dim = data.size()[-1]
95 | last_dim = last_dim//2
96 |
97 | if len(data.size()) == 3:
98 | res = data[:,:,:last_dim], data[:,:,last_dim:]
99 |
100 | if len(data.size()) == 2:
101 | res = data[:,:last_dim], data[:,last_dim:]
102 | return res
103 |
104 |
105 | def init_network_weights(net, std = 0.1):
106 | for m in net.modules():
107 | if isinstance(m, nn.Linear):
108 | nn.init.normal_(m.weight, mean=0, std=std)
109 | nn.init.constant_(m.bias, val=0)
110 |
111 |
112 | def flatten(x, dim):
113 | return x.reshape(x.size()[:dim] + (-1, ))
114 |
115 |
116 | def subsample_timepoints(data, time_steps, mask, n_tp_to_sample = None):
117 | # n_tp_to_sample: number of time points to subsample. If not None, sample exactly n_tp_to_sample points
118 | if n_tp_to_sample is None:
119 | return data, time_steps, mask
120 | n_tp_in_batch = len(time_steps)
121 |
122 |
123 | if n_tp_to_sample > 1:
124 | # Subsample exact number of points
125 | assert(n_tp_to_sample <= n_tp_in_batch)
126 | n_tp_to_sample = int(n_tp_to_sample)
127 |
128 | for i in range(data.size(0)):
129 | missing_idx = sorted(np.random.choice(np.arange(n_tp_in_batch), n_tp_in_batch - n_tp_to_sample, replace = False))
130 |
131 | data[i, missing_idx] = 0.
132 | if mask is not None:
133 | mask[i, missing_idx] = 0.
134 |
135 | elif (n_tp_to_sample <= 1) and (n_tp_to_sample > 0):
136 | # Subsample percentage of points from each time series
137 | percentage_tp_to_sample = n_tp_to_sample
138 | for i in range(data.size(0)):
139 | # take mask for current training sample and sum over all features -- figure out which time points don't have any measurements at all in this batch
140 | current_mask = mask[i].sum(-1).cpu()
141 | non_missing_tp = np.where(current_mask > 0)[0]
142 | n_tp_current = len(non_missing_tp)
143 | n_to_sample = int(n_tp_current * percentage_tp_to_sample)
144 | subsampled_idx = sorted(np.random.choice(non_missing_tp, n_to_sample, replace = False))
145 | tp_to_set_to_zero = np.setdiff1d(non_missing_tp, subsampled_idx)
146 |
147 | data[i, tp_to_set_to_zero] = 0.
148 | if mask is not None:
149 | mask[i, tp_to_set_to_zero] = 0.
150 |
151 | return data, time_steps, mask
152 |
153 |
154 |
155 | def cut_out_timepoints(data, time_steps, mask, n_points_to_cut = None):
156 | # n_points_to_cut: number of consecutive time points to cut out
157 | if n_points_to_cut is None:
158 | return data, time_steps, mask
159 | n_tp_in_batch = len(time_steps)
160 |
161 | if n_points_to_cut < 1:
162 | raise Exception("Number of time points to cut out must be >= 1")
163 |
164 | assert(n_points_to_cut <= n_tp_in_batch)
165 | n_points_to_cut = int(n_points_to_cut)
166 |
167 | for i in range(data.size(0)):
168 | start = np.random.choice(np.arange(5, n_tp_in_batch - n_points_to_cut-5), replace = False)
169 |
170 | data[i, start : (start + n_points_to_cut)] = 0.
171 | if mask is not None:
172 | mask[i, start : (start + n_points_to_cut)] = 0.
173 |
174 | return data, time_steps, mask
175 |
176 |
177 |
178 |
179 |
180 | def get_device(tensor):
181 | device = torch.device("cpu")
182 | if tensor.is_cuda:
183 | device = tensor.get_device()
184 | return device
185 |
186 | def sample_standard_gaussian(mu, sigma):
187 | device = get_device(mu)
188 |
189 | d = torch.distributions.normal.Normal(torch.Tensor([0.]).to(device), torch.Tensor([1.]).to(device))
190 | r = d.sample(mu.size()).squeeze(-1)
191 | return r * sigma.float() + mu.float()
192 |
193 |
194 | def split_train_test(data, train_fraq = 0.8):
195 | n_samples = data.size(0)
196 | data_train = data[:int(n_samples * train_fraq)]
197 | data_test = data[int(n_samples * train_fraq):]
198 | return data_train, data_test
199 |
200 | def split_train_test_data_and_time(data, time_steps, train_fraq = 0.8):
201 | n_samples = data.size(0)
202 | data_train = data[:int(n_samples * train_fraq)]
203 | data_test = data[int(n_samples * train_fraq):]
204 |
205 | assert(len(time_steps.size()) == 2)
206 | train_time_steps = time_steps[:, :int(n_samples * train_fraq)]
207 | test_time_steps = time_steps[:, int(n_samples * train_fraq):]
208 |
209 | return data_train, data_test, train_time_steps, test_time_steps
210 |
211 |
212 |
213 | def get_next_batch(dataloader):
214 | # Make the union of all time points and perform normalization across the whole dataset
215 | data_dict = dataloader.__next__()
216 |
217 | batch_dict = get_dict_template()
218 |
219 | # remove the time points where there are no observations in this batch
220 | non_missing_tp = torch.sum(data_dict["observed_data"],(0,2)) != 0.
221 | batch_dict["observed_data"] = data_dict["observed_data"][:, non_missing_tp]
222 | batch_dict["observed_tp"] = data_dict["observed_tp"][non_missing_tp]
223 |
224 | # print("observed data")
225 | # print(batch_dict["observed_data"].size())
226 |
227 | if ("observed_mask" in data_dict) and (data_dict["observed_mask"] is not None):
228 | batch_dict["observed_mask"] = data_dict["observed_mask"][:, non_missing_tp]
229 |
230 | batch_dict[ "data_to_predict"] = data_dict["data_to_predict"]
231 | batch_dict["tp_to_predict"] = data_dict["tp_to_predict"]
232 |
233 | non_missing_tp = torch.sum(data_dict["data_to_predict"],(0,2)) != 0.
234 | batch_dict["data_to_predict"] = data_dict["data_to_predict"][:, non_missing_tp]
235 | batch_dict["tp_to_predict"] = data_dict["tp_to_predict"][non_missing_tp]
236 |
237 | # print("data_to_predict")
238 | # print(batch_dict["data_to_predict"].size())
239 |
240 | if ("mask_predicted_data" in data_dict) and (data_dict["mask_predicted_data"] is not None):
241 | batch_dict["mask_predicted_data"] = data_dict["mask_predicted_data"][:, non_missing_tp]
242 |
243 | if ("labels" in data_dict) and (data_dict["labels"] is not None):
244 | batch_dict["labels"] = data_dict["labels"]
245 |
246 | batch_dict["mode"] = data_dict["mode"]
247 | return batch_dict
248 |
249 |
250 |
251 | def get_ckpt_model(ckpt_path, model, device):
252 | if not os.path.exists(ckpt_path):
253 | raise Exception("Checkpoint " + ckpt_path + " does not exist.")
254 | # Load checkpoint.
255 | checkpt = torch.load(ckpt_path, map_location=device) # load tensors directly onto the requested device
256 | ckpt_args = checkpt['args']
257 | state_dict = checkpt['state_dict']
258 | model_dict = model.state_dict()
259 |
260 | # 1. filter out unnecessary keys
261 | state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
262 | # 2. overwrite entries in the existing state dict
263 | model_dict.update(state_dict)
264 | # 3. load the new state dict
265 | model.load_state_dict(model_dict) # merged dict: checkpoint values where available, current values otherwise
266 | model.to(device)
267 |
268 |
269 | def update_learning_rate(optimizer, decay_rate = 0.999, lowest = 1e-3):
270 | for param_group in optimizer.param_groups:
271 | lr = param_group['lr']
272 | lr = max(lr * decay_rate, lowest)
273 | param_group['lr'] = lr
274 |
275 |
276 | def linspace_vector(start, end, n_points):
277 | # start is either one value or a vector
278 | size = np.prod(start.size())
279 |
280 | assert(start.size() == end.size())
281 | if size == 1:
282 | # start and end are 1d-tensors
283 | res = torch.linspace(start, end, n_points)
284 | else:
285 | # start and end are vectors
286 | res = torch.Tensor()
287 | for i in range(0, start.size(0)):
288 | res = torch.cat((res,
289 | torch.linspace(start[i], end[i], n_points)),0)
290 | res = torch.t(res.reshape(start.size(0), n_points))
291 | return res
292 |
293 | def reverse(tensor):
294 | idx = [i for i in range(tensor.size(0)-1, -1, -1)]
295 | return tensor[idx]
296 |
297 |
298 | def create_net(n_inputs, n_outputs, n_layers = 1,
299 | n_units = 100, nonlinear = nn.Tanh):
300 | layers = [nn.Linear(n_inputs, n_units)]
301 | for i in range(n_layers):
302 | layers.append(nonlinear())
303 | layers.append(nn.Linear(n_units, n_units))
304 |
305 | layers.append(nonlinear())
306 | layers.append(nn.Linear(n_units, n_outputs))
307 | return nn.Sequential(*layers)
308 |
309 |
310 | def get_item_from_pickle(pickle_file, item_name):
311 | from_pickle = load_pickle(pickle_file)
312 | if item_name in from_pickle:
313 | return from_pickle[item_name]
314 | return None
315 |
316 |
317 | def get_dict_template():
318 | return {"observed_data": None,
319 | "observed_tp": None,
320 | "data_to_predict": None,
321 | "tp_to_predict": None,
322 | "observed_mask": None,
323 | "mask_predicted_data": None,
324 | "labels": None
325 | }
326 |
327 |
328 | def normalize_data(data):
329 | reshaped = data.reshape(-1, data.size(-1))
330 |
331 | att_min = torch.min(reshaped, 0)[0]
332 | att_max = torch.max(reshaped, 0)[0]
333 |
334 | # we don't want to divide by zero
335 | att_max[ att_max == 0.] = 1.
336 |
337 | if (att_max != 0.).all():
338 | data_norm = (data - att_min) / att_max
339 | else:
340 | raise Exception("Zero!")
341 |
342 | if torch.isnan(data_norm).any():
343 | raise Exception("nans!")
344 |
345 | return data_norm, att_min, att_max
346 |
347 |
348 | def normalize_masked_data(data, mask, att_min, att_max):
349 | # we don't want to divide by zero
350 | att_max[ att_max == 0.] = 1.
351 |
352 | if (att_max != 0.).all():
353 | data_norm = (data - att_min) / att_max
354 | else:
355 | raise Exception("Zero!")
356 |
357 | if torch.isnan(data_norm).any():
358 | raise Exception("nans!")
359 |
360 | # set masked out elements back to zero
361 | data_norm[mask == 0] = 0
362 |
363 | return data_norm, att_min, att_max
364 |
365 |
366 | def shift_outputs(outputs, first_datapoint = None):
367 | outputs = outputs[:,:,:-1,:]
368 |
369 | if first_datapoint is not None:
370 | n_traj, n_dims = first_datapoint.size()
371 | first_datapoint = first_datapoint.reshape(1, n_traj, 1, n_dims)
372 | outputs = torch.cat((first_datapoint, outputs), 2)
373 | return outputs
374 |
375 |
376 |
377 |
378 | def split_data_extrap(data_dict, dataset = ""):
379 | device = get_device(data_dict["data"])
380 |
381 | n_observed_tp = data_dict["data"].size(1) // 2
382 | if dataset == "hopper":
383 | n_observed_tp = data_dict["data"].size(1) // 3
384 |
385 | split_dict = {"observed_data": data_dict["data"][:,:n_observed_tp,:].clone(),
386 | "observed_tp": data_dict["time_steps"][:n_observed_tp].clone(),
387 | "data_to_predict": data_dict["data"][:,n_observed_tp:,:].clone(),
388 | "tp_to_predict": data_dict["time_steps"][n_observed_tp:].clone()}
389 |
390 | split_dict["observed_mask"] = None
391 | split_dict["mask_predicted_data"] = None
392 | split_dict["labels"] = None
393 |
394 | if ("mask" in data_dict) and (data_dict["mask"] is not None):
395 | split_dict["observed_mask"] = data_dict["mask"][:, :n_observed_tp].clone()
396 | split_dict["mask_predicted_data"] = data_dict["mask"][:, n_observed_tp:].clone()
397 |
398 | if ("labels" in data_dict) and (data_dict["labels"] is not None):
399 | split_dict["labels"] = data_dict["labels"].clone()
400 |
401 | split_dict["mode"] = "extrap"
402 | return split_dict
403 |
404 |
405 |
406 |
407 |
408 | def split_data_interp(data_dict):
409 | device = get_device(data_dict["data"])
410 |
411 | split_dict = {"observed_data": data_dict["data"].clone(),
412 | "observed_tp": data_dict["time_steps"].clone(),
413 | "data_to_predict": data_dict["data"].clone(),
414 | "tp_to_predict": data_dict["time_steps"].clone()}
415 |
416 | split_dict["observed_mask"] = None
417 | split_dict["mask_predicted_data"] = None
418 | split_dict["labels"] = None
419 |
420 | if "mask" in data_dict and data_dict["mask"] is not None:
421 | split_dict["observed_mask"] = data_dict["mask"].clone()
422 | split_dict["mask_predicted_data"] = data_dict["mask"].clone()
423 |
424 | if ("labels" in data_dict) and (data_dict["labels"] is not None):
425 | split_dict["labels"] = data_dict["labels"].clone()
426 |
427 | split_dict["mode"] = "interp"
428 | return split_dict
429 |
430 |
431 |
432 | def add_mask(data_dict):
433 | data = data_dict["observed_data"]
434 | mask = data_dict["observed_mask"]
435 |
436 | if mask is None:
437 | mask = torch.ones_like(data).to(get_device(data))
438 |
439 | data_dict["observed_mask"] = mask
440 | return data_dict
441 |
442 |
443 | def subsample_observed_data(data_dict, n_tp_to_sample = None, n_points_to_cut = None):
444 | # n_tp_to_sample -- if not None, randomly subsample the time points. The resulting timeline has n_tp_to_sample points
445 | # n_points_to_cut -- if not None, cut out consecutive points on the timeline. The resulting timeline has (N - n_points_to_cut) points
446 |
447 | if n_tp_to_sample is not None:
448 | # Randomly subsample time points
449 | data, time_steps, mask = subsample_timepoints(
450 | data_dict["observed_data"].clone(),
451 | time_steps = data_dict["observed_tp"].clone(),
452 | mask = (data_dict["observed_mask"].clone() if data_dict["observed_mask"] is not None else None),
453 | n_tp_to_sample = n_tp_to_sample)
454 |
455 | if n_points_to_cut is not None:
456 | # Remove consecutive time points
457 | data, time_steps, mask = cut_out_timepoints(
458 | data_dict["observed_data"].clone(),
459 | time_steps = data_dict["observed_tp"].clone(),
460 | mask = (data_dict["observed_mask"].clone() if data_dict["observed_mask"] is not None else None),
461 | n_points_to_cut = n_points_to_cut)
462 |
463 | new_data_dict = {}
464 | for key in data_dict.keys():
465 | new_data_dict[key] = data_dict[key]
466 |
467 | new_data_dict["observed_data"] = data.clone()
468 | new_data_dict["observed_tp"] = time_steps.clone()
469 | new_data_dict["observed_mask"] = mask.clone()
470 |
471 | if n_points_to_cut is not None:
472 | # Cut the section in the data to predict as well
473 | # Used only for the demo on the periodic function
474 | new_data_dict["data_to_predict"] = data.clone()
475 | new_data_dict["tp_to_predict"] = time_steps.clone()
476 | new_data_dict["mask_predicted_data"] = mask.clone()
477 |
478 | return new_data_dict
479 |
480 |
481 | def split_and_subsample_batch(data_dict, args, data_type = "train"):
482 | if data_type == "train":
483 | # Training set
484 | if args.extrap:
485 | processed_dict = split_data_extrap(data_dict, dataset = args.dataset)
486 | else:
487 | processed_dict = split_data_interp(data_dict)
488 |
489 | else:
490 | # Test set
491 | if args.extrap:
492 | processed_dict = split_data_extrap(data_dict, dataset = args.dataset)
493 | else:
494 | processed_dict = split_data_interp(data_dict)
495 |
496 | # add mask
497 | processed_dict = add_mask(processed_dict)
498 |
499 | # Subsample points or cut out the whole section of the timeline
500 | if (args.sample_tp is not None) or (args.cut_tp is not None):
501 | processed_dict = subsample_observed_data(processed_dict,
502 | n_tp_to_sample = args.sample_tp,
503 | n_points_to_cut = args.cut_tp)
504 |
505 | # if (args.sample_tp is not None):
506 | # processed_dict = subsample_observed_data(processed_dict,
507 | # n_tp_to_sample = args.sample_tp)
508 | return processed_dict
509 |
510 |
511 |
512 |
513 |
514 | def compute_loss_all_batches(model,
515 | test_dataloader, args,
516 | n_batches, experimentID, device,
517 | n_traj_samples = 1, kl_coef = 1.,
518 | max_samples_for_eval = None):
519 |
520 | total = {}
521 | total["loss"] = 0
522 | total["likelihood"] = 0
523 | total["mse"] = 0
524 | total["kl_first_p"] = 0
525 | total["std_first_p"] = 0
526 | total["pois_likelihood"] = 0
527 | total["ce_loss"] = 0
528 |
529 | n_test_batches = 0
530 |
531 | classif_predictions = torch.Tensor([]).to(device)
532 | all_test_labels = torch.Tensor([]).to(device)
533 |
534 | for i in range(n_batches):
535 | print("Computing loss... " + str(i))
536 |
537 | batch_dict = get_next_batch(test_dataloader)
538 |
539 | results = model.compute_all_losses(batch_dict,
540 | n_traj_samples = n_traj_samples, kl_coef = kl_coef)
541 |
542 | if args.classif:
543 | n_labels = model.n_labels #batch_dict["labels"].size(-1)
544 | n_traj_samples = results["label_predictions"].size(0)
545 |
546 | classif_predictions = torch.cat((classif_predictions,
547 | results["label_predictions"].reshape(n_traj_samples, -1, n_labels)),1)
548 | all_test_labels = torch.cat((all_test_labels,
549 | batch_dict["labels"].reshape(-1, n_labels)),0)
550 |
551 | for key in total.keys():
552 | if key in results:
553 | var = results[key]
554 | if isinstance(var, torch.Tensor):
555 | var = var.detach()
556 | total[key] += var
557 |
558 | n_test_batches += 1
559 |
560 | # for speed
561 | if max_samples_for_eval is not None:
562 | if n_test_batches * batch_dict["observed_data"].size(0) >= max_samples_for_eval:
563 | break
564 |
565 | if n_test_batches > 0:
566 | for key, value in total.items():
567 | total[key] = total[key] / n_test_batches
568 |
569 | if args.classif:
570 | if args.dataset == "physionet":
571 | #all_test_labels = all_test_labels.reshape(-1)
572 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
573 | all_test_labels = all_test_labels.repeat(n_traj_samples,1,1)
574 |
575 |
576 | idx_not_nan = ~torch.isnan(all_test_labels)
577 | classif_predictions = classif_predictions[idx_not_nan]
578 | all_test_labels = all_test_labels[idx_not_nan]
579 |
580 | dirname = "plots/" + str(experimentID) + "/"
581 | os.makedirs(dirname, exist_ok=True)
582 |
583 | total["auc"] = 0.
584 | if torch.sum(all_test_labels) != 0.:
585 | print("Number of labeled examples: {}".format(len(all_test_labels.reshape(-1))))
586 | print("Number of examples with mortality 1: {}".format(torch.sum(all_test_labels == 1.)))
587 |
588 | # Cannot compute AUC with only 1 class
589 | total["auc"] = sk.metrics.roc_auc_score(all_test_labels.cpu().numpy().reshape(-1),
590 | classif_predictions.cpu().numpy().reshape(-1))
591 | else:
592 | print("Warning: Couldn't compute AUC -- all examples are from the same class")
593 |
594 | if args.dataset == "activity":
595 | all_test_labels = all_test_labels.repeat(n_traj_samples,1,1)
596 |
597 | labeled_tp = torch.sum(all_test_labels, -1) > 0.
598 |
599 | all_test_labels = all_test_labels[labeled_tp]
600 | classif_predictions = classif_predictions[labeled_tp]
601 |
602 | # classif_predictions and all_test_labels are in one-hot encoding -- convert to class ids
603 | _, pred_class_id = torch.max(classif_predictions, -1)
604 | _, class_labels = torch.max(all_test_labels, -1)
605 |
606 | pred_class_id = pred_class_id.reshape(-1)
607 |
608 | total["accuracy"] = sk.metrics.accuracy_score(
609 | class_labels.cpu().numpy(),
610 | pred_class_id.cpu().numpy())
611 | return total
612 |
613 | def check_mask(data, mask):
614 | #check that "mask" argument indeed contains a mask for data
615 | n_zeros = torch.sum(mask == 0.).cpu().numpy()
616 | n_ones = torch.sum(mask == 1.).cpu().numpy()
617 |
618 | # mask should contain only zeros and ones
619 | assert((n_zeros + n_ones) == np.prod(list(mask.size())))
620 |
621 | # all masked out elements should be zeros
622 | assert(torch.sum(data[mask == 0.] != 0.) == 0)
623 |
624 |
625 | import time
626 | import numpy as np
627 |
628 | import torch
629 | import torch.nn as nn
630 |
631 | from torch.distributions.multivariate_normal import MultivariateNormal
632 |
633 | # git clone https://github.com/rtqichen/torchdiffeq.git
634 | from torchdiffeq import odeint as odeint
635 |
636 | #####################################################################################################
637 |
638 | class DiffeqSolver(nn.Module):
639 | def __init__(self, input_dim, ode_func, method, latents,
640 | odeint_rtol = 1e-4, odeint_atol = 1e-5, device = torch.device("cpu")):
641 | super(DiffeqSolver, self).__init__()
642 |
643 | self.ode_method = method
644 | self.latents = latents
645 | self.device = device
646 | self.ode_func = ode_func
647 |
648 | self.odeint_rtol = odeint_rtol
649 | self.odeint_atol = odeint_atol
650 |
651 | def forward(self, first_point, time_steps_to_predict, backwards = False):
652 | """
653 | # Decode the trajectory through ODE Solver
654 | """
655 | n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1]
656 | n_dims = first_point.size()[-1]
657 |
658 | pred_y = odeint(self.ode_func, first_point, time_steps_to_predict,
659 | rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method)
660 | pred_y = pred_y.permute(1,2,0,3)
661 |
662 | assert(torch.mean(torch.abs(pred_y[:, :, 0, :] - first_point)) < 0.001)
663 | assert(pred_y.size()[0] == n_traj_samples)
664 | assert(pred_y.size()[1] == n_traj)
665 |
666 | return pred_y
667 |
668 | def sample_traj_from_prior(self, starting_point_enc, time_steps_to_predict,
669 | n_traj_samples = 1):
670 | """
671 | # Decode the trajectory through ODE Solver using samples from the prior
672 |
673 | time_steps_to_predict: time steps at which we want to sample the new trajectory
674 | """
675 | func = self.ode_func.sample_next_point_from_prior
676 |
677 | pred_y = odeint(func, starting_point_enc, time_steps_to_predict,
678 | rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method)
679 | # shape: [n_traj_samples, n_traj, n_tp, n_dim]
680 | pred_y = pred_y.permute(1,2,0,3)
681 | return pred_y
682 |
683 |
684 | """
685 | https://github.com/YuliaRubanova/latent_ode/blob/master/lib/encoder_decoder.py
686 | """
687 | from torch.nn.modules.rnn import LSTM, GRU
688 |
689 |
690 | # GRU description:
691 | # http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/
692 | class GRU_unit(nn.Module):
693 | def __init__(self, latent_dim, input_dim,
694 | n_units = 100,
695 | # device = torch.device("cpu"),
696 | ):
697 | super(GRU_unit, self).__init__()
698 |
699 | self.update_gate = nn.Sequential(
700 | nn.Linear(latent_dim * 2 + input_dim, n_units),
701 | nn.Tanh(),
702 | nn.Linear(n_units, latent_dim),
703 | nn.Sigmoid())
704 |
705 | self.reset_gate = nn.Sequential(
706 | nn.Linear(latent_dim * 2 + input_dim, n_units),
707 | nn.Tanh(),
708 | nn.Linear(n_units, latent_dim),
709 | nn.Sigmoid())
710 |
711 | self.new_state_net = nn.Sequential(
712 | nn.Linear(latent_dim * 2 + input_dim, n_units),
713 | nn.Tanh(),
714 | nn.Linear(n_units, latent_dim * 2))
715 |
716 |
717 |
718 | def forward(self, y_mean, y_std, x):
719 | y_concat = torch.cat([y_mean, y_std, x], -1)
720 |
721 | update_gate = self.update_gate(y_concat)
722 | reset_gate = self.reset_gate(y_concat)
723 | concat = torch.cat([y_mean * reset_gate, y_std * reset_gate, x], -1)
724 |
725 | new_state, new_state_std = split_last_dim(self.new_state_net(concat))
726 | new_state_std = new_state_std.abs()
727 |
728 | new_y = (1-update_gate) * new_state + update_gate * y_mean
729 | new_y_std = (1-update_gate) * new_state_std + update_gate * y_std
730 |
731 | assert(not torch.isnan(new_y).any())
732 |
733 | # took out masked update (no mask provided)
734 |
735 | new_y_std = new_y_std.abs()
736 | return new_y, new_y_std
737 |
--------------------------------------------------------------------------------
/src/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def mse_loss(pred, true):
4 |     return torch.mean((pred - true) ** 2)
5 |
6 | def l1_loss(pred, true):
7 |     return torch.mean(torch.abs(pred - true))
--------------------------------------------------------------------------------
/src/utils/metric_calc.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 | # dist calculations
5 | import math
6 | from typing import Union
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from typing import Optional
12 | from .mmd import linear_mmd2, mix_rbf_mmd2, poly_mmd2
13 | # from .optimal_transport import wasserstein
14 | import ot as pot
15 | from functools import partial
16 | import scipy.stats as stats
17 | from scipy.stats import pearsonr, spearmanr
18 | import pandas as pd
19 |
20 |
21 | # auroc and auprc
22 | from sklearn.metrics import roc_auc_score, average_precision_score
23 |
24 |
25 |
26 | def wasserstein(
27 |     x0: torch.Tensor,
28 |     x1: torch.Tensor,
29 |     method: Optional[str] = None,
30 |     reg: float = 0.05,
31 |     power: int = 2,
32 |     **kwargs,
33 | ) -> float:
34 |     assert power == 1 or power == 2
35 |     # ot_fn should take (a, b, M) as arguments where a, b are marginals and
36 |     # M is a cost matrix
37 |     if method == "exact" or method is None:
38 |         ot_fn = pot.emd2
39 |     elif method == "sinkhorn":
40 |         ot_fn = partial(pot.sinkhorn2, reg=reg)
41 |     else:
42 |         raise ValueError(f"Unknown method: {method}")
43 |
44 |     a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
45 |     if x0.dim() > 2:
46 |         x0 = x0.reshape(x0.shape[0], -1)
47 |     if x1.dim() > 2:
48 |         x1 = x1.reshape(x1.shape[0], -1)
49 |     M = torch.cdist(x0, x1)
50 |     if power == 2:
51 |         M = M**2
52 |     ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=int(1e7))  # iteration cap passed as an int
53 |     if power == 2:
54 |         ret = math.sqrt(ret)
55 |     return ret
56 |
57 | def compute_distances(pred, true):
58 |     """Computes (MSE, RMSE, L1) between two vectors."""
59 |     mse = torch.nn.functional.mse_loss(pred, true).item()
60 |     me = math.sqrt(mse)
61 |     return mse, me, torch.nn.functional.l1_loss(pred, true).item()
62 |
63 | def compute_distribution_distances_new(pred: Union[torch.Tensor, list], true: Union[torch.Tensor, list]):
64 |     """Computes distances between distributions.
65 |     pred: [batch, times, dims] tensor or list[batch[i], dims] of length times
66 |     true: [batch, times, dims] tensor or list[batch[i], dims] of length times
67 |
68 |     Jagged inputs (lists of tensors) are evaluated at the last time point only;
69 |     tensor inputs are compared as a whole.
70 |     """
71 |     NAMES = [
72 |         "1-Wasserstein",
73 |         "2-Wasserstein",
74 |         "Linear_MMD",
75 |         "Poly_MMD",
76 |         "RBF_MMD",
77 |         "Mean_MSE",
78 |         "Mean_L2",
79 |         "Mean_L1",
80 |         "Median_MSE",
81 |         "Median_L2",
82 |         "Median_L1",
83 |     ]
84 |     is_jagged = isinstance(true, list)
85 |     pred_is_jagged = isinstance(pred, list)
86 |     dists = []
87 |     to_return = []
88 |     names = []
89 |     filtered_names = [name for name in NAMES if not (is_jagged or pred_is_jagged) or not name.endswith("MMD")]
90 |     ts = len(pred) if pred_is_jagged else pred.shape[1]
91 |     # for t in np.arange(ts):
92 |     t = max(ts - 1, 0)
93 |     if pred_is_jagged:
94 |         a = pred[t]
95 |     else:
96 |         a = torch.as_tensor(pred).float().clone().detach()
97 |     if is_jagged:
98 |         b = true[t]
99 |     else:
100 |         b = torch.as_tensor(true).float().clone().detach()
101 |     w1 = wasserstein(a, b, power=1)
102 |     w2 = wasserstein(a, b, power=2)
103 |
104 |     if not pred_is_jagged and not is_jagged:
105 |         mmd_linear = linear_mmd2(a, b).item()
106 |         mmd_poly = poly_mmd2(a, b, d=2, alpha=1.0, c=2.0).item()
107 |         mmd_rbf = mix_rbf_mmd2(a, b, sigma_list=[0.01, 0.1, 1, 10, 100]).item()
108 |     mean_dists = compute_distances(torch.mean(a, dim=0), torch.mean(b, dim=0))
109 |     median_dists = compute_distances(torch.median(a, dim=0)[0], torch.median(b, dim=0)[0])
110 |
111 |     if pred_is_jagged or is_jagged:
112 |         dists.append((w1, w2, *mean_dists, *median_dists))
113 |     else:
114 |         dists.append((w1, w2, mmd_linear, mmd_poly, mmd_rbf, *mean_dists, *median_dists))
115 |
116 |     to_return.extend(np.array(dists).mean(axis=0))
117 |     names.extend(filtered_names)
118 |     return names, to_return
119 |
120 |
121 | def metrics_calculation(pred, true, metrics=['mse_loss', 'l1_loss'], cutoff=-0.91, map_idx = 1):
122 |
123 |     # if pred is a tensor, convert both pred and true to numpy
124 |     if isinstance(pred, torch.Tensor):
125 |         pred = pred.detach().cpu().squeeze().numpy()
126 |         true = true.detach().cpu().squeeze().numpy()
127 |
128 |     loss_D = {key : None for key in metrics}
129 |     for metric in metrics:
130 |         if metric == 'mse_loss':
131 |             loss_D['mse_loss'] = np.mean((pred - true)**2)
132 |             # self.log('mse_loss', self.loss_fn(pred, true))
133 |         if metric == 'l1_loss':
134 |             loss_D['l1_loss'] = np.mean(np.abs(pred - true))
135 |             # self.log('l1_loss', torch.mean(torch.abs(pred - true)))
136 |         if metric == 'crit_map':
137 |             auroc, auprc = critical_state_pred(pred, true, cutoff, map_idx)
138 |             loss_D['crit_state_auroc'] = auroc
139 |             loss_D['crit_state_auprc'] = auprc
140 |         if metric == 'variance_dist':
141 |             # calculate distribution difference in variance
142 |             # add to loss_D, multiple items
143 |             var_d = variance_dist(pred, true)
144 |             for key, val in var_d.items():
145 |                 loss_D[key] = val
146 |     # remove the empty placeholder key 'variance_dist'
147 |     if 'variance_dist' in loss_D:
148 |         del loss_D['variance_dist']
149 |
150 |     return loss_D
151 |
152 |
153 | def variance_dist(pred, true):
154 |     """Compares the step-to-step changes (first differences along axis 0) of two trajectories
155 |
156 |     Args:
157 |         pred (numpy array): predicted trajectory; the first axis is treated as time
158 |         true (numpy array): ground-truth trajectory in the same layout
159 |     """
160 |     pred_var = np.diff(pred, axis=0)
161 |     true_var = np.diff(true, axis=0)
162 |     # convert both to torch tensors and use compute_distribution_distances_new
163 |     pred_var = torch.tensor(pred_var).float()
164 |     true_var = torch.tensor(true_var).float()
165 |     names, values = compute_distribution_distances_new(pred_var, true_var)
166 |     var_met_dict = {name: value for name, value in zip(names, values)}
167 |     return var_met_dict
168 |
169 |
170 | def critical_state_pred(pred, true, cutoff=-0.91, map_idx = 1):
171 |     """Scores how well predicted critical states match the true ones (AUROC, AUPRC)
172 |     always in format: HR, MAP
173 |     MAP: 60 cutoff
174 |
175 |     Args:
176 |         pred (numpy array): predictions; column map_idx holds MAP
177 |         true (numpy array): ground truth in the same layout as pred
178 |         cutoff (float, optional): values below this count as critical. Defaults to -0.91.
179 |         -0.91 for normalized (not scaled)
180 |     """
181 |     # binarize both series at the cutoff, then score the binary predictions
182 |     critical_cutoff = cutoff
183 |     pred_critical = (pred[:, map_idx] < critical_cutoff).astype(int)
184 |     true_critical = (true[:, map_idx] < critical_cutoff).astype(int)
185 |     auroc = roc_auc_score(true_critical, pred_critical)
186 |     auprc = average_precision_score(true_critical, pred_critical)
187 |     return auroc, auprc
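
Note on `wasserstein` above: for intuition about what the exact solver returns, the 1-D case with equal sample sizes has a closed form, because the optimal plan simply pairs sorted samples. The sketch below is a dependency-free illustration and is not part of the repository; `wasserstein_1d` is a hypothetical helper name, and it assumes uniform weights, as `pot.unif` produces above.

```python
def wasserstein_1d(xs, ys, power=2):
    """p-Wasserstein distance between two equal-size 1-D samples with uniform weights."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    if power not in (1, 2):
        raise ValueError("power must be 1 or 2")
    # In 1-D the optimal plan pairs the i-th smallest x with the i-th smallest y.
    pairs = zip(sorted(xs), sorted(ys))
    cost = sum(abs(x - y) ** power for x, y in pairs) / len(xs)
    # Same convention as wasserstein(): take the p-th root of the transport cost.
    return cost ** (1.0 / power)


w1 = wasserstein_1d([0.0, 1.0, 2.0], [3.0, 1.0, 2.0], power=1)
w2 = wasserstein_1d([0.0, 1.0, 2.0], [3.0, 1.0, 2.0], power=2)
```

Here every sorted pair differs by exactly one unit, so both distances come out equal; for higher-dimensional data there is no sorting shortcut and a solver such as the one used above is needed.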
--------------------------------------------------------------------------------
/src/utils/mmd.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 |
4 | import torch
5 |
6 | min_var_est = 1e-8
7 |
8 |
9 | # Consider linear time MMD with a linear kernel:
10 | # K(f(x), f(y)) = f(x)^Tf(y)
11 | # h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
12 | # = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
13 | #
14 | # f_of_X: batch_size * k
15 | # f_of_Y: batch_size * k
16 | def linear_mmd2(f_of_X, f_of_Y):
17 | loss = 0.0
18 | delta = f_of_X - f_of_Y
19 | loss = torch.mean((delta[:-1] * delta[1:]).sum(1))
20 | return loss
21 |
22 |
23 | # Consider linear time MMD with a polynomial kernel:
24 | # K(f(x), f(y)) = (alpha*f(x)^Tf(y) + c)^d
25 | # f_of_X: batch_size * k
26 | # f_of_Y: batch_size * k
27 | def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0):
28 | K_XX = alpha * (f_of_X[:-1] * f_of_X[1:]).sum(1) + c
29 | K_XX_mean = torch.mean(K_XX.pow(d))
30 |
31 | K_YY = alpha * (f_of_Y[:-1] * f_of_Y[1:]).sum(1) + c
32 | K_YY_mean = torch.mean(K_YY.pow(d))
33 |
34 | K_XY = alpha * (f_of_X[:-1] * f_of_Y[1:]).sum(1) + c
35 | K_XY_mean = torch.mean(K_XY.pow(d))
36 |
37 | K_YX = alpha * (f_of_Y[:-1] * f_of_X[1:]).sum(1) + c
38 | K_YX_mean = torch.mean(K_YX.pow(d))
39 |
40 | return K_XX_mean + K_YY_mean - K_XY_mean - K_YX_mean
41 |
42 |
43 | def _mix_rbf_kernel(X, Y, sigma_list):
44 | assert X.size(0) == Y.size(0)
45 | m = X.size(0)
46 |
47 | Z = torch.cat((X, Y), 0)
48 | ZZT = torch.mm(Z, Z.t())
49 | diag_ZZT = torch.diag(ZZT).unsqueeze(1)
50 | Z_norm_sqr = diag_ZZT.expand_as(ZZT)
51 | exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()
52 |
53 | K = 0.0
54 | for sigma in sigma_list:
55 | gamma = 1.0 / (2 * sigma**2)
56 | K += torch.exp(-gamma * exponent)
57 |
58 | return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list)
59 |
60 |
61 | def mix_rbf_mmd2(X, Y, sigma_list, biased=True):
62 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list)
63 | # return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased)
64 | return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)
65 |
66 |
67 | def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True):
68 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list)
69 | # return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased)
70 | return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)
71 |
72 |
73 | ################################################################################
74 | # Helper functions to compute variances based on kernel matrices
75 | ################################################################################
76 |
77 |
78 | def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
79 | m = K_XX.size(0) # assume X, Y are same shape
80 |
81 | # Get the various sums of kernels that we'll use
82 | # Kts drop the diagonal, but we don't need to compute them explicitly
83 | if const_diagonal is not False:
84 | diag_X = diag_Y = const_diagonal
85 | sum_diag_X = sum_diag_Y = m * const_diagonal
86 | else:
87 | diag_X = torch.diag(K_XX) # (m,)
88 | diag_Y = torch.diag(K_YY) # (m,)
89 | sum_diag_X = torch.sum(diag_X)
90 | sum_diag_Y = torch.sum(diag_Y)
91 |
92 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X
93 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y
94 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e
95 |
96 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e
97 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e
98 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e
99 |
100 | if biased:
101 | mmd2 = (
102 | (Kt_XX_sum + sum_diag_X) / (m * m)
103 | + (Kt_YY_sum + sum_diag_Y) / (m * m)
104 | - 2.0 * K_XY_sum / (m * m)
105 | )
106 | else:
107 | mmd2 = Kt_XX_sum / (m * (m - 1)) + Kt_YY_sum / (m * (m - 1)) - 2.0 * K_XY_sum / (m * m)
108 |
109 | return mmd2
110 |
111 |
112 | def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
113 | mmd2, var_est = _mmd2_and_variance(
114 | K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased
115 | )
116 | loss = mmd2 / torch.sqrt(torch.clamp(var_est, min=min_var_est))
117 | return loss, mmd2, var_est
118 |
119 |
120 | def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
121 | m = K_XX.size(0) # assume X, Y are same shape
122 |
123 | # Get the various sums of kernels that we'll use
124 | # Kts drop the diagonal, but we don't need to compute them explicitly
125 | if const_diagonal is not False:
126 | diag_X = diag_Y = const_diagonal
127 | sum_diag_X = sum_diag_Y = m * const_diagonal
128 | sum_diag2_X = sum_diag2_Y = m * const_diagonal**2
129 | else:
130 | diag_X = torch.diag(K_XX) # (m,)
131 | diag_Y = torch.diag(K_YY) # (m,)
132 | sum_diag_X = torch.sum(diag_X)
133 | sum_diag_Y = torch.sum(diag_Y)
134 | sum_diag2_X = diag_X.dot(diag_X)
135 | sum_diag2_Y = diag_Y.dot(diag_Y)
136 |
137 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X
138 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y
139 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e
140 | K_XY_sums_1 = K_XY.sum(dim=1) # K_{XY} * e
141 |
142 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e
143 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e
144 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e
145 |
146 | Kt_XX_2_sum = (K_XX**2).sum() - sum_diag2_X # \| \tilde{K}_XX \|_F^2
147 | Kt_YY_2_sum = (K_YY**2).sum() - sum_diag2_Y # \| \tilde{K}_YY \|_F^2
148 | K_XY_2_sum = (K_XY**2).sum() # \| K_{XY} \|_F^2
149 |
150 | if biased:
151 | mmd2 = (
152 | (Kt_XX_sum + sum_diag_X) / (m * m)
153 | + (Kt_YY_sum + sum_diag_Y) / (m * m)
154 | - 2.0 * K_XY_sum / (m * m)
155 | )
156 | else:
157 | mmd2 = Kt_XX_sum / (m * (m - 1)) + Kt_YY_sum / (m * (m - 1)) - 2.0 * K_XY_sum / (m * m)
158 |
159 | var_est = (
160 | 2.0
161 | / (m**2 * (m - 1.0) ** 2)
162 | * (
163 | 2 * Kt_XX_sums.dot(Kt_XX_sums)
164 | - Kt_XX_2_sum
165 | + 2 * Kt_YY_sums.dot(Kt_YY_sums)
166 | - Kt_YY_2_sum
167 | )
168 | - (4.0 * m - 6.0) / (m**3 * (m - 1.0) ** 3) * (Kt_XX_sum**2 + Kt_YY_sum**2)
169 | + 4.0
170 | * (m - 2.0)
171 | / (m**3 * (m - 1.0) ** 2)
172 | * (K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0))
173 | - 4.0 * (m - 3.0) / (m**3 * (m - 1.0) ** 2) * (K_XY_2_sum)
174 | - (8 * m - 12) / (m**5 * (m - 1)) * K_XY_sum**2
175 | + 8.0
176 | / (m**3 * (m - 1.0))
177 | * (
178 | 1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
179 | - Kt_XX_sums.dot(K_XY_sums_1)
180 | - Kt_YY_sums.dot(K_XY_sums_0)
181 | )
182 | )
183 | return mmd2, var_est
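
Note on the biased branch of `_mmd2` above: with equal sample sizes it reduces to mean k(x, x') + mean k(y, y') - 2 * mean k(x, y). The sketch below spells this out in plain Python for a single RBF bandwidth; it is an illustration under those assumptions, not the repository's implementation (`mix_rbf_mmd2` sums kernels over several bandwidths), and `rbf` and `biased_mmd2` are hypothetical names.

```python
import math


def rbf(u, v, sigma):
    """Gaussian kernel exp(-||u - v||^2 / (2 * sigma^2))."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def biased_mmd2(X, Y, sigma=1.0):
    """Biased (V-statistic) MMD^2 estimate; diagonal kernel terms are included."""
    m, n = len(X), len(Y)
    k_xx = sum(rbf(a, b, sigma) for a in X for b in X) / (m * m)
    k_yy = sum(rbf(a, b, sigma) for a in Y for b in Y) / (n * n)
    k_xy = sum(rbf(a, b, sigma) for a in X for b in Y) / (m * n)
    return k_xx + k_yy - 2.0 * k_xy


X = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
Y = [[5.0, 5.0], [6.0, 5.0], [5.0, 6.0]]
same = biased_mmd2(X, X)  # identical samples: the three terms cancel
diff = biased_mmd2(X, Y)  # well-separated samples: the cross term is nearly zero
```

The estimate is zero for identical samples and grows as the two samples separate, which is the behaviour the MMD entries in `compute_distribution_distances_new` rely on.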
--------------------------------------------------------------------------------
/src/utils/sde.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class SDE(torch.nn.Module):
4 |
5 |     noise_type = "diagonal"
6 |     sde_type = "ito"
7 |
8 |     # noise is sigma (cf. sigma * (t * (1 - t)) in the notebook); g() below applies it as a constant
9 |     def __init__(self, ode_drift, noise=1.0, reverse=False):
10 |         super().__init__()
11 |         self.drift = ode_drift
12 |         self.reverse = reverse
13 |         self.noise = noise
14 |
15 |     # Drift
16 |     def f(self, t, y):
17 |         if self.reverse:
18 |             t = 1 - t
19 |         if len(t.shape) == len(y.shape):
20 |             x = torch.cat([y, t], 1)
21 |         else:
22 |             x = torch.cat([y, t.repeat(y.shape[0])[:, None]], 1)
23 |         return self.drift(x)
24 |
25 |     # Diffusion
26 |     def g(self, t, y):
27 |         return torch.ones_like(t) * torch.ones_like(y) * self.noise
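
Note on the `SDE` class above: `f` and `g` are the drift and diffusion of an Ito SDE dy = f(t, y) dt + g(t, y) dW. Which solver the repository pairs with this class is not shown in this file, so the sketch below uses a hand-written fixed-step Euler-Maruyama loop on a scalar state purely as an illustration; `euler_maruyama` and the toy drift are hypothetical.

```python
import math
import random


def euler_maruyama(drift, diffusion, y0, t0, t1, n_steps, rng):
    """Integrate dy = drift(t, y) dt + diffusion(t, y) dW with fixed steps."""
    dt = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment ~ N(0, dt)
        y = y + drift(t, y) * dt + diffusion(t, y) * dw
        t += dt
    return y


noise = 0.0  # plays the role of SDE.noise; 0 reduces the SDE to an ODE
y_end = euler_maruyama(
    drift=lambda t, y: -y,          # toy drift standing in for the learned network
    diffusion=lambda t, y: noise,   # constant diffusion, as in SDE.g
    y0=1.0, t0=0.0, t1=1.0, n_steps=1000,
    rng=random.Random(0),
)
```

With `noise = 0.0` the loop is plain forward Euler, so `y_end` lands close to exp(-1); raising `noise` adds a Gaussian perturbation of standard deviation `noise * sqrt(dt)` at every step.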
--------------------------------------------------------------------------------
/src/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from mpl_toolkits.mplot3d import Axes3D
5 |
6 | def plot_3d_path_ind(traj, groundtruth, t_span=torch.linspace(0, 4 * np.pi, 100), title=""):
7 | n = len(t_span)
8 | fig = plt.figure(figsize=(15, 10))
9 | ax1 = fig.add_subplot(1, 1, 1, projection='3d')
10 | len_traj = traj.shape[0]
11 | ax1.scatter([0] * len_traj, traj[0, 0], traj[0, 1], alpha=0.5, c="red") # start
12 | for i in range(n - 1):
13 | ax1.plot([t_span[i], t_span[i + 1]], [traj[i, 0], traj[i + 1, 0]], [traj[i, 1], traj[i + 1, 1]], alpha=1, c="olive") # path
14 | ax1.plot([t_span[i], t_span[i + 1]], [groundtruth[i, 0], groundtruth[i + 1, 0]], [groundtruth[i, 1], groundtruth[i + 1, 1]], alpha=1, c="pink")
15 | ax1.scatter(t_span, traj[:, 0], traj[:, 1], alpha=0.5, c="blue") # end
16 | ax1.scatter(t_span, groundtruth[:, 0], groundtruth[:, 1], alpha=0.5, c="purple") # ground truth
17 | ax1.set_title(title)
18 |
19 | return fig
20 |
21 |
22 | def plot_3d_path_ind_noise(traj, groundtruth, noise, t_span=torch.linspace(0, 4 * np.pi, 100), title=""):
23 | n = len(t_span)
24 | fig = plt.figure(figsize=(15, 10))
25 | ax1 = fig.add_subplot(1, 1, 1, projection='3d')
26 |
27 | noise = noise.cpu().numpy()
28 |
29 | len_traj = traj.shape[0]
30 | ax1.scatter([0] * len_traj, traj[0, 0], traj[0, 1], alpha=0.5, c="red") # start
31 |
32 | # Plot trajectory and ground truth
33 | ax1.plot(t_span, traj[:, 0], traj[:, 1], label='Trajectory', c='olive')
34 | ax1.plot(t_span, groundtruth[:, 0], groundtruth[:, 1], label='Ground Truth', c='pink')
35 | ax1.scatter(t_span, traj[:, 0], traj[:, 1], alpha=0.5, c="blue") # end
36 | ax1.scatter(t_span, groundtruth[:, 0], groundtruth[:, 1], alpha=0.5, c="purple") # ground truth
37 | # Plot uncertainty as scatter points around each trajectory point
38 | # Plus and minus noise values for visualization
39 | for i in range(n-1):
40 | if i == 0:
41 | continue
42 | x_noise_pos = traj[i+1, 0] + noise[i, 0]
43 | y_noise_pos = traj[i+1, 1] + noise[i, 1]
44 | x_noise_neg = traj[i+1, 0] - noise[i, 0]
45 | y_noise_neg = traj[i+1, 1] - noise[i, 1]
46 | ax1.scatter([t_span[i]]*2, [x_noise_pos, x_noise_neg], [y_noise_pos, y_noise_neg], color='gray', alpha=0.5)
47 |
48 | ax1.set_title(title)
49 | ax1.set_xlabel('Time')
50 | ax1.set_ylabel('X')
51 | ax1.set_zlabel('Y')
52 | ax1.legend()
53 |
54 | return fig
55 |
56 |
57 |
58 | def join_3d_plots(figs, rows, cols):
59 | new_fig = plt.figure(figsize=(15 * cols, 10 * rows))
60 | for i, fig in enumerate(figs):
61 | ax = new_fig.add_subplot(rows, cols, i + 1, projection='3d')
62 | original_ax = fig.get_children()[1]
63 | for line in original_ax.get_lines():
64 | ax.add_line(line)
65 | for patch in original_ax.get_patches():
66 | ax.add_patch(patch)
67 | return new_fig
--------------------------------------------------------------------------------
/src/utils/wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class torch_wrapper_tv(torch.nn.Module):
4 |     """Wraps model to torchdyn compatible format."""
5 |
6 |     def __init__(self, model):
7 |         super().__init__()
8 |         self.model = model
9 |
10 |     def forward(self, t, x, *args, **kwargs):
11 |         return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))
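
Note on `torch_wrapper_tv` above: the wrapper takes `(t, x)` and feeds the model a single input in which the scalar time is appended as the last feature of every row. The plain-Python sketch below mirrors only that input layout; `append_time`, `toy_model`, and `wrapped` are hypothetical stand-ins, not repository code.

```python
def append_time(x_rows, t):
    """Return a copy of each state row with the scalar time t appended as the last feature."""
    return [list(row) + [t] for row in x_rows]


def toy_model(rows):
    # Hypothetical stand-in for the wrapped network: sums the state features and adds time.
    return [[sum(row[:-1]) + row[-1]] for row in rows]


def wrapped(t, x_rows):
    # Same calling convention as torch_wrapper_tv.forward(t, x): time first, state second.
    return toy_model(append_time(x_rows, t))


out = wrapped(0.5, [[1.0, 2.0], [3.0, 4.0]])
```

Each output row here is the sum of the state features plus the time value, which makes it easy to confirm that time landed in the last column.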
--------------------------------------------------------------------------------