├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── PLACEHOLDER.py
│   ├── crossing_oscillation.png
│   ├── pred_t.png
│   ├── revised_tfm_fig1.png
│   └── traj_concept.png
├── data
│   └── toy_data.pkl
├── environment.yml
├── notebook
│   └── 3Oscillation.ipynb
└── src
    ├── conf
    │   ├── config.yaml
    │   ├── data
    │   │   ├── eICU.yaml
    │   │   ├── eICU_ablated.yaml
    │   │   ├── eICU_multdim.yaml
    │   │   └── mimic_liver.yaml
    │   ├── model
    │   │   ├── cfm.yaml
    │   │   ├── fm.yaml
    │   │   ├── latentODE.yaml
    │   │   ├── ode.yaml
    │   │   ├── ode_256.yaml
    │   │   ├── sde.yaml
    │   │   ├── sde_256.yaml
    │   │   ├── tfm_ablated_size_memory_uncertainty.yaml
    │   │   ├── tfm_ablated_uncertainty.yaml
    │   │   ├── tfm_ode.yaml
    │   │   ├── tfm_ode_ablated_size_memory_uncertainty.yaml
    │   │   ├── tfm_ode_ablated_uncertainty.yaml
    │   │   └── tfm_sde.yaml
    │   ├── perturb_config.yaml
    │   └── trainer.yaml
    ├── data
    │   └── datamodule.py
    ├── main.py
    ├── model
    │   ├── FM_baseline.py
    │   ├── base_models.py
    │   ├── cfm_baseline.py
    │   ├── components
    │   │   ├── grad_util.py
    │   │   ├── mlp.py
    │   │   ├── positional_encoding.py
    │   │   └── sde_func_solver.py
    │   ├── latent_ode.py
    │   ├── mlp_memory.py
    │   ├── mlp_noise.py
    │   └── ode_baseline.py
    └── utils
        ├── latent_ode_utils.py
        ├── loss.py
        ├── metric_calc.py
        ├── mmd.py
        ├── sde.py
        ├── visualize.py
        └── wrapper.py
/.gitignore: -------------------------------------------------------------------------------- 1 | src/outputs/* 2 | outputs/* 3 | *pycache* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Xi (Nicole) Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Trajectory Flow Matching 4 | 9 | 10 | [![TFM Preprint](http://img.shields.io/badge/paper-arxiv.2410.21154-B31B1B.svg)](http://arxiv.org/abs/2410.21154) 11 | [![pytorch](https://img.shields.io/badge/PyTorch_1.8+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/) 12 | [![lightning](https://img.shields.io/badge/-Lightning_1.6+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/) 13 | [![hydra](https://img.shields.io/badge/Config-Hydra_1.2-89b8cd)](https://hydra.cc/) 14 | [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://opensource.org/license/mit) 15 | Template 16 | 17 |
18 | 19 | 20 | ## Description 21 | Trajectory Flow Matching (TFM) is a method that leverages the flow matching technique from generative modeling to model time series. The approach is simulation-free at training time, allowing stochastic differential equations to be fitted to time-series data efficiently. Augmented with memory, time-interval prediction, and uncertainty prediction, TFM can better model irregularly sampled, inherently stochastic trajectories such as clinical time series. 22 | 23 |
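Concretely, the flow matching regression at the heart of TFM can be sketched in a few lines of PyTorch. This is a minimal illustration mirroring `__x_processing__` in `src/model/FM_baseline.py`; the function name, `v_theta`, and the variable names are ours, not part of the repository API.

```python
import torch

def flow_matching_targets(x0, x1, t0, t1, sigma=0.1):
    """Build regression targets between consecutive observations (x0, t0) -> (x1, t1).
    Minimal sketch mirroring __x_processing__ in src/model/FM_baseline.py."""
    t = torch.rand(x0.shape[0], 1)              # interpolation fraction in [0, 1)
    dt = (t1 - t0).unsqueeze(1)                 # observed time gap for each pair
    mu_t = (1 - t) * x0 + t * x1                # straight-line interpolant
    x_t = mu_t + sigma * torch.randn_like(x0)   # probability-path noise around it
    u_t = (x1 - x0) / (dt + 1e-4)               # target velocity between observations
    t_model = t * dt + t0.unsqueeze(1)          # absolute time fed to the network
    future_t = t1.unsqueeze(1) - t_model        # remaining time to the next observation
    return x_t, u_t, t_model, future_t

# Training regresses the network output v = v_theta([x_t, cond, t_model]) onto u_t,
# with an extra head matched to future_t, e.g.
#   loss = ((v[:, :-1] - u_t) ** 2).mean() + ((v[:, -1:] - future_t) ** 2).mean()
```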

26 | 27 | TFM applies the flow matching concept to predict both the stochastic uncertainty and the next value in the time series; the prediction is conditioned on past data and on conditional variables. 28 | 29 | 30 | 31 | ## How to run 32 | 37 | 38 | ### Initialize environment 39 | ```bash 40 | # clone project 41 | git clone https://github.com/nZhangx/TrajectoryFlowMatching.git 42 | cd TrajectoryFlowMatching 43 | 44 | # [OPTIONAL] create conda environment 45 | conda create -n tfm python=3.10 46 | conda activate tfm 47 | 48 | # install requirements (note: environment.yml defines an environment named ti-env) 49 | conda env create -f environment.yml 50 | ``` 51 | 52 | ### Run experiments 53 | Under `src`, create a new `DATA_NAME.yaml` under `conf/data` and a `MODEL_NAME.yaml` under `conf/model` with the desired configurations. Point the `data` and `model` entries in `conf/config.yaml` at your `DATA_NAME` and `MODEL_NAME`, then run 54 | ```bash 55 | python src/main.py 56 | ``` 57 | 58 | ### Demo 59 | We include an example of TFM modeling three crossing oscillations in the self-contained Jupyter notebook `notebook/3Oscillation.ipynb`. 60 | 61 | ### Implemented models 62 | 63 | - TFM and ablations (size ablated, uncertainty ablated) 64 | - Aligned Flow Matching [[Liu et al., 2023](https://arxiv.org/abs/2302.05872)][[Somnath et al., 2023](https://arxiv.org/abs/2302.11419)] 65 | - NeuralODE [[Chen et al., 2018](https://arxiv.org/abs/1806.07366)] 66 | - NeuralSDE [[Li et al., 2020](https://arxiv.org/abs/2001.01328)] [[Kidger et al., 2021](https://arxiv.org/abs/2102.03657)] 67 | - LatentODE [[Rubanova et al., 2019](https://arxiv.org/abs/1907.03907)] 68 | 69 | | | ICU Sepsis | ICU Cardiac Arrest | ICU GIB | ED GIB | 70 | |-------------------------------|----------------------------|----------------------------|----------------------------|-----------------------------| 71 | | NeuralODE | 4.776 $\pm$ 0.000 | 6.153 $\pm$ 0.000 | 3.170 $\pm$ 0.000 | 10.859 $\pm$ 0.000 | 72 | | FM baseline ODE | 4.671 $\pm$ 0.791 | 10.207 $\pm$ 1.076 | 118.439 $\pm$ 17.947 | 11.923 $\pm$ 1.123 | 73 | | LatentODE-RNN | 61.806 $\pm$ 46.573 | 386.190 $\pm$ 558.140 | 422.886 $\pm$ 431.954 | 980.228 $\pm$ 1032.393 | 74 | | **TFM-ODE (ours)** | **0.793 $\pm$ 0.017** | 2.762 $\pm$ 0.021 | 2.673 $\pm$ 0.069 | **8.245 $\pm$ 0.495** | 75 | || 76 | | NeuralSDE | 4.747 $\pm$ 0.000 | 3.250 $\pm$ 0.024 | 3.186 $\pm$ 0.000 | 10.850 $\pm$ 0.043 | 77 | | **TFM (ours)** | 0.796 $\pm$ 0.026 | **2.755 $\pm$ 0.015** | **2.596 $\pm$ 0.079** | 8.613 $\pm$ 0.260 | 78 | 79 | 80 | ### Available datasets 81 | 82 | We plan to share the clinical datasets we used, derived from the [eICU Collaborative Research Database v2.0](https://physionet.org/content/eicu-crd/2.0/) (ICU Sepsis and ICU Cardiac Arrest) and the [Medical Information Mart for Intensive Care III (MIMIC-III) critical care database](https://physionet.org/content/mimiciii/1.4/) (ICU GIB), on [PhysioNet](https://physionet.org/). 83 | 84 | ## How to cite 85 | 86 | This repository contains the code to reproduce the main experiments and illustrations of the preprint [Trajectory Flow Matching with Applications to Clinical Time Series Modeling](https://arxiv.org/abs/2410.21154). We are excited that it was accepted as a **spotlight** presentation. 88 | 89 | If you find this code useful in your research, please cite (expand for BibTeX): 90 | 91 |
92 | 93 | bibtex citation 94 | 95 | 96 | ```bibtex 97 | @article{TFM, 98 | title = {Trajectory Flow Matching with Applications to Clinical Time Series Modelling}, 99 | author = {Zhang, Xi and Pu, Yuan and Kawamura, Yuki and Loza, Andrew and Bengio, Yoshua and Shung, Dennis and Tong, Alexander}, 100 | year = 2024, 101 | journal = {NeurIPS}, 102 | } 103 | ``` 104 |
105 | 106 | 107 | 108 | ## License 109 | This repo is licensed under the [MIT License](https://opensource.org/license/mit). 110 | -------------------------------------------------------------------------------- /assets/PLACEHOLDER.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nZhangx/TrajectoryFlowMatching/b562e386074cfec62ee7ddd23e826c3f33dd8ba2/assets/PLACEHOLDER.py -------------------------------------------------------------------------------- /assets/crossing_oscillation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nZhangx/TrajectoryFlowMatching/b562e386074cfec62ee7ddd23e826c3f33dd8ba2/assets/crossing_oscillation.png -------------------------------------------------------------------------------- /assets/pred_t.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nZhangx/TrajectoryFlowMatching/b562e386074cfec62ee7ddd23e826c3f33dd8ba2/assets/pred_t.png -------------------------------------------------------------------------------- /assets/revised_tfm_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nZhangx/TrajectoryFlowMatching/b562e386074cfec62ee7ddd23e826c3f33dd8ba2/assets/revised_tfm_fig1.png -------------------------------------------------------------------------------- /assets/traj_concept.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nZhangx/TrajectoryFlowMatching/b562e386074cfec62ee7ddd23e826c3f33dd8ba2/assets/traj_concept.png -------------------------------------------------------------------------------- /data/toy_data.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nZhangx/TrajectoryFlowMatching/b562e386074cfec62ee7ddd23e826c3f33dd8ba2/data/toy_data.pkl -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ti-env 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1 7 | - _openmp_mutex=4.5 8 | - asttokens=2.4.1 9 | - bzip2=1.0.8 10 | - ca-certificates=2024.6.2 11 | - comm=0.2.2 12 | - debugpy=1.6.7 13 | - decorator=5.1.1 14 | - executing=2.0.1 15 | - importlib-metadata=7.2.1 16 | - importlib_metadata=7.2.1 17 | - ipykernel=6.29.4 18 | - ipython=8.25.0 19 | - jedi=0.19.1 20 | - jupyter_client=8.6.2 21 | - jupyter_core=5.7.2 22 | - ld_impl_linux-64=2.38 23 | - libffi=3.4.4 24 | - libgcc-ng=13.2.0 25 | - libgomp=13.2.0 26 | - libsodium=1.0.18 27 | - libstdcxx-ng=11.2.0 28 | - libuuid=1.41.5 29 | - matplotlib-inline=0.1.7 30 | - ncurses=6.4 31 | - nest-asyncio=1.6.0 32 | - openssl=3.3.1 33 | - packaging=24.1 34 | - parso=0.8.4 35 | - pexpect=4.9.0 36 | - pickleshare=0.7.5 37 | - pip=24.0 38 | - platformdirs=4.2.2 39 | - prompt-toolkit=3.0.47 40 | - ptyprocess=0.7.0 41 | - pure_eval=0.2.2 42 | - pygments=2.18.0 43 | - python=3.10.14 44 | - python_abi=3.10 45 | - pyzmq=25.1.2 46 | - readline=8.2 47 | - setuptools=69.5.1 48 | - six=1.16.0 49 | - sqlite=3.45.3 50 | - stack_data=0.6.2 51 | - tk=8.6.14 52 | - tornado=6.4.1 53 | - traitlets=5.14.3 54 | - typing_extensions=4.12.2 55 | - wcwidth=0.2.13 56 | - wheel=0.43.0 57 | - xz=5.4.6 58 | - zeromq=4.3.5 59 | - zipp=3.19.2 60 | - 
zlib=1.2.13 61 | - pip: 62 | - aiohttp==3.9.5 63 | - aiosignal==1.3.1 64 | - alembic==1.13.1 65 | - anndata==0.10.7 66 | - antlr4-python3-runtime==4.9.3 67 | - array-api-compat==1.7.1 68 | - async-timeout==4.0.3 69 | - attrs==23.2.0 70 | - autopage==0.5.2 71 | - black==24.4.2 72 | - certifi==2024.6.2 73 | - cfgv==3.4.0 74 | - charset-normalizer==3.3.2 75 | - click==8.1.7 76 | - cliff==4.7.0 77 | - cloudpickle==3.0.0 78 | - cmaes==0.10.0 79 | - cmd2==2.4.3 80 | - colorlog==6.8.2 81 | - contourpy==1.2.1 82 | - cycler==0.12.1 83 | - distlib==0.3.8 84 | - docker-pycreds==0.4.0 85 | - exceptiongroup==1.2.1 86 | - filelock==3.15.1 87 | - fire==0.6.0 88 | - flake8==7.0.0 89 | - flake8-pyproject==1.2.3 90 | - fonttools==4.53.0 91 | - frozenlist==1.4.1 92 | - fsspec==2024.6.0 93 | - gitdb==4.0.11 94 | - gitpython==3.1.43 95 | - greenlet==3.0.3 96 | - h5py==3.11.0 97 | - huggingface-hub==0.23.4 98 | - hydra-colorlog==1.2.0 99 | - hydra-core==1.2.0 100 | - hydra-optuna-sweeper==1.2.0 101 | - hydra-submitit-launcher==1.2.0 102 | - identify==2.5.36 103 | - idna==3.7 104 | - iniconfig==2.0.0 105 | - ipywidgets==8.1.3 106 | - isort==5.13.2 107 | - joblib==1.4.2 108 | - jupyterlab-widgets==3.0.11 109 | - kiwisolver==1.4.5 110 | - legacy-api-wrap==1.4 111 | - lightning-bolts==0.6.0.post1 112 | - lightning-utilities==0.3.0 113 | - llvmlite==0.43.0 114 | - mako==1.3.5 115 | - markdown-it-py==3.0.0 116 | - markupsafe==2.1.5 117 | - matplotlib==3.9.0 118 | - mccabe==0.7.0 119 | - mdurl==0.1.2 120 | - multidict==6.0.5 121 | - mypy-extensions==1.0.0 122 | - natsort==8.4.0 123 | - networkx==3.3 124 | - nodeenv==1.9.1 125 | - numba==0.60.0 126 | - numpy==1.26.4 127 | - omegaconf==2.3.0 128 | - optuna==2.10.1 129 | - pandas==2.0.3 130 | - pastel==0.2.1 131 | - pathspec==0.12.1 132 | - patsy==0.5.6 133 | - pbr==6.0.0 134 | - pillow==10.3.0 135 | - pluggy==1.5.0 136 | - poethepoet==0.10.0 137 | - pot==0.9.3 138 | - pre-commit==3.7.1 139 | - prettytable==3.10.0 140 | - protobuf==5.27.1 141 | - psutil==5.9.8 142 | - pycodestyle==2.11.1 143 | - pyflakes==3.2.0 144 | - pynndescent==0.5.12 145 | - pyparsing==3.1.2 146 | - pyperclip==1.8.2 147 | - pyrootutils==1.0.4 148 | - pytest==8.2.2 149 | - python-dateutil==2.9.0.post0 150 | - python-dotenv==1.0.1 151 | # - python-graphviz==0.20.3 152 | - pytorch-lightning==1.8.3 153 | - pytz==2024.1 154 | - pyyaml==6.0.1 155 | - requests==2.32.3 156 | - rich==13.7.1 157 | - safetensors==0.4.3 158 | - scanpy==1.10.1 159 | - scikit-learn==1.5.0 160 | - scipy==1.13.1 161 | - scprep==1.2.3 162 | - seaborn==0.13.2 163 | - sentry-sdk==2.5.1 164 | - session-info==1.0.0 165 | - setproctitle==1.3.3 166 | - smmap==5.0.1 167 | - sqlalchemy==2.0.30 168 | - statsmodels==0.14.2 169 | - stdlib-list==0.10.0 170 | - stevedore==5.2.0 171 | - submitit==1.5.1 172 | - tensorboardx==2.6.2.2 173 | - termcolor==2.4.0 174 | - threadpoolctl==3.5.0 175 | - timm==1.0.3 176 | - tomli==2.0.1 177 | - tomlkit==0.12.5 178 | - torch==1.12.1+cu113 179 | - torchaudio==0.12.1+cu113 180 | - torchcde==0.2.5 181 | # - torchcubicspline==0.0.3 182 | - torchdiffeq==0.2.4 183 | - torchdyn==1.0.6 184 | - torchmetrics==0.11.0 185 | - torchsde==0.2.6 186 | - torchvision==0.13.1+cu113 187 | - torchviz==0.0.2 188 | - tqdm==4.66.4 189 | - trampoline==0.1.2 190 | - tzdata==2024.1 191 | - umap-learn==0.5.6 192 | - urllib3==2.2.1 193 | - virtualenv==20.26.2 194 | - wandb==0.17.1 195 | - widgetsnbextension==4.0.11 196 | - yarl==1.9.4 197 | -------------------------------------------------------------------------------- 
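The Hydra configs that follow under `src/conf` are composed from `config.yaml` plus one file each from the `data` and `model` groups. As a minimal sketch of that composition outside `src/main.py` (assuming Hydra 1.2 as pinned above, a working directory at the repository root so `src/conf` resolves, and `src` on the import path; the override values are just examples):

```python
from hydra import compose, initialize
from hydra.utils import instantiate

# Compose the experiment config the same way `python src/main.py data=eICU model=tfm_ode`
# would; config_path is resolved relative to the calling file.
with initialize(version_base=None, config_path="src/conf"):
    cfg = compose(config_name="config", overrides=["data=eICU", "model=tfm_ode", "seed=0"])
    print(cfg.model_module["_target_"])         # model.mlp_noise.Noise_MLP_Cond_Memory_Module
    data_module = instantiate(cfg.data_module)  # builds clinical_DataModule from conf/data/eICU.yaml
```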
/src/conf/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - data: eICU 4 | - model: tfm_sde 5 | 6 | wandb_logging: true 7 | wandb_project: clinical_trajectory 8 | wandb_dir: wandb_log/ 9 | ckpt_dir: checkpoints/ 10 | max_epochs: 200 11 | max_time: 00:48:00:00 12 | check_val_every_n_epoch: 10 13 | skip_training: false 14 | mode: None 15 | seed: 42 -------------------------------------------------------------------------------- /src/conf/data/eICU.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_module: 3 | _target_: data.datamodule.clinical_DataModule 4 | train_consecutive: true 5 | file_path: data/toy_data.pkl 6 | naming: eICU_DataModule_J30_v1 7 | t_headings: time_scaled_v1 8 | x_headings: 9 | - hr_normalized 10 | - map_normalized 11 | cond_headings: 12 | - apache_outcome_prob 13 | - norepi_inf_scaled 14 | memory: 0 15 | -------------------------------------------------------------------------------- /src/conf/data/eICU_ablated.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_module: 3 | _target_: data.datamodule.clinical_DataModule 4 | train_consecutive: true 5 | file_path: data/toy_data.pkl 6 | naming: eICU_DataModule_J30_v1 7 | t_headings: time_scaled_v1 8 | x_headings: 9 | - hr_normalized 10 | - map_normalized 11 | cond_headings: 12 | - apache_outcome_prob 13 | memory: 0 -------------------------------------------------------------------------------- /src/conf/data/eICU_multdim.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_module: 3 | _target_: data.datamodule.clinical_DataModule 4 | train_consecutive: true 5 | file_path: eICU_CART_downsampled.pkl 6 | naming: eICU-CART_DataModule_APR16 7 | t_headings: time_scaled_v1 8 | x_headings: 9 | - hr_normalized_scaled 10 | - dbp_normalized_scaled 11 | - rr_normalized_scaled 12 | cond_headings: 13 | - AGE_AT_ADM_normalized 14 | memory: 0 -------------------------------------------------------------------------------- /src/conf/data/mimic_liver.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_module: 3 | _target_: data.mimic_liver_data.MIMICLiverDataModule 4 | naming: MIMICLiverDataModule 5 | train_consecutive: true 6 | memory: 0 7 | x_headings: 8 | - 1 9 | - MAP 10 | cond_headings: 11 | - prbc_outcome 12 | - pressor 13 | - bloodprod 14 | - severe_liver 15 | -------------------------------------------------------------------------------- /src/conf/model/cfm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.cfm_baseline.MLP_CFM 4 | time_varying: true 5 | conditional: true 6 | treatment_cond: 0 7 | clip: 1e-2 8 | sigma: 0.1 9 | dim: 2 10 | metrics: 11 | - variance_dist 12 | - mse_loss 13 | - l1_loss -------------------------------------------------------------------------------- /src/conf/model/fm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.FM_baseline.MLP_FM 4 | time_varying: true 5 | conditional: false 6 | treatment_cond: 0 7 | clip: 1e-2 8 | sigma: 0.1 9 | dim: 2 10 | metrics: 11 | - variance_dist 12 | - mse_loss 13 | - l1_loss 
-------------------------------------------------------------------------------- /src/conf/model/latentODE.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.latent_ode.LatentODE_pl 4 | input_dim: 2 5 | latent_dim: 2 6 | output_dim: 2 7 | metrics: 8 | - variance_dist 9 | - mse_loss 10 | - l1_loss -------------------------------------------------------------------------------- /src/conf/model/ode.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.ode_baseline.ODEBaseline 4 | dim: 2 5 | w: 64 6 | lr: 1e-5 7 | metrics: 8 | - variance_dist 9 | - mse_loss 10 | - l1_loss -------------------------------------------------------------------------------- /src/conf/model/ode_256.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.ode_baseline.ODEBaseline 4 | dim: 2 5 | w: 256 6 | lr: 1e-5 7 | metrics: 8 | - variance_dist 9 | - mse_loss 10 | - l1_loss -------------------------------------------------------------------------------- /src/conf/model/sde.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.ode_baseline.SDEBaseline 4 | dim: 2 5 | w: 64 6 | lr: 1e-5 7 | metrics: 8 | - variance_dist 9 | - mse_loss 10 | - l1_loss -------------------------------------------------------------------------------- /src/conf/model/sde_256.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.ode_baseline.SDEBaseline 4 | dim: 2 5 | w: 256 6 | lr: 1e-5 7 | metrics: 8 | - variance_dist 9 | - mse_loss 10 | - l1_loss -------------------------------------------------------------------------------- /src/conf/model/tfm_ablated_size_memory_uncertainty.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.mlp_memory.MLP_Cond_Memory_Module 4 | time_varying: true 5 | conditional: true 6 | treatment_cond: 0 7 | memory: 3 8 | dim: 2 9 | implementation: SDE 10 | sde_noise: 0.1 11 | clip: 1e-2 12 | sigma: 0.1 13 | metrics: 14 | - variance_dist 15 | - mse_loss 16 | - l1_loss -------------------------------------------------------------------------------- /src/conf/model/tfm_ablated_uncertainty.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.mlp_memory.MLP_Cond_Memory_Module 4 | time_varying: true 5 | conditional: true 6 | treatment_cond: 0 7 | memory: 3 8 | w: 256 9 | implementation: SDE 10 | clip: 1e-2 11 | sigma: 0.1 12 | dim: 2 13 | metrics: 14 | - variance_dist 15 | - mse_loss 16 | - l1_loss -------------------------------------------------------------------------------- /src/conf/model/tfm_ode.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.mlp_noise.Noise_MLP_Cond_Memory_Module 4 | time_varying: true 5 | conditional: true 6 | treatment_cond: 0 7 | memory: 3 8 | dim: 2 9 | w: 256 10 | clip: 1e-2 11 | sigma: 0.1 12 | metrics: 13 | - variance_dist 14 | - mse_loss 15 | - l1_loss 16 | -------------------------------------------------------------------------------- 
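Each model YAML above (and the ablation configs below) is consumed by `hydra.utils.instantiate` in `src/main.py`: the `_target_` dotted path is imported and called with the remaining keys as keyword arguments. A minimal sketch of that mechanism using the `ode.yaml` values, assuming `src` is on the import path so the dotted path resolves:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# The inner `model_module` mapping of conf/model/ode.yaml, recreated inline;
# instantiate() imports model.ode_baseline.ODEBaseline and passes the rest as kwargs.
cfg = OmegaConf.create({
    "_target_": "model.ode_baseline.ODEBaseline",
    "dim": 2,
    "w": 64,
    "lr": 1e-5,
    "metrics": ["variance_dist", "mse_loss", "l1_loss"],
})
model = instantiate(cfg)  # equivalent to ODEBaseline(dim=2, w=64, lr=1e-5, metrics=[...])
```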
/src/conf/model/tfm_ode_ablated_size_memory_uncertainty.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.mlp_memory.MLP_Cond_Memory_Module 4 | time_varying: true 5 | conditional: true 6 | treatment_cond: 0 7 | memory: 3 8 | dim: 2 9 | clip: 1e-2 10 | sigma: 0.1 11 | metrics: 12 | - variance_dist 13 | - mse_loss 14 | - l1_loss 15 | -------------------------------------------------------------------------------- /src/conf/model/tfm_ode_ablated_uncertainty.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.mlp_memory.MLP_Cond_Memory_Module 4 | time_varying: true 5 | conditional: true 6 | treatment_cond: 0 7 | memory: 3 8 | w: 256 9 | clip: 1e-2 10 | sigma: 0.1 11 | dim: 2 12 | metrics: 13 | - variance_dist 14 | - mse_loss 15 | - l1_loss 16 | -------------------------------------------------------------------------------- /src/conf/model/tfm_sde.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_module: 3 | _target_: model.mlp_noise.Noise_MLP_Cond_Memory_Module 4 | time_varying: true 5 | conditional: true 6 | treatment_cond: 0 7 | memory: 3 8 | dim: 2 9 | w: 256 10 | clip: 1e-2 11 | sigma: 0.1 12 | metrics: 13 | - variance_dist 14 | - mse_loss 15 | - l1_loss 16 | implementation: SDE -------------------------------------------------------------------------------- /src/conf/perturb_config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - data: eICU 4 | - model: noise_mlp_memory_sde 5 | 6 | 7 | wandb_logging: true 8 | wandb_project: clinical_trajectory 9 | wandb_dir: wandb_log/ 10 | ckpt_dir: checkpoints/ 11 | max_epochs: 200 12 | max_time: 00:48:00:00 13 | check_val_every_n_epoch: 10 14 | skip_training: false 15 | mode: None 16 | seed: 42 -------------------------------------------------------------------------------- /src/conf/trainer.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | max_epochs: 20000 3 | check_val_every_n_epoch: 100 4 | gpus: 1 5 | precision: 32 6 | -------------------------------------------------------------------------------- /src/data/datamodule.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import pandas as pd 3 | import pickle 4 | import numpy as np 5 | 6 | # from pytorch_lightning.utilities.types import EVAL_DATALOADERS 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | 10 | """Helper class for specific training/testing dataset""" 11 | class TrainingDataset(Dataset): 12 | def __init__(self, x0_values, x0_classes, x1_values, times_x0, times_x1): 13 | self.x0_values = x0_values 14 | self.x0_classes = x0_classes 15 | self.x1_values = x1_values 16 | self.times_x0 = times_x0 17 | self.times_x1 = times_x1 18 | 19 | def __len__(self): 20 | return len(self.x0_values) 21 | 22 | def __getitem__(self, idx): 23 | return (self.x0_values[idx], self.x0_classes[idx], self.x1_values[idx], self.times_x0[idx], self.times_x1[idx]) 24 | 25 | class PatientDataset(Dataset): 26 | def __init__(self, patient_data): 27 | self.patient_data = patient_data 28 | 29 | def __len__(self): 30 | return len(self.patient_data) 31 | 32 | def __getitem__(self, idx): 33 | return self.patient_data[idx] 34 | 35 | class 
clinical_DataModule(pl.LightningDataModule): 36 | 37 | """returns: 38 | x0_values, x0_classes, x1_values, times_x0, times_x1 39 | """ 40 | 41 | def __init__(self, 42 | train_consecutive=False, 43 | batch_size=256, 44 | file_path=None, 45 | naming = None, 46 | t_headings = None, 47 | x_headings = [None], 48 | cond_headings = [None], 49 | memory=0): 50 | super().__init__() 51 | self.batch_size = batch_size 52 | self.file_path = file_path 53 | self.x_headings = x_headings 54 | self.cond_headings = cond_headings 55 | self.t_headings = t_headings 56 | self.input_dim = len(self.x_headings) + len(self.cond_headings) 57 | self.output_dim = len(self.x_headings) 58 | self.naming = naming 59 | self.memory = memory 60 | self.min_timept = 5 + self.memory 61 | self.train_consecutive = train_consecutive 62 | print("DataModule initialized to x_headings: ", self.x_headings, " cond_headings: ", self.cond_headings, " t_headings: ", self.t_headings, " train_consecutive: ", self.train_consecutive) 63 | 64 | def prepare_data(self): 65 | pass 66 | 67 | def __filter_data(self, data_set): 68 | # filter out data with less than 5 time points 69 | return data_set.groupby('HADM_ID').filter(lambda x: len(x) > self.min_timept) 70 | 71 | def __unpack__(self, data_set): 72 | x = data_set[self.x_headings].values 73 | cond = data_set[self.cond_headings].values 74 | t = data_set[self.t_headings].values 75 | return x, cond, t 76 | 77 | def __getitem__(self, idx): 78 | sample = self.data.iloc[idx].values.astype(np.float32) 79 | return sample 80 | 81 | def setup(self, stage=None): 82 | self.data = pd.read_pickle(self.file_path) 83 | if stage == 'fit' or stage is None: 84 | self.train = self.__filter_data(self.data['train']) 85 | self.val = self.__filter_data(self.data['val']) 86 | if stage == 'test' or stage is None: 87 | self.test = self.__filter_data(self.data['test']) 88 | 89 | def __sort_group__(self, data_set): 90 | grouped = data_set.groupby('HADM_ID') 91 | grouped_sorted = grouped.apply(lambda x: x.sort_values([self.t_headings], ascending = True)).reset_index(drop=True) 92 | return grouped_sorted 93 | 94 | def train_dataloader(self, shuffle=True): 95 | if not self.train_consecutive: 96 | train_data = self.create_pairs(self.train) 97 | # print(len(train_data)) 98 | train_dataset = TrainingDataset(*train_data) 99 | return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=1) 100 | else: 101 | train_data = self.create_patient_data_t0(self.train) 102 | train_dataset = PatientDataset(train_data) 103 | return DataLoader(train_dataset, batch_size=1, shuffle=shuffle, num_workers=1) 104 | 105 | def val_dataloader(self): 106 | if self.train_consecutive: 107 | val_data = self.create_patient_data_t0(self.val) 108 | val_dataset = PatientDataset(val_data) 109 | return DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1) 110 | val_data = self.create_patient_data(self.val) 111 | val_dataset = PatientDataset(val_data) 112 | return DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1) 113 | 114 | def test_dataloader(self): 115 | if self.train_consecutive: 116 | test_data = self.create_patient_data_t0(self.test) 117 | test_dataset = PatientDataset(test_data) 118 | return DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1) 119 | test_data = self.create_patient_data(self.test) 120 | test_dataset = PatientDataset(test_data) 121 | return DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1) 122 | 123 | def create_patient_data(self, df): 124 | """Create 
formatted patient data from the DataFrame 125 | ex. patient 0: x0_values, x0_classes, times_x0... 126 | 127 | Args: 128 | df (_type_): _description_ 129 | time_column (str, optional): _description_. Defaults to 'time_normalized'. 130 | """ 131 | patient_lst = [] 132 | for _, group in df.groupby('HADM_ID'): 133 | 134 | x0_values = [] 135 | x0_classes = [] 136 | x1_values = [] 137 | times_x0 = [] 138 | times_x1 = [] 139 | 140 | sorted_group = group.sort_values(by=self.t_headings) 141 | x0_values, x0_classes, x1_values, times_x0, times_x1 = self.create_pairs(sorted_group) 142 | 143 | if len(self.cond_headings)<2: 144 | x0_classes = np.expand_dims(x0_classes, axis=1) 145 | else: 146 | x0_classes = x0_classes.squeeze().astype(np.float32) 147 | 148 | 149 | patient_lst.append((x0_values.squeeze().astype(np.float32), 150 | x0_classes, 151 | x1_values.squeeze().astype(np.float32), 152 | times_x0.squeeze().astype(np.float32), 153 | times_x1.squeeze().astype(np.float32))) 154 | return patient_lst 155 | 156 | 157 | def create_pairs(self, df): 158 | """create pairs of consecutive points from the DataFrame (for training the model) 159 | 160 | Args: 161 | df (pandas.DataFrame): _description_ 162 | time_column (str, optional): _description_. Defaults to 'time_normalized'. 163 | 164 | Returns: 165 | numpy.array : x0_values, x0_classes, x1_values, times_x0, times_x1 166 | 167 | """ 168 | # Initialize empty lists to store the components of the pairs 169 | x0_values = [] 170 | x0_classes = [] 171 | x1_values = [] 172 | times_x0 = [] 173 | times_x1 = [] 174 | 175 | # Group the DataFrame by HADM_ID and iterate through each group 176 | for _, group in df.groupby('HADM_ID'): 177 | # Sort the group by time_normalized 178 | sorted_group = group.sort_values(by=self.t_headings) 179 | 180 | # Iterate through the sorted group to create pairs of consecutive points 181 | for i in range(self.memory,len(sorted_group) - 1): 182 | x0 = sorted_group.iloc[i] 183 | x0_class = x0[self.cond_headings].values 184 | x0_value = x0[self.x_headings].values 185 | 186 | x1 = sorted_group.iloc[i + 1] 187 | x1_value = x1[self.x_headings].values 188 | 189 | # memory component 190 | if self.memory>0: 191 | x0_memory = sorted_group.iloc[i - self.memory:i] 192 | x0_memory_flatten = x0_memory[self.x_headings].values.flatten() 193 | x0_class = np.append(x0_class, x0_memory_flatten) 194 | 195 | x0_values.append(x0_value) 196 | x0_classes.append(x0_class) 197 | x1_values.append(x1_value) 198 | times_x0.append(x0[self.t_headings]) 199 | times_x1.append(x1[self.t_headings]) 200 | 201 | # Convert the lists to arrays 202 | x0_values = np.array(x0_values).squeeze().astype(np.float32) 203 | x0_classes = np.array(x0_classes).squeeze().astype(np.float32) 204 | x1_values = np.array(x1_values).squeeze().astype(np.float32) 205 | times_x0 = np.array(times_x0).squeeze().astype(np.float32) 206 | times_x1 = np.array(times_x1).squeeze().astype(np.float32) 207 | 208 | if len(self.cond_headings)<2: 209 | x0_classes = np.expand_dims(x0_classes, axis=1) 210 | 211 | return x0_values, x0_classes, x1_values, times_x0, times_x1 212 | 213 | 214 | def create_patient_data_t0(self, df): 215 | """Create formatted patient data from the DataFrame 216 | ex. patient 0: x0_values, x0_classes, times_x0... 217 | This version has x0 constant and x1 varying (as well as t) 218 | 219 | Args: 220 | df (_type_): _description_ 221 | time_column (str, optional): _description_. Defaults to 'time_normalized'. 
222 | """ 223 | patient_lst = [] 224 | for _, group in df.groupby('HADM_ID'): 225 | 226 | x0_values = [] 227 | x0_classes = [] 228 | x1_values = [] 229 | times_x0 = [] 230 | times_x1 = [] 231 | 232 | sorted_group = group.sort_values(by=self.t_headings) 233 | x0_values, x0_classes, x1_values, times_x0, times_x1 = self.create_pairs(sorted_group) 234 | 235 | if len(self.cond_headings)<2: 236 | x0_classes = np.expand_dims(x0_classes, axis=1) 237 | else: 238 | x0_classes = x0_classes.squeeze().astype(np.float32) 239 | 240 | # repeat the first point for x0_values and x0_classes, and times_x0 241 | x0_values = np.repeat(x0_values[0][None, :], len(x0_values), axis=0) 242 | x0_classes = np.repeat(x0_classes[0][None, :], len(x0_values), axis=0) 243 | times_x0 = np.repeat(times_x0[0], len(x0_values)) 244 | 245 | patient_lst.append((x0_values.squeeze().astype(np.float32), 246 | x0_classes, 247 | x1_values.squeeze().astype(np.float32), 248 | times_x0.squeeze().astype(np.float32), 249 | times_x1.squeeze().astype(np.float32))) 250 | return patient_lst 251 | 252 | @property 253 | def dims(self): 254 | # x, cond, t 255 | return len(self.x_headings), len(self.cond_headings), len(self.t_headings) -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | # training script for lightning model 2 | """ 3 | run via: 4 | (changing data and model) 5 | python main.py data=data/data3.yaml model=model/model3.yaml 6 | 7 | to test: 8 | python main.py skip_training=true 9 | """ 10 | import pytorch_lightning as pl 11 | import torch 12 | from torch import nn 13 | import pandas as pd 14 | import numpy as np 15 | import wandb 16 | from pytorch_lightning.loggers import WandbLogger 17 | import os 18 | 19 | import pytorch_lightning as pl 20 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 21 | 22 | 23 | import hydra 24 | from omegaconf import OmegaConf 25 | import pytorch_lightning as pl 26 | from hydra.utils import instantiate 27 | 28 | from pytorch_lightning.strategies import DDPStrategy 29 | 30 | @hydra.main(config_path="conf", config_name="config") 31 | def train_model(cfg): 32 | # set seed 33 | pl.seed_everything(cfg.seed) 34 | 35 | print(cfg) 36 | if 'memory' in cfg['model_module'].keys(): # if is key 37 | memory = cfg['model_module']['memory'] 38 | cfg.data_module.memory = memory 39 | 40 | data_module = instantiate(cfg.data_module) 41 | 42 | # correct dim 43 | x_dim = data_module.dims[0] 44 | if 'dim' in cfg.model_module.keys(): 45 | cfg.model_module.dim = x_dim 46 | elif 'input_dim' in cfg.model_module.keys(): 47 | cfg.model_module.input_dim = x_dim 48 | cfg.model_module.output_dim = x_dim 49 | 50 | 51 | if 'treatment_cond' in cfg.model_module.keys(): 52 | # for conditional models, need this to configure 53 | cfg.model_module.treatment_cond = len(data_module.cond_headings) 54 | 55 | model = instantiate(cfg.model_module) 56 | 57 | # conditional models need train_consecutive false! 
58 | if not('Cond' in model.naming): 59 | cfg.data_module.train_consecutive = True 60 | data_module = instantiate(cfg.data_module) 61 | else: 62 | cfg.data_module.train_consecutive = False 63 | data_module = instantiate(cfg.data_module) 64 | 65 | wandb_config = {key: value for key, value in cfg.model_module.items() if key not in ['_target_']} 66 | wandb_config['model'] = model.naming 67 | wandb_config['data'] = data_module.naming 68 | wandb_config['mode'] = 'batch_run' 69 | wandb_config['x_headings'] = data_module.x_headings 70 | wandb_config['cond_headings'] = data_module.cond_headings 71 | wandb_config['t_headings'] = data_module.t_headings 72 | wandb_config['seed'] = cfg.seed 73 | 74 | wandb_logger = None # keep defined even when logging is disabled or training is skipped 75 | if cfg.wandb_logging and not(cfg.skip_training): 76 | wandb.init(project=cfg.wandb_project, # use configured project/dir instead of hardcoded user paths 77 | dir = cfg.wandb_dir, 78 | config = wandb_config 79 | ) 80 | wandb_logger = WandbLogger() 81 | 82 | ckpt_savedir = cfg.ckpt_dir + model.naming + '_' + data_module.naming + '/' 83 | if not os.path.exists(ckpt_savedir): 84 | os.makedirs(ckpt_savedir) 85 | 86 | checkpoint_callback = ModelCheckpoint( 87 | dirpath=ckpt_savedir, 88 | filename='best_model', 89 | save_top_k=1, 90 | verbose=True, 91 | monitor='val_loss', 92 | mode='min', 93 | save_last=True 94 | ) 95 | 96 | early_stopping_callback = EarlyStopping( 97 | monitor='val_loss', 98 | patience=3, 99 | verbose=True, 100 | mode='min' 101 | ) 102 | 103 | strategy_ddps = DDPStrategy(find_unused_parameters=True) 104 | 105 | trainer = pl.Trainer( 106 | max_epochs=cfg.max_epochs, 107 | max_time=cfg.max_time, 108 | check_val_every_n_epoch=cfg.check_val_every_n_epoch, # honor the config instead of a hardcoded 50 109 | callbacks=[checkpoint_callback, early_stopping_callback], 110 | accelerator='gpu' if torch.cuda.is_available() else 'cpu', 111 | logger=wandb_logger, 112 | limit_train_batches=0 if cfg.skip_training else 1.0, 113 | strategy=strategy_ddps, 114 | ) 115 | 116 | # Train the model 117 | trainer.fit(model, datamodule=data_module) 118 | 119 | # Test the model 120 | trainer.test(model, datamodule=data_module) 121 | 122 | if cfg.wandb_logging and not(cfg.skip_training): 123 | wandb.finish() 124 | 125 | def main(): 126 | train_model() 127 | 128 | if __name__ == '__main__': 129 | main() -------------------------------------------------------------------------------- /src/model/FM_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import time 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import ot as pot 8 | import torchdyn 9 | from torchdyn.core import NeuralODE 10 | import pytorch_lightning as pl 11 | from torch import optim 12 | import torch.nn.functional as F 13 | import wandb 14 | 15 | from utils.visualize import * 16 | from utils.metric_calc import * 17 | from utils.sde import SDE 18 | 19 | from model.components.positional_encoding import * 20 | from model.components.grad_util import torch_wrapper_tv 21 | 22 | 23 | class MLP(torch.nn.Module): 24 | def __init__(self, dim, out_dim=None, w=64, time_varying=False): 25 | super().__init__() 26 | self.time_varying = time_varying 27 | if out_dim is None: 28 | out_dim = dim 29 | self.net = torch.nn.Sequential( 30 | torch.nn.Linear(dim + (1 if time_varying else 0), w), 31 | torch.nn.SELU(), 32 | torch.nn.Linear(w, w), 33 | torch.nn.SELU(), 34 | torch.nn.Linear(w, w), 35 | torch.nn.SELU(), 36 | torch.nn.Linear(w, out_dim), 37 | ) 38 | 39 | def forward(self, x,
*args, **kwargs): 40 | return self.net(x) 41 | 42 | 43 | # conditional liver model 44 | class MLP_conditional_liver(torch.nn.Module): 45 | """ Conditional with many available classes 46 | 47 | return the class as is 48 | """ 49 | def __init__(self, dim, treatment_cond, out_dim=None, w=64, time_varying=False, conditional=False): 50 | super().__init__() 51 | self.time_varying = time_varying 52 | if out_dim is None: 53 | self.out_dim = dim 54 | self.treatment_cond = treatment_cond 55 | self.dim = dim 56 | self.indim = dim + (1 if time_varying else 0) + (self.treatment_cond if conditional else 0) 57 | self.net = torch.nn.Sequential( 58 | torch.nn.Linear(self.indim, w), 59 | torch.nn.SELU(), 60 | torch.nn.Linear(w, w), 61 | torch.nn.SELU(), 62 | torch.nn.Linear(w, w), 63 | torch.nn.SELU(), 64 | torch.nn.Linear(w,self.out_dim), 65 | ) 66 | self.default_class = 0 67 | 68 | 69 | def forward(self, x): 70 | """forward pass 71 | Assume first two dimensions are x, c, then t 72 | """ 73 | result = self.net(x) 74 | return torch.cat([result, x[:,self.dim:-1]], dim=1) 75 | 76 | class FM_baseline(torch.nn.Module): 77 | """ Conditional with many available classes 78 | 79 | return the class as is 80 | """ 81 | def __init__(self, dim, 82 | out_dim=None, 83 | w=64, 84 | time_varying=False, 85 | conditional=False, 86 | treatment_cond = 0, 87 | time_dim = NUM_FREQS * 2, 88 | clip = None): 89 | super().__init__() 90 | self.time_varying = time_varying 91 | if out_dim is None: 92 | self.out_dim = dim 93 | self.out_dim += 1 94 | self.treatment_cond = treatment_cond 95 | self.indim = dim + (time_dim if time_varying else 0) + (self.treatment_cond if conditional else 0) 96 | self.net = torch.nn.Sequential( 97 | torch.nn.Linear(self.indim, w), 98 | torch.nn.SELU(), 99 | torch.nn.Linear(w, w), 100 | torch.nn.SELU(), 101 | torch.nn.Linear(w, w), 102 | torch.nn.SELU(), 103 | torch.nn.Linear(w,self.out_dim), 104 | ) 105 | self.default_class = 0 106 | # self.encoding_function = positional_encoding_tensor() 107 | self.clip = clip 108 | 109 | def encoding_function(self, time_tensor): 110 | return positional_encoding_tensor(time_tensor) 111 | 112 | def forward_train(self, x): 113 | """forward pass 114 | Assume first two dimensions are x, c, then t 115 | input: x0 116 | output: vt 117 | """ 118 | time_tensor = x[:,-1] 119 | encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2) 120 | new_x = torch.cat([x[:,:-1], encoded_time_span], dim = 1).to(torch.float32) 121 | result = self.net(new_x) 122 | return torch.cat([result[:,:-1], result[:,-1].unsqueeze(1)], dim=1) 123 | 124 | def forward(self,x): 125 | """Function for simulation testing 126 | 127 | 128 | Args: 129 | x (_type_): x + time dimension 130 | 131 | Returns: 132 | forward_train(x)[:,:-1]: x without time dimension 133 | """ 134 | return self.forward_train(x)[:,:-1] 135 | 136 | 137 | 138 | """ Lightning module """ 139 | def mse_loss(pred, true): 140 | return torch.mean((pred - true) ** 2) 141 | 142 | def l1_loss(pred, true): 143 | return torch.mean(torch.abs(pred - true)) 144 | 145 | class MLP_FM(pl.LightningModule): 146 | def __init__(self, 147 | treatment_cond, 148 | dim=2, 149 | w=64, 150 | time_varying=True, 151 | conditional=False, 152 | lr=1e-6, 153 | sigma = 0.1, 154 | loss_fn = mse_loss, 155 | metrics = ['mse_loss', 'l1_loss'], 156 | implementation = "ODE", # can be SDE 157 | sde_noise = 0.1, 158 | clip = None, # float 159 | naming = None, 160 | ): 161 | super().__init__() 162 | self.model = FM_baseline(dim=dim, 163 | w=w, 164 | 
time_varying=time_varying, 165 | conditional=conditional, # no conditional for baseline 166 | clip = clip, 167 | treatment_cond=treatment_cond) 168 | self.loss_fn = loss_fn 169 | self.save_hyperparameters() 170 | self.dim = dim 171 | # self.out_dim = out_dim 172 | self.w = w 173 | self.time_varying = time_varying 174 | self.conditional = conditional 175 | self.treatment_cond = treatment_cond 176 | self.lr = lr 177 | self.sigma = sigma 178 | self.naming = "FM_baseline_"+implementation 179 | self.metrics = metrics 180 | self.implementation = implementation 181 | self.sde_noise = sde_noise 182 | self.clip = clip 183 | 184 | 185 | def __convert_tensor__(self, tensor): 186 | return tensor.to(torch.float32) 187 | 188 | def __x_processing__(self, x0, x1, t0, t1): 189 | # squeeze xs (prevent mismatch) 190 | x0 = x0.squeeze(0) 191 | x1 = x1.squeeze(0) 192 | t0 = t0.squeeze() 193 | t1 = t1.squeeze() 194 | 195 | t = torch.rand(x0.shape[0],1).to(x0.device) 196 | mu_t = x0 * (1 - t) + x1 * t 197 | data_t_diff = (t1 - t0).unsqueeze(1) 198 | x = mu_t + self.sigma * torch.randn(x0.shape[0], self.dim).to(x0.device) 199 | ut = (x1 - x0) / (data_t_diff + 1e-4) 200 | t_model = t * data_t_diff + t0.unsqueeze(1) 201 | futuretime = t1 - t_model 202 | return x, ut, t_model, futuretime, t 203 | 204 | def training_step(self, batch, batch_idx): 205 | """_summary_ 206 | 207 | Args: 208 | batch (list of output): x0_values, x0_classes, x1_values, times_x0, times_x1 209 | batch_idx (_type_): _description_ 210 | 211 | Returns: 212 | _type_: _description_ 213 | """ 214 | x0, x0_class, x1, x0_time, x1_time = batch 215 | x0, x0_class, x1, x0_time, x1_time = self.__convert_tensor__(x0), self.__convert_tensor__(x0_class), self.__convert_tensor__(x1), self.__convert_tensor__(x0_time), self.__convert_tensor__(x1_time) 216 | 217 | 218 | x, ut, t_model, futuretime, t = self.__x_processing__(x0, x1, x0_time, x1_time) 219 | 220 | 221 | if len(x0_class.shape) == 3: 222 | x0_class = x0_class.squeeze(0) 223 | 224 | # in_tensor = torch.cat([x,x0_class, t_model], dim = -1) 225 | in_tensor = torch.cat([x, t_model], dim = -1) 226 | vt = self.model.forward_train(in_tensor) 227 | 228 | # SDE: inject noise in the loss 229 | if self.implementation == "SDE": 230 | variance = t*(1-t)*(self.sde_noise ** 2) 231 | noise = torch.randn_like(vt[:,:self.dim]) * torch.sqrt(variance) 232 | loss = self.loss_fn(vt[:,:self.dim]+noise, ut) + self.loss_fn(vt[:,-1], futuretime) 233 | else: 234 | loss = self.loss_fn(vt[:,:self.dim], ut) + self.loss_fn(vt[:,-1], futuretime) 235 | self.log('train_loss', loss) 236 | return loss 237 | 238 | def config_optimizer(self): 239 | return torch.optim.Adam(self.parameters(), lr=self.lr) 240 | 241 | def validation_step(self, batch, batch_idx): 242 | """validation_step 243 | 244 | Args: 245 | batch (_type_): batch size of 1 (since uneven) 246 | batch_idx (_type_): _description_ 247 | 248 | Returns: 249 | _type_: _description_ 250 | """ 251 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='val') 252 | self.log('val_loss', loss) 253 | for key, value in metricD.items(): 254 | self.log(key+"_val", value) 255 | # return total_loss, traj_pairs 256 | return {'val_loss':loss, 'traj_pairs':pairs} 257 | 258 | def test_step(self, batch, batch_idx): 259 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='test') 260 | self.log('test_loss', loss) 261 | for key, value in metricD.items(): 262 | self.log(key+"_test", value) 263 | # return total_loss, traj_pairs 264 | return {'test_loss':loss, 
'traj_pairs':pairs} 265 | 266 | def test_func_step(self, batch, batch_idx, mode='none'): 267 | """assuming each is one patient/batch""" 268 | total_loss = [] 269 | traj_pairs = [] 270 | 271 | x0_values, x0_classes, x1_values, times_x0, times_x1 = batch 272 | times_x0 = times_x0.squeeze() 273 | times_x1 = times_x1.squeeze() 274 | 275 | full_traj = torch.cat([x0_values[0,0,:self.dim].unsqueeze(0), 276 | x1_values[0,:,:self.dim]], 277 | dim=0) 278 | full_time = torch.cat([times_x0[0].unsqueeze(0), times_x1], dim=0) 279 | ind_loss, pred_traj = self.test_trajectory(batch) 280 | total_loss.append(ind_loss) 281 | traj_pairs.append([full_traj, pred_traj]) 282 | 283 | full_traj = full_traj.detach().cpu().numpy() 284 | pred_traj = pred_traj.detach().cpu().numpy() 285 | full_time = full_time.detach().cpu().numpy() 286 | 287 | # graph 288 | fig = plot_3d_path_ind(pred_traj, 289 | full_traj, 290 | t_span=full_time, 291 | title="{}_trajectory_patient_{}".format(mode, batch_idx)) 292 | if self.logger: 293 | # may cause problem if wandb disabled 294 | self.logger.experiment.log({"{}_trajectory_patient_{}".format(mode, batch_idx): wandb.Image(fig)}) 295 | 296 | plt.close(fig) 297 | 298 | # metrics 299 | metricD = metrics_calculation(pred_traj, full_traj, metrics=self.metrics) 300 | return np.mean(total_loss), traj_pairs, metricD 301 | 302 | def test_trajectory(self,pt_tensor): 303 | if self.implementation == "ODE": 304 | return self.test_trajectory_ode(pt_tensor) 305 | elif self.implementation == "SDE": 306 | return self.test_trajectory_sde(pt_tensor) 307 | 308 | def test_trajectory_ode(self,pt_tensor): 309 | """test_trajectory 310 | 311 | Args: 312 | pt_tensor (numpy.array): (x0_values, x0_classes, x1_values, times_x0, times_x1), 313 | 314 | 315 | Returns: 316 | mse_all, total_pred_tensor: _description_ 317 | """ 318 | node = NeuralODE( 319 | torch_wrapper_tv(self.model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4 320 | ) 321 | total_pred = [] 322 | mse = [] 323 | t_max = 0 324 | 325 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor 326 | # squeeze all 327 | x0_values = x0_values.squeeze(0) 328 | x1_values = x1_values.squeeze(0) 329 | times_x0 = times_x0.squeeze() 330 | times_x1 = times_x1.squeeze() 331 | x0_classes = x0_classes.squeeze() 332 | 333 | if len(x0_classes.shape) == 1: 334 | x0_classes = x0_classes.unsqueeze(1) 335 | 336 | 337 | 338 | total_pred.append(x0_values[0].unsqueeze(0)) 339 | len_path = x0_values.shape[0] 340 | assert len_path == x1_values.shape[0] 341 | for i in range(len_path): 342 | # print(i) 343 | t_max = (times_x1[i]-times_x0[i]) # calculate time difference (cumulative) 344 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device) 345 | # encoded_time_span = positional_encoding_tensor(time_span).squeeze(1).reshape(-1, NUM_FREQS * 2) 346 | with torch.no_grad(): 347 | # get last pred, if none then use startpt 348 | if i == 0: 349 | testpt = torch.cat([x0_values[i].unsqueeze(0)],dim=1) 350 | else: # incorporate last prediction 351 | testpt = pred_traj 352 | # print(testpt.shape) 353 | traj = node.trajectory( 354 | testpt, 355 | t_span=time_span, 356 | ) 357 | pred_traj = traj[-1,:,:self.dim] 358 | total_pred.append(pred_traj) 359 | ground_truth_coords = x1_values[i] 360 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy()) 361 | mse_all = np.mean(mse) 362 | total_pred_tensor = torch.stack(total_pred).squeeze(1) 363 | return mse_all, total_pred_tensor 364 | 365 | def 
configure_optimizers(self): 366 | # Define the optimizer 367 | optimizer = optim.Adam(self.parameters(), lr=self.lr) 368 | return optimizer 369 | 370 | def test_trajectory_sde(self,pt_tensor): 371 | """test_trajectory 372 | 373 | Args: 374 | pt_tensor (numpy.array): (x0_values, x0_classes, x1_values, times_x0, times_x1), 375 | 376 | 377 | Returns: 378 | mse_all, total_pred_tensor: _description_ 379 | """ 380 | sde = SDE(self.model, noise=self.sde_noise) 381 | total_pred = [] 382 | mse = [] 383 | t_max = 0 384 | 385 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor 386 | # squeeze all 387 | x0_values = x0_values.squeeze(0) 388 | x1_values = x1_values.squeeze(0) 389 | times_x0 = times_x0.squeeze() 390 | times_x1 = times_x1.squeeze() 391 | x0_classes = x0_classes.squeeze() 392 | 393 | if len(x0_classes.shape) == 1: 394 | x0_classes = x0_classes.unsqueeze(1) 395 | 396 | 397 | 398 | total_pred.append(x0_values[0].unsqueeze(0)) 399 | len_path = x0_values.shape[0] 400 | assert len_path == x1_values.shape[0] 401 | for i in range(len_path): 402 | # print(i) 403 | t_max = (times_x1[i]-times_x0[i]) # calculate time difference (cumulative) 404 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device) 405 | 406 | with torch.no_grad(): 407 | # get last pred, if none then use startpt 408 | if i == 0: 409 | testpt = torch.cat([x0_values[i].unsqueeze(0)],dim=1) 410 | else: 411 | testpt = pred_traj 412 | # print(testpt.shape) 413 | traj = self._sde_solver(sde, testpt, time_span) 414 | 415 | pred_traj = traj[-1,:,:self.dim] 416 | total_pred.append(pred_traj) 417 | ground_truth_coords = x1_values[i] 418 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy()) 419 | mse_all = np.mean(mse) 420 | total_pred_tensor = torch.stack(total_pred).squeeze(1) 421 | return mse_all, total_pred_tensor 422 | 423 | 424 | def _sde_solver(self, sde, initial_state, time_span): 425 | dt = time_span[1] - time_span[0] # Time step 426 | current_state = initial_state 427 | trajectory = [current_state] 428 | 429 | for t in time_span[1:]: 430 | drift = sde.f(t, current_state) 431 | diffusion = sde.g(t, current_state) 432 | noise = torch.randn_like(current_state) * torch.sqrt(dt) 433 | current_state = current_state + drift * dt + diffusion * noise 434 | trajectory.append(current_state) 435 | 436 | return torch.stack(trajectory) 437 | 438 | -------------------------------------------------------------------------------- /src/model/base_models.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | ########################### 4 | # Adapted from: 5 | # Latent ODEs for Irregularly-Sampled Time Series 6 | # Author: Yulia Rubanova 7 | ########################### 8 | 9 | import utils.latent_ode_utils as utils 10 | from utils.latent_ode_utils import get_device 11 | 12 | from torch.distributions.multivariate_normal import MultivariateNormal 13 | from torch.distributions.normal import Normal 14 | from torch.distributions import kl_divergence, Independent 15 | 16 | 17 | def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices = None): 18 | n_data_points = mu_2d.size()[-1] 19 | 20 | if n_data_points > 0: 21 | gaussian = Independent(Normal(loc = mu_2d, scale = obsrv_std.repeat(n_data_points)), 1) 22 | log_prob = gaussian.log_prob(data_2d) 23 | log_prob = log_prob / n_data_points 24 | else: 25 | log_prob = torch.zeros([1]).to(get_device(data_2d)).squeeze() 26 | return log_prob 27 | 28 | 29 | def 
poisson_log_likelihood(masked_log_lambdas, masked_data, indices, int_lambdas): 30 | # masked_log_lambdas and masked_data 31 | n_data_points = masked_data.size()[-1] 32 | 33 | if n_data_points > 0: 34 | log_prob = torch.sum(masked_log_lambdas) - int_lambdas[indices] 35 | #log_prob = log_prob / n_data_points 36 | else: 37 | log_prob = torch.zeros([1]).to(get_device(masked_data)).squeeze() 38 | return log_prob 39 | 40 | 41 | 42 | def compute_binary_CE_loss(label_predictions, mortality_label): 43 | #print("Computing binary classification loss: compute_CE_loss") 44 | 45 | mortality_label = mortality_label.reshape(-1) 46 | 47 | if len(label_predictions.size()) == 1: 48 | label_predictions = label_predictions.unsqueeze(0) 49 | 50 | n_traj_samples = label_predictions.size(0) 51 | label_predictions = label_predictions.reshape(n_traj_samples, -1) 52 | 53 | idx_not_nan = ~torch.isnan(mortality_label) 54 | if not idx_not_nan.any(): # was `len(idx_not_nan) == 0.`, which measured tensor length rather than the number of non-NaN labels; return early instead of falling through with an empty batch 55 | print("All labels are NaNs!") 56 | return torch.tensor(0.).to(get_device(mortality_label)) 57 | 58 | label_predictions = label_predictions[:,idx_not_nan] 59 | mortality_label = mortality_label[idx_not_nan] 60 | 61 | if torch.sum(mortality_label == 0.) == 0 or torch.sum(mortality_label == 1.) == 0: 62 | print("Warning: all examples in a batch belong to the same class -- please increase the batch size.") 63 | 64 | assert(not torch.isnan(label_predictions).any()) 65 | assert(not torch.isnan(mortality_label).any()) 66 | 67 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them 68 | mortality_label = mortality_label.repeat(n_traj_samples, 1) 69 | ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label) 70 | 71 | # divide by number of patients in a batch 72 | ce_loss = ce_loss / n_traj_samples 73 | return ce_loss 74 | 75 | 76 | def compute_multiclass_CE_loss(label_predictions, true_label, mask): 77 | #print("Computing multi-class classification loss: compute_multiclass_CE_loss") 78 | 79 | if (len(label_predictions.size()) == 3): 80 | label_predictions = label_predictions.unsqueeze(0) 81 | 82 | n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size() 83 | 84 | # assert(not torch.isnan(label_predictions).any()) 85 | # assert(not torch.isnan(true_label).any()) 86 | 87 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them 88 | true_label = true_label.repeat(n_traj_samples, 1, 1) 89 | 90 | label_predictions = label_predictions.reshape(n_traj_samples * n_traj * n_tp, n_dims) 91 | true_label = true_label.reshape(n_traj_samples * n_traj * n_tp, n_dims) 92 | 93 | # choose time points with at least one measurement 94 | mask = torch.sum(mask, -1) > 0 95 | 96 | # repeat the mask for each label to mark that the label for this time point is present 97 | pred_mask = mask.repeat(n_dims, 1,1).permute(1,2,0) 98 | 99 | label_mask = mask 100 | pred_mask = pred_mask.repeat(n_traj_samples,1,1,1) 101 | label_mask = label_mask.repeat(n_traj_samples,1,1,1) 102 | 103 | pred_mask = pred_mask.reshape(n_traj_samples * n_traj * n_tp, n_dims) 104 | label_mask = label_mask.reshape(n_traj_samples * n_traj * n_tp, 1) 105 | 106 | if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1): 107 | assert(label_predictions.size(-1) == true_label.size(-1)) 108 | # targets are in one-hot encoding -- convert to indices 109 | _, true_label = true_label.max(-1) 110 | 111 | res = [] 112 | for i in range(true_label.size(0)): 113 | pred_masked = torch.masked_select(label_predictions[i], pred_mask[i].bool())
114 | labels = torch.masked_select(true_label[i], label_mask[i].bool()) 115 | 116 | pred_masked = pred_masked.reshape(-1, n_dims) 117 | 118 | if (len(labels) == 0): 119 | continue 120 | 121 | ce_loss = nn.CrossEntropyLoss()(pred_masked, labels.long()) 122 | res.append(ce_loss) 123 | 124 | ce_loss = torch.stack(res, 0).to(get_device(label_predictions)) 125 | ce_loss = torch.mean(ce_loss) 126 | # # divide by number of patients in a batch 127 | # ce_loss = ce_loss / n_traj_samples 128 | return ce_loss 129 | 130 | 131 | 132 | 133 | def compute_masked_likelihood(mu, data, mask, likelihood_func): 134 | # Compute the likelihood per patient and per attribute so that we don't priorize patients with more measurements 135 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size() 136 | 137 | res = [] 138 | for i in range(n_traj_samples): 139 | for k in range(n_traj): 140 | for j in range(n_dims): 141 | data_masked = torch.masked_select(data[i,k,:,j], mask[i,k,:,j].bool()) 142 | 143 | #assert(torch.sum(data_masked == 0.) < 10) 144 | 145 | mu_masked = torch.masked_select(mu[i,k,:,j], mask[i,k,:,j].bool()) 146 | log_prob = likelihood_func(mu_masked, data_masked, indices = (i,k,j)) 147 | res.append(log_prob) 148 | # shape: [n_traj*n_traj_samples, 1] 149 | 150 | res = torch.stack(res, 0).to(get_device(data)) 151 | res = res.reshape((n_traj_samples, n_traj, n_dims)) 152 | # Take mean over the number of dimensions 153 | res = torch.mean(res, -1) # !!!!!!!!!!! changed from sum to mean 154 | res = res.transpose(0,1) 155 | return res 156 | 157 | 158 | def masked_gaussian_log_density(mu, data, obsrv_std, mask = None): 159 | # these cases are for plotting through plot_estim_density 160 | if (len(mu.size()) == 3): 161 | # add additional dimension for gp samples 162 | mu = mu.unsqueeze(0) 163 | 164 | if (len(data.size()) == 2): 165 | # add additional dimension for gp samples and time step 166 | data = data.unsqueeze(0).unsqueeze(2) 167 | elif (len(data.size()) == 3): 168 | # add additional dimension for gp samples 169 | data = data.unsqueeze(0) 170 | 171 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size() 172 | 173 | assert(data.size()[-1] == n_dims) 174 | 175 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims] 176 | if mask is None: 177 | mu_flat = mu.reshape(n_traj_samples*n_traj, n_timepoints * n_dims) 178 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size() 179 | data_flat = data.reshape(n_traj_samples*n_traj, n_timepoints * n_dims) 180 | 181 | res = gaussian_log_likelihood(mu_flat, data_flat, obsrv_std) 182 | res = res.reshape(n_traj_samples, n_traj).transpose(0,1) 183 | else: 184 | # Compute the likelihood per patient so that we don't priorize patients with more measurements 185 | func = lambda mu, data, indices: gaussian_log_likelihood(mu, data, obsrv_std = obsrv_std, indices = indices) 186 | res = compute_masked_likelihood(mu, data, mask, func) 187 | return res 188 | 189 | 190 | 191 | def mse(mu, data, indices = None): 192 | n_data_points = mu.size()[-1] 193 | 194 | if n_data_points > 0: 195 | mse = nn.MSELoss()(mu, data) 196 | else: 197 | mse = torch.zeros([1]).to(get_device(data)).squeeze() 198 | return mse 199 | 200 | 201 | def compute_mse(mu, data, mask = None): 202 | # these cases are for plotting through plot_estim_density 203 | if (len(mu.size()) == 3): 204 | # add additional dimension for gp samples 205 | mu = mu.unsqueeze(0) 206 | 207 | if (len(data.size()) == 2): 208 | # add additional dimension for gp samples and time step 209 | data = 
201 | def compute_mse(mu, data, mask = None): 202 | # these cases are for plotting through plot_estim_density 203 | if (len(mu.size()) == 3): 204 | # add additional dimension for gp samples 205 | mu = mu.unsqueeze(0) 206 | 207 | if (len(data.size()) == 2): 208 | # add additional dimension for gp samples and time step 209 | data = data.unsqueeze(0).unsqueeze(2) 210 | elif (len(data.size()) == 3): 211 | # add additional dimension for gp samples 212 | data = data.unsqueeze(0) 213 | 214 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size() 215 | assert(data.size()[-1] == n_dims) 216 | 217 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims] 218 | if mask is None: 219 | mu_flat = mu.reshape(n_traj_samples*n_traj, n_timepoints * n_dims) 220 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size() 221 | data_flat = data.reshape(n_traj_samples*n_traj, n_timepoints * n_dims) 222 | res = mse(mu_flat, data_flat) 223 | else: 224 | # Compute the likelihood per patient so that we don't prioritize patients with more measurements 225 | res = compute_masked_likelihood(mu, data, mask, mse) 226 | return res 227 | 228 | 229 | 230 | 231 | def compute_poisson_proc_likelihood(truth, pred_y, info, mask = None): 232 | # Compute Poisson likelihood 233 | # https://math.stackexchange.com/questions/344487/log-likelihood-of-a-realization-of-a-poisson-process 234 | # Sum log lambdas across all time points 235 | if mask is None: 236 | poisson_log_l = torch.sum(info["log_lambda_y"], 2) - info["int_lambda"] 237 | # Sum over data dims 238 | poisson_log_l = torch.mean(poisson_log_l, -1) 239 | else: 240 | # Compute likelihood of the data under the predictions 241 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1) 242 | mask_repeated = mask.repeat(pred_y.size(0), 1, 1, 1) 243 | 244 | # Compute the likelihood per patient and per attribute so that we don't prioritize patients with more measurements 245 | int_lambda = info["int_lambda"] 246 | f = lambda log_lam, data, indices: poisson_log_likelihood(log_lam, data, indices, int_lambda) 247 | poisson_log_l = compute_masked_likelihood(info["log_lambda_y"], truth_repeated, mask_repeated, f) 248 | poisson_log_l = poisson_log_l.permute(1,0) 249 | # Take mean over n_traj 250 | #poisson_log_l = torch.mean(poisson_log_l, 1) 251 | 252 | # poisson_log_l shape: [n_traj_samples, n_traj] 253 | return poisson_log_l 254 | 255 | 256 | import numpy as np 257 | import torch 258 | import torch.nn as nn 259 | from torch.nn.functional import relu 260 | 261 | from utils.latent_ode_utils import * 262 | from model.latent_ode import * 263 | # from lib.likelihood_eval import * 264 | 265 | from torch.distributions.multivariate_normal import MultivariateNormal 266 | from torch.distributions.normal import Normal 267 | from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase 268 | 269 | 270 | 271 | # from torch.distributions.normal import Normal  # duplicate of the import above 272 | from torch.distributions import Independent 273 | from torch.nn.parameter import Parameter 274 | 275 | def create_classifier(z0_dim, n_labels): 276 | return nn.Sequential( 277 | nn.Linear(z0_dim, 300), 278 | nn.ReLU(), 279 | nn.Linear(300, 300), 280 | nn.ReLU(), 281 | nn.Linear(300, n_labels),) 282 |
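# --- Editor's note (not part of the original source): create_classifier above
# builds a two-hidden-layer MLP head, e.g. mapping a 20-d latent state to one
# mortality logit:
# clf = create_classifier(z0_dim=20, n_labels=1)
# logit = clf(torch.randn(4, 20))  # shape [4, 1]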
283 | class VAE_Baseline(nn.Module): 284 | def __init__(self, input_dim, latent_dim, 285 | z0_prior, device, 286 | obsrv_std = 0.01, 287 | use_binary_classif = False, 288 | classif_per_tp = False, 289 | use_poisson_proc = False, 290 | linear_classifier = False, 291 | n_labels = 1, 292 | train_classif_w_reconstr = False): 293 | 294 | super(VAE_Baseline, self).__init__() 295 | 296 | self.input_dim = input_dim 297 | self.latent_dim = latent_dim 298 | self.device = device 299 | self.n_labels = n_labels 300 | 301 | self.obsrv_std = torch.Tensor([obsrv_std]).to(device) 302 | 303 | self.z0_prior = z0_prior 304 | self.use_binary_classif = use_binary_classif 305 | self.classif_per_tp = classif_per_tp 306 | self.use_poisson_proc = use_poisson_proc 307 | self.linear_classifier = linear_classifier 308 | self.train_classif_w_reconstr = train_classif_w_reconstr 309 | 310 | z0_dim = latent_dim 311 | if use_poisson_proc: 312 | z0_dim += latent_dim 313 | 314 | if use_binary_classif: 315 | if linear_classifier: 316 | self.classifier = nn.Sequential( 317 | nn.Linear(z0_dim, n_labels)) 318 | else: 319 | self.classifier = create_classifier(z0_dim, n_labels) 320 | init_network_weights(self.classifier) # exported by utils.latent_ode_utils via the star-import above 321 | 322 | 323 | def get_gaussian_likelihood(self, truth, pred_y, mask = None): 324 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim] 325 | # truth shape [n_traj, n_tp, n_dim] 326 | n_traj, n_tp, n_dim = truth.size() 327 | 328 | # Compute likelihood of the data under the predictions 329 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1) 330 | 331 | if mask is not None: 332 | mask = mask.repeat(pred_y.size(0), 1, 1, 1) 333 | log_density_data = masked_gaussian_log_density(pred_y, truth_repeated, 334 | obsrv_std = self.obsrv_std, mask = mask) 335 | log_density_data = log_density_data.permute(1,0) 336 | log_density = torch.mean(log_density_data, 1) 337 | 338 | # shape: [n_traj_samples] 339 | return log_density 340 | 341 | 342 | def get_mse(self, truth, pred_y, mask = None): 343 | 344 | # Compute likelihood of the data under the predictions 345 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1) 346 | 347 | if mask is not None: 348 | mask = mask.repeat(pred_y.size(0), 1, 1, 1) 349 | 350 | # Compute likelihood of the data under the predictions 351 | log_density_data = compute_mse(pred_y, truth_repeated, mask = mask) 352 | # shape: [1] 353 | return torch.mean(log_density_data) 354 | 355 | 356 | def compute_all_losses(self, batch_dict, n_traj_samples = 1, kl_coef = 1.): 357 | # Condition on subsampled points 358 | # Make predictions for all the points 359 | pred_y, info = self.get_reconstruction(batch_dict["tp_to_predict"], 360 | batch_dict["observed_data"], batch_dict["observed_tp"], 361 | mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples, 362 | mode = batch_dict["mode"]) 363 | fp_mu, fp_std, fp_enc = info["first_point"] 364 | fp_std = fp_std.abs() 365 | fp_distr = Normal(fp_mu, fp_std) 366 | 367 | assert(torch.sum(fp_std < 0) == 0.) 368 | 369 | kldiv_z0 = kl_divergence(fp_distr, self.z0_prior) 370 | 371 | if torch.isnan(kldiv_z0).any(): 372 | print(fp_mu) 373 | print(fp_std) 374 | raise Exception("kldiv_z0 is NaN!") 375 | 376 | # Mean over number of latent dimensions 377 | kldiv_z0 = torch.mean(kldiv_z0,(1,2)) 378 | 379 | # Compute likelihood of all the points 380 | rec_likelihood = self.get_gaussian_likelihood( 381 | batch_dict["data_to_predict"], pred_y, 382 | mask = batch_dict["mask_predicted_data"]) 383 | 384 | mse = self.get_mse( 385 | batch_dict["data_to_predict"], pred_y, 386 | mask = batch_dict["mask_predicted_data"]) 387 | 388 | pois_log_likelihood = torch.Tensor([0.]).to(get_device(batch_dict["data_to_predict"])) 389 | if self.use_poisson_proc: 390 | pois_log_likelihood = compute_poisson_proc_likelihood( 391 | batch_dict["data_to_predict"], pred_y, 392 | info, mask = batch_dict["mask_predicted_data"]) 393 | # Take mean over n_traj 394 | pois_log_likelihood = torch.mean(pois_log_likelihood, 1) 395 | 396 | 397 | # IWAE loss 398 | loss = - torch.logsumexp(rec_likelihood - kl_coef * kldiv_z0,0) 399 | if torch.isnan(loss): 400 | loss = - torch.mean(rec_likelihood - kl_coef * kldiv_z0,0) 401 | 402 | results = {} 403 | results["loss"] = torch.mean(loss) 404 | results["likelihood"] = torch.mean(rec_likelihood).detach() 405 | results["mse"] = torch.mean(mse).detach() 406 | results["kl_first_p"] = torch.mean(kldiv_z0).detach() 407 | results["std_first_p"] = torch.mean(fp_std).detach() 408 | 409 | if batch_dict["labels"] is not None and self.use_binary_classif: 410 | results["label_predictions"] = info["label_predictions"].detach() 411 | 412 | return results
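# --- Editor's sketch (not part of the original source): the IWAE-style bound
# used by compute_all_losses above, on toy tensors. logsumexp aggregates the
# n_traj_samples importance samples drawn from q(z0); kl_coef anneals the KL term.
def _demo_iwae_bound():
    rec_likelihood = torch.tensor([[-1.2], [-0.9]])  # [n_traj_samples, n_traj]
    kldiv_z0 = torch.tensor([[0.3], [0.2]])
    kl_coef = 1.0
    return -torch.logsumexp(rec_likelihood - kl_coef * kldiv_z0, 0)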
-------------------------------------------------------------------------------- /src/model/cfm_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import time 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import ot as pot 8 | import torchdyn 9 | from torchdyn.core import NeuralODE 10 | import pytorch_lightning as pl 11 | from torch import optim 12 | import torch.nn.functional as F 13 | import wandb 14 | from model.components.grad_util import * 15 | 16 | from utils.visualize import * 17 | from utils.metric_calc import * 18 | from utils.sde import SDE 19 | from model.components.positional_encoding import * 20 | from utils.loss import mse_loss, l1_loss 21 | 22 | 23 | 24 | class MLP(torch.nn.Module): 25 | def __init__(self, dim, out_dim=None, w=64, time_varying=False): 26 | super().__init__() 27 | self.time_varying = time_varying 28 | if out_dim is None: 29 | out_dim = dim 30 | self.net = torch.nn.Sequential( 31 | torch.nn.Linear(dim + (1 if time_varying else 0), w), 32 | torch.nn.SELU(), 33 | torch.nn.Linear(w, w), 34 | torch.nn.SELU(), 35 | torch.nn.Linear(w, w), 36 | torch.nn.SELU(), 37 | torch.nn.Linear(w, out_dim), 38 | ) 39 | 40 | def forward(self, x, *args, **kwargs): 41 | return self.net(x) 42 | 43 |
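# --- Editor's note (not part of the original source): with time_varying=True the
# MLP above expects the scalar time appended to the state as one extra column:
# net = MLP(dim=2, time_varying=True)
# v = net(torch.cat([torch.randn(8, 2), torch.rand(8, 1)], dim=1))  # -> [8, 2]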
44 | # conditional liver model 45 | class MLP_conditional_liver(torch.nn.Module): 46 | """ Conditional with many available classes 47 | 48 | return the class as is 49 | """ 50 | def __init__(self, dim, treatment_cond, out_dim=None, w=64, time_varying=False, conditional=False): 51 | super().__init__() 52 | self.time_varying = time_varying 53 | # default the output width to the input width 54 | self.out_dim = dim if out_dim is None else out_dim 55 | self.treatment_cond = treatment_cond 56 | self.dim = dim 57 | self.indim = dim + (1 if time_varying else 0) + (self.treatment_cond if conditional else 0) 58 | self.net = torch.nn.Sequential( 59 | torch.nn.Linear(self.indim, w), 60 | torch.nn.SELU(), 61 | torch.nn.Linear(w, w), 62 | torch.nn.SELU(), 63 | torch.nn.Linear(w, w), 64 | torch.nn.SELU(), 65 | torch.nn.Linear(w,self.out_dim), 66 | ) 67 | self.default_class = 0 68 | 69 | 70 | def forward(self, x): 71 | """forward pass 72 | Assume first two dimensions are x, c, then t 73 | """ 74 | result = self.net(x) 75 | return torch.cat([result, x[:,self.dim:-1]], dim=1) 76 | 77 | class FM_baseline(torch.nn.Module): 78 | """Flow matching baseline network. 79 | 80 | Conditioning classes are passed through unchanged. 81 | """ 82 | def __init__(self, dim, 83 | out_dim=None, 84 | w=64, 85 | time_varying=False, 86 | conditional=False, 87 | treatment_cond = 0, 88 | time_dim = NUM_FREQS * 2, 89 | clip = None): 90 | super().__init__() 91 | self.dim = dim 92 | self.time_varying = time_varying 93 | # default the output width to the input width; one extra slot predicts time-to-next-observation 94 | self.out_dim = dim if out_dim is None else out_dim 95 | self.out_dim += 1 96 | self.treatment_cond = treatment_cond 97 | self.indim = dim + (time_dim if time_varying else 0) + (self.treatment_cond if conditional else 0) 98 | self.net = torch.nn.Sequential( 99 | torch.nn.Linear(self.indim, w), 100 | torch.nn.SELU(), 101 | torch.nn.Linear(w, w), 102 | torch.nn.SELU(), 103 | torch.nn.Linear(w, w), 104 | torch.nn.SELU(), 105 | torch.nn.Linear(w,self.out_dim), 106 | ) 107 | self.default_class = 0 108 | self.clip = clip 109 | 110 | def encoding_function(self, time_tensor): 111 | return positional_encoding_tensor(time_tensor) 112 | 113 | def forward_train(self, x): 114 | """forward pass 115 | Assume first two dimensions are x, c, then t 116 | input: x0 117 | output: vt 118 | """ 119 | time_tensor = x[:,-1] 120 | encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2) 121 | new_x = torch.cat([x[:,:-1], encoded_time_span], dim = 1).to(torch.float32) 122 | result = self.net(new_x) 123 | return torch.cat([result[:,:-1], x[:,self.dim:-1],result[:,-1].unsqueeze(1)], dim=1) 124 | 125 | def forward(self,x): 126 | """Function for simulation testing 127 | 128 | 129 | Args: 130 | x (torch.Tensor): x + time dimension 131 | 132 | Returns: 133 | forward_train(x)[:,:-1]: x without the time dimension 134 | """ 135 | return self.forward_train(x)[:,:-1] 136 | 137 | 138 | 139 | """ Lightning module """ 140 | 141 | class MLP_CFM(pl.LightningModule): 142 | def __init__(self, 143 | treatment_cond, 144 | dim=2, 145 | w=64, 146 | time_varying=True, 147 | conditional=True, 148 | lr=1e-6, 149 | sigma = 0.1, 150 | loss_fn = mse_loss, 151 | metrics = ['mse_loss', 'l1_loss'], 152 | implementation = "ODE", # can be SDE 153 | sde_noise = 0.1, 154 | clip = None, # float 155 | naming = None, 156 | ): 157 | super().__init__() 158 | self.model = FM_baseline(dim=dim, 159 | w=w, 160 | time_varying=time_varying, 161 | conditional=conditional, 162 | clip = clip, 163 | treatment_cond=treatment_cond) 164 | self.loss_fn = loss_fn 165 | self.save_hyperparameters() 166 | self.dim = dim 167 | # self.out_dim = out_dim 168 | self.w = w 169 | self.time_varying = time_varying 170 | self.conditional = conditional 171 | self.treatment_cond = treatment_cond 172 | self.lr = lr 173 | self.sigma = sigma 174 | self.naming = "CFM_baseline_"+implementation 175 | self.metrics = metrics 176 | self.implementation = implementation 177 | self.sde_noise = sde_noise 178 | self.clip = clip 179 | 180 | 181 | def __convert_tensor__(self, tensor): 182 | return tensor.to(torch.float32) 183 | 184 | def __x_processing__(self, x0, x1, t0, t1): 185 | x0 = x0.squeeze(0) 186 | x1 = x1.squeeze(0) 187 | t0 = t0.squeeze() 188 | t1 = t1.squeeze() 189 | 190 | t = torch.rand(x0.shape[0],1).to(x0.device) 191 | mu_t = x0 * (1 - t) + x1 * t 192 | data_t_diff = (t1 - t0).unsqueeze(1) 193 | x = mu_t + self.sigma * torch.randn(x0.shape[0], self.dim).to(x0.device) 194 | ut = (x1 - x0) / (data_t_diff + 1e-4) 195 | t_model = t * data_t_diff + t0.unsqueeze(1) 196 | futuretime = t1 - t_model 197 | return x, ut, t_model, futuretime, t 198 |
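    # --- Editor's note (not part of the original source): __x_processing__ above
    # draws t ~ U(0, 1) per transition and builds
    #   x_t     = (1 - t) * x0 + t * x1 + sigma * eps,  eps ~ N(0, I)
    #   u_t     = (x1 - x0) / (t1 - t0 + 1e-4)          (regression target: velocity)
    #   t_model = t0 + t * (t1 - t0)                    (absolute time fed to the net)
    # so the target is the constant-speed displacement between two observations,
    # rescaled by the irregular time gap; futuretime = t1 - t_model supervises the
    # model's extra time-to-next-observation output.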
199 | def training_step(self, batch, batch_idx): 200 | """Flow matching training step. 201 | 202 | Args: 203 | batch (list): x0_values, x0_classes, x1_values, times_x0, times_x1 204 | batch_idx (int): index of the batch 205 | 206 | Returns: 207 | torch.Tensor: scalar training loss 208 | """ 209 | x0, x0_class, x1, x0_time, x1_time = batch 210 | x0, x0_class, x1, x0_time, x1_time = self.__convert_tensor__(x0), self.__convert_tensor__(x0_class), self.__convert_tensor__(x1), self.__convert_tensor__(x0_time), self.__convert_tensor__(x1_time) 211 | 212 | 213 | x, ut, t_model, futuretime, t = self.__x_processing__(x0, x1, x0_time, x1_time) 214 | 215 | 216 | if len(x0_class.shape) == 3: 217 | x0_class = x0_class.squeeze(0) 218 | 219 | in_tensor = torch.cat([x,x0_class, t_model], dim = -1) 220 | vt = self.model.forward_train(in_tensor) 221 | 222 | # SDE: inject noise in the loss 223 | if self.implementation == "SDE": 224 | variance = t*(1-t)*(self.sde_noise ** 2) 225 | noise = torch.randn_like(vt[:,:self.dim]) * torch.sqrt(variance) 226 | loss = self.loss_fn(vt[:,:self.dim]+noise, ut) + self.loss_fn(vt[:,-1], futuretime) 227 | else: 228 | loss = self.loss_fn(vt[:,:self.dim], ut) + self.loss_fn(vt[:,-1], futuretime) 229 | self.log('train_loss', loss) 230 | return loss 231 | 232 | def config_optimizer(self): # unused; Lightning calls configure_optimizers below 233 | return torch.optim.Adam(self.parameters(), lr=self.lr) 234 | 235 | def validation_step(self, batch, batch_idx): 236 | """validation_step 237 | 238 | Args: 239 | batch: a single patient trajectory (batch size of 1, since lengths are uneven) 240 | batch_idx (int): index of the batch 241 | 242 | Returns: 243 | dict: validation loss and (true, predicted) trajectory pairs 244 | """ 245 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='val') 246 | self.log('val_loss', loss) 247 | for key, value in metricD.items(): 248 | self.log(key+"_val", value) 249 | return {'val_loss':loss, 'traj_pairs':pairs} 250 | 251 | def test_step(self, batch, batch_idx): 252 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='test') 253 | self.log('test_loss', loss) 254 | for key, value in metricD.items(): 255 | self.log(key+"_test", value) 256 | return {'test_loss':loss, 'traj_pairs':pairs} 257 | 258 | def test_func_step(self, batch, batch_idx, mode='none'): 259 | """assuming each is one patient/batch""" 260 | total_loss = [] 261 | traj_pairs = [] 262 | 263 | x0_values, x0_classes, x1_values, times_x0, times_x1 = batch 264 | times_x0 = times_x0.squeeze() 265 | times_x1 = times_x1.squeeze() 266 | 267 | full_traj = torch.cat([x0_values[0,0,:self.dim].unsqueeze(0), 268 | x1_values[0,:,:self.dim]], 269 | dim=0) 270 | full_time = torch.cat([times_x0[0].unsqueeze(0), times_x1], dim=0) 271 | ind_loss, pred_traj = self.test_trajectory(batch) 272 | total_loss.append(ind_loss) 273 | traj_pairs.append([full_traj, pred_traj]) 274 | 275 | full_traj = full_traj.detach().cpu().numpy() 276 | pred_traj = pred_traj.detach().cpu().numpy() 277 | full_time = full_time.detach().cpu().numpy() 278 | 279 | # graph 280 | fig = plot_3d_path_ind(pred_traj, 281 | full_traj, 282 | t_span=full_time, 283 | title="{}_trajectory_patient_{}".format(mode, batch_idx)) 284
| if self.logger: 285 | # may cause problem if wandb disabled 286 | self.logger.experiment.log({"{}_trajectory_patient_{}".format(mode, batch_idx): wandb.Image(fig)}) 287 | 288 | plt.close(fig) 289 | 290 | # metrics 291 | metricD = metrics_calculation(pred_traj, full_traj, metrics=self.metrics) 292 | return np.mean(total_loss), traj_pairs, metricD 293 | 294 | def test_trajectory(self,pt_tensor): 295 | if self.implementation == "ODE": 296 | return self.test_trajectory_ode(pt_tensor) 297 | elif self.implementation == "SDE": 298 | return self.test_trajectory_sde(pt_tensor) 299 | 300 | def test_trajectory_ode(self,pt_tensor): 301 | """test_trajectory 302 | 303 | Args: 304 | pt_tensor (numpy.array): (x0_values, x0_classes, x1_values, times_x0, times_x1), 305 | 306 | 307 | Returns: 308 | mse_all, total_pred_tensor: _description_ 309 | """ 310 | node = NeuralODE( 311 | torch_wrapper_tv(self.model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4 312 | ) 313 | total_pred = [] 314 | mse = [] 315 | t_max = 0 316 | 317 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor 318 | # squeeze all 319 | x0_values = x0_values.squeeze(0) 320 | x1_values = x1_values.squeeze(0) 321 | times_x0 = times_x0.squeeze() 322 | times_x1 = times_x1.squeeze() 323 | x0_classes = x0_classes.squeeze() 324 | 325 | if len(x0_classes.shape) == 1: 326 | x0_classes = x0_classes.unsqueeze(1) 327 | 328 | 329 | 330 | total_pred.append(x0_values[0].unsqueeze(0)) 331 | len_path = x0_values.shape[0] 332 | assert len_path == x1_values.shape[0] 333 | for i in range(len_path): 334 | t_max = (times_x1[i]-times_x0[i]) 335 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device) 336 | with torch.no_grad(): 337 | if i == 0: 338 | testpt = torch.cat([x0_values[i].unsqueeze(0),x0_classes[i].unsqueeze(0)],dim=1) 339 | else: # incorporate last prediction 340 | testpt = torch.cat([pred_traj, x0_classes[i].unsqueeze(0)], dim=1) 341 | traj = node.trajectory( 342 | testpt, 343 | t_span=time_span, 344 | ) 345 | pred_traj = traj[-1,:,:self.dim] 346 | total_pred.append(pred_traj) 347 | ground_truth_coords = x1_values[i] 348 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy()) 349 | mse_all = np.mean(mse) 350 | total_pred_tensor = torch.stack(total_pred).squeeze(1) 351 | return mse_all, total_pred_tensor 352 | 353 | def configure_optimizers(self): 354 | optimizer = optim.Adam(self.parameters(), lr=self.lr) 355 | return optimizer 356 | 357 | def test_trajectory_sde(self,pt_tensor): 358 | """test_trajectory 359 | 360 | Args: 361 | pt_tensor (numpy.array): (x0_values, x0_classes, x1_values, times_x0, times_x1), 362 | 363 | 364 | Returns: 365 | mse_all, total_pred_tensor: _description_ 366 | """ 367 | sde = SDE(self.model, noise=self.sde_noise) 368 | total_pred = [] 369 | mse = [] 370 | t_max = 0 371 | 372 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor 373 | # squeeze all 374 | x0_values = x0_values.squeeze(0) 375 | x1_values = x1_values.squeeze(0) 376 | times_x0 = times_x0.squeeze() 377 | times_x1 = times_x1.squeeze() 378 | x0_classes = x0_classes.squeeze() 379 | 380 | if len(x0_classes.shape) == 1: 381 | x0_classes = x0_classes.unsqueeze(1) 382 | 383 | 384 | 385 | total_pred.append(x0_values[0].unsqueeze(0)) 386 | len_path = x0_values.shape[0] 387 | assert len_path == x1_values.shape[0] 388 | for i in range(len_path): 389 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device) 390 | 391 
| with torch.no_grad(): 392 | # get last pred, if none then use startpt 393 | if i == 0: 394 | testpt = torch.cat([x0_values[i].unsqueeze(0),x0_classes[i].unsqueeze(0)],dim=1) 395 | else: # incorporate last prediction 396 | testpt = torch.cat([pred_traj, x0_classes[i].unsqueeze(0)], dim=1) 397 | traj = self._sde_solver(sde, testpt, time_span) 398 | 399 | pred_traj = traj[-1,:,:self.dim] 400 | total_pred.append(pred_traj) 401 | ground_truth_coords = x1_values[i] 402 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy()) 403 | mse_all = np.mean(mse) 404 | total_pred_tensor = torch.stack(total_pred).squeeze(1) 405 | return mse_all, total_pred_tensor 406 | 407 | 408 | def _sde_solver(self, sde, initial_state, time_span): 409 | dt = time_span[1] - time_span[0] # Time step 410 | current_state = initial_state 411 | trajectory = [current_state] 412 | 413 | for t in time_span[1:]: 414 | drift = sde.f(t, current_state) 415 | diffusion = sde.g(t, current_state) 416 | noise = torch.randn_like(current_state) * torch.sqrt(dt) 417 | current_state = current_state + drift * dt + diffusion * noise 418 | trajectory.append(current_state) 419 | 420 | return torch.stack(trajectory) 421 | 422 | -------------------------------------------------------------------------------- /src/model/components/grad_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class GradModel(torch.nn.Module): 4 | def __init__(self, action): 5 | super().__init__() 6 | self.action = action 7 | 8 | def forward(self, x): 9 | x = x.requires_grad_(True) 10 | grad = torch.autograd.grad(torch.sum(self.action(x)), x, create_graph=True)[0] 11 | return grad[:, :-1] 12 | 13 | 14 | class torch_wrapper(torch.nn.Module): 15 | """Wraps model to torchdyn compatible format.""" 16 | 17 | def __init__(self, model): 18 | super().__init__() 19 | self.model = model 20 | 21 | def forward(self, t, x): 22 | return self.model(x) 23 | # return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))  # unreachable time-varying variant; see torch_wrapper_tv 24 | 25 | class torch_wrapper_tv(torch.nn.Module): 26 | """Wraps model to torchdyn compatible format.""" 27 | 28 | def __init__(self, model): 29 | super().__init__() 30 | self.model = model 31 | 32 | def forward(self, t, x, *args, **kwargs): 33 | # print(x.shape, t.shape) 34 | # print(t) 35 | # return self.model(x) 36 | return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1)) 37 | 38 | 39 | 40 | class torch_wrapper_cond(torch.nn.Module): 41 | """Wraps model to torchdyn compatible format.""" 42 | 43 | def __init__(self, model): 44 | super().__init__() 45 | self.model = model 46 | 47 | def forward(self, t, x, *args, **kwargs): 48 | # unpack the input 49 | # class_cond = torch.zeros(x.shape[0],1) 50 | input = torch.cat([x, t.repeat(x.shape[0])[:, None]], 1) 51 | # print(input.shape)  # debug 52 | return self.model(input)
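# --- Editor's sketch (not part of the original source): how torchdyn consumes
# the wrappers above. torchdyn calls f(t, x); the wrapper appends a broadcast
# time column so a network trained on [x, t] inputs can serve as the vector field.
# wrapped = torch_wrapper_tv(torch.nn.Linear(3, 2))  # toy field over 2-d states + time
# v = wrapped(torch.tensor(0.5), torch.randn(8, 2))  # -> [8, 2]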
-------------------------------------------------------------------------------- /src/model/components/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import time 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import ot as pot 8 | import torchdyn 9 | from torchdyn.core import NeuralODE 10 | import pytorch_lightning as pl 11 | from torch import optim 12 | import torch.nn.functional as F 13 | import wandb 14 | 15 | from utils.visualize import * 16 | from utils.metric_calc import * 17 | from utils.sde import SDE 18 | 19 | from model.components.positional_encoding import * 20 | from model.components.grad_util import torch_wrapper_tv 21 | from utils.loss import mse_loss 22 | 23 | 24 | class MLP(torch.nn.Module): 25 | def __init__(self, dim, out_dim=None, w=64, time_varying=False): 26 | super().__init__() 27 | self.time_varying = time_varying 28 | if out_dim is None: 29 | out_dim = dim 30 | self.net = torch.nn.Sequential( 31 | torch.nn.Linear(dim + (1 if time_varying else 0), w), 32 | torch.nn.SELU(), 33 | torch.nn.Linear(w, w), 34 | torch.nn.SELU(), 35 | torch.nn.Linear(w, w), 36 | torch.nn.SELU(), 37 | torch.nn.Linear(w, out_dim), 38 | ) 39 | 40 | def forward(self, x, *args, **kwargs): 41 | return self.net(x) 42 | 43 | 44 | # conditional liver model 45 | class MLP_conditional_liver(torch.nn.Module): 46 | """ Conditional with many available classes 47 | 48 | return the class as is 49 | """ 50 | def __init__(self, dim, treatment_cond, out_dim=None, w=64, time_varying=False, conditional=False): 51 | super().__init__() 52 | self.time_varying = time_varying 53 | # default the output width to the input width 54 | self.out_dim = dim if out_dim is None else out_dim 55 | self.treatment_cond = treatment_cond 56 | self.indim = dim + (1 if time_varying else 0) + (self.treatment_cond if conditional else 0) 57 | self.net = torch.nn.Sequential( 58 | torch.nn.Linear(self.indim, w), 59 | torch.nn.SELU(), 60 | torch.nn.Linear(w, w), 61 | torch.nn.SELU(), 62 | torch.nn.Linear(w, w), 63 | torch.nn.SELU(), 64 | torch.nn.Linear(w,self.out_dim), 65 | ) 66 | self.default_class = 0 67 | 68 | 69 | def forward(self, x): 70 | """forward pass 71 | Assume first two dimensions are x, c, then t 72 | """ 73 | result = self.net(x) 74 | # print(result.shape) 75 | # print(x[:,2:-2].shape) 76 | return torch.cat([result, x[:,2:-1]], dim=1) 77 | 78 | class MLP_conditional_liver_pe(torch.nn.Module): 79 | """ Conditional with many available classes 80 | 81 | return the class as is 82 | """ 83 | def __init__(self, dim, treatment_cond, out_dim=None, w=64, time_varying=False, conditional=False, time_dim = NUM_FREQS * 2, clip = None): 84 | super().__init__() 85 | self.time_varying = time_varying 86 | # default the output width to the input width; one extra slot predicts time-to-next-observation 87 | self.out_dim = dim if out_dim is None else out_dim 88 | self.out_dim += 1 89 | self.treatment_cond = treatment_cond 90 | self.dim = dim 91 | self.indim = dim + (time_dim if time_varying else 0) + (self.treatment_cond if conditional else 0) 92 | self.net = torch.nn.Sequential( 93 | torch.nn.Linear(self.indim, w), 94 | torch.nn.SELU(), 95 | torch.nn.Linear(w, w), 96 | torch.nn.SELU(), 97 | torch.nn.Linear(w, w), 98 | torch.nn.SELU(), 99 | torch.nn.Linear(w,self.out_dim), 100 | ) 101 | self.default_class = 0 102 | self.clip = clip 103 | 104 | def encoding_function(self, time_tensor): 105 | return positional_encoding_tensor(time_tensor) 106 | 107 | def forward_train(self, x): 108 | """forward pass 109 | Assume first two dimensions are x, c, then t 110 | """ 111 | time_tensor = x[:,-1] 112 | encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2) 113 | new_x = torch.cat([x[:,:-1], encoded_time_span], dim = 1).to(torch.float32) 114 | result = self.net(new_x) 115 | return torch.cat([result[:,:-1], x[:,self.dim:-1],result[:,-1].unsqueeze(1)], dim=1) 116 | 117 | def forward(self, x): 118 | """ call forward_train for training 119 | x here is x_t 120 | xt = (t)x1 + (1-t)x0 121 | (xt - tx1)/(1-t) = x0 122 | """ 123 | x1 = self.forward_train(x) 124 | x1_coord = x1[:,:self.dim] 125 | t = x[:,-1] 126 | pred_time_till_t1 = x1[:,-1:] # keep the trailing dim so division broadcasts over the batch 127 | x_coord = x[:,:self.dim] 128 | if self.clip is None: 129 | vt = (x1_coord - x_coord)/(pred_time_till_t1) 130 | else: 131 | vt 
= (x1_coord - x_coord)/torch.clip((pred_time_till_t1),min=self.clip) 132 | final_vt = torch.cat([vt, torch.zeros_like(x[:,self.dim:-1])], dim=1) 133 | return final_vt 134 | 135 | class MLP_Cond_Module(pl.LightningModule): 136 | def __init__(self, 137 | treatment_cond, 138 | dim=2, 139 | w=64, 140 | time_varying=True, 141 | conditional=True, 142 | lr=1e-6, 143 | sigma = 0.1, 144 | loss_fn = mse_loss, 145 | metrics = ['mse_loss', 'l1_loss'], 146 | implementation = "ODE", # can be SDE 147 | sde_noise = 0.1, 148 | clip = None, # float 149 | naming = None, 150 | ): 151 | super().__init__() 152 | self.model = MLP_conditional_liver_pe(dim=dim, 153 | w=w, 154 | time_varying=time_varying, 155 | conditional=conditional, 156 | treatment_cond=treatment_cond, 157 | clip = clip) 158 | self.loss_fn = loss_fn 159 | self.save_hyperparameters() 160 | self.dim = dim 161 | # self.out_dim = out_dim 162 | self.w = w 163 | self.time_varying = time_varying 164 | self.conditional = conditional 165 | self.treatment_cond = treatment_cond 166 | self.lr = lr 167 | self.sigma = sigma 168 | self.naming = "MLP_Cond_Module_"+implementation if naming is None else naming 169 | self.metrics = metrics 170 | self.implementation = implementation 171 | self.sde_noise = sde_noise 172 | self.clip = clip 173 | 174 | 175 | def __convert_tensor__(self, tensor): 176 | return tensor.to(torch.float32) 177 | 178 | def __x_processing__(self, x0, x1, t0, t1): 179 | # squeeze xs (prevent mismatch) 180 | x0 = x0.squeeze(0) 181 | x1 = x1.squeeze(0) 182 | t0 = t0.squeeze() 183 | t1 = t1.squeeze() 184 | 185 | t = torch.rand(x0.shape[0],1).to(x0.device) 186 | mu_t = x0 * (1 - t) + x1 * t 187 | data_t_diff = (t1 - t0).unsqueeze(1) 188 | x = mu_t + self.sigma * torch.randn(x0.shape[0], self.dim).to(x0.device) 189 | ut = (x1 - x0) / (data_t_diff + 1e-4) 190 | t_model = t * data_t_diff + t0.unsqueeze(1) 191 | futuretime = t1 - t_model 192 | return x, ut, t_model, futuretime, t 193 | 194 | def training_step(self, batch, batch_idx): 195 | """_summary_ 196 | 197 | Args: 198 | batch (list of output): x0_values, x0_classes, x1_values, times_x0, times_x1 199 | batch_idx (_type_): _description_ 200 | 201 | Returns: 202 | _type_: _description_ 203 | """ 204 | x0, x0_class, x1, x0_time, x1_time = batch 205 | x0, x0_class, x1, x0_time, x1_time = self.__convert_tensor__(x0), self.__convert_tensor__(x0_class), self.__convert_tensor__(x1), self.__convert_tensor__(x0_time), self.__convert_tensor__(x1_time) 206 | 207 | 208 | x, ut, t_model, futuretime, t = self.__x_processing__(x0, x1, x0_time, x1_time) 209 | 210 | 211 | if len(x0_class.shape) == 3: 212 | x0_class = x0_class.squeeze(0) 213 | 214 | if self.conditional: 215 | in_tensor = torch.cat([x,x0_class, t_model], dim = -1) 216 | else: 217 | in_tensor = torch.cat([x, t_model], dim = -1) 218 | xt = self.model.forward_train(in_tensor) 219 | 220 | if self.implementation == "SDE": 221 | variance = t*(1-t)*(self.sde_noise ** 2) 222 | noise = torch.randn_like(xt[:,:self.dim]) * torch.sqrt(variance) 223 | loss = self.loss_fn(xt[:,:self.dim]+noise, x1) + self.loss_fn(xt[:,-1], futuretime) 224 | else: 225 | loss = self.loss_fn(xt[:,:self.dim], x1) + self.loss_fn(xt[:,-1], futuretime) 226 | self.log('train_loss', loss) 227 | return loss 228 | 229 | def config_optimizer(self): 230 | return torch.optim.Adam(self.parameters(), lr=self.lr) 231 | 232 | def validation_step(self, batch, batch_idx): 233 | """validation_step 234 | 235 | Args: 236 | batch (_type_): batch size of 1 (since uneven) 237 | batch_idx (_type_): 
_description_ 238 | 239 | Returns: 240 | _type_: _description_ 241 | """ 242 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='val') 243 | self.log('val_loss', loss) 244 | for key, value in metricD.items(): 245 | self.log(key+"_val", value) 246 | # return total_loss, traj_pairs 247 | return {'val_loss':loss, 'traj_pairs':pairs} 248 | 249 | def test_step(self, batch, batch_idx): 250 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='test') 251 | self.log('test_loss', loss) 252 | for key, value in metricD.items(): 253 | self.log(key+"_test", value) 254 | return {'test_loss':loss, 'traj_pairs':pairs} 255 | 256 | def test_func_step(self, batch, batch_idx, mode='none'): 257 | """assuming each is one patient/batch""" 258 | total_loss = [] 259 | traj_pairs = [] 260 | 261 | x0_values, x0_classes, x1_values, times_x0, times_x1 = batch 262 | times_x0 = times_x0.squeeze() 263 | times_x1 = times_x1.squeeze() 264 | 265 | # print(x0_values.shape) 266 | # print(x1_values.shape) 267 | full_traj = torch.cat([x0_values[0,0,:self.dim].unsqueeze(0), 268 | x1_values[0,:,:self.dim]], 269 | dim=0) 270 | full_time = torch.cat([times_x0[0].unsqueeze(0), times_x1], dim=0) 271 | ind_loss, pred_traj = self.test_trajectory(batch) 272 | total_loss.append(ind_loss) 273 | traj_pairs.append([full_traj, pred_traj]) 274 | 275 | full_traj = full_traj.detach().cpu().numpy() 276 | pred_traj = pred_traj.detach().cpu().numpy() 277 | full_time = full_time.detach().cpu().numpy() 278 | 279 | # graph 280 | fig = plot_3d_path_ind(pred_traj, 281 | full_traj, 282 | t_span=full_time, 283 | title="{}_trajectory_patient_{}".format(mode, batch_idx)) 284 | if self.logger: 285 | # may cause problem if wandb disabled 286 | self.logger.experiment.log({"{}_trajectory_patient_{}".format(mode, batch_idx): wandb.Image(fig)}) 287 | 288 | plt.close(fig) 289 | 290 | # metrics 291 | metricD = metrics_calculation(pred_traj, full_traj, metrics=self.metrics) 292 | return np.mean(total_loss), traj_pairs, metricD 293 | 294 | def test_trajectory(self,pt_tensor): 295 | if self.implementation == "ODE": 296 | return self.test_trajectory_ode(pt_tensor) 297 | elif self.implementation == "SDE": 298 | return self.test_trajectory_sde(pt_tensor) 299 | 300 | def test_trajectory_ode(self,pt_tensor): 301 | """test_trajectory 302 | 303 | Args: 304 | pt_tensor (numpy.array): (x0_values, x0_classes, x1_values, times_x0, times_x1), 305 | 306 | 307 | Returns: 308 | mse_all, total_pred_tensor: _description_ 309 | """ 310 | node = NeuralODE( 311 | torch_wrapper_tv(self.model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4 312 | ) 313 | total_pred = [] 314 | mse = [] 315 | t_max = 0 316 | 317 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor 318 | # squeeze all 319 | x0_values = x0_values.squeeze(0) 320 | x1_values = x1_values.squeeze(0) 321 | times_x0 = times_x0.squeeze() 322 | times_x1 = times_x1.squeeze() 323 | x0_classes = x0_classes.squeeze() 324 | 325 | if len(x0_classes.shape) == 1: 326 | x0_classes = x0_classes.unsqueeze(1) 327 | 328 | 329 | 330 | total_pred.append(x0_values[0].unsqueeze(0)) 331 | len_path = x0_values.shape[0] 332 | assert len_path == x1_values.shape[0] 333 | for i in range(len_path): 334 | t_max = (times_x1[i]-times_x0[i]) # calculate time difference (cumulative) 335 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device) 336 | with torch.no_grad(): 337 | # get last pred, if none then use startpt 338 | if i == 0: 339 | if 
self.conditional: 340 | testpt = torch.cat([x0_values[i].unsqueeze(0),x0_classes[i].unsqueeze(0)],dim=1) 341 | else: 342 | testpt = x0_values[i].unsqueeze(0) 343 | else: # incorporate last prediction 344 | if self.conditional: 345 | testpt = torch.cat([pred_traj, x0_classes[i].unsqueeze(0)], dim=1) 346 | else: 347 | testpt = pred_traj 348 | traj = node.trajectory( 349 | testpt, 350 | t_span=time_span, 351 | ) 352 | pred_traj = traj[-1,:,:self.dim] 353 | total_pred.append(pred_traj) 354 | ground_truth_coords = x1_values[i] 355 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy()) 356 | mse_all = np.mean(mse) 357 | total_pred_tensor = torch.stack(total_pred).squeeze(1) 358 | return mse_all, total_pred_tensor 359 | 360 | def configure_optimizers(self): 361 | # Define the optimizer 362 | optimizer = optim.Adam(self.parameters(), lr=self.lr) 363 | return optimizer 364 | 365 | def test_trajectory_sde(self,pt_tensor): 366 | """test_trajectory 367 | 368 | Args: 369 | pt_tensor (numpy.array): (x0_values, x0_classes, x1_values, times_x0, times_x1), 370 | 371 | 372 | Returns: 373 | mse_all, total_pred_tensor: _description_ 374 | """ 375 | sde = SDE(self.model, noise=0.1) 376 | total_pred = [] 377 | mse = [] 378 | t_max = 0 379 | 380 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor 381 | # squeeze all 382 | x0_values = x0_values.squeeze(0) 383 | x1_values = x1_values.squeeze(0) 384 | times_x0 = times_x0.squeeze() 385 | times_x1 = times_x1.squeeze() 386 | x0_classes = x0_classes.squeeze() 387 | 388 | if len(x0_classes.shape) == 1: 389 | x0_classes = x0_classes.unsqueeze(1) 390 | 391 | 392 | 393 | total_pred.append(x0_values[0].unsqueeze(0)) 394 | len_path = x0_values.shape[0] 395 | assert len_path == x1_values.shape[0] 396 | for i in range(len_path): 397 | t_max = (times_x1[i]-times_x0[i]) 398 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device) 399 | 400 | with torch.no_grad(): 401 | # get last pred, if none then use startpt 402 | if i == 0: 403 | if self.conditional: 404 | testpt = torch.cat([x0_values[i].unsqueeze(0),x0_classes[i].unsqueeze(0)],dim=1) 405 | else: 406 | testpt = x0_values[i].unsqueeze(0) 407 | else: # incorporate last prediction 408 | if self.conditional: 409 | testpt = torch.cat([pred_traj, x0_classes[i].unsqueeze(0)], dim=1) 410 | else: 411 | testpt = pred_traj 412 | traj = self._sde_solver(sde, testpt, time_span) 413 | 414 | pred_traj = traj[-1,:,:self.dim] 415 | total_pred.append(pred_traj) 416 | ground_truth_coords = x1_values[i] 417 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy()) 418 | mse_all = np.mean(mse) 419 | total_pred_tensor = torch.stack(total_pred).squeeze(1) 420 | return mse_all, total_pred_tensor 421 | 422 | 423 | def _sde_solver(self, sde, initial_state, time_span): 424 | dt = time_span[1] - time_span[0] # Time step 425 | current_state = initial_state 426 | trajectory = [current_state] 427 | 428 | for t in time_span[1:]: 429 | drift = sde.f(t, current_state) 430 | diffusion = sde.g(t, current_state) 431 | noise = torch.randn_like(current_state) * torch.sqrt(dt) 432 | current_state = current_state + drift * dt + diffusion * noise 433 | trajectory.append(current_state) 434 | 435 | return torch.stack(trajectory) 436 | 437 | 438 | 439 | 440 | class MLP_conditional_liver_pe_memory(torch.nn.Module): 441 | """ Conditional with many available classes 442 | 443 | return the class as is 444 | """ 445 | def __init__(self, 446 | dim, 447 | 
treatment_cond, 448 | memory, # how many time steps 449 | out_dim=None, 450 | w=64, 451 | time_varying=False, 452 | conditional=False, 453 | time_dim = NUM_FREQS * 2, 454 | clip = None, 455 | ): 456 | super().__init__() 457 | self.time_varying = time_varying 458 | # default the output width to the input width 459 | self.out_dim = dim if out_dim is None else out_dim 460 | self.out_dim += 1 # for the time dimension 461 | self.treatment_cond = treatment_cond 462 | self.memory = memory 463 | self.dim = dim 464 | self.indim = dim + (time_dim if time_varying else 0) + (self.treatment_cond if conditional else 0) + (dim * memory) 465 | self.net = torch.nn.Sequential( 466 | torch.nn.Linear(self.indim, w), 467 | torch.nn.SELU(), 468 | torch.nn.Linear(w, w), 469 | torch.nn.SELU(), 470 | torch.nn.Linear(w, w), 471 | torch.nn.SELU(), 472 | torch.nn.Linear(w,self.out_dim), 473 | ) 474 | self.default_class = 0 475 | self.clip = clip 476 | 477 | def encoding_function(self, time_tensor): 478 | return positional_encoding_tensor(time_tensor) 479 | 480 | def forward_train(self, x): 481 | """forward pass 482 | Assume first two dimensions are x, c, then t 483 | """ 484 | time_tensor = x[:,-1] 485 | encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2) 486 | new_x = torch.cat([x[:,:-1], encoded_time_span], dim=1) 487 | result = self.net(new_x) 488 | return torch.cat([result[:,:-1], x[:,self.dim:-1], result[:,-1].unsqueeze(1)], dim=1) 489 | 490 | def forward(self, x): 491 | """ call forward_train for training 492 | x here is x_t 493 | xt = (t)x1 + (1-t)x0 494 | (xt - tx1)/(1-t) = x0 495 | """ 496 | x1 = self.forward_train(x) 497 | x1_coord = x1[:,:self.dim] 498 | t = x[:,-1] 499 | pred_time_till_t1 = x1[:,-1:] # keep the trailing dim so division broadcasts over the batch 500 | x_coord = x[:,:self.dim] 501 | if self.clip is None: 502 | vt = (x1_coord - x_coord)/(pred_time_till_t1) 503 | else: 504 | vt = (x1_coord - x_coord)/torch.clip((pred_time_till_t1),min=self.clip) 505 | 506 | final_vt = torch.cat([vt, torch.zeros_like(x[:,self.dim:-1])], dim=1) 507 | return final_vt -------------------------------------------------------------------------------- /src/model/components/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | PE_BASE = 0.012 # 0.012615662610100801 4 | NUM_FREQS = 10 5 | 6 | def positional_encoding_tensor(time_tensor, num_frequencies=NUM_FREQS, base=PE_BASE): 7 | # Ensure the time tensor is in the range [0, 1] 8 | time_tensor = time_tensor.clamp(0, 1).unsqueeze(1) # Clamp and add dimension for broadcasting 9 | 10 | # Compute the arguments for the sine and cosine functions using the custom base 11 | frequencies = torch.pow(base, -torch.arange(0, num_frequencies, dtype=torch.float32) / num_frequencies).to(time_tensor.device) 12 | angles = time_tensor * frequencies 13 | 14 | # Compute the sine and cosine for even and odd indices respectively 15 | sine = torch.sin(angles) 16 | cosine = torch.cos(angles) 17 | 18 | # Stack them along the last dimension 19 | pos_encoding = torch.stack((sine, cosine), dim=-1) 20 | pos_encoding = pos_encoding.flatten(start_dim=2) 21 | 22 | # Normalize to have values between 0 and 1 23 | pos_encoding = (pos_encoding + 1) / 2 # Now values are between 0 and 1 24 | 25 | return pos_encoding 26 | 27 | def positional_encoding_df(df, col_mod = "time_normalized"): 28 | pe_tensors = torch.tensor(df[col_mod].values).to(torch.float32) 29 | pos_encoding = positional_encoding_tensor(pe_tensors) 30 | pos_encoding_array = pos_encoding.numpy().reshape(-1, NUM_FREQS * 2) 31 | return pos_encoding_array 32 |
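# --- Editor's sketch (not part of the original source): shape check for
# positional_encoding_tensor above.
# t = torch.linspace(0.0, 1.0, 4)
# pe = positional_encoding_tensor(t)    # NUM_FREQS sine/cosine pairs per time point
# flat = pe.reshape(-1, NUM_FREQS * 2)  # [4, 20], the layout the models consume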
-------------------------------------------------------------------------------- /src/model/components/sde_func_solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class SDE_func_solver(torch.nn.Module): 4 | 5 | noise_type = "diagonal" 6 | sde_type = "ito" 7 | 8 | # noise is sigma in this notebook for the equation sigma * (t * (1 - t)) 9 | def __init__(self, ode_drift, noise, reverse=False): 10 | super().__init__() 11 | self.drift = ode_drift 12 | self.reverse = reverse 13 | self.noise = noise # changeable, a model itself 14 | 15 | # Drift 16 | def f(self, t, y): 17 | if self.reverse: 18 | t = 1 - t 19 | if len(t.shape) == len(y.shape): 20 | x = torch.cat([y, t], 1) 21 | else: 22 | x = torch.cat([y, t.repeat(y.shape[0])[:, None]], 1) 23 | return self.drift(x) 24 | 25 | # Diffusion 26 | def g(self, t, y): 27 | if self.reverse: 28 | t = 1 - t 29 | if len(t.shape) == len(y.shape): 30 | x = torch.cat([y, t], 1) 31 | else: 32 | x = torch.cat([y, t.repeat(y.shape[0])[:, None]], 1) 33 | noise_result = self.noise(x) 34 | return noise_result * torch.sqrt(t * (1 - t)) -------------------------------------------------------------------------------- /src/model/latent_ode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | from torchdiffeq import odeint 5 | from utils.metric_calc import * 6 | from utils.latent_ode_utils import * 7 | 8 | class Encoder(nn.Module): 9 | def __init__(self, input_dim, latent_dim): 10 | super().__init__() 11 | self.net = nn.Sequential( 12 | nn.Linear(input_dim, 20), 13 | nn.Tanh(), 14 | nn.Linear(20, latent_dim), 15 | ) 16 | 17 | def forward(self, x): 18 | return self.net(x) 19 | 20 | 21 |
22 | class Encoder_z0_ODE_RNN(nn.Module): 23 | # Derive z0 by running ode backwards. 24 | # For every y_i we have two versions: encoded from data and derived from ODE by running it backwards from t_i+1 to t_i 25 | # Compute a weighted sum of y_i from data and y_i from ode. Use weighted y_i as an initial value for ODE running from t_i to t_i-1 26 | # Continue until we get to z0 27 | def __init__(self, latent_dim, input_dim, z0_diffeq_solver = None, 28 | z0_dim = None, GRU_update = None, 29 | n_gru_units = 100, 30 | ): 31 | 32 | super(Encoder_z0_ODE_RNN, self).__init__() 33 | 34 | self.z0_dim = latent_dim 35 | self.GRU_update = GRU_unit(latent_dim, input_dim, 36 | n_units = n_gru_units) 37 | # device=device).to(device) 38 | 39 | self.z0_diffeq_solver = z0_diffeq_solver 40 | self.latent_dim = latent_dim 41 | self.input_dim = input_dim 42 | # self.device = device 43 | self.extra_info = None 44 | 45 | self.transform_z0 = nn.Sequential( 46 | nn.Linear(latent_dim * 2, 100), 47 | nn.Tanh(), 48 | nn.Linear(100, self.z0_dim * 2),) 49 | 50 | 51 | def forward(self, x1): 52 | # data, time_steps -- observations and their time stamps 53 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 54 | # xi is timepoint i of the data 55 | device_of_x1 = x1.device 56 | x1_len = x1.shape[1] 57 | 58 | prev_y = torch.zeros((1, x1_len, self.latent_dim)).to(device_of_x1) 59 | prev_std = torch.zeros((1, x1_len, self.latent_dim)).to(device_of_x1) 60 | 61 | # x1 = x1.unsqueeze(0) 62 | 63 | last_yi, last_yi_std = self.GRU_update(prev_y, prev_std, x1) 64 | 65 | means_z0 = last_yi.reshape(1, x1_len, self.latent_dim) 66 | std_z0 = last_yi_std.reshape(1, x1_len, self.latent_dim) 67 | 68 | mean_z0, std_z0 = split_last_dim(self.transform_z0( torch.cat((means_z0, std_z0), -1))) 69 | std_z0 = std_z0.abs() 70 | 71 | return mean_z0, std_z0 72 | 73 | 74 | 75 | from torch.distributions.multivariate_normal import MultivariateNormal 76 | from torch.distributions.normal import Normal 77 | from torch.distributions import kl_divergence, Independent 78 | from model.base_models import VAE_Baseline 79 | 80 | 81 | class VAE(nn.Module): 82 | def __init__(self, encoder, decoder, ode_func, latent_dim): 83 | super(VAE, self).__init__() 84 | self.encoder = encoder 85 | self.decoder = decoder 86 | self.ode_func = ode_func 87 | self.latent_dim = latent_dim 88 | 89 | def forward(self, x, t, reverse=False): 90 | """Forward pass for training with the option to reverse time. 91 | 92 | Args: 93 | x: Input data (e.g., sequence of observations). 94 | t: Corresponding time points for the input data. 95 | reverse: Whether to reverse time for the ODE solver. 96 | """ 97 | if len(t.shape)>1: 98 | t = t.squeeze() 99 | 100 | if reverse: 101 | # Training mode: reverse time 102 | t = torch.flip(t, dims=[0]) 103 | x = torch.flip(x, dims=[1]) 104 | 105 | mu, std_z0 = self.encoder(x) 106 | z = self.reparameterize(mu, std_z0) 107 | z_pred = odeint(self.ode_func, z, t) 108 | x_pred = self.decoder(z_pred) 109 | 110 | if reverse: 111 | # Flip predictions back to original time order 112 | x_pred = torch.flip(x_pred, dims=[1]) 113 | 114 | return x_pred, mu, std_z0 115 | 116 | def reparameterize(self, mu, logvar): 117 | std = torch.exp(0.5 * logvar) 118 | eps = torch.randn_like(std) 119 | return mu + eps * std 120 | 121 |
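    # --- Editor's note (not part of the original source): reparameterize above
    # expects a log-variance (std = exp(0.5 * logvar)), while VAE.forward passes the
    # encoder's std_z0 directly, so the sample is effectively drawn from
    # N(mu, exp(std_z0)) rather than N(mu, std_z0^2); worth verifying before reuse.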
128 | """ 129 | # do one at a time 130 | z_pred = odeint(self.ode_func, z0, t) 131 | return self.decoder(z_pred) 132 | 133 | 134 | 135 | class Decoder(nn.Module): 136 | def __init__(self, latent_dim, output_dim): 137 | super().__init__() 138 | self.net = nn.Sequential( 139 | nn.Linear(latent_dim, output_dim), 140 | ) 141 | 142 | def forward(self, x): 143 | return self.net(x) 144 | 145 | class LatentODEFunc(nn.Module): 146 | def __init__(self, latent_dim): 147 | super().__init__() 148 | self.net = nn.Sequential( 149 | nn.Linear(latent_dim, 50), 150 | nn.Tanh(), 151 | nn.Linear(50, 50), 152 | nn.Tanh(), 153 | nn.Linear(50, latent_dim), 154 | ) 155 | 156 | def forward(self, t, x): 157 | return self.net(x) 158 | 159 | class LatentODE_pl(pl.LightningModule): 160 | def __init__(self, 161 | input_dim, 162 | latent_dim, 163 | output_dim, 164 | lr=1e-3, 165 | loss_fn=nn.MSELoss(), 166 | metrics = ['mse_loss', 'l1_loss'], 167 | ): 168 | super().__init__() 169 | self.encoder = Encoder_z0_ODE_RNN(latent_dim,input_dim) 170 | self.decoder = Decoder(latent_dim, output_dim) 171 | self.vae = VAE(self.encoder, self.decoder, LatentODEFunc(latent_dim), latent_dim) 172 | self.lr = lr 173 | self.loss_fn = loss_fn 174 | self.naming = 'LatentODE_RNN' 175 | self.metrics = metrics 176 | 177 | def forward(self, x, x_time, mode='train'): 178 | if mode == 'train': 179 | reverse = True 180 | else: 181 | reverse = False 182 | x_pred, mu, std_z0 = self.vae(x, x_time, reverse=reverse) 183 | return x_pred, mu, std_z0 184 | 185 | def training_step(self, batch, batch_idx): 186 | x0, x0_class, x1, x0_time, x1_time = batch 187 | t_span = x1_time.squeeze() 188 | x_pred, mu, std_z0 = self.forward(x1, x1_time, mode='train') 189 | # loss = self.loss_fn(x_pred[-1], x1) 190 | loss = self.loss_function(x_pred[-1], x0, mu, std_z0) 191 | self.log('train_loss', loss) 192 | return loss 193 | 194 | def configure_optimizers(self): 195 | return torch.optim.Adam(self.parameters(), lr=self.lr) 196 | 197 | def validation_step(self, batch, batch_idx): 198 | x0, x0_class, x1, x0_time, x1_time = batch 199 | x_pred, x1 = self.testing_vae(batch) 200 | loss = self.loss_fn(x_pred[-1], x1) 201 | 202 | # metrics 203 | metricD = metrics_calculation(x_pred[-1], x1, metrics=self.metrics) 204 | for key, value in metricD.items(): 205 | self.log(key+"_val", value) 206 | 207 | loss = self.loss_fn(x_pred[-1], x1) 208 | self.log('val_loss', loss) 209 | return loss 210 | 211 | def test_step(self, batch, batch_idx): 212 | x0, x0_class, x1, x0_time, x1_time = batch 213 | x_pred, x1 = self.testing_vae(batch) 214 | loss = self.loss_fn(x_pred[-1], x1) 215 | 216 | # Calculate metrics 217 | metricD = metrics_calculation(x_pred[-1], x1, metrics=self.metrics) 218 | 219 | for key, value in metricD.items(): 220 | self.log(key+"_test", value) 221 | 222 | self.log('test_loss', loss) 223 | return loss 224 | 225 | 226 | 227 | def loss_function(self, recon_x, x, mu, logvar): 228 | recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum') 229 | kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 230 | return recon_loss + kld_loss 231 | 232 | 233 | def sample(self, num_samples=1): 234 | """Sample from the latent space. 235 | 236 | Args: 237 | num_samples: Number of samples to generate. 
238 | """ 239 | z = torch.randn(num_samples, self.latent_dim) 240 | return self.decoder(z) 241 | 242 | def testing_vae(self, batch): 243 | x0, x0_class, x1, x0_time, x1_time = batch 244 | t_span = x1_time.squeeze() 245 | z0_mean, z0_logvar = self.encoder(x0) 246 | z0 = self.vae.reparameterize(z0_mean, z0_logvar) # sample 247 | #@HERE: squeeze z0? 248 | x_pred = self.vae.extrapolate(z0, t_span) 249 | return x_pred, x1 250 | 251 | 252 | class LatentODE_deprecated(pl.LightningModule): 253 | def __init__(self, 254 | input_dim, 255 | latent_dim, 256 | output_dim, 257 | lr=1e-3, 258 | loss_fn=nn.MSELoss(), 259 | metrics = ['mse_loss', 'l1_loss']): 260 | super().__init__() 261 | self.encoder = Encoder(input_dim, latent_dim) 262 | self.decoder = Decoder(latent_dim, output_dim) 263 | self.ode_func = LatentODEFunc(latent_dim) 264 | self.lr = lr 265 | self.loss_fn = loss_fn 266 | self.naming = 'LatentODE' 267 | self.metrics = metrics 268 | 269 | def forward(self, x0, t_span): 270 | z0 = self.encoder(x0) 271 | z_pred = odeint(self.ode_func, z0, t_span) 272 | x_pred = self.decoder(z_pred) 273 | return x_pred 274 | 275 | def training_step(self, batch, batch_idx): 276 | x0, x0_class, x1, x0_time, x1_time = batch 277 | t_span = x1_time.squeeze() 278 | x_pred = self.forward(x0, t_span) 279 | loss = self.loss_fn(x_pred[-1], x1) 280 | self.log('train_loss', loss) 281 | return loss 282 | 283 | def configure_optimizers(self): 284 | return torch.optim.Adam(self.parameters(), lr=self.lr) 285 | 286 | def validation_step(self, batch, batch_idx): 287 | x0, x0_class, x1, x0_time, x1_time = batch 288 | t_span = x1_time.squeeze() 289 | x_pred = self.forward(x0, t_span) 290 | 291 | # metrics 292 | metricD = metrics_calculation(x_pred[-1], x1, metrics=self.metrics) 293 | for key, value in metricD.items(): 294 | self.log(key+"_val", value) 295 | 296 | loss = self.loss_fn(x_pred[-1], x1) 297 | self.log('val_loss', loss) 298 | return loss 299 | 300 | def test_step(self, batch, batch_idx): 301 | x0, x0_class, x1, x0_time, x1_time = batch 302 | t_span = x1_time.squeeze() 303 | x_pred = self.forward(x0, t_span) 304 | loss = self.loss_fn(x_pred[-1], x1) 305 | 306 | # Calculate metrics 307 | metricD = metrics_calculation(x_pred[-1], x1, metrics=self.metrics) 308 | 309 | for key, value in metricD.items(): 310 | self.log(key+"_test", value) 311 | 312 | self.log('test_loss', loss) 313 | return loss 314 | -------------------------------------------------------------------------------- /src/model/mlp_memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import time 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import ot as pot 8 | import torchdyn 9 | from torchdyn.core import NeuralODE 10 | import pytorch_lightning as pl 11 | from torch import optim 12 | import torch.functional as F 13 | import wandb 14 | 15 | from utils.visualize import * 16 | from utils.metric_calc import * 17 | from utils.sde import SDE 18 | from model.components.positional_encoding import * 19 | from model.components.mlp import * # in case need wrapper 20 | from model.components.grad_util import torch_wrapper_tv 21 | from utils.loss import mse_loss, l1_loss 22 | 23 | 24 | 25 | class MLP_Cond_Memory_Module(pl.LightningModule): 26 | def __init__(self, 27 | treatment_cond, 28 | memory=3, # can increase / tune to see effect 29 | dim=2, 30 | w=64, 31 | time_varying=True, 32 | conditional=True, 33 | lr=1e-6, 34 | sigma = 0.1, 35 | loss_fn = mse_loss, 36 | metrics = 
24 | 25 | class MLP_Cond_Memory_Module(pl.LightningModule): 26 | def __init__(self, 27 | treatment_cond, 28 | memory=3, # can increase / tune to see effect 29 | dim=2, 30 | w=64, 31 | time_varying=True, 32 | conditional=True, 33 | lr=1e-6, 34 | sigma = 0.1, 35 | loss_fn = mse_loss, 36 | metrics = ['mse_loss', 'l1_loss'], 37 | implementation = "ODE", # can be SDE 38 | sde_noise = 0.1, 39 | clip = None, 40 | naming = None, 41 | 42 | ): 43 | super().__init__() 44 | self.model = MLP_conditional_liver_pe_memory(dim=dim, 45 | w=w, 46 | time_varying=time_varying, 47 | conditional=conditional, 48 | treatment_cond=treatment_cond, 49 | memory=memory, 50 | clip = clip) 51 | self.loss_fn = loss_fn 52 | self.save_hyperparameters() 53 | self.dim = dim 54 | # self.out_dim = out_dim 55 | self.w = w 56 | self.time_varying = time_varying 57 | self.conditional = conditional 58 | self.treatment_cond = treatment_cond 59 | self.lr = lr 60 | self.sigma = sigma 61 | self.naming = "MLP_Cond_memory_Module_"+implementation if naming is None else naming 62 | self.metrics = metrics 63 | self.implementation = implementation 64 | self.memory = memory 65 | self.sde_noise = sde_noise 66 | self.clip = clip 67 | if self.memory > 1: 68 | self.naming += "_Memory_"+str(self.memory) 69 | 70 | def __convert_tensor__(self, tensor): 71 | return tensor.to(torch.float32) 72 | 73 | def __x_processing__(self, x0, x1, t0, t1): 74 | # squeeze xs (prevent mismatch) 75 | x0 = x0.squeeze(0) 76 | x1 = x1.squeeze(0) 77 | t0 = t0.squeeze() 78 | t1 = t1.squeeze() 79 | 80 | t = torch.rand(x0.shape[0],1).to(x0.device) 81 | mu_t = x0 * (1 - t) + x1 * t 82 | data_t_diff = (t1 - t0).unsqueeze(1) 83 | x = mu_t + self.sigma * torch.randn(x0.shape[0], self.dim).to(x0.device) 84 | 85 | ut = (x1 - x0) / (data_t_diff + 1e-4) 86 | t_model = t * data_t_diff + t0.unsqueeze(1) 87 | futuretime = t1 - t_model 88 | 89 | return x, ut, t_model, futuretime, t 90 | 91 | def training_step(self, batch, batch_idx): 92 | """Flow matching training step with memory conditioning. 93 | 94 | Args: 95 | batch (list): x0_values, x0_classes, x1_values, times_x0, times_x1 96 | batch_idx (int): index of the batch 97 | 98 | Returns: 99 | torch.Tensor: scalar training loss 100 | """ 101 | x0, x0_class, x1, x0_time, x1_time = batch 102 | x0, x0_class, x1, x0_time, x1_time = self.__convert_tensor__(x0), self.__convert_tensor__(x0_class), self.__convert_tensor__(x1), self.__convert_tensor__(x0_time), self.__convert_tensor__(x1_time) 103 | 104 | 105 | x, ut, t_model, futuretime, t = self.__x_processing__(x0, x1, x0_time, x1_time) 106 | 107 | 108 | if len(x0_class.shape) == 3: 109 | x0_class = x0_class.squeeze() 110 | 111 | in_tensor = torch.cat([x,x0_class, t_model], dim = -1) 112 | xt = self.model.forward_train(in_tensor) 113 | 114 | if self.implementation == "SDE": 115 | variance = t*(1-t)*(self.sde_noise ** 2) # sigma^2 scaling, matching the other SDE baselines 116 | noise = torch.randn_like(xt[:,:self.dim]) * torch.sqrt(variance) 117 | loss = self.loss_fn(xt[:,:self.dim] + noise, x1) + self.loss_fn(xt[:,-1], futuretime) 118 | else: 119 | loss = self.loss_fn(xt[:,:self.dim], x1) + self.loss_fn(xt[:,-1], futuretime) 120 | self.log('train_loss', loss) 121 | return loss 122 | 123 | def config_optimizer(self): # unused; Lightning calls configure_optimizers below 124 | return torch.optim.Adam(self.parameters(), lr=self.lr) 125 | 126 | def validation_step(self, batch, batch_idx): 127 | """validation_step 128 | 129 | Args: 130 | batch: a single patient trajectory (batch size of 1, since lengths are uneven) 131 | batch_idx (int): index of the batch 132 | 133 | Returns: 134 | dict: validation loss and (true, predicted) trajectory pairs 135 | """ 136 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='val') 137 | self.log('val_loss', loss) 138 | for key, value in metricD.items(): 139 | self.log(key+"_val", value) 140 | return {'val_loss':loss, 'traj_pairs':pairs} 141 | 142 | def test_step(self, batch, batch_idx): 143 | loss, pairs, metricD = self.test_func_step(batch, batch_idx, mode='test') 144
| self.log('test_loss', loss) 145 | for key, value in metricD.items(): 146 | self.log(key+"_test", value) 147 | return {'test_loss':loss, 'traj_pairs':pairs} 148 | 149 | def test_func_step(self, batch, batch_idx, mode='none'): 150 | """assuming each is one patient/batch""" 151 | total_loss = [] 152 | traj_pairs = [] 153 | 154 | x0_values, x0_classes, x1_values, times_x0, times_x1 = batch 155 | times_x0 = times_x0.squeeze() 156 | times_x1 = times_x1.squeeze() 157 | 158 | full_traj = torch.cat([x0_values[0,0,:self.dim].unsqueeze(0), 159 | x1_values[0,:,:self.dim]], 160 | dim=0) 161 | full_time = torch.cat([times_x0[0].unsqueeze(0), times_x1], dim=0) 162 | ind_loss, pred_traj = self.test_trajectory(batch) 163 | total_loss.append(ind_loss) 164 | traj_pairs.append([full_traj, pred_traj]) 165 | 166 | full_traj = full_traj.detach().cpu().numpy() 167 | pred_traj = pred_traj.detach().cpu().numpy() 168 | full_time = full_time.detach().cpu().numpy() 169 | 170 | # graph 171 | fig = plot_3d_path_ind(pred_traj, 172 | full_traj, 173 | t_span=full_time, 174 | title="{}_trajectory_patient_{}".format(mode, batch_idx)) 175 | if self.logger: 176 | # may cause problem if wandb disabled 177 | self.logger.experiment.log({"{}_trajectory_patient_{}".format(mode, batch_idx): wandb.Image(fig)}) 178 | 179 | plt.close(fig) 180 | 181 | # metrics 182 | metricD = metrics_calculation(pred_traj, full_traj, metrics=self.metrics) 183 | return np.mean(total_loss), traj_pairs, metricD 184 | 185 | def test_trajectory(self,pt_tensor): 186 | if self.implementation == "ODE": 187 | return self.test_trajectory_ode(pt_tensor) 188 | elif self.implementation == "SDE": 189 | return self.test_trajectory_sde(pt_tensor) 190 | 191 | def test_trajectory_ode(self,pt_tensor): 192 | """test_trajectory 193 | 194 | Args: 195 | pt_tensor (numpy.array): (x0_values, x0_classes, x1_values, times_x0, times_x1), 196 | 197 | 198 | Returns: 199 | mse_all, total_pred_tensor: _description_ 200 | """ 201 | node = NeuralODE( 202 | torch_wrapper_tv(self.model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4 203 | ) 204 | total_pred = [] 205 | mse = [] 206 | t_max = 0 207 | 208 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor 209 | # squeeze all 210 | x0_values = x0_values.squeeze(0) 211 | x1_values = x1_values.squeeze(0) 212 | times_x0 = times_x0.squeeze() 213 | times_x1 = times_x1.squeeze() 214 | x0_classes = x0_classes.squeeze() 215 | 216 | if len(x0_classes.shape) == 1: 217 | x0_classes = x0_classes.unsqueeze(1) 218 | 219 | 220 | total_pred.append(x0_values[0].unsqueeze(0)) 221 | len_path = x0_values.shape[0] 222 | assert len_path == x1_values.shape[0] 223 | 224 | time_history = x0_classes[0][-(self.memory*self.dim):] 225 | 226 | for i in range(len_path): 227 | # print(i) 228 | t_max = (times_x1[i]-times_x0[i]) 229 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device) 230 | 231 | new_x_classes = torch.cat([x0_classes[i][:-(self.memory*self.dim)].unsqueeze(0), time_history.unsqueeze(0)], dim=1) 232 | with torch.no_grad(): 233 | # get last pred, if none then use startpt 234 | if i == 0: 235 | testpt = torch.cat([x0_values[i].unsqueeze(0),new_x_classes],dim=1) 236 | else: # incorporate last prediction 237 | testpt = torch.cat([pred_traj, new_x_classes], dim=1) 238 | # print(testpt.shape) 239 | traj = node.trajectory( 240 | testpt, 241 | t_span=time_span, 242 | ) 243 | pred_traj = traj[-1,:,:self.dim] 244 | total_pred.append(pred_traj) 245 | ground_truth_coords = x1_values[i] 
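# Editor's note: the rollout above integrates the learned field only over each
# observed interval [t_i, t_{i+1}] and feeds every endpoint prediction back in as
# the next start point, while the history update a few lines below keeps a sliding
# window of the last `memory` predicted states. A minimal standalone sketch of
# that pattern, assuming a hypothetical `vector_field` module with a forward(t, x)
# signature (illustrative only, not this module's actual API):
#
#   import torch
#   from torchdyn.core import NeuralODE
#
#   node = NeuralODE(vector_field, solver="dopri5", atol=1e-4, rtol=1e-4)
#   x = x_start.unsqueeze(0)                       # batch of one patient
#   history = x_start.repeat(3)                    # memory=3 sliding window
#   for t0, t1 in zip(obs_times[:-1], obs_times[1:]):
#       t_span = torch.linspace(t0, t1, 10)        # integrate only the observed gap
#       x = node.trajectory(x, t_span=t_span)[-1]  # keep the endpoint as the prediction
#       history = torch.cat([history[x.shape[-1]:], x.flatten()])  # drop oldest, append newest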
246 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy()) 247 | 248 | # history update 249 | flattened_coords = pred_traj.flatten() 250 | time_history = torch.cat([time_history[self.dim:].unsqueeze(0), flattened_coords.unsqueeze(0)], dim=1).squeeze() 251 | 252 | mse_all = np.mean(mse) 253 | total_pred_tensor = torch.stack(total_pred).squeeze(1) 254 | return mse_all, total_pred_tensor 255 | 256 | def configure_optimizers(self): 257 | # Define the optimizer 258 | optimizer = optim.Adam(self.parameters(), lr=self.lr) 259 | return optimizer 260 | 261 | def test_trajectory_sde(self,pt_tensor): 262 | """test_trajectory 263 | 264 | Args: 265 | pt_tensor (numpy.array): (x0_values, x0_classes, x1_values, times_x0, times_x1), 266 | 267 | 268 | Returns: 269 | mse_all, total_pred_tensor: _description_ 270 | """ 271 | sde = SDE(self.model, noise=0.1) 272 | total_pred = [] 273 | mse = [] 274 | t_max = 0 275 | 276 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor 277 | # squeeze all 278 | x0_values = x0_values.squeeze(0) 279 | x1_values = x1_values.squeeze(0) 280 | times_x0 = times_x0.squeeze() 281 | times_x1 = times_x1.squeeze() 282 | x0_classes = x0_classes.squeeze() 283 | 284 | if len(x0_classes.shape) == 1: 285 | x0_classes = x0_classes.unsqueeze(1) 286 | 287 | 288 | 289 | total_pred.append(x0_values[0].unsqueeze(0)) 290 | len_path = x0_values.shape[0] 291 | assert len_path == x1_values.shape[0] 292 | 293 | time_history = x0_classes[0][-(self.memory*self.dim):] 294 | 295 | for i in range(len_path): 296 | 297 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device) 298 | 299 | new_x_classes = torch.cat([x0_classes[i][:-(self.memory*self.dim)].unsqueeze(0), time_history.unsqueeze(0)], dim=1) 300 | with torch.no_grad(): 301 | # get last pred, if none then use startpt 302 | if i == 0: 303 | testpt = torch.cat([x0_values[i].unsqueeze(0),new_x_classes],dim=1) 304 | else: # incorporate last prediction 305 | testpt = torch.cat([pred_traj, new_x_classes], dim=1) 306 | # print(testpt.shape) 307 | traj = self._sde_solver(sde, testpt, time_span) 308 | 309 | pred_traj = traj[-1,:,:self.dim] 310 | total_pred.append(pred_traj) 311 | ground_truth_coords = x1_values[i] 312 | mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy()) 313 | 314 | # history update 315 | flattened_coords = pred_traj.flatten() 316 | time_history = torch.cat([time_history[self.dim:].unsqueeze(0), flattened_coords.unsqueeze(0)], dim=1).squeeze() 317 | 318 | mse_all = np.mean(mse) 319 | total_pred_tensor = torch.stack(total_pred).squeeze(1) 320 | return mse_all, total_pred_tensor 321 | 322 | 323 | def _sde_solver(self, sde, initial_state, time_span): 324 | dt = time_span[1] - time_span[0] # Time step 325 | current_state = initial_state 326 | trajectory = [current_state] 327 | 328 | for t in time_span[1:]: 329 | drift = sde.f(t, current_state) 330 | diffusion = sde.g(t, current_state) 331 | noise = torch.randn_like(current_state) * torch.sqrt(dt) 332 | current_state = current_state + drift * dt + diffusion * noise 333 | trajectory.append(current_state) 334 | 335 | return torch.stack(trajectory) 336 | 337 | 338 | 339 | -------------------------------------------------------------------------------- /src/model/mlp_noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import time 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import ot as 
pot
8 | import torchdyn
9 | from torchdyn.core import NeuralODE
10 | import pytorch_lightning as pl
11 | from torch import optim
12 | import torch.nn.functional as F
13 | import wandb
14 | 
15 | from utils.visualize import *
16 | from utils.metric_calc import *
17 | from utils.sde import SDE
18 | from model.components.positional_encoding import *
19 | from model.components.mlp import *
20 | from model.components.sde_func_solver import *
21 | from model.components.grad_util import *
22 | 
23 | class MLP_conditional_memory(torch.nn.Module):
24 | """Conditional MLP with memory; supports many conditioning classes.
25 | 
26 | The conditioning features are passed through the output unchanged.
27 | """
28 | def __init__(self,
29 | dim,
30 | treatment_cond,
31 | memory, # how many time steps of history to condition on
32 | out_dim=None,
33 | w=64,
34 | time_varying=False,
35 | conditional=False,
36 | time_dim = NUM_FREQS * 2,
37 | clip = None,
38 | ):
39 | super().__init__()
40 | self.time_varying = time_varying
41 | if out_dim is None:
42 | out_dim = dim
43 | self.out_dim = out_dim + 1 # extra output slot for the predicted time gap
44 | self.treatment_cond = treatment_cond
45 | self.memory = memory
46 | self.dim = dim
47 | self.indim = dim + (time_dim if time_varying else 0) + (self.treatment_cond if conditional else 0) + (dim * memory)
48 | self.net = torch.nn.Sequential(
49 | torch.nn.Linear(self.indim, w),
50 | torch.nn.SELU(),
51 | torch.nn.Linear(w, w),
52 | torch.nn.SELU(),
53 | torch.nn.Linear(w, w),
54 | torch.nn.SELU(),
55 | torch.nn.Linear(w,self.out_dim),
56 | )
57 | self.default_class = 0
58 | self.clip = clip
59 | # self.encoding_function = positional_encoding_tensor()
60 | 
61 | def encoding_function(self, time_tensor):
62 | return positional_encoding_tensor(time_tensor)
63 | 
64 | def forward_train(self, x):
65 | """forward pass
66 | Assumes the input is ordered as x, then conditioning c, then time t
67 | """
68 | time_tensor = x[:,-1]
69 | encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2)
70 | new_x = torch.cat([x[:,:-1], encoded_time_span], dim=1)
71 | result = self.net(new_x)
72 | return torch.cat([result[:,:-1], x[:,self.dim:-1], result[:,-1].unsqueeze(1)], dim=1)
73 | 
74 | def forward(self, x):
75 | """calls forward_train, then converts the x1 prediction into a velocity
76 | x here is x_t
77 | xt = (t)x1 + (1-t)x0
78 | (xt - tx1)/(1-t) = x0
79 | """
80 | x1 = self.forward_train(x)
81 | x1_coord = x1[:,:self.dim]
82 | t = x[:,-1]
83 | pred_time_till_t1 = x1[:,-1]
84 | x_coord = x[:,:self.dim]
85 | if self.clip is None:
86 | vt = (x1_coord - x_coord)/(pred_time_till_t1)
87 | else:
88 | vt = (x1_coord - x_coord)/torch.clip((pred_time_till_t1),min=self.clip)
89 | 
90 | final_vt = torch.cat([vt, torch.zeros_like(x[:,self.dim:-1])], dim=1)
91 | return final_vt
92 | 
93 | 
94 | class MLP_conditional_memory_sde_noise(torch.nn.Module):
95 | """Conditional MLP with memory that predicts the SDE noise magnitude.
96 | 
97 | The conditioning features are passed through the output unchanged.
98 | """
99 | def __init__(self,
100 | dim,
101 | treatment_cond,
102 | memory, # how many time steps of history to condition on
103 | out_dim=None,
104 | w=64,
105 | time_varying=False,
106 | conditional=False,
107 | time_dim = NUM_FREQS * 2,
108 | clip = None,
109 | ):
110 | super().__init__()
111 | self.time_varying = time_varying
112 | out_dim = 1 if out_dim is None else out_dim # single output for the noise magnitude
113 | self.out_dim = out_dim
114 | self.treatment_cond = treatment_cond
115 | self.memory = memory
116 | self.dim = dim
117 | self.indim = dim + (time_dim if time_varying else 0) + (self.treatment_cond if conditional else 0) + (dim * memory)
118 | self.net = torch.nn.Sequential(
119 | torch.nn.Linear(self.indim, w),
120 | torch.nn.SELU(),
121 | torch.nn.Linear(w, w),
122 | torch.nn.SELU(),
123
| torch.nn.Linear(w, w), 124 | torch.nn.SELU(), 125 | torch.nn.Linear(w,self.out_dim), 126 | ) 127 | self.default_class = 0 128 | self.clip = clip 129 | # self.encoding_function = positional_encoding_tensor() 130 | 131 | def encoding_function(self, time_tensor): 132 | return positional_encoding_tensor(time_tensor) 133 | 134 | def forward_train(self, x): 135 | """forward pass 136 | Assume first two dimensions are x, c, then t 137 | """ 138 | time_tensor = x[:,-1] 139 | encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2) 140 | new_x = torch.cat([x[:,:-1], encoded_time_span], dim=1) 141 | result = self.net(new_x) 142 | return result 143 | 144 | def forward(self,x): 145 | result = self.forward_train(x) 146 | return torch.cat([result, torch.zeros_like(x[:,1:-1])], dim=1) 147 | 148 | """ Lightning module """ 149 | def mse_loss(pred, true): 150 | return torch.mean((pred - true) ** 2) 151 | 152 | def l1_loss(pred, true): 153 | return torch.mean(torch.abs(pred - true)) 154 | 155 | class Noise_MLP_Cond_Memory_Module(pl.LightningModule): 156 | def __init__(self, 157 | treatment_cond, 158 | memory=3, # can increase / tune to see effect 159 | dim=2, 160 | w=64, 161 | time_varying=True, 162 | conditional=True, 163 | lr=1e-6, 164 | sigma = 0.1, 165 | loss_fn = mse_loss, 166 | metrics = ['mse_loss', 'l1_loss'], 167 | implementation = "ODE", # can be SDE 168 | sde_noise = 0.1, 169 | clip = None, 170 | naming = None, 171 | 172 | ): 173 | super().__init__() 174 | self.flow_model = MLP_conditional_memory(dim=dim, 175 | w=w, 176 | time_varying=time_varying, 177 | conditional=conditional, 178 | treatment_cond=treatment_cond, 179 | memory=memory, 180 | clip = clip) 181 | if implementation == "SDE": 182 | self.noise_model = MLP_conditional_memory_sde_noise(dim=dim, # @TODO: give \hat{x_1}? 183 | w=w, 184 | time_varying=time_varying, 185 | conditional=conditional, 186 | treatment_cond=treatment_cond, 187 | memory=memory, 188 | clip = clip) 189 | else: 190 | self.noise_model = MLP_conditional_memory(dim=dim, # @TODO: give \hat{x_1}? 
191 | w=w, 192 | time_varying=time_varying, 193 | conditional=conditional, 194 | treatment_cond=treatment_cond, 195 | memory=memory, 196 | clip = clip) 197 | self.automatic_optimization = False 198 | self.loss_fn = loss_fn 199 | self.save_hyperparameters() 200 | self.dim = dim 201 | # self.out_dim = out_dim 202 | self.w = w 203 | self.time_varying = time_varying 204 | self.conditional = conditional 205 | self.treatment_cond = treatment_cond 206 | self.lr = lr 207 | self.sigma = sigma 208 | self.naming = "Noise_MLP_Cond_memory_Module_"+implementation if naming is None else naming 209 | self.metrics = metrics 210 | self.implementation = implementation 211 | self.memory = memory 212 | self.sde_noise = sde_noise 213 | self.clip = clip 214 | if self.memory > 1: 215 | self.naming += "_Memory_"+str(self.memory) 216 | 217 | def __convert_tensor__(self, tensor): 218 | return tensor.to(torch.float32) 219 | 220 | def __x_processing__(self, x0, x1, t0, t1): 221 | # squeeze xs (prevent mismatch) 222 | x0 = x0.squeeze(0) 223 | x1 = x1.squeeze(0) 224 | t0 = t0.squeeze() 225 | t1 = t1.squeeze() 226 | 227 | t = torch.rand(x0.shape[0],1).to(x0.device) 228 | mu_t = x0 * (1 - t) + x1 * t 229 | data_t_diff = (t1 - t0).unsqueeze(1) 230 | x = mu_t + self.sigma * torch.randn(x0.shape[0], self.dim).to(x0.device) 231 | 232 | ut = (x1 - x0) / (data_t_diff + 1e-4) 233 | t_model = t * data_t_diff + t0.unsqueeze(1) 234 | futuretime = t1 - t_model 235 | 236 | return x, ut, t_model, futuretime, t 237 | 238 | def training_step(self, batch, batch_idx): 239 | """_summary_ 240 | 241 | Args: 242 | batch (list of output): x0_values, x0_classes, x1_values, times_x0, times_x1 243 | batch_idx (_type_): _description_ 244 | 245 | Returns: 246 | _type_: _description_ 247 | """ 248 | flow_opt, noise_opt = self.optimizers() 249 | 250 | x0, x0_class, x1, x0_time, x1_time = batch 251 | x0, x0_class, x1, x0_time, x1_time = self.__convert_tensor__(x0), self.__convert_tensor__(x0_class), self.__convert_tensor__(x1), self.__convert_tensor__(x0_time), self.__convert_tensor__(x1_time) 252 | 253 | 254 | x, ut, t_model, futuretime, t = self.__x_processing__(x0, x1, x0_time, x1_time) 255 | 256 | 257 | if len(x0_class.shape) == 3: 258 | x0_class = x0_class.squeeze() 259 | 260 | in_tensor = torch.cat([x,x0_class, t_model], dim = -1) 261 | xt = self.flow_model.forward_train(in_tensor) 262 | 263 | if self.implementation == "SDE": 264 | sde_noise = self.noise_model.forward_train(in_tensor) 265 | variance = torch.sqrt(t*(1-t))*sde_noise 266 | noise = torch.randn_like(xt[:,:self.dim]) * variance 267 | loss = self.loss_fn(xt[:,:self.dim] + noise.clone().detach(), x1) + self.loss_fn(xt[:,-1], futuretime) 268 | uncertainty =(xt[:,:self.dim].clone().detach() + noise) 269 | noise_loss = self.loss_fn(uncertainty,x1) 270 | else: 271 | loss = self.loss_fn(xt[:,:self.dim], x1) + self.loss_fn(xt[:,-1], futuretime) 272 | uncertainty = torch.abs(xt[:,:self.dim].clone().detach() - x1) 273 | # noise model incorporation (model loss) 274 | noise_loss = self.loss_fn(self.noise_model.forward_train(in_tensor)[:,:self.dim], uncertainty) 275 | 276 | flow_opt.zero_grad() 277 | self.manual_backward(loss) 278 | flow_opt.step() 279 | 280 | noise_opt.zero_grad() 281 | self.manual_backward(noise_loss) 282 | noise_opt.step() 283 | 284 | self.log('train_loss', loss) 285 | self.log('noise_loss', noise_loss) 286 | return loss + noise_loss 287 | 288 | def configure_optimizers(self): 289 | print("configuring optimizers") 290 | self.flow_optimizer = 
torch.optim.Adam(self.flow_model.parameters(), lr=self.lr) 291 | self.noise_optimizer = torch.optim.Adam(self.noise_model.parameters(), lr=self.lr) 292 | return [self.flow_optimizer, self.noise_optimizer] 293 | 294 | def validation_step(self, batch, batch_idx): 295 | """validation_step 296 | 297 | Args: 298 | batch (_type_): batch size of 1 (since uneven) 299 | batch_idx (_type_): _description_ 300 | 301 | Returns: 302 | _type_: _description_ 303 | """ 304 | loss, pairs, metricD, noise_loss, noise_pair = self.test_func_step(batch, batch_idx, mode='val') 305 | self.log('val_loss', loss, on_epoch=True, on_step=False, sync_dist=True) 306 | self.log('noise_val_loss', noise_loss, on_epoch=True, on_step=False, sync_dist=True) 307 | for key, value in metricD.items(): 308 | self.log(key+"_val", value, on_epoch=True, on_step=False, sync_dist=True) 309 | # return total_loss, traj_pairs 310 | return {'val_loss':loss, 'traj_pairs':pairs} 311 | 312 | def test_step(self, batch, batch_idx): 313 | loss, pairs, metricD, noise_loss, noise_pair = self.test_func_step(batch, batch_idx, mode='test') 314 | self.log('test_loss', loss, on_epoch=True, on_step=False, sync_dist=True) 315 | self.log('noise_test_loss', noise_loss, on_epoch=True, on_step=False, sync_dist=True) 316 | for key, value in metricD.items(): 317 | self.log(key+"_test", value, on_epoch=True, on_step=False, sync_dist=True) 318 | # return total_loss, traj_pairs 319 | return {'test_loss':loss, 'traj_pairs':pairs} 320 | 321 | def test_func_step(self, batch, batch_idx, mode='none'): 322 | """assuming each is one patient/batch""" 323 | total_loss = [] 324 | traj_pairs = [] 325 | 326 | total_noise_loss = [] 327 | noise_pairs = [] 328 | 329 | x0_values, x0_classes, x1_values, times_x0, times_x1 = batch 330 | times_x0 = times_x0.squeeze() 331 | times_x1 = times_x1.squeeze() 332 | 333 | # print(x0_values.shape) 334 | # print(x1_values.shape) 335 | full_traj = torch.cat([x0_values[0,0,:self.dim].unsqueeze(0), 336 | x1_values[0,:,:self.dim]], 337 | dim=0) 338 | full_time = torch.cat([times_x0[0].unsqueeze(0), times_x1], dim=0) 339 | ind_loss, pred_traj, noise_mse, noise_pred = self.test_trajectory(batch) 340 | total_loss.append(ind_loss) 341 | traj_pairs.append([full_traj, pred_traj]) 342 | noise_pairs.append([full_traj, noise_pred]) 343 | total_noise_loss.append(noise_mse) 344 | 345 | full_traj = full_traj.detach().cpu().numpy() 346 | pred_traj = pred_traj.detach().cpu().numpy() 347 | full_time = full_time.detach().cpu().numpy() 348 | 349 | # graph 350 | fig = plot_3d_path_ind_noise(pred_traj, 351 | full_traj, 352 | noise_pred, 353 | t_span=full_time, 354 | title="{}_trajectory_patient_{}".format(mode, batch_idx)) 355 | if self.logger: 356 | # may cause problem if wandb disabled 357 | self.logger.experiment.log({"{}_trajectory_patient_{}".format(mode, batch_idx): wandb.Image(fig)}) 358 | 359 | plt.close(fig) 360 | 361 | # metrics 362 | metricD = metrics_calculation(pred_traj, full_traj, metrics=self.metrics) 363 | return np.mean(total_loss), traj_pairs, metricD, np.mean(total_noise_loss), noise_pairs 364 | 365 | def test_trajectory(self,pt_tensor): 366 | if self.implementation == "ODE": 367 | return self.test_trajectory_ode(pt_tensor) 368 | elif self.implementation == "SDE": 369 | return self.test_trajectory_sde(pt_tensor) 370 | 371 | def test_trajectory_ode(self,pt_tensor): 372 | """test_trajectory 373 | 374 | Args: 375 | pt_tensor (numpy.array): (x0_values, x0_classes, x1_values, times_x0, times_x1), 376 | 377 | 378 | Returns: 379 | mse_all, 
total_pred_tensor: _description_ 380 | """ 381 | node = NeuralODE( 382 | torch_wrapper_tv(self.flow_model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4 383 | ) 384 | node_noise = NeuralODE( 385 | torch_wrapper_tv(self.noise_model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4 386 | ) 387 | total_pred = [] 388 | noise_pred = [] 389 | mse = [] 390 | noise_mse = [] 391 | # t_max = 0 392 | 393 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor 394 | # squeeze all 395 | x0_values = x0_values.squeeze(0) 396 | x1_values = x1_values.squeeze(0) 397 | times_x0 = times_x0.squeeze() 398 | times_x1 = times_x1.squeeze() 399 | x0_classes = x0_classes.squeeze() 400 | 401 | if len(x0_classes.shape) == 1: 402 | x0_classes = x0_classes.unsqueeze(1) 403 | 404 | total_pred.append(x0_values[0].unsqueeze(0)) 405 | len_path = x0_values.shape[0] 406 | assert len_path == x1_values.shape[0] 407 | 408 | time_history = x0_classes[0][-(self.memory*self.dim):] 409 | 410 | for i in range(len_path): 411 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device) 412 | 413 | new_x_classes = torch.cat([x0_classes[i][:-(self.memory*self.dim)].unsqueeze(0), time_history.unsqueeze(0)], dim=1) 414 | with torch.no_grad(): 415 | # get last pred, if none then use startpt 416 | if i == 0: 417 | testpt = torch.cat([x0_values[i].unsqueeze(0),new_x_classes],dim=1) 418 | else: # incorporate last prediction 419 | testpt = torch.cat([pred_traj, new_x_classes], dim=1) 420 | # print(testpt.shape) 421 | traj = node.trajectory( 422 | testpt, 423 | t_span=time_span, 424 | ) 425 | # add noise prediction 426 | noise_traj = node_noise.trajectory( 427 | testpt, 428 | t_span=time_span, 429 | ) 430 | 431 | pred_traj = traj[-1,:,:self.dim] 432 | noise_traj = noise_traj[-1,:,:self.dim] 433 | total_pred.append(pred_traj) 434 | noise_pred.append(noise_traj) 435 | 436 | ground_truth_coords = x1_values[i] 437 | mse_traj = self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy() 438 | mse.append(mse_traj) 439 | uncertainty_traj = ground_truth_coords - pred_traj 440 | noise_mse_traj = self.loss_fn(noise_traj, uncertainty_traj).detach().cpu().numpy() 441 | noise_mse.append(noise_mse_traj) 442 | 443 | # history update 444 | flattened_coords = pred_traj.flatten() 445 | time_history = torch.cat([time_history[self.dim:].unsqueeze(0), flattened_coords.unsqueeze(0)], dim=1).squeeze() 446 | 447 | mse_all = np.mean(mse) 448 | total_pred_tensor = torch.stack(total_pred).squeeze(1) 449 | noise_pred = torch.stack(noise_pred).squeeze(1) 450 | return mse_all, total_pred_tensor, noise_mse, noise_pred 451 | 452 | 453 | def test_trajectory_sde(self,pt_tensor): 454 | """test_trajectory 455 | 456 | Args: 457 | pt_tensor (numpy.array): (x0_values, x0_classes, x1_values, times_x0, times_x1), 458 | 459 | 460 | Returns: 461 | mse_all, total_pred_tensor: _description_ 462 | """ 463 | sde = SDE_func_solver(self.flow_model, noise=self.noise_model) 464 | total_pred = [] 465 | mse = [] 466 | noise_pred = [] 467 | noise_mse = [] 468 | t_max = 0 469 | 470 | x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor 471 | # squeeze all 472 | x0_values = x0_values.squeeze(0) 473 | x1_values = x1_values.squeeze(0) 474 | times_x0 = times_x0.squeeze() 475 | times_x1 = times_x1.squeeze() 476 | x0_classes = x0_classes.squeeze() 477 | 478 | if len(x0_classes.shape) == 1: 479 | x0_classes = x0_classes.unsqueeze(1) 480 | 481 | 482 | 483 | total_pred.append(x0_values[0].unsqueeze(0)) 484 
| len_path = x0_values.shape[0] 485 | assert len_path == x1_values.shape[0] 486 | 487 | time_history = x0_classes[0][-(self.memory*self.dim):] 488 | 489 | for i in range(len_path): 490 | 491 | time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device) 492 | 493 | new_x_classes = torch.cat([x0_classes[i][:-(self.memory*self.dim)].unsqueeze(0), time_history.unsqueeze(0)], dim=1) 494 | with torch.no_grad(): 495 | # get last pred, if none then use startpt 496 | if i == 0: 497 | testpt = torch.cat([x0_values[i].unsqueeze(0),new_x_classes],dim=1) 498 | else: # incorporate last prediction 499 | testpt = torch.cat([pred_traj, new_x_classes], dim=1) 500 | traj, noise_traj = self._sde_solver(sde, testpt, time_span) 501 | 502 | pred_traj = traj[-1,:,:self.dim] 503 | noise_traj = noise_traj[-1,:,:self.dim] 504 | 505 | total_pred.append(pred_traj) 506 | noise_pred.append(noise_traj) 507 | 508 | ground_truth_coords = x1_values[i] 509 | calculated_mse = self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy() 510 | mse.append(calculated_mse) 511 | noise_mse.append(calculated_mse) 512 | 513 | # history update 514 | flattened_coords = pred_traj.flatten() 515 | time_history = torch.cat([time_history[self.dim:].unsqueeze(0), flattened_coords.unsqueeze(0)], dim=1).squeeze() 516 | 517 | 518 | 519 | mse_all = np.mean(mse) 520 | noise_mse_all = np.mean(noise_mse) 521 | total_pred_tensor = torch.stack(total_pred).squeeze(1) 522 | noise_pred_tensor = torch.stack(noise_pred).squeeze(1) 523 | return mse_all, total_pred_tensor, noise_mse_all, noise_pred_tensor 524 | 525 | 526 | def _sde_solver(self, sde, initial_state, time_span): 527 | dt = time_span[1] - time_span[0] # Time step 528 | current_state = initial_state 529 | trajectory = [current_state] 530 | noise_trajectory = [] 531 | 532 | for t in time_span[1:]: 533 | drift = sde.f(t, current_state) 534 | diffusion = sde.g(t, current_state) 535 | noise = torch.randn_like(current_state) * torch.sqrt(dt) 536 | current_state = current_state + drift * dt + diffusion * noise # @NEED this or not? 537 | trajectory.append(current_state) 538 | pred_diff = diffusion * noise 539 | noise_trajectory.append(pred_diff) 540 | 541 | return torch.stack(trajectory), torch.stack(noise_trajectory) 542 | 543 | 544 | 545 | 546 | -------------------------------------------------------------------------------- /src/model/ode_baseline.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Neural ODE and SDE baseline models. 
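Editor's note: unlike the TFM modules, these baselines are trained by
backpropagating through the solver itself (torchdiffeq's odeint for the ODE,
torchsde's sdeint for the SDE), i.e. simulation-based training. A hedged usage
sketch, assuming a datamodule that yields (x0, x0_class, x1, x0_time, x1_time)
batches as the steps below expect:

    model = ODEBaseline(dim=2, w=64, lr=1e-5)
    x_pred = model(x0, t_span)                      # odeint from x0 over t_span
    loss = model.loss_fn(x_pred[-1], x1.squeeze())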
4 | lr = 1e-3 5 | 6 | """ 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import pytorch_lightning as pl 12 | from torchdiffeq import odeint 13 | from utils.metric_calc import * 14 | class ODEFunc(nn.Module): 15 | def __init__(self, dim, w=64): 16 | super().__init__() 17 | self.net = nn.Sequential( 18 | nn.Linear(dim, w), 19 | nn.Tanh(), 20 | nn.Linear(w, w), 21 | nn.Tanh(), 22 | nn.Linear(w, dim), 23 | ) 24 | 25 | def forward(self, t, x): 26 | return self.net(x) 27 | 28 | class ODEBaseline(pl.LightningModule): 29 | def __init__(self, 30 | dim=2, 31 | w=64, 32 | lr=1e-5, 33 | loss_fn=nn.MSELoss(), 34 | metrics = ['mse_loss', 'l1_loss']): 35 | super().__init__() 36 | self.ode_func = ODEFunc(dim, w) 37 | self.lr = lr 38 | self.loss_fn = loss_fn 39 | self.naming = 'ODEBaseline' 40 | self.metrics = metrics 41 | 42 | def forward(self, x0, t_span): 43 | return odeint(self.ode_func, x0, t_span) 44 | 45 | def training_step(self, batch, batch_idx): 46 | """x0, x0_class, x1, x0_time, x1_time """ 47 | x0, x0_class, x1, x0_time, x1_time = batch 48 | t_span = x1_time.squeeze() #- x0_time 49 | print("training_step") 50 | # print(x0.shape, x1.shape, t_span.shape) # torch.Size([256, 2]) torch.Size([256, 2]) torch.Size([256]) 51 | x_pred = self.forward(x0, t_span) 52 | loss = self.loss_fn(x_pred[-1], x1.squeeze()) 53 | self.log('train_loss', loss) 54 | return loss 55 | 56 | def configure_optimizers(self): 57 | return torch.optim.Adam(self.parameters(), lr=self.lr) 58 | 59 | def validation_step(self, batch, batch_idx): 60 | x0, x0_class, x1, x0_time, x1_time = batch 61 | t_span = x1_time.squeeze() 62 | x_pred = self.forward(x0, t_span) 63 | loss = self.loss_fn(x_pred[-1], x1.squeeze()) 64 | 65 | # metrics 66 | metricsD = metrics_calculation(x_pred[-1], x1, metrics=self.metrics) 67 | for k, v in metricsD.items(): 68 | self.log(f'{k}_val', v) 69 | 70 | self.log('val_loss', loss) 71 | return loss 72 | 73 | def test_step(self, batch, batch_idx): 74 | x0, x0_class, x1, x0_time, x1_time = batch 75 | t_span = x1_time.squeeze() 76 | x_pred = self.forward(x0, t_span) 77 | loss = self.loss_fn(x_pred[-1], x1.squeeze()) 78 | # metrics 79 | metricsD = metrics_calculation(x_pred[-1], x1, metrics=self.metrics) 80 | for k, v in metricsD.items(): 81 | self.log(f'{k}_test', v) 82 | 83 | self.log('test_loss', loss) 84 | return loss 85 | 86 | # sde 87 | import torchsde 88 | 89 | class SDEFunc(nn.Module): 90 | noise_type = 'diagonal' # diagonal = noise is uncorrelated across dimensions 91 | sde_type = 'ito' 92 | def __init__(self, dim, w=64): 93 | super().__init__() 94 | self.mu = nn.Sequential( 95 | nn.Linear(dim, w), 96 | nn.Tanh(), 97 | nn.Linear(w, w), 98 | nn.Tanh(), 99 | nn.Linear(w, dim), 100 | ) 101 | self.sigma = nn.Sequential( 102 | nn.Linear(dim, w), 103 | nn.Tanh(), 104 | nn.Linear(w, w), 105 | nn.Tanh(), 106 | nn.Linear(w, dim), 107 | ) 108 | 109 | def f(self, t, x): 110 | return self.mu(x) 111 | 112 | def g(self, t, x): 113 | return self.sigma(x) 114 | 115 | class SDEBaseline(pl.LightningModule): 116 | def __init__(self, 117 | dim=2, 118 | w=64, 119 | lr=1e-5, 120 | loss_fn=nn.MSELoss(), 121 | metrics = ['mse_loss', 'l1_loss']): 122 | super().__init__() 123 | self.sde_func = SDEFunc(dim, w) 124 | self.lr = lr 125 | self.loss_fn = loss_fn 126 | self.naming = 'SDEBaseline' 127 | self.metrics = metrics 128 | 129 | def forward(self, x0, t_span): 130 | # print(x0.shape) 131 | x0 = x0.squeeze() 132 | batch_size, dim = x0.shape 133 | bm = torchsde.BrownianInterval(t0=t_span[0], 134 | t1=t_span[-1], 135 | 
dtype=x0.dtype, 136 | device=x0.device, 137 | size=(batch_size, dim), 138 | levy_area_approximation="space-time") 139 | return torchsde.sdeint(self.sde_func, x0, t_span, bm=bm) 140 | 141 | def training_step(self, batch, batch_idx): 142 | x0, x0_class, x1, x0_time, x1_time = batch 143 | t_span = x1_time.squeeze() #- x0_time 144 | # x0, x1, t_span = batch 145 | x_pred = self.forward(x0, t_span) 146 | loss = self.loss_fn(x_pred[-1], x1.squeeze()) 147 | self.log('train_loss', loss) 148 | return loss 149 | 150 | def configure_optimizers(self): 151 | return torch.optim.Adam(self.parameters(), lr=self.lr) 152 | 153 | def validation_step(self, batch, batch_idx): 154 | x0, x0_class, x1, x0_time, x1_time = batch 155 | t_span = x1_time.squeeze() 156 | x_pred = self.forward(x0, t_span) 157 | loss = self.loss_fn(x_pred[-1], x1.squeeze()) 158 | # metrics 159 | metricsD = metrics_calculation(x_pred[-1], x1, metrics=self.metrics) 160 | for k, v in metricsD.items(): 161 | self.log(f'{k}_val', v) 162 | 163 | self.log('val_loss', loss) 164 | return loss 165 | 166 | def test_step(self, batch, batch_idx): 167 | x0, x0_class, x1, x0_time, x1_time = batch 168 | t_span = x1_time.squeeze() 169 | x_pred = self.forward(x0, t_span) 170 | loss = self.loss_fn(x_pred[-1], x1.squeeze()) 171 | # metrics 172 | metricsD = metrics_calculation(x_pred[-1], x1, metrics=self.metrics) 173 | for k, v in metricsD.items(): 174 | self.log(f'{k}_test', v) 175 | 176 | self.log('test_loss', loss) 177 | return loss -------------------------------------------------------------------------------- /src/utils/latent_ode_utils.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import os 7 | import logging 8 | import pickle 9 | 10 | import torch 11 | import torch.nn as nn 12 | import numpy as np 13 | import pandas as pd 14 | import math 15 | import glob 16 | import re 17 | from shutil import copyfile 18 | import sklearn as sk 19 | import subprocess 20 | import datetime 21 | 22 | def makedirs(dirname): 23 | if not os.path.exists(dirname): 24 | os.makedirs(dirname) 25 | 26 | def save_checkpoint(state, save, epoch): 27 | if not os.path.exists(save): 28 | os.makedirs(save) 29 | filename = os.path.join(save, 'checkpt-%04d.pth' % epoch) 30 | torch.save(state, filename) 31 | 32 | 33 | def get_logger(logpath, filepath, package_files=[], 34 | displaying=True, saving=True, debug=False): 35 | logger = logging.getLogger() 36 | if debug: 37 | level = logging.DEBUG 38 | else: 39 | level = logging.INFO 40 | logger.setLevel(level) 41 | if saving: 42 | info_file_handler = logging.FileHandler(logpath, mode='w') 43 | info_file_handler.setLevel(level) 44 | logger.addHandler(info_file_handler) 45 | if displaying: 46 | console_handler = logging.StreamHandler() 47 | console_handler.setLevel(level) 48 | logger.addHandler(console_handler) 49 | logger.info(filepath) 50 | 51 | for f in package_files: 52 | logger.info(f) 53 | with open(f, 'r') as package_f: 54 | logger.info(package_f.read()) 55 | 56 | return logger 57 | 58 | 59 | def inf_generator(iterable): 60 | """Allows training with DataLoaders in a single infinite loop: 61 | for i, (x, y) in enumerate(inf_generator(train_loader)): 62 | """ 63 | iterator = iterable.__iter__() 64 | while True: 65 | try: 66 | yield iterator.__next__() 67 | except StopIteration: 68 | iterator = iterable.__iter__() 69 | 70 | def dump_pickle(data, 
filename): 71 | with open(filename, 'wb') as pkl_file: 72 | pickle.dump(data, pkl_file) 73 | 74 | def load_pickle(filename): 75 | with open(filename, 'rb') as pkl_file: 76 | filecontent = pickle.load(pkl_file) 77 | return filecontent 78 | 79 | def make_dataset(dataset_type = "spiral",**kwargs): 80 | if dataset_type == "spiral": 81 | data_path = "data/spirals.pickle" 82 | dataset = load_pickle(data_path)["dataset"] 83 | chiralities = load_pickle(data_path)["chiralities"] 84 | elif dataset_type == "chiralspiral": 85 | data_path = "data/chiral-spirals.pickle" 86 | dataset = load_pickle(data_path)["dataset"] 87 | chiralities = load_pickle(data_path)["chiralities"] 88 | else: 89 | raise Exception("Unknown dataset type " + dataset_type) 90 | return dataset, chiralities 91 | 92 | 93 | def split_last_dim(data): 94 | last_dim = data.size()[-1] 95 | last_dim = last_dim//2 96 | 97 | if len(data.size()) == 3: 98 | res = data[:,:,:last_dim], data[:,:,last_dim:] 99 | 100 | if len(data.size()) == 2: 101 | res = data[:,:last_dim], data[:,last_dim:] 102 | return res 103 | 104 | 105 | def init_network_weights(net, std = 0.1): 106 | for m in net.modules(): 107 | if isinstance(m, nn.Linear): 108 | nn.init.normal_(m.weight, mean=0, std=std) 109 | nn.init.constant_(m.bias, val=0) 110 | 111 | 112 | def flatten(x, dim): 113 | return x.reshape(x.size()[:dim] + (-1, )) 114 | 115 | 116 | def subsample_timepoints(data, time_steps, mask, n_tp_to_sample = None): 117 | # n_tp_to_sample: number of time points to subsample. If not None, sample exactly n_tp_to_sample points 118 | if n_tp_to_sample is None: 119 | return data, time_steps, mask 120 | n_tp_in_batch = len(time_steps) 121 | 122 | 123 | if n_tp_to_sample > 1: 124 | # Subsample exact number of points 125 | assert(n_tp_to_sample <= n_tp_in_batch) 126 | n_tp_to_sample = int(n_tp_to_sample) 127 | 128 | for i in range(data.size(0)): 129 | missing_idx = sorted(np.random.choice(np.arange(n_tp_in_batch), n_tp_in_batch - n_tp_to_sample, replace = False)) 130 | 131 | data[i, missing_idx] = 0. 132 | if mask is not None: 133 | mask[i, missing_idx] = 0. 134 | 135 | elif (n_tp_to_sample <= 1) and (n_tp_to_sample > 0): 136 | # Subsample percentage of points from each time series 137 | percentage_tp_to_sample = n_tp_to_sample 138 | for i in range(data.size(0)): 139 | # take mask for current training sample and sum over all features -- figure out which time points don't have any measurements at all in this batch 140 | current_mask = mask[i].sum(-1).cpu() 141 | non_missing_tp = np.where(current_mask > 0)[0] 142 | n_tp_current = len(non_missing_tp) 143 | n_to_sample = int(n_tp_current * percentage_tp_to_sample) 144 | subsampled_idx = sorted(np.random.choice(non_missing_tp, n_to_sample, replace = False)) 145 | tp_to_set_to_zero = np.setdiff1d(non_missing_tp, subsampled_idx) 146 | 147 | data[i, tp_to_set_to_zero] = 0. 148 | if mask is not None: 149 | mask[i, tp_to_set_to_zero] = 0. 
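# Editor's note: both branches above zero out the dropped time points in `data`
# and `mask` in place, keeping tensor shapes fixed so downstream code can rely on
# the mask alone. A hedged usage sketch with assumed shapes (illustrative only):
#
#   data = torch.randn(8, 50, 3)                  # [batch, time, features]
#   mask = torch.ones_like(data)
#   time_steps = torch.linspace(0., 1., 50)
#   data, time_steps, mask = subsample_timepoints(data, time_steps, mask, n_tp_to_sample=20)
#   assert int((mask[0].sum(-1) > 0).sum()) == 20  # exactly 20 observed time points remain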
150 | 151 | return data, time_steps, mask 152 | 153 | 154 | 155 | def cut_out_timepoints(data, time_steps, mask, n_points_to_cut = None): 156 | # n_points_to_cut: number of consecutive time points to cut out 157 | if n_points_to_cut is None: 158 | return data, time_steps, mask 159 | n_tp_in_batch = len(time_steps) 160 | 161 | if n_points_to_cut < 1: 162 | raise Exception("Number of time points to cut out must be > 1") 163 | 164 | assert(n_points_to_cut <= n_tp_in_batch) 165 | n_points_to_cut = int(n_points_to_cut) 166 | 167 | for i in range(data.size(0)): 168 | start = np.random.choice(np.arange(5, n_tp_in_batch - n_points_to_cut-5), replace = False) 169 | 170 | data[i, start : (start + n_points_to_cut)] = 0. 171 | if mask is not None: 172 | mask[i, start : (start + n_points_to_cut)] = 0. 173 | 174 | return data, time_steps, mask 175 | 176 | 177 | 178 | 179 | 180 | def get_device(tensor): 181 | device = torch.device("cpu") 182 | if tensor.is_cuda: 183 | device = tensor.get_device() 184 | return device 185 | 186 | def sample_standard_gaussian(mu, sigma): 187 | device = get_device(mu) 188 | 189 | d = torch.distributions.normal.Normal(torch.Tensor([0.]).to(device), torch.Tensor([1.]).to(device)) 190 | r = d.sample(mu.size()).squeeze(-1) 191 | return r * sigma.float() + mu.float() 192 | 193 | 194 | def split_train_test(data, train_fraq = 0.8): 195 | n_samples = data.size(0) 196 | data_train = data[:int(n_samples * train_fraq)] 197 | data_test = data[int(n_samples * train_fraq):] 198 | return data_train, data_test 199 | 200 | def split_train_test_data_and_time(data, time_steps, train_fraq = 0.8): 201 | n_samples = data.size(0) 202 | data_train = data[:int(n_samples * train_fraq)] 203 | data_test = data[int(n_samples * train_fraq):] 204 | 205 | assert(len(time_steps.size()) == 2) 206 | train_time_steps = time_steps[:, :int(n_samples * train_fraq)] 207 | test_time_steps = time_steps[:, int(n_samples * train_fraq):] 208 | 209 | return data_train, data_test, train_time_steps, test_time_steps 210 | 211 | 212 | 213 | def get_next_batch(dataloader): 214 | # Make the union of all time points and perform normalization across the whole dataset 215 | data_dict = dataloader.__next__() 216 | 217 | batch_dict = get_dict_template() 218 | 219 | # remove the time points where there are no observations in this batch 220 | non_missing_tp = torch.sum(data_dict["observed_data"],(0,2)) != 0. 221 | batch_dict["observed_data"] = data_dict["observed_data"][:, non_missing_tp] 222 | batch_dict["observed_tp"] = data_dict["observed_tp"][non_missing_tp] 223 | 224 | # print("observed data") 225 | # print(batch_dict["observed_data"].size()) 226 | 227 | if ("observed_mask" in data_dict) and (data_dict["observed_mask"] is not None): 228 | batch_dict["observed_mask"] = data_dict["observed_mask"][:, non_missing_tp] 229 | 230 | batch_dict[ "data_to_predict"] = data_dict["data_to_predict"] 231 | batch_dict["tp_to_predict"] = data_dict["tp_to_predict"] 232 | 233 | non_missing_tp = torch.sum(data_dict["data_to_predict"],(0,2)) != 0. 
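# Editor's note: as with the observed data above, summing the [batch, time, feature]
# tensor over the batch and feature axes gives a per-time-point observation count,
# and comparing against 0 yields a boolean index over the time axis. Minimal
# sketch of the pattern (illustrative shapes):
#
#   x = torch.zeros(4, 10, 3); x[:, ::2] = 1.        # observations at even steps only
#   keep = torch.sum(x, (0, 2)) != 0.                # [10] boolean mask over time
#   x_kept = x[:, keep]                              # -> [4, 5, 3]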
234 | batch_dict["data_to_predict"] = data_dict["data_to_predict"][:, non_missing_tp] 235 | batch_dict["tp_to_predict"] = data_dict["tp_to_predict"][non_missing_tp] 236 | 237 | # print("data_to_predict") 238 | # print(batch_dict["data_to_predict"].size()) 239 | 240 | if ("mask_predicted_data" in data_dict) and (data_dict["mask_predicted_data"] is not None): 241 | batch_dict["mask_predicted_data"] = data_dict["mask_predicted_data"][:, non_missing_tp] 242 | 243 | if ("labels" in data_dict) and (data_dict["labels"] is not None): 244 | batch_dict["labels"] = data_dict["labels"] 245 | 246 | batch_dict["mode"] = data_dict["mode"] 247 | return batch_dict 248 | 249 | 250 | 251 | def get_ckpt_model(ckpt_path, model, device): 252 | if not os.path.exists(ckpt_path): 253 | raise Exception("Checkpoint " + ckpt_path + " does not exist.") 254 | # Load checkpoint. 255 | checkpt = torch.load(ckpt_path) 256 | ckpt_args = checkpt['args'] 257 | state_dict = checkpt['state_dict'] 258 | model_dict = model.state_dict() 259 | 260 | # 1. filter out unnecessary keys 261 | state_dict = {k: v for k, v in state_dict.items() if k in model_dict} 262 | # 2. overwrite entries in the existing state dict 263 | model_dict.update(state_dict) 264 | # 3. load the new state dict 265 | model.load_state_dict(state_dict) 266 | model.to(device) 267 | 268 | 269 | def update_learning_rate(optimizer, decay_rate = 0.999, lowest = 1e-3): 270 | for param_group in optimizer.param_groups: 271 | lr = param_group['lr'] 272 | lr = max(lr * decay_rate, lowest) 273 | param_group['lr'] = lr 274 | 275 | 276 | def linspace_vector(start, end, n_points): 277 | # start is either one value or a vector 278 | size = np.prod(start.size()) 279 | 280 | assert(start.size() == end.size()) 281 | if size == 1: 282 | # start and end are 1d-tensors 283 | res = torch.linspace(start, end, n_points) 284 | else: 285 | # start and end are vectors 286 | res = torch.Tensor() 287 | for i in range(0, start.size(0)): 288 | res = torch.cat((res, 289 | torch.linspace(start[i], end[i], n_points)),0) 290 | res = torch.t(res.reshape(start.size(0), n_points)) 291 | return res 292 | 293 | def reverse(tensor): 294 | idx = [i for i in range(tensor.size(0)-1, -1, -1)] 295 | return tensor[idx] 296 | 297 | 298 | def create_net(n_inputs, n_outputs, n_layers = 1, 299 | n_units = 100, nonlinear = nn.Tanh): 300 | layers = [nn.Linear(n_inputs, n_units)] 301 | for i in range(n_layers): 302 | layers.append(nonlinear()) 303 | layers.append(nn.Linear(n_units, n_units)) 304 | 305 | layers.append(nonlinear()) 306 | layers.append(nn.Linear(n_units, n_outputs)) 307 | return nn.Sequential(*layers) 308 | 309 | 310 | def get_item_from_pickle(pickle_file, item_name): 311 | from_pickle = load_pickle(pickle_file) 312 | if item_name in from_pickle: 313 | return from_pickle[item_name] 314 | return None 315 | 316 | 317 | def get_dict_template(): 318 | return {"observed_data": None, 319 | "observed_tp": None, 320 | "data_to_predict": None, 321 | "tp_to_predict": None, 322 | "observed_mask": None, 323 | "mask_predicted_data": None, 324 | "labels": None 325 | } 326 | 327 | 328 | def normalize_data(data): 329 | reshaped = data.reshape(-1, data.size(-1)) 330 | 331 | att_min = torch.min(reshaped, 0)[0] 332 | att_max = torch.max(reshaped, 0)[0] 333 | 334 | # we don't want to divide by zero 335 | att_max[ att_max == 0.] = 1. 
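# Editor's note: this normalization divides by att_max rather than by
# (att_max - att_min), following the vendored latent-ODE code; zero columns are
# clamped to 1 above purely to avoid division by zero. Worked example
# (illustrative values):
#
#   data = torch.tensor([[0., 10.], [5., 20.]])
#   att_min = torch.min(data, 0)[0]                  # [0., 10.]
#   att_max = torch.max(data, 0)[0]                  # [5., 20.]
#   norm = (data - att_min) / att_max                # [[0., 0.], [1., 0.5]]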
336 | 337 | if (att_max != 0.).all(): 338 | data_norm = (data - att_min) / att_max 339 | else: 340 | raise Exception("Zero!") 341 | 342 | if torch.isnan(data_norm).any(): 343 | raise Exception("nans!") 344 | 345 | return data_norm, att_min, att_max 346 | 347 | 348 | def normalize_masked_data(data, mask, att_min, att_max): 349 | # we don't want to divide by zero 350 | att_max[ att_max == 0.] = 1. 351 | 352 | if (att_max != 0.).all(): 353 | data_norm = (data - att_min) / att_max 354 | else: 355 | raise Exception("Zero!") 356 | 357 | if torch.isnan(data_norm).any(): 358 | raise Exception("nans!") 359 | 360 | # set masked out elements back to zero 361 | data_norm[mask == 0] = 0 362 | 363 | return data_norm, att_min, att_max 364 | 365 | 366 | def shift_outputs(outputs, first_datapoint = None): 367 | outputs = outputs[:,:,:-1,:] 368 | 369 | if first_datapoint is not None: 370 | n_traj, n_dims = first_datapoint.size() 371 | first_datapoint = first_datapoint.reshape(1, n_traj, 1, n_dims) 372 | outputs = torch.cat((first_datapoint, outputs), 2) 373 | return outputs 374 | 375 | 376 | 377 | 378 | def split_data_extrap(data_dict, dataset = ""): 379 | device = get_device(data_dict["data"]) 380 | 381 | n_observed_tp = data_dict["data"].size(1) // 2 382 | if dataset == "hopper": 383 | n_observed_tp = data_dict["data"].size(1) // 3 384 | 385 | split_dict = {"observed_data": data_dict["data"][:,:n_observed_tp,:].clone(), 386 | "observed_tp": data_dict["time_steps"][:n_observed_tp].clone(), 387 | "data_to_predict": data_dict["data"][:,n_observed_tp:,:].clone(), 388 | "tp_to_predict": data_dict["time_steps"][n_observed_tp:].clone()} 389 | 390 | split_dict["observed_mask"] = None 391 | split_dict["mask_predicted_data"] = None 392 | split_dict["labels"] = None 393 | 394 | if ("mask" in data_dict) and (data_dict["mask"] is not None): 395 | split_dict["observed_mask"] = data_dict["mask"][:, :n_observed_tp].clone() 396 | split_dict["mask_predicted_data"] = data_dict["mask"][:, n_observed_tp:].clone() 397 | 398 | if ("labels" in data_dict) and (data_dict["labels"] is not None): 399 | split_dict["labels"] = data_dict["labels"].clone() 400 | 401 | split_dict["mode"] = "extrap" 402 | return split_dict 403 | 404 | 405 | 406 | 407 | 408 | def split_data_interp(data_dict): 409 | device = get_device(data_dict["data"]) 410 | 411 | split_dict = {"observed_data": data_dict["data"].clone(), 412 | "observed_tp": data_dict["time_steps"].clone(), 413 | "data_to_predict": data_dict["data"].clone(), 414 | "tp_to_predict": data_dict["time_steps"].clone()} 415 | 416 | split_dict["observed_mask"] = None 417 | split_dict["mask_predicted_data"] = None 418 | split_dict["labels"] = None 419 | 420 | if "mask" in data_dict and data_dict["mask"] is not None: 421 | split_dict["observed_mask"] = data_dict["mask"].clone() 422 | split_dict["mask_predicted_data"] = data_dict["mask"].clone() 423 | 424 | if ("labels" in data_dict) and (data_dict["labels"] is not None): 425 | split_dict["labels"] = data_dict["labels"].clone() 426 | 427 | split_dict["mode"] = "interp" 428 | return split_dict 429 | 430 | 431 | 432 | def add_mask(data_dict): 433 | data = data_dict["observed_data"] 434 | mask = data_dict["observed_mask"] 435 | 436 | if mask is None: 437 | mask = torch.ones_like(data).to(get_device(data)) 438 | 439 | data_dict["observed_mask"] = mask 440 | return data_dict 441 | 442 | 443 | def subsample_observed_data(data_dict, n_tp_to_sample = None, n_points_to_cut = None): 444 | # n_tp_to_sample -- if not None, randomly subsample the time points. 
The resulting timeline has n_tp_to_sample points 445 | # n_points_to_cut -- if not None, cut out consecutive points on the timeline. The resulting timeline has (N - n_points_to_cut) points 446 | 447 | if n_tp_to_sample is not None: 448 | # Randomly subsample time points 449 | data, time_steps, mask = subsample_timepoints( 450 | data_dict["observed_data"].clone(), 451 | time_steps = data_dict["observed_tp"].clone(), 452 | mask = (data_dict["observed_mask"].clone() if data_dict["observed_mask"] is not None else None), 453 | n_tp_to_sample = n_tp_to_sample) 454 | 455 | if n_points_to_cut is not None: 456 | # Remove consecutive time points 457 | data, time_steps, mask = cut_out_timepoints( 458 | data_dict["observed_data"].clone(), 459 | time_steps = data_dict["observed_tp"].clone(), 460 | mask = (data_dict["observed_mask"].clone() if data_dict["observed_mask"] is not None else None), 461 | n_points_to_cut = n_points_to_cut) 462 | 463 | new_data_dict = {} 464 | for key in data_dict.keys(): 465 | new_data_dict[key] = data_dict[key] 466 | 467 | new_data_dict["observed_data"] = data.clone() 468 | new_data_dict["observed_tp"] = time_steps.clone() 469 | new_data_dict["observed_mask"] = mask.clone() 470 | 471 | if n_points_to_cut is not None: 472 | # Cut the section in the data to predict as well 473 | # Used only for the demo on the periodic function 474 | new_data_dict["data_to_predict"] = data.clone() 475 | new_data_dict["tp_to_predict"] = time_steps.clone() 476 | new_data_dict["mask_predicted_data"] = mask.clone() 477 | 478 | return new_data_dict 479 | 480 | 481 | def split_and_subsample_batch(data_dict, args, data_type = "train"): 482 | if data_type == "train": 483 | # Training set 484 | if args.extrap: 485 | processed_dict = split_data_extrap(data_dict, dataset = args.dataset) 486 | else: 487 | processed_dict = split_data_interp(data_dict) 488 | 489 | else: 490 | # Test set 491 | if args.extrap: 492 | processed_dict = split_data_extrap(data_dict, dataset = args.dataset) 493 | else: 494 | processed_dict = split_data_interp(data_dict) 495 | 496 | # add mask 497 | processed_dict = add_mask(processed_dict) 498 | 499 | # Subsample points or cut out the whole section of the timeline 500 | if (args.sample_tp is not None) or (args.cut_tp is not None): 501 | processed_dict = subsample_observed_data(processed_dict, 502 | n_tp_to_sample = args.sample_tp, 503 | n_points_to_cut = args.cut_tp) 504 | 505 | # if (args.sample_tp is not None): 506 | # processed_dict = subsample_observed_data(processed_dict, 507 | # n_tp_to_sample = args.sample_tp) 508 | return processed_dict 509 | 510 | 511 | 512 | 513 | 514 | def compute_loss_all_batches(model, 515 | test_dataloader, args, 516 | n_batches, experimentID, device, 517 | n_traj_samples = 1, kl_coef = 1., 518 | max_samples_for_eval = None): 519 | 520 | total = {} 521 | total["loss"] = 0 522 | total["likelihood"] = 0 523 | total["mse"] = 0 524 | total["kl_first_p"] = 0 525 | total["std_first_p"] = 0 526 | total["pois_likelihood"] = 0 527 | total["ce_loss"] = 0 528 | 529 | n_test_batches = 0 530 | 531 | classif_predictions = torch.Tensor([]).to(device) 532 | all_test_labels = torch.Tensor([]).to(device) 533 | 534 | for i in range(n_batches): 535 | print("Computing loss... 
" + str(i)) 536 | 537 | batch_dict = get_next_batch(test_dataloader) 538 | 539 | results = model.compute_all_losses(batch_dict, 540 | n_traj_samples = n_traj_samples, kl_coef = kl_coef) 541 | 542 | if args.classif: 543 | n_labels = model.n_labels #batch_dict["labels"].size(-1) 544 | n_traj_samples = results["label_predictions"].size(0) 545 | 546 | classif_predictions = torch.cat((classif_predictions, 547 | results["label_predictions"].reshape(n_traj_samples, -1, n_labels)),1) 548 | all_test_labels = torch.cat((all_test_labels, 549 | batch_dict["labels"].reshape(-1, n_labels)),0) 550 | 551 | for key in total.keys(): 552 | if key in results: 553 | var = results[key] 554 | if isinstance(var, torch.Tensor): 555 | var = var.detach() 556 | total[key] += var 557 | 558 | n_test_batches += 1 559 | 560 | # for speed 561 | if max_samples_for_eval is not None: 562 | if n_batches * batch_size >= max_samples_for_eval: 563 | break 564 | 565 | if n_test_batches > 0: 566 | for key, value in total.items(): 567 | total[key] = total[key] / n_test_batches 568 | 569 | if args.classif: 570 | if args.dataset == "physionet": 571 | #all_test_labels = all_test_labels.reshape(-1) 572 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them 573 | all_test_labels = all_test_labels.repeat(n_traj_samples,1,1) 574 | 575 | 576 | idx_not_nan = ~torch.isnan(all_test_labels) 577 | classif_predictions = classif_predictions[idx_not_nan] 578 | all_test_labels = all_test_labels[idx_not_nan] 579 | 580 | dirname = "plots/" + str(experimentID) + "/" 581 | os.makedirs(dirname, exist_ok=True) 582 | 583 | total["auc"] = 0. 584 | if torch.sum(all_test_labels) != 0.: 585 | print("Number of labeled examples: {}".format(len(all_test_labels.reshape(-1)))) 586 | print("Number of examples with mortality 1: {}".format(torch.sum(all_test_labels == 1.))) 587 | 588 | # Cannot compute AUC with only 1 class 589 | total["auc"] = sk.metrics.roc_auc_score(all_test_labels.cpu().numpy().reshape(-1), 590 | classif_predictions.cpu().numpy().reshape(-1)) 591 | else: 592 | print("Warning: Couldn't compute AUC -- all examples are from the same class") 593 | 594 | if args.dataset == "activity": 595 | all_test_labels = all_test_labels.repeat(n_traj_samples,1,1) 596 | 597 | labeled_tp = torch.sum(all_test_labels, -1) > 0. 598 | 599 | all_test_labels = all_test_labels[labeled_tp] 600 | classif_predictions = classif_predictions[labeled_tp] 601 | 602 | # classif_predictions and all_test_labels are in on-hot-encoding -- convert to class ids 603 | _, pred_class_id = torch.max(classif_predictions, -1) 604 | _, class_labels = torch.max(all_test_labels, -1) 605 | 606 | pred_class_id = pred_class_id.reshape(-1) 607 | 608 | total["accuracy"] = sk.metrics.accuracy_score( 609 | class_labels.cpu().numpy(), 610 | pred_class_id.cpu().numpy()) 611 | return total 612 | 613 | def check_mask(data, mask): 614 | #check that "mask" argument indeed contains a mask for data 615 | n_zeros = torch.sum(mask == 0.).cpu().numpy() 616 | n_ones = torch.sum(mask == 1.).cpu().numpy() 617 | 618 | # mask should contain only zeros and ones 619 | assert((n_zeros + n_ones) == np.prod(list(mask.size()))) 620 | 621 | # all masked out elements should be zeros 622 | assert(torch.sum(data[mask == 0.] != 0.) 
== 0) 623 | 624 | 625 | import time 626 | import numpy as np 627 | 628 | import torch 629 | import torch.nn as nn 630 | 631 | from torch.distributions.multivariate_normal import MultivariateNormal 632 | 633 | # git clone https://github.com/rtqichen/torchdiffeq.git 634 | from torchdiffeq import odeint as odeint 635 | 636 | ##################################################################################################### 637 | 638 | class DiffeqSolver(nn.Module): 639 | def __init__(self, input_dim, ode_func, method, latents, 640 | odeint_rtol = 1e-4, odeint_atol = 1e-5, device = torch.device("cpu")): 641 | super(DiffeqSolver, self).__init__() 642 | 643 | self.ode_method = method 644 | self.latents = latents 645 | self.device = device 646 | self.ode_func = ode_func 647 | 648 | self.odeint_rtol = odeint_rtol 649 | self.odeint_atol = odeint_atol 650 | 651 | def forward(self, first_point, time_steps_to_predict, backwards = False): 652 | """ 653 | # Decode the trajectory through ODE Solver 654 | """ 655 | n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1] 656 | n_dims = first_point.size()[-1] 657 | 658 | pred_y = odeint(self.ode_func, first_point, time_steps_to_predict, 659 | rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method) 660 | pred_y = pred_y.permute(1,2,0,3) 661 | 662 | assert(torch.mean(pred_y[:, :, 0, :] - first_point) < 0.001) 663 | assert(pred_y.size()[0] == n_traj_samples) 664 | assert(pred_y.size()[1] == n_traj) 665 | 666 | return pred_y 667 | 668 | def sample_traj_from_prior(self, starting_point_enc, time_steps_to_predict, 669 | n_traj_samples = 1): 670 | """ 671 | # Decode the trajectory through ODE Solver using samples from the prior 672 | 673 | time_steps_to_predict: time steps at which we want to sample the new trajectory 674 | """ 675 | func = self.ode_func.sample_next_point_from_prior 676 | 677 | pred_y = odeint(func, starting_point_enc, time_steps_to_predict, 678 | rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method) 679 | # shape: [n_traj_samples, n_traj, n_tp, n_dim] 680 | pred_y = pred_y.permute(1,2,0,3) 681 | return pred_y 682 | 683 | 684 | """ 685 | https://github.com/YuliaRubanova/latent_ode/blob/master/lib/encoder_decoder.py 686 | """ 687 | from torch.nn.modules.rnn import LSTM, GRU 688 | 689 | 690 | # GRU description: 691 | # http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/ 692 | class GRU_unit(nn.Module): 693 | def __init__(self, latent_dim, input_dim, 694 | n_units = 100, 695 | # device = torch.device("cpu"), 696 | ): 697 | super(GRU_unit, self).__init__() 698 | 699 | self.update_gate = nn.Sequential( 700 | nn.Linear(latent_dim * 2 + input_dim, n_units), 701 | nn.Tanh(), 702 | nn.Linear(n_units, latent_dim), 703 | nn.Sigmoid()) 704 | 705 | self.reset_gate = nn.Sequential( 706 | nn.Linear(latent_dim * 2 + input_dim, n_units), 707 | nn.Tanh(), 708 | nn.Linear(n_units, latent_dim), 709 | nn.Sigmoid()) 710 | 711 | self.new_state_net = nn.Sequential( 712 | nn.Linear(latent_dim * 2 + input_dim, n_units), 713 | nn.Tanh(), 714 | nn.Linear(n_units, latent_dim * 2)) 715 | 716 | 717 | 718 | def forward(self, y_mean, y_std, x): 719 | y_concat = torch.cat([y_mean, y_std, x], -1) 720 | 721 | update_gate = self.update_gate(y_concat) 722 | reset_gate = self.reset_gate(y_concat) 723 | concat = torch.cat([y_mean * reset_gate, y_std * reset_gate, x], -1) 724 | 725 | new_state, new_state_std = split_last_dim(self.new_state_net(concat)) 726 | 
new_state_std = new_state_std.abs() 727 | 728 | new_y = (1-update_gate) * new_state + update_gate * y_mean 729 | new_y_std = (1-update_gate) * new_state_std + update_gate * y_std 730 | 731 | assert(not torch.isnan(new_y).any()) 732 | 733 | # took out masked update (no mask provided) 734 | 735 | new_y_std = new_y_std.abs() 736 | return new_y, new_y_std 737 | -------------------------------------------------------------------------------- /src/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def mse_loss(pred, true): 4 | return torch.mean((pred - true) ** 2) 5 | 6 | def l1_loss(pred, true): 7 | return torch.mean(torch.abs(pred - true)) -------------------------------------------------------------------------------- /src/utils/metric_calc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | # dist calculations 5 | import math 6 | from typing import Union 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from typing import Optional 12 | from .mmd import linear_mmd2, mix_rbf_mmd2, poly_mmd2 13 | # from .optimal_transport import wasserstein 14 | import ot as pot 15 | from functools import partial 16 | import scipy.stats as stats 17 | from scipy.stats import pearsonr, spearmanr 18 | import pandas as pd 19 | 20 | 21 | # auroc and auprc 22 | from sklearn.metrics import roc_auc_score, average_precision_score 23 | 24 | 25 | 26 | def wasserstein( 27 | x0: torch.Tensor, 28 | x1: torch.Tensor, 29 | method: Optional[str] = None, 30 | reg: float = 0.05, 31 | power: int = 2, 32 | **kwargs, 33 | ) -> float: 34 | assert power == 1 or power == 2 35 | # ot_fn should take (a, b, M) as arguments where a, b are marginals and 36 | # M is a cost matrix 37 | if method == "exact" or method is None: 38 | ot_fn = pot.emd2 39 | elif method == "sinkhorn": 40 | ot_fn = partial(pot.sinkhorn2, reg=reg) 41 | else: 42 | raise ValueError(f"Unknown method: {method}") 43 | 44 | a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0]) 45 | if x0.dim() > 2: 46 | x0 = x0.reshape(x0.shape[0], -1) 47 | if x1.dim() > 2: 48 | x1 = x1.reshape(x1.shape[0], -1) 49 | M = torch.cdist(x0, x1) 50 | if power == 2: 51 | M = M**2 52 | ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=1e7) 53 | if power == 2: 54 | ret = math.sqrt(ret) 55 | return ret 56 | 57 | def compute_distances(pred, true): 58 | """computes distances between vectors.""" 59 | mse = torch.nn.functional.mse_loss(pred, true).item() 60 | me = math.sqrt(mse) 61 | return mse, me, torch.nn.functional.l1_loss(pred, true).item() 62 | 63 | def compute_distribution_distances_new(pred: torch.Tensor, true: Union[torch.Tensor, list]): 64 | """computes distances between distributions. 65 | pred: [batch, times, dims] tensor 66 | true: [batch, times, dims] tensor or list[batch[i], dims] of length times 67 | 68 | This handles jagged times as a list of tensors. 
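Editor's note: "jagged" means the number of samples may differ across time
points, so the input arrives as a list of [n_i, dims] tensors instead of one
stacked tensor; the MMD metrics are skipped in that case because the kernel
code asserts equal sample counts. Hedged usage sketch, matching how
variance_dist below calls this with 2-D per-step differences:

    pred_var = torch.randn(40, 2)   # per-step differences of a predicted trajectory
    true_var = torch.randn(40, 2)
    names, values = compute_distribution_distances_new(pred_var, true_var)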
/src/utils/metric_calc.py:
--------------------------------------------------------------------------------
# dist calculations
import math
from functools import partial
from typing import Optional, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

import ot as pot
import scipy.stats as stats
from scipy.stats import pearsonr, spearmanr

from .mmd import linear_mmd2, mix_rbf_mmd2, poly_mmd2
# from .optimal_transport import wasserstein

# auroc and auprc
from sklearn.metrics import roc_auc_score, average_precision_score


def wasserstein(
    x0: torch.Tensor,
    x1: torch.Tensor,
    method: Optional[str] = None,
    reg: float = 0.05,
    power: int = 2,
    **kwargs,
) -> float:
    assert power == 1 or power == 2
    # ot_fn should take (a, b, M) as arguments where a, b are marginals and
    # M is a cost matrix
    if method == "exact" or method is None:
        ot_fn = pot.emd2
    elif method == "sinkhorn":
        ot_fn = partial(pot.sinkhorn2, reg=reg)
    else:
        raise ValueError(f"Unknown method: {method}")

    a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
    if x0.dim() > 2:
        x0 = x0.reshape(x0.shape[0], -1)
    if x1.dim() > 2:
        x1 = x1.reshape(x1.shape[0], -1)
    M = torch.cdist(x0, x1)
    if power == 2:
        M = M**2
    ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=int(1e7))
    if power == 2:
        ret = math.sqrt(ret)
    return ret


def compute_distances(pred, true):
    """Computes pointwise distances between two vectors: MSE, L2 (RMSE), and L1."""
    mse = torch.nn.functional.mse_loss(pred, true).item()
    me = math.sqrt(mse)
    return mse, me, torch.nn.functional.l1_loss(pred, true).item()


def compute_distribution_distances_new(pred: torch.Tensor, true: Union[torch.Tensor, list]):
    """Computes distances between distributions.

    pred: [batch, times, dims] tensor, or list[batch[i], dims] of length times
    true: [batch, times, dims] tensor, or list[batch[i], dims] of length times

    Jagged time grids are handled as lists of tensors; for jagged inputs only
    the last time point is evaluated, while dense tensors are compared as a
    whole (flattened over trailing dimensions inside `wasserstein`). MMD
    metrics are only computed when both inputs are dense, and the MMD path
    expects 2-D tensors.
    """
    NAMES = [
        "1-Wasserstein",
        "2-Wasserstein",
        "Linear_MMD",
        "Poly_MMD",
        "RBF_MMD",
        "Mean_MSE",
        "Mean_L2",
        "Mean_L1",
        "Median_MSE",
        "Median_L2",
        "Median_L1",
    ]
    is_jagged = isinstance(true, list)
    pred_is_jagged = isinstance(pred, list)
    dists = []
    to_return = []
    names = []
    # skip MMD names whenever either input is jagged, so the returned names
    # stay aligned with the returned values
    filtered_names = [
        name for name in NAMES if not (is_jagged or pred_is_jagged) or not name.endswith("MMD")
    ]
    ts = len(pred) if pred_is_jagged else pred.shape[1]
    # evaluate at the last time point only
    t = max(ts - 1, 0)
    if pred_is_jagged:
        a = pred[t]
    else:
        a = torch.as_tensor(pred).clone().detach().float()
    if is_jagged:
        b = true[t]
    else:
        b = torch.as_tensor(true).clone().detach().float()
    w1 = wasserstein(a, b, power=1)
    w2 = wasserstein(a, b, power=2)

    # mean/median distances are needed in both branches below; computing them
    # only for dense inputs (as before) raised a NameError for jagged inputs
    mean_dists = compute_distances(torch.mean(a, dim=0), torch.mean(b, dim=0))
    median_dists = compute_distances(torch.median(a, dim=0)[0], torch.median(b, dim=0)[0])

    if not pred_is_jagged and not is_jagged:
        mmd_linear = linear_mmd2(a, b).item()
        mmd_poly = poly_mmd2(a, b, d=2, alpha=1.0, c=2.0).item()
        mmd_rbf = mix_rbf_mmd2(a, b, sigma_list=[0.01, 0.1, 1, 10, 100]).item()

    if pred_is_jagged or is_jagged:
        dists.append((w1, w2, *mean_dists, *median_dists))
    else:
        dists.append((w1, w2, mmd_linear, mmd_poly, mmd_rbf, *mean_dists, *median_dists))

    to_return.extend(np.array(dists).mean(axis=0))
    names.extend(filtered_names)
    return names, to_return


def metrics_calculation(pred, true, metrics=['mse_loss', 'l1_loss'], cutoff=-0.91, map_idx=1):
    """Computes the requested metrics and returns them as a dict."""
    # if pred is a tensor, convert to numpy
    if isinstance(pred, torch.Tensor):
        pred = pred.detach().cpu().squeeze().numpy()
        true = true.detach().cpu().squeeze().numpy()

    loss_D = {key: None for key in metrics}
    for metric in metrics:
        if metric == 'mse_loss':
            loss_D['mse_loss'] = np.mean((pred - true) ** 2)
        if metric == 'l1_loss':
            loss_D['l1_loss'] = np.mean(np.abs(pred - true))
        if metric == 'crit_map':
            auroc, auprc = critical_state_pred(pred, true, cutoff, map_idx)
            loss_D['crit_state_auroc'] = auroc
            loss_D['crit_state_auprc'] = auprc
            # remove the empty placeholder key 'crit_map' now that both scores are stored
            if 'crit_map' in loss_D:
                del loss_D['crit_map']
        if metric == 'variance_dist':
            # distribution difference of step-to-step changes; adds multiple items
            var_d = variance_dist(pred, true)
            for key, val in var_d.items():
                loss_D[key] = val
            # remove the empty placeholder key 'variance_dist'
            if 'variance_dist' in loss_D:
                del loss_D['variance_dist']

    return loss_D


def variance_dist(pred, true):
    """Compares the distributions of step-to-step changes over a full trajectory.

    Takes first differences along the time axis as a proxy for local
    variability, then compares their distributions.

    Args:
        pred (numpy array): predicted trajectory, [times, dims]
        true (numpy array): ground-truth trajectory, [times, dims]
    """
    pred_var = np.diff(pred, axis=0)
    true_var = np.diff(true, axis=0)
    # convert both to torch tensors and reuse the distribution distances
    pred_var = torch.tensor(pred_var).float()
    true_var = torch.tensor(true_var).float()
    names, values = compute_distribution_distances_new(pred_var, true_var)
    var_met_dict = {name: value for name, value in zip(names, values)}
    return var_met_dict


def critical_state_pred(pred, true, cutoff=-0.91, map_idx=1):
    """Scores how well critical (hypotensive) states are predicted.

    Features are always ordered HR, MAP; a MAP below 60 mmHg defines the
    critical state, which corresponds to -0.91 after normalization (not
    scaling). Returns AUROC and AUPRC of the binarized predictions against
    the binarized ground truth.

    Args:
        pred: predicted values, [n, dims]
        true: ground-truth values, [n, dims]
        cutoff (float, optional): critical threshold in normalized units.
            Defaults to -0.91.
        map_idx (int, optional): column index of MAP. Defaults to 1.
    """
    critical_cutoff = cutoff
    pred_critical = (pred[:, map_idx] < critical_cutoff).astype(int)
    true_critical = (true[:, map_idx] < critical_cutoff).astype(int)
    auroc = roc_auc_score(true_critical, pred_critical)
    auprc = average_precision_score(true_critical, pred_critical)
    return auroc, auprc
--------------------------------------------------------------------------------
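A minimal sketch of how these metrics are typically called. The shapes and the `crit_map` default cutoff come from the file above; the data itself is hypothetical:

```python
import numpy as np
import torch

# two toy trajectories: [times, dims] with dims ordered (HR, MAP)
rng = np.random.default_rng(0)
true = rng.normal(size=(200, 2))
pred = true + 0.1 * rng.normal(size=(200, 2))

scores = metrics_calculation(pred, true, metrics=['mse_loss', 'l1_loss', 'crit_map'])
print(scores)  # {'mse_loss': ..., 'l1_loss': ..., 'crit_state_auroc': ..., 'crit_state_auprc': ...}

# distributional distance between two point clouds
w2 = wasserstein(torch.randn(128, 2), torch.randn(128, 2) + 1.0, power=2)
```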
/src/utils/mmd.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python


import torch

min_var_est = 1e-8


# Consider linear time MMD with a linear kernel:
# K(f(x), f(y)) = f(x)^T f(y)
# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
#             = [f(x_i) - f(y_i)]^T [f(x_j) - f(y_j)]
#
# f_of_X: batch_size * k
# f_of_Y: batch_size * k
def linear_mmd2(f_of_X, f_of_Y):
    delta = f_of_X - f_of_Y
    loss = torch.mean((delta[:-1] * delta[1:]).sum(1))
    return loss


# Consider linear time MMD with a polynomial kernel:
# K(f(x), f(y)) = (alpha * f(x)^T f(y) + c)^d
# f_of_X: batch_size * k
# f_of_Y: batch_size * k
def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0):
    K_XX = alpha * (f_of_X[:-1] * f_of_X[1:]).sum(1) + c
    K_XX_mean = torch.mean(K_XX.pow(d))

    K_YY = alpha * (f_of_Y[:-1] * f_of_Y[1:]).sum(1) + c
    K_YY_mean = torch.mean(K_YY.pow(d))

    K_XY = alpha * (f_of_X[:-1] * f_of_Y[1:]).sum(1) + c
    K_XY_mean = torch.mean(K_XY.pow(d))

    K_YX = alpha * (f_of_Y[:-1] * f_of_X[1:]).sum(1) + c
    K_YX_mean = torch.mean(K_YX.pow(d))

    return K_XX_mean + K_YY_mean - K_XY_mean - K_YX_mean


def _mix_rbf_kernel(X, Y, sigma_list):
    assert X.size(0) == Y.size(0)
    m = X.size(0)

    Z = torch.cat((X, Y), 0)
    ZZT = torch.mm(Z, Z.t())
    diag_ZZT = torch.diag(ZZT).unsqueeze(1)
    Z_norm_sqr = diag_ZZT.expand_as(ZZT)
    # squared pairwise distances: ||z_i||^2 - 2 z_i^T z_j + ||z_j||^2
    exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()

    K = 0.0
    for sigma in sigma_list:
        gamma = 1.0 / (2 * sigma**2)
        K += torch.exp(-gamma * exponent)

    return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list)


def mix_rbf_mmd2(X, Y, sigma_list, biased=True):
    K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list)
    # return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased)
    return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)


def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True):
    K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list)
    # return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased)
    return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)
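# A quick usage sketch for the estimators above (data and sigma_list are
# hypothetical; a multi-scale sigma_list is a common default):
#
#   X = torch.randn(256, 8)                                # samples from P
#   Y = torch.randn(256, 8) + 1.0                          # samples from Q (shifted mean)
#   mix_rbf_mmd2(X, Y, sigma_list=[0.1, 1.0, 10.0])        # clearly positive
#   mix_rbf_mmd2(X, torch.randn(256, 8), sigma_list=[0.1, 1.0, 10.0])  # near zero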
################################################################################
# Helper functions to compute variances based on kernel matrices
################################################################################


def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    m = K_XX.size(0)  # assume X, Y are same shape

    # Get the various sums of kernels that we'll use
    # Kts drop the diagonal, but we don't need to compute them explicitly
    if const_diagonal is not False:
        diag_X = diag_Y = const_diagonal
        sum_diag_X = sum_diag_Y = m * const_diagonal
    else:
        diag_X = torch.diag(K_XX)  # (m,)
        diag_Y = torch.diag(K_YY)  # (m,)
        sum_diag_X = torch.sum(diag_X)
        sum_diag_Y = torch.sum(diag_Y)

    Kt_XX_sums = K_XX.sum(dim=1) - diag_X  # \tilde{K}_XX * e = K_XX * e - diag_X
    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y  # \tilde{K}_YY * e = K_YY * e - diag_Y
    K_XY_sums_0 = K_XY.sum(dim=0)  # K_{XY}^T * e

    Kt_XX_sum = Kt_XX_sums.sum()  # e^T * \tilde{K}_XX * e
    Kt_YY_sum = Kt_YY_sums.sum()  # e^T * \tilde{K}_YY * e
    K_XY_sum = K_XY_sums_0.sum()  # e^T * K_{XY} * e

    if biased:
        mmd2 = (
            (Kt_XX_sum + sum_diag_X) / (m * m)
            + (Kt_YY_sum + sum_diag_Y) / (m * m)
            - 2.0 * K_XY_sum / (m * m)
        )
    else:
        mmd2 = Kt_XX_sum / (m * (m - 1)) + Kt_YY_sum / (m * (m - 1)) - 2.0 * K_XY_sum / (m * m)

    return mmd2


def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    mmd2, var_est = _mmd2_and_variance(
        K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased
    )
    loss = mmd2 / torch.sqrt(torch.clamp(var_est, min=min_var_est))
    return loss, mmd2, var_est


def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    m = K_XX.size(0)  # assume X, Y are same shape

    # Get the various sums of kernels that we'll use
    # Kts drop the diagonal, but we don't need to compute them explicitly
    if const_diagonal is not False:
        diag_X = diag_Y = const_diagonal
        sum_diag_X = sum_diag_Y = m * const_diagonal
        sum_diag2_X = sum_diag2_Y = m * const_diagonal**2
    else:
        diag_X = torch.diag(K_XX)  # (m,)
        diag_Y = torch.diag(K_YY)  # (m,)
        sum_diag_X = torch.sum(diag_X)
        sum_diag_Y = torch.sum(diag_Y)
        sum_diag2_X = diag_X.dot(diag_X)
        sum_diag2_Y = diag_Y.dot(diag_Y)

    Kt_XX_sums = K_XX.sum(dim=1) - diag_X  # \tilde{K}_XX * e = K_XX * e - diag_X
    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y  # \tilde{K}_YY * e = K_YY * e - diag_Y
    K_XY_sums_0 = K_XY.sum(dim=0)  # K_{XY}^T * e
    K_XY_sums_1 = K_XY.sum(dim=1)  # K_{XY} * e

    Kt_XX_sum = Kt_XX_sums.sum()  # e^T * \tilde{K}_XX * e
    Kt_YY_sum = Kt_YY_sums.sum()  # e^T * \tilde{K}_YY * e
    K_XY_sum = K_XY_sums_0.sum()  # e^T * K_{XY} * e

    Kt_XX_2_sum = (K_XX**2).sum() - sum_diag2_X  # \| \tilde{K}_XX \|_F^2
    Kt_YY_2_sum = (K_YY**2).sum() - sum_diag2_Y  # \| \tilde{K}_YY \|_F^2
    K_XY_2_sum = (K_XY**2).sum()  # \| K_{XY} \|_F^2

    if biased:
        mmd2 = (
            (Kt_XX_sum + sum_diag_X) / (m * m)
            + (Kt_YY_sum + sum_diag_Y) / (m * m)
            - 2.0 * K_XY_sum / (m * m)
        )
    else:
        mmd2 = Kt_XX_sum / (m * (m - 1)) + Kt_YY_sum / (m * (m - 1)) - 2.0 * K_XY_sum / (m * m)

    var_est = (
        2.0
        / (m**2 * (m - 1.0) ** 2)
        * (
            2 * Kt_XX_sums.dot(Kt_XX_sums)
            - Kt_XX_2_sum
            + 2 * Kt_YY_sums.dot(Kt_YY_sums)
            - Kt_YY_2_sum
        )
        - (4.0 * m - 6.0) / (m**3 * (m - 1.0) ** 3) * (Kt_XX_sum**2 + Kt_YY_sum**2)
        + 4.0
        * (m - 2.0)
        / (m**3 * (m - 1.0) ** 2)
        * (K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0))
        - 4.0 * (m - 3.0) / (m**3 * (m - 1.0) ** 2) * (K_XY_2_sum)
        - (8 * m - 12) / (m**5 * (m - 1)) * K_XY_sum**2
        + 8.0
        / (m**3 * (m - 1.0))
        * (
            1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
            - Kt_XX_sums.dot(K_XY_sums_1)
            - Kt_YY_sums.dot(K_XY_sums_0)
        )
    )
    return mmd2, var_est
--------------------------------------------------------------------------------
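For reference, the unbiased branch of `_mmd2` above is the standard U-statistic estimator of squared MMD, where the diagonal-removed sums correspond to the off-diagonal kernel entries:

$$
\widehat{\mathrm{MMD}}_u^2(X, Y) = \frac{1}{m(m-1)} \sum_{i \neq j} k(x_i, x_j) + \frac{1}{m(m-1)} \sum_{i \neq j} k(y_i, y_j) - \frac{2}{m^2} \sum_{i, j} k(x_i, y_j)
$$

The biased variant keeps the diagonal terms and divides all three sums by $m^2$.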
/src/utils/sde.py:
--------------------------------------------------------------------------------
import torch

class SDE(torch.nn.Module):
    """Wraps an ODE drift network as an Ito SDE with constant diagonal diffusion."""

    noise_type = "diagonal"
    sde_type = "ito"

    # noise is sigma in this notebook for the equation sigma * (t * (1 - t))
    def __init__(self, ode_drift, noise=1.0, reverse=False):
        super().__init__()
        self.drift = ode_drift
        self.reverse = reverse
        self.noise = noise

    # Drift
    def f(self, t, y):
        if self.reverse:
            t = 1 - t
        # append time as an extra input channel for the drift network
        if len(t.shape) == len(y.shape):
            x = torch.cat([y, t], 1)
        else:
            x = torch.cat([y, t.repeat(y.shape[0])[:, None]], 1)
        return self.drift(x)

    # Diffusion
    def g(self, t, y):
        return torch.ones_like(t) * torch.ones_like(y) * self.noise
--------------------------------------------------------------------------------
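The `noise_type`/`sde_type` attributes follow torchsde's conventions, so a minimal integration sketch might look as follows; the drift network, batch size, and noise level are hypothetical:

```python
import torch
import torchsde  # assumed available; this file only depends on torch

# hypothetical drift net: input [y (2 dims), t (1 dim)] -> dy/dt (2 dims)
drift = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.Tanh(), torch.nn.Linear(64, 2))
sde = SDE(drift, noise=0.1)

y0 = torch.randn(16, 2)                # [batch, dims]
ts = torch.linspace(0, 1, 20)
with torch.no_grad():
    ys = torchsde.sdeint(sde, y0, ts)  # [len(ts), batch, dims]
```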
/src/utils/visualize.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def plot_3d_path_ind(traj, groundtruth, t_span=torch.linspace(0, 4 * np.pi, 100), title=""):
    """Plots a predicted 2-D trajectory against its ground truth over time in 3-D."""
    n = len(t_span)
    fig = plt.figure(figsize=(15, 10))
    ax1 = fig.add_subplot(1, 1, 1, projection='3d')
    ax1.scatter([t_span[0]], [traj[0, 0]], [traj[0, 1]], alpha=0.5, c="red")  # start point
    for i in range(n - 1):
        ax1.plot([t_span[i], t_span[i + 1]], [traj[i, 0], traj[i + 1, 0]], [traj[i, 1], traj[i + 1, 1]], alpha=1, c="olive")  # predicted path
        ax1.plot([t_span[i], t_span[i + 1]], [groundtruth[i, 0], groundtruth[i + 1, 0]], [groundtruth[i, 1], groundtruth[i + 1, 1]], alpha=1, c="pink")  # ground-truth path
    ax1.scatter(t_span, traj[:, 0], traj[:, 1], alpha=0.5, c="blue")  # predictions
    ax1.scatter(t_span, groundtruth[:, 0], groundtruth[:, 1], alpha=0.5, c="purple")  # ground truth
    ax1.set_title(title)

    return fig


def plot_3d_path_ind_noise(traj, groundtruth, noise, t_span=torch.linspace(0, 4 * np.pi, 100), title=""):
    """Same as plot_3d_path_ind, but also shows the predicted uncertainty around each point."""
    n = len(t_span)
    fig = plt.figure(figsize=(15, 10))
    ax1 = fig.add_subplot(1, 1, 1, projection='3d')

    noise = noise.cpu().numpy()

    ax1.scatter([t_span[0]], [traj[0, 0]], [traj[0, 1]], alpha=0.5, c="red")  # start point

    # Plot trajectory and ground truth
    ax1.plot(t_span, traj[:, 0], traj[:, 1], label='Trajectory', c='olive')
    ax1.plot(t_span, groundtruth[:, 0], groundtruth[:, 1], label='Ground Truth', c='pink')
    ax1.scatter(t_span, traj[:, 0], traj[:, 1], alpha=0.5, c="blue")  # predictions
    ax1.scatter(t_span, groundtruth[:, 0], groundtruth[:, 1], alpha=0.5, c="purple")  # ground truth
    # Plot uncertainty as scatter points around each trajectory point:
    # plus and minus noise values for visualization (the first step is skipped)
    for i in range(1, n - 1):
        x_noise_pos = traj[i + 1, 0] + noise[i, 0]
        y_noise_pos = traj[i + 1, 1] + noise[i, 1]
        x_noise_neg = traj[i + 1, 0] - noise[i, 0]
        y_noise_neg = traj[i + 1, 1] - noise[i, 1]
        ax1.scatter([t_span[i]] * 2, [x_noise_pos, x_noise_neg], [y_noise_pos, y_noise_neg], color='gray', alpha=0.5)

    ax1.set_title(title)
    ax1.set_xlabel('Time')
    ax1.set_ylabel('X')
    ax1.set_zlabel('Y')
    ax1.legend()

    return fig


def join_3d_plots(figs, rows, cols):
    """Combines several 3-D figures into a single grid figure.

    Note: this assumes the axes are the second child of each source figure and
    re-parents their artists, which is fragile across matplotlib versions.
    """
    new_fig = plt.figure(figsize=(15 * cols, 10 * rows))
    for i, fig in enumerate(figs):
        ax = new_fig.add_subplot(rows, cols, i + 1, projection='3d')
        original_ax = fig.get_children()[1]
        for line in original_ax.get_lines():
            ax.add_line(line)
        for patch in original_ax.get_patches():
            ax.add_patch(patch)
    return new_fig
--------------------------------------------------------------------------------
/src/utils/wrapper.py:
--------------------------------------------------------------------------------
import torch

class torch_wrapper_tv(torch.nn.Module):
    """Wraps a time-conditioned model to a torchdyn-compatible f(t, x) vector field."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        # broadcast the scalar time to a per-sample column and feed [x, t] to the model
        return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))
--------------------------------------------------------------------------------
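A minimal sketch of `torch_wrapper_tv` in use. It turns a model that expects `[x, t]` concatenated along the feature dimension into an `f(t, x)` vector field, so it plugs directly into `torchdiffeq.odeint` (or torchdyn's `NeuralODE`); the velocity network and shapes here are hypothetical:

```python
import torch
from torchdiffeq import odeint

# hypothetical velocity net over 2-D states: input [x (2 dims), t (1 dim)]
velocity_net = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.SELU(), torch.nn.Linear(64, 2))
vector_field = torch_wrapper_tv(velocity_net)

x0 = torch.randn(8, 2)                 # [batch, dims]
ts = torch.linspace(0, 1, 50)
traj = odeint(vector_field, x0, ts)    # [len(ts), batch, dims]
```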