├── .gitignore
├── README.md
├── bin
│   ├── arima.sh
│   ├── gluformer.sh
│   ├── latentode.sh
│   ├── linreg.sh
│   ├── nhits.sh
│   ├── tft.sh
│   ├── transformer.sh
│   └── xgbtree.sh
├── config
│   ├── colas.yaml
│   ├── dubosson.yaml
│   ├── hall.yaml
│   ├── iglu.yaml
│   └── weinstock.yaml
├── data_formatter
│   ├── __init__.py
│   ├── base.py
│   ├── types.py
│   └── utils.py
├── example.ipynb
├── exploratory_analysis
│   ├── colas.ipynb
│   ├── dubosson.ipynb
│   ├── hall.ipynb
│   ├── iglu.ipynb
│   └── weinstock.ipynb
├── lib
│   ├── __init__.py
│   ├── arima.py
│   ├── gluformer.py
│   ├── gluformer
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── decoder.py
│   │   ├── embed.py
│   │   ├── encoder.py
│   │   ├── model.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── collate.py
│   │   │   ├── evaluation.py
│   │   │   └── training.py
│   │   └── variance.py
│   ├── latent_ode
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── base_models.py
│   │   ├── create_latent_ode_model.py
│   │   ├── diffeq_solver.py
│   │   ├── encoder_decoder.py
│   │   ├── eval_glunet.py
│   │   ├── latent_ode.py
│   │   ├── likelihood_eval.py
│   │   ├── ode_func.py
│   │   ├── ode_rnn.py
│   │   ├── plotting.py
│   │   ├── rnn_baselines.py
│   │   ├── run_models.py
│   │   ├── test_glunet.ipynb
│   │   ├── trainer_glunet.py
│   │   └── utils.py
│   ├── latentode.py
│   ├── linreg.py
│   ├── nhits.py
│   ├── tft.py
│   ├── transformer.py
│   └── xgbtree.py
├── output
│   ├── arima_colas.txt
│   ├── arima_dubosson.txt
│   ├── arima_hall.txt
│   ├── arima_iglu.txt
│   ├── arima_weinstock.txt
│   ├── gluformer_colas.txt
│   ├── gluformer_dubosson.txt
│   ├── gluformer_hall.txt
│   ├── gluformer_iglu.txt
│   ├── gluformer_weinstock.txt
│   ├── latentode_colas.txt
│   ├── latentode_dubosson.txt
│   ├── latentode_hall.txt
│   ├── latentode_iglu.txt
│   ├── latentode_weinstock.txt
│   ├── linreg_colas.txt
│   ├── linreg_covariates_colas.txt
│   ├── linreg_covariates_dubosson.txt
│   ├── linreg_covariates_hall.txt
│   ├── linreg_covariates_iglu.txt
│   ├── linreg_covariates_weinstock.txt
│   ├── linreg_dubosson.txt
│   ├── linreg_hall.txt
│   ├── linreg_iglu.txt
│   ├── linreg_weinstock.txt
│   ├── nhits_colas.txt
│   ├── nhits_covariates_colas.txt
│   ├── nhits_covariates_dubosson.txt
│   ├── nhits_covariates_hall.txt
│   ├── nhits_covariates_iglu.txt
│   ├── nhits_covariates_weinstock.txt
│   ├── nhits_dubosson.txt
│   ├── nhits_hall.txt
│   ├── nhits_iglu.txt
│   ├── nhits_weinstock.txt
│   ├── tft_colas.txt
│   ├── tft_covariates_colas.txt
│   ├── tft_covariates_dubosson.txt
│   ├── tft_covariates_hall.txt
│   ├── tft_covariates_iglu.txt
│   ├── tft_covariates_weinstock.txt
│   ├── tft_dubosson.txt
│   ├── tft_hall.txt
│   ├── tft_iglu.txt
│   ├── tft_weinstock.txt
│   ├── transformer_colas.txt
│   ├── transformer_covariates_colas.txt
│   ├── transformer_covariates_dubosson.txt
│   ├── transformer_covariates_hall.txt
│   ├── transformer_covariates_iglu.txt
│   ├── transformer_covariates_weinstock.txt
│   ├── transformer_dubosson.txt
│   ├── transformer_hall.txt
│   ├── transformer_iglu.txt
│   ├── transformer_weinstock.txt
│   ├── xgboost_colas.txt
│   ├── xgboost_covariates_colas.txt
│   ├── xgboost_covariates_dubosson.txt
│   ├── xgboost_covariates_hall.txt
│   ├── xgboost_covariates_iglu.txt
│   ├── xgboost_covariates_weinstock.txt
│   ├── xgboost_dubosson.txt
│   ├── xgboost_hall.txt
│   ├── xgboost_iglu.txt
│   └── xgboost_weinstock.txt
├── paper_results
│   ├── covariate_importance_xgb.ipynb
│   ├── figure2.ipynb
│   ├── figure3.ipynb
│   ├── figure4.ipynb
│   ├── figure5.ipynb
│   ├── forecast_compute.ipynb
│   ├── forecast_plot.ipynb
│   ├── parser.py
│   ├── plots
│   │   ├── figure2.pdf
│   │   ├── figure3.pdf
│   │   ├── figure3_annot.pptx
│   │   ├── figure4.pdf
│   │   ├── figure5.pdf
│   │   ├── figure6.pdf
│   │   ├── figure6.png
│   │   └── nhits_single_prediction.pdf
│   └── tables.ipynb
├── raw_data.zip
├── requirements.txt
└── utils
    ├── __init__.py
    ├── darts_dataset.py
    ├── darts_evaluation.py
    ├── darts_processing.py
    └── darts_training.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# ignore model and data files
*.pth
*.npy
*.pkl
cache/*
*.out
track_*
output/tensorboard_*/*

# ignore folder raw_data
raw_data/

# ignore pre-compiled code
*.cpython-37.pyc
__pycache__/
*.py[cod]
*.pyc
*.DS_Store
.ipynb_checkpoints

# ignore irrelevant files
legacy
papers
meeting_notes.md

# ignore workspace
*.code-workspace

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GlucoBench

The official implementation of the paper "GlucoBench: Curated List of Continuous Glucose Monitoring Datasets with Prediction Benchmarks."
If you found our work interesting and plan to re-use the code, please cite us as:
```
@article{sergazinov2023glucobench,
  author  = {Renat Sergazinov and Valeriya Rogovchenko and Elizabeth Chun and Nathaniel Fernandes and Irina Gaynanova},
  title   = {GlucoBench: Curated List of Continuous Glucose Monitoring Datasets with Prediction Benchmarks},
  journal = {arXiv},
  year    = {2023},
}
```

# Dependencies

We recommend setting up a clean Python environment with [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html "conda-env") by running `conda create -n glucobench python=3.10`. All dependencies can then be installed by running `pip install -r requirements.txt`.

To run the [Latent ODE](https://github.com/YuliaRubanova/latent_ode) model, additionally install [`torchdiffeq`](https://github.com/rtqichen/torchdiffeq).

# Code organization

The code is organized as follows:

- `bin/`: training commands for all models
- `config/`: configuration files for all datasets
- `data_formatter/`
    - `base.py`: performs **all** pre-processing for all CGM datasets
- `exploratory_analysis/`: notebooks with the processing steps for pulling the data and converting it to `.csv` files
- `lib/`
    - `gluformer/`: Gluformer model implementation
    - `latent_ode/`: Latent ODE model implementation
    - `*.py`: hyper-parameter tuning, training, validation, and testing scripts
- `output/`: hyper-parameter optimization and testing logs
- `paper_results/`: code for producing the tables and plots found in the paper
- `utils/`: helper functions for model training and testing
- `raw_data.zip`: web-pulled CGM data (processed using `exploratory_analysis/`)
- `requirements.txt`: Python dependencies

# Data

Each dataset is distributed under its own license; the licenses and download links are listed in the table below.

| Dataset | License | Number of patients | CGM Frequency |
| ------- | ------- | ------------------ | ------------- |
| [Colas](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0225817#sec018) | [Creative Commons 4.0](https://creativecommons.org/licenses/by/3.0/us/) | 208 | 5 minutes |
| [Dubosson](https://doi.org/10.5281/zenodo.1421615) | [Creative Commons 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode) | 9 | 5 minutes |
| [Hall](https://journals.plos.org/plosbiology/article?id=10.1371/journal.pbio.2005143#pbio.2005143.s010) | [Creative Commons 4.0](https://creativecommons.org/licenses/by/4.0/) | 57 | 5 minutes |
| [Broll](https://github.com/irinagain/iglu) | [GPL-2](https://www.r-project.org/Licenses/GPL-2) | 5 | 5 minutes |
| [Weinstock](https://public.jaeb.org/dataset/537) | [Creative Commons 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode) | 200 | 5 minutes |

To process the data, follow the instructions in the `exploratory_analysis/` folder. Processed datasets should be saved in the `raw_data/` folder; we provide examples in the `raw_data.zip` file.

# How to reproduce results?

## Setting up the environment

We recommend setting up a clean Python environment using [Conda](https://docs.conda.io/). Follow these steps:

1. Create a new environment named `glucobench` with Python 3.10 by running:
```
conda create -n glucobench python=3.10
```

2. Activate the environment with:
```
conda activate glucobench
```

3. Install all required dependencies by running:
```
pip install -r requirements.txt
```

4. (Optional) To confirm that you're installing into the correct environment, run:
```
which pip
```
This should display the path to the `pip` executable within the `glucobench` environment.

## Changing the configs

The `config/` folder stores the best hyper-parameters (selected by [Optuna](https://optuna.org)) for each dataset and model. It also stores the dataset-specific parameters for interpolation, dropping, splitting, and scaling. To train and evaluate a model with these defaults, we can simply run:

```
python ./lib/model.py --dataset dataset --use_covs False --optuna False
```

## Changing the hyper-parameters

To change the search grid for the hyper-parameters, we need to modify the corresponding `./lib/model.py` file: locate the `objective()` function and adjust the `trial.suggest_*` calls to set the desired ranges. Once done, we can re-run the hyper-parameter optimization with:

```
python ./lib/model.py --dataset dataset --use_covs False --optuna True
```

# How to work with the repository?

We provide a detailed example of the workflow in the `example.ipynb` notebook. For clarity, we also give some general suggestions below, in order of increasing complexity.

## Just the data

To start experimenting with the data, we can run the following code:

```python
import yaml
from data_formatter.base import DataFormatter

dataset = 'iglu'  # one of: colas, dubosson, hall, iglu, weinstock
with open(f'./config/{dataset}.yaml', 'r') as f:
    config = yaml.safe_load(f)
formatter = DataFormatter(config)
```

This creates an object of class `DataFormatter`, which automatically pre-processes the data upon initialization. The pre-processing steps can be controlled via the `config/` files. The `DataFormatter` object exposes the following attributes:

1. `formatter.train_data`: training data (as `pandas.DataFrame`)
2. `formatter.val_data`: validation data
3. `formatter.test_data`: testing (in-distribution and out-of-distribution) data
    i. `formatter.test_data.loc[~formatter.test_data.index.isin(formatter.test_idx_ood)]`: in-distribution testing data
    ii. `formatter.test_data.loc[formatter.test_data.index.isin(formatter.test_idx_ood)]`: out-of-distribution testing data
4. `formatter.data`: unscaled full data

## Integration with PyTorch

Training models with PyTorch typically boils down to (1) defining a `Dataset` class with a `__getitem__()` method, (2) wrapping it into a `DataLoader`, (3) defining a `torch.nn.Module` class with a `forward()` method that implements the model, and (4) optimizing the model with `torch.optim` in a training loop.

**Parts (1) and (2)** crucially depend on the definition of the `Dataset` class. Essentially, given the data in table format (e.g. `formatter.train_data`), how do we sample input-output pairs and pass the covariate information? The `Dataset` classes conveniently adapted from the `Darts` library (see [here](https://unit8co.github.io/darts/generated_api/darts.utils.data.training_dataset.html)) offer one way to wrap the data. They differ in what information is provided to the model:

1. `SamplingDatasetPast`: supports only past covariates
2. `SamplingDatasetDual`: supports only future-known covariates
3. `SamplingDatasetMixed`: supports both past and future-known covariates

Below we give an example of loading the data and wrapping it into a `Dataset`:

```python
from utils.darts_processing import load_data
from utils.darts_dataset import SamplingDatasetDual

in_len, out_len, max_samples_per_ts = 96, 12, 100  # example values; see config/*.yaml for tuned settings

formatter, series, scalers = load_data(seed=0,
                                       dataset=dataset,
                                       use_covs=True,
                                       cov_type='dual',
                                       use_static_covs=True)
dataset_train = SamplingDatasetDual(series['train']['target'],
                                    series['train']['future'],
                                    output_chunk_length=out_len,
                                    input_chunk_length=in_len,
                                    use_static_covariates=True,
                                    max_samples_per_ts=max_samples_per_ts,)
```

**Parts (3) and (4)** are model-specific, so we omit their discussion. For inspiration, we suggest taking a look at the `lib/gluformer/model.py` and `lib/latent_ode/trainer_glunet.py` files.
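
To make parts (3) and (4) concrete, the sketch below wraps `dataset_train` from the previous snippet into a `DataLoader` and trains a toy model. It is a minimal illustration under two assumptions: each sample is a tuple of numpy arrays whose first element is the past target window and whose last element is the future target (the exact layout depends on the `Dataset` class, so check `utils/darts_dataset.py`), and the glucose target is univariate. The `MLPForecaster` here is a hypothetical stand-in for illustration only, not one of the benchmarked models.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader

# The default collate function stacks the numpy arrays in each sample into batched tensors.
loader = DataLoader(dataset_train, batch_size=32, shuffle=True)

class MLPForecaster(nn.Module):
    """Toy model: flattens the input window and regresses the forecast horizon."""
    def __init__(self, in_len: int, out_len: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_len, 128),
            nn.ReLU(),
            nn.Linear(128, out_len),
        )

    def forward(self, x):
        return self.net(x)

model = MLPForecaster(in_len=in_len, out_len=out_len)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for epoch in range(10):
    for batch in loader:
        # past target window and future target; the tuple layout is an assumption (see note above)
        x, y = batch[0].float(), batch[-1].float()
        optimizer.zero_grad()
        y_pred = model(x).view_as(y)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
    print(f'epoch {epoch}: last batch loss {loss.item():.4f}')
```

In practice, the model-specific trainers referenced above also handle covariates, scaling, and evaluation; this loop only shows the skeleton of steps (3) and (4).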
151 | 152 | 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /bin/arima.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # execute in parallel 4 | nohup python ./lib/arima.py --dataset weinstock > ./output/track_arima_weinstock.txt & 5 | nohup python ./lib/arima.py --dataset colas > ./output/track_arima_colas.txt & 6 | nohup python ./lib/arima.py --dataset dubosson > ./output/track_arima_dubosson.txt & 7 | nohup python ./lib/arima.py --dataset hall > ./output/track_arima_hall.txt & 8 | nohup python ./lib/arima.py --dataset iglu > ./output/track_arima_iglu.txt & 9 | 10 | -------------------------------------------------------------------------------- /bin/gluformer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # execute in parallel 4 | nohup python ./lib/gluformer.py --dataset weinstock --gpu_id 1 --optuna False > ./output/track_gluformer_weinstock.txt & 5 | # nohup python ./lib/gluformer.py --dataset colas --gpu_id 2 --optuna False > ./output/track_gluformer_colas.txt & 6 | # nohup python ./lib/gluformer.py --dataset dubosson --gpu_id 0 --optuna False > ./output/track_gluformer_dubosson.txt & 7 | # nohup python ./lib/gluformer.py --dataset hall --gpu_id 3 --optuna False > ./output/track_gluformer_hall.txt & 8 | # nohup python ./lib/gluformer.py --dataset iglu --gpu_id 0 --optuna False > ./output/track_gluformer_iglu.txt & 9 | 10 | -------------------------------------------------------------------------------- /bin/latentode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # execute in parallel 4 | nohup python ./lib/latentode.py --dataset weinstock --gpu_id 3 --optuna False > ./output/track_latentode_weinstock.txt & 5 | nohup python ./lib/latentode.py --dataset colas --gpu_id 2 --optuna False > ./output/track_latentode_colas.txt & 6 | nohup python ./lib/latentode.py --dataset dubosson --gpu_id 0 --optuna False > ./output/track_latentode_dubosson.txt & 7 | nohup python ./lib/latentode.py --dataset hall --gpu_id 1 --optuna False > ./output/track_latentode_hall.txt & 8 | nohup python ./lib/latentode.py --dataset iglu --gpu_id 0 --optuna False > ./output/track_latentode_iglu.txt & 9 | 10 | -------------------------------------------------------------------------------- /bin/linreg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # execute in parallel 4 | nohup python ./lib/linreg.py --dataset weinstock --use_covs False --optuna False > ./output/track_linreg_weinstock.txt & 5 | nohup python ./lib/linreg.py --dataset weinstock --use_covs True --optuna False > ./output/track_linreg_covariates_weinstock.txt & 6 | nohup python ./lib/linreg.py --dataset colas --use_covs False --optuna False > ./output/track_linreg_colas.txt & 7 | nohup python ./lib/linreg.py --dataset colas --use_covs True --optuna False > ./output/track_linreg_covariates_colas.txt & 8 | nohup python ./lib/linreg.py --dataset dubosson --use_covs False --optuna False > ./output/track_linreg_dubosson.txt & 9 | nohup python ./lib/linreg.py --dataset dubosson --use_covs True --optuna False > ./output/track_linreg_covariates_dubosson.txt & 10 | nohup python ./lib/linreg.py --dataset hall --use_covs False --optuna False > ./output/track_linreg_hall.txt & 11 | nohup python ./lib/linreg.py --dataset hall --use_covs True --optuna False > 
./output/track_linreg_covariates_hall.txt & 12 | nohup python ./lib/linreg.py --dataset iglu --use_covs False --optuna False > ./output/track_linreg_iglu.txt & 13 | nohup python ./lib/linreg.py --dataset iglu --use_covs True --optuna False > ./output/track_linreg_covariates_iglu.txt & 14 | -------------------------------------------------------------------------------- /bin/nhits.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # execute in parallel 4 | nohup python ./lib/nhits.py --dataset weinstock --use_covs False --optuna False > ./output/track_nhits_weinstock.txt & 5 | nohup python ./lib/nhits.py --dataset weinstock --use_covs True --optuna False > ./output/track_nhits_covariates_weinstock.txt & 6 | nohup python ./lib/nhits.py --dataset colas --use_covs False --optuna False > ./output/track_nhits_colas.txt & 7 | nohup python ./lib/nhits.py --dataset colas --use_covs True --optuna False > ./output/track_nhits_covariates_colas.txt & 8 | nohup python ./lib/nhits.py --dataset dubosson --use_covs False --optuna False > ./output/track_nhits_dubosson.txt & 9 | nohup python ./lib/nhits.py --dataset dubosson --use_covs True --optuna False > ./output/track_nhits_covariates_dubosson.txt & 10 | nohup python ./lib/nhits.py --dataset hall --use_covs False --optuna False > ./output/track_nhits_hall.txt & 11 | nohup python ./lib/nhits.py --dataset hall --use_covs True --optuna False > ./output/track_nhits_covariates_hall.txt & 12 | nohup python ./lib/nhits.py --dataset iglu --use_covs False --optuna False > ./output/track_nhits_iglu.txt & 13 | nohup python ./lib/nhits.py --dataset iglu --use_covs True --optuna False > ./output/track_nhits_covariates_iglu.txt & 14 | -------------------------------------------------------------------------------- /bin/tft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # execute in parallel 4 | nohup python ./lib/tft.py --dataset weinstock --use_covs False --optuna False > ./output/track_tft_weinstock.txt & 5 | # nohup python ./lib/tft.py --dataset weinstock --use_covs True --optuna False > ./output/track_tft_covariates_weinstock.txt & 6 | # nohup python ./lib/tft.py --dataset colas --use_covs False --optuna False > ./output/track_tft_colas.txt & 7 | # nohup python ./lib/tft.py --dataset colas --use_covs True --optuna False > ./output/track_tft_covariates_colas.txt & 8 | # nohup python ./lib/tft.py --dataset dubosson --use_covs False --optuna False > ./output/track_tft_dubosson.txt & 9 | # nohup python ./lib/tft.py --dataset dubosson --use_covs True --optuna False > ./output/track_tft_covariates_dubosson.txt & 10 | # nohup python ./lib/tft.py --dataset hall --use_covs False --optuna False > ./output/track_tft_hall.txt & 11 | # nohup python ./lib/tft.py --dataset hall --use_covs True --optuna False > ./output/track_tft_covariates_hall.txt & 12 | # nohup python ./lib/tft.py --dataset iglu --use_covs False --optuna False > ./output/track_tft_iglu.txt & 13 | # nohup python ./lib/tft.py --dataset iglu --use_covs True --optuna False > ./output/track_tft_covariates_iglu.txt & 14 | -------------------------------------------------------------------------------- /bin/transformer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # execute in parallel 4 | nohup python ./lib/transformer.py --dataset weinstock --use_covs False --optuna False > ./output/track_transformer_weinstock.txt & 5 | nohup python 
./lib/transformer.py --dataset weinstock --use_covs True --optuna False > ./output/track_transformer_covariates_weinstock.txt & 6 | nohup python ./lib/transformer.py --dataset colas --use_covs False --optuna False > ./output/track_transformer_colas.txt & 7 | nohup python ./lib/transformer.py --dataset colas --use_covs True --optuna False > ./output/track_transformer_covariates_colas.txt & 8 | nohup python ./lib/transformer.py --dataset dubosson --use_covs False --optuna False > ./output/track_transformer_dubosson.txt & 9 | nohup python ./lib/transformer.py --dataset dubosson --use_covs True --optuna False > ./output/track_transformer_covariates_dubosson.txt & 10 | nohup python ./lib/transformer.py --dataset hall --use_covs False --optuna False > ./output/track_transformer_hall.txt & 11 | nohup python ./lib/transformer.py --dataset hall --use_covs True --optuna False > ./output/track_transformer_covariates_hall.txt & 12 | nohup python ./lib/transformer.py --dataset iglu --use_covs False --optuna False > ./output/track_transformer_iglu.txt & 13 | nohup python ./lib/transformer.py --dataset iglu --use_covs True --optuna False > ./output/track_transformer_covariates_iglu.txt & 14 | -------------------------------------------------------------------------------- /bin/xgbtree.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # execute in parallel 4 | nohup python ./lib/xgbtree.py --dataset weinstock --use_covs False --optuna False > ./output/track_xgboost_weinstock.txt & 5 | nohup python ./lib/xgbtree.py --dataset weinstock --use_covs True --optuna False > ./output/track_xgboost_covariates_weinstock.txt & 6 | nohup python ./lib/xgbtree.py --dataset colas --use_covs False --optuna False > ./output/track_xgboost_colas.txt & 7 | nohup python ./lib/xgbtree.py --dataset colas --use_covs True --optuna False > ./output/track_xgboost_covariates_colas.txt & 8 | nohup python ./lib/xgbtree.py --dataset dubosson --use_covs False --optuna False > ./output/track_xgboost_dubosson.txt & 9 | nohup python ./lib/xgbtree.py --dataset dubosson --use_covs True --optuna False > ./output/track_xgboost_covariates_dubosson.txt & 10 | nohup python ./lib/xgbtree.py --dataset hall --use_covs False --optuna False > ./output/track_xgboost_hall.txt & 11 | nohup python ./lib/xgbtree.py --dataset hall --use_covs True --optuna False > ./output/track_xgboost_covariates_hall.txt & 12 | nohup python ./lib/xgbtree.py --dataset iglu --use_covs False --optuna False > ./output/track_xgboost_iglu.txt & 13 | nohup python ./lib/xgbtree.py --dataset iglu --use_covs True --optuna False > ./output/track_xgboost_covariates_iglu.txt & 14 | -------------------------------------------------------------------------------- /config/colas.yaml: -------------------------------------------------------------------------------- 1 | # Dataset 2 | ds_name: colas2019 3 | data_csv_path: ./raw_data/colas.csv 4 | index_col: -1 5 | observation_interval: 5min 6 | 7 | # Columns 8 | column_definition: 9 | - name: id 10 | data_type: categorical 11 | input_type: id 12 | - name: time 13 | data_type: date 14 | input_type: time 15 | - name: gl 16 | data_type: real_valued 17 | input_type: target 18 | - name: gender 19 | data_type: categorical 20 | input_type: static_input 21 | - name: age 22 | data_type: real_valued 23 | input_type: static_input 24 | - name: BMI 25 | data_type: real_valued 26 | input_type: static_input 27 | - name: glycaemia 28 | data_type: real_valued 29 | input_type: static_input 30 | - name: 
HbA1c 31 | data_type: real_valued 32 | input_type: static_input 33 | - name: follow.up 34 | data_type: real_valued 35 | input_type: static_input 36 | - name: T2DM 37 | data_type: categorical 38 | input_type: static_input 39 | 40 | # Drop 41 | drop: null 42 | 43 | # NA values abbreviation 44 | nan_vals: null 45 | 46 | # Interpolation parameters 47 | interpolation_params: 48 | gap_threshold: 45 49 | min_drop_length: 192 50 | 51 | # Splitting parameters 52 | split_params: 53 | test_percent_subjects: 0.1 54 | length_segment: 72 55 | random_state: 0 56 | 57 | # Encoding parameters 58 | encoding_params: 59 | date: 60 | - year 61 | - month 62 | - day 63 | - hour 64 | - minute 65 | 66 | # Scaling parameters 67 | scaling_params: 68 | scaler: None 69 | 70 | # Model params 71 | max_length_input: 144 72 | length_pred: 12 73 | 74 | linreg: 75 | in_len: 12 76 | 77 | linreg_covariates: 78 | in_len: 12 79 | 80 | tft: 81 | in_len: 132 82 | max_samples_per_ts: 200 83 | hidden_size: 256 84 | num_attention_heads: 3 85 | dropout: 0.22683125764190215 86 | lr: 0.0005939103829095587 87 | batch_size: 32 88 | max_grad_norm: 0.9791152645996767 89 | 90 | tft_covariates: 91 | in_len: 120 92 | max_samples_per_ts: 100 93 | hidden_size: 32 94 | num_attention_heads: 3 95 | dropout: 0.10643530677029577 96 | lr: 0.004702414513886559 97 | batch_size: 32 98 | max_grad_norm: 0.8047252326588638 99 | 100 | xgboost: 101 | in_len: 120 102 | lr: 0.509 103 | subsample: 0.9 104 | min_child_weight: 5.0 105 | colsample_bytree: 0.9 106 | max_depth: 7 107 | gamma: 0.5 108 | alpha: 0.216 109 | lambda_: 0.241 110 | n_estimators: 352 111 | 112 | xgboost_covariates: 113 | in_len: 144 114 | lr: 0.883 115 | subsample: 0.9 116 | min_child_weight: 3.0 117 | colsample_bytree: 0.8 118 | max_depth: 5 119 | gamma: 0.5 120 | alpha: 0.055 121 | lambda_: 0.08700000000000001 122 | n_estimators: 416 123 | 124 | transformer: 125 | in_len: 108 126 | max_samples_per_ts: 200 127 | d_model: 64 128 | n_heads: 2 129 | num_encoder_layers: 3 130 | num_decoder_layers: 3 131 | dim_feedforward: 480 132 | dropout: 0.12434517563324206 133 | lr: 0.00048663109178350133 134 | batch_size: 32 135 | lr_epochs: 8 136 | max_grad_norm: 0.8299004621292704 137 | 138 | transformer_covariates: 139 | in_len: 120 140 | max_samples_per_ts: 200 141 | d_model: 128 142 | n_heads: 4 143 | num_encoder_layers: 4 144 | num_decoder_layers: 1 145 | dim_feedforward: 128 146 | dropout: 0.19572808311258694 147 | lr: 0.0008814762155445509 148 | batch_size: 32 149 | lr_epochs: 18 150 | max_grad_norm: 0.8168361106999547 151 | 152 | nhits: 153 | in_len: 132 154 | max_samples_per_ts: 100 155 | kernel_sizes: 3 156 | dropout: 0.18002875427414997 157 | lr: 0.0006643638126306677 158 | batch_size: 32 159 | lr_epochs: 2 160 | 161 | nhits_covariates: 162 | in_len: 96 163 | max_samples_per_ts: 50 164 | kernel_sizes: 3 165 | dropout: 0.13142967835347927 166 | lr: 0.0008921763677516184 167 | batch_size: 48 168 | lr_epochs: 16 169 | 170 | gluformer: 171 | in_len: 96 172 | max_samples_per_ts: 150 173 | d_model: 384 174 | n_heads: 12 175 | d_fcn: 512 176 | num_enc_layers: 1 177 | num_dec_layers: 1 178 | 179 | latentode: 180 | in_len: 48 181 | max_samples_per_ts: 100 182 | latents: 20 183 | rec_dims: 40 184 | rec_layers: 3 185 | gen_layers: 3 186 | units: 100 187 | gru_units: 100 188 | -------------------------------------------------------------------------------- /config/dubosson.yaml: -------------------------------------------------------------------------------- 1 | # Dataset 2 | ds_name: dubosson2018 3 | 
data_csv_path: ./raw_data/dubosson.csv 4 | index_col: -1 5 | observation_interval: 5min 6 | 7 | # Columns 8 | column_definition: 9 | - name: id 10 | data_type: categorical 11 | input_type: id 12 | - name: time 13 | data_type: date 14 | input_type: time 15 | - name: gl 16 | data_type: real_valued 17 | input_type: target 18 | - name: fast_insulin 19 | data_type: real_valued 20 | input_type: observed_input 21 | - name: slow_insulin 22 | data_type: real_valued 23 | input_type: observed_input 24 | - name: calories 25 | data_type: real_valued 26 | input_type: observed_input 27 | - name: balance 28 | data_type: categorical 29 | input_type: observed_input 30 | - name: quality 31 | data_type: categorical 32 | input_type: observed_input 33 | - name: HR 34 | data_type: real_valued 35 | input_type: observed_input 36 | - name: BR 37 | data_type: real_valued 38 | input_type: observed_input 39 | - name: Posture 40 | data_type: real_valued 41 | input_type: observed_input 42 | - name: Activity 43 | data_type: real_valued 44 | input_type: observed_input 45 | - name: HRV 46 | data_type: real_valued 47 | input_type: observed_input 48 | - name: CoreTemp 49 | data_type: real_valued 50 | input_type: observed_input 51 | 52 | # Drop 53 | drop: 54 | rows: null 55 | columns: 56 | id: 57 | - 9 58 | 59 | # NA values abbreviation 60 | nan_vals: null 61 | 62 | # Interpolation parameters 63 | interpolation_params: 64 | gap_threshold: 30 # in minutes 65 | min_drop_length: 240 # number of points 66 | 67 | # Splitting parameters 68 | split_params: 69 | test_percent_subjects: 0.1 70 | length_segment: 144 71 | random_state: 0 72 | 73 | # Encoding parameters 74 | encoding_params: 75 | date: 76 | - year 77 | - month 78 | - day 79 | - hour 80 | - minute 81 | 82 | # Scaling parameters 83 | scaling_params: 84 | scaler: None 85 | 86 | # Model params 87 | max_length_input: 192 88 | length_pred: 12 89 | 90 | linreg: 91 | in_len: 12 92 | 93 | linreg_covariates: 94 | in_len: 12 95 | 96 | tft: 97 | in_len: 168 98 | max_samples_per_ts: 50 99 | hidden_size: 240 100 | num_attention_heads: 2 101 | dropout: 0.24910705171945197 102 | lr: 0.003353965994113796 103 | batch_size: 64 104 | max_grad_norm: 0.999584070166802 105 | 106 | tft_covariates: 107 | in_len: 120 108 | max_samples_per_ts: 50 109 | hidden_size: 240 110 | num_attention_heads: 1 111 | dropout: 0.2354005483884536 112 | lr: 0.0014372065280028868 113 | batch_size: 32 114 | max_grad_norm: 0.08770929102027172 115 | 116 | xgboost: 117 | in_len: 168 118 | lr: 0.6910000000000001 119 | subsample: 0.8 120 | min_child_weight: 5.0 121 | colsample_bytree: 0.8 122 | max_depth: 10 123 | gamma: 0.5 124 | alpha: 0.201 125 | lambda_: 0.279 126 | n_estimators: 416 127 | 128 | xgboost_covariates: 129 | in_len: 36 130 | lr: 0.651 131 | subsample: 0.8 132 | min_child_weight: 2.0 133 | colsample_bytree: 1.0 134 | max_depth: 6 135 | gamma: 1.5 136 | alpha: 0.148 137 | lambda_: 0.094 138 | n_estimators: 480 139 | 140 | transformer: 141 | in_len: 108 142 | max_samples_per_ts: 50 143 | d_model: 32 144 | n_heads: 2 145 | num_encoder_layers: 1 146 | num_decoder_layers: 1 147 | dim_feedforward: 384 148 | dropout: 0.038691123579122515 149 | lr: 0.0004450217945481336 150 | batch_size: 32 151 | lr_epochs: 6 152 | max_grad_norm: 0.20863935142150056 153 | 154 | transformer_covariates: 155 | in_len: 156 156 | max_samples_per_ts: 50 157 | d_model: 64 158 | n_heads: 2 159 | num_encoder_layers: 2 160 | num_decoder_layers: 1 161 | dim_feedforward: 384 162 | dropout: 0.0026811942171770446 163 | lr: 0.000998963295875978 
164 | batch_size: 48 165 | lr_epochs: 20 166 | max_grad_norm: 0.1004169110387992 167 | 168 | nhits: 169 | in_len: 108 170 | max_samples_per_ts: 50 171 | kernel_sizes: 3 172 | dropout: 0.06496948174462439 173 | lr: 0.0003359362814711015 174 | batch_size: 32 175 | lr_epochs: 2 176 | 177 | nhits_covariates: 178 | in_len: 120 179 | max_samples_per_ts: 50 180 | kernel_sizes: 2 181 | dropout: 0.16272090435698405 182 | lr: 0.0004806891979994542 183 | batch_size: 48 184 | lr_epochs: 12 185 | 186 | gluformer: 187 | in_len: 108 188 | max_samples_per_ts: 100 189 | d_model: 384 190 | n_heads: 8 191 | d_fcn: 1024 192 | num_enc_layers: 1 193 | num_dec_layers: 3 194 | 195 | latentode: 196 | in_len: 48 197 | max_samples_per_ts: 100 198 | latents: 20 199 | rec_dims: 40 200 | rec_layers: 3 201 | gen_layers: 3 202 | units: 100 203 | gru_units: 100 204 | 205 | 206 | -------------------------------------------------------------------------------- /config/hall.yaml: -------------------------------------------------------------------------------- 1 | # Dataset 2 | ds_name: hall2018 3 | data_csv_path: ./raw_data/hall.csv 4 | #./raw_data/Hall2018_processed_akhil.csv 5 | index_col: -1 6 | observation_interval: 5min 7 | 8 | # Columns 9 | column_definition: 10 | - name: id 11 | data_type: categorical 12 | input_type: id 13 | - name: time 14 | data_type: date 15 | input_type: time 16 | - name: gl 17 | data_type: real_valued 18 | input_type: target 19 | - name: Age 20 | data_type: real_valued 21 | input_type: static_input 22 | - name: BMI 23 | data_type: real_valued 24 | input_type: static_input 25 | - name: A1C 26 | data_type: real_valued 27 | input_type: static_input 28 | - name: FBG 29 | data_type: real_valued 30 | input_type: static_input 31 | - name: ogtt.2hr 32 | data_type: real_valued 33 | input_type: static_input 34 | - name: insulin 35 | data_type: real_valued 36 | input_type: static_input 37 | - name: hs.CRP 38 | data_type: real_valued 39 | input_type: static_input 40 | - name: Tchol 41 | data_type: real_valued 42 | input_type: static_input 43 | - name: Trg 44 | data_type: real_valued 45 | input_type: static_input 46 | - name: HDL 47 | data_type: real_valued 48 | input_type: static_input 49 | - name: LDL 50 | data_type: real_valued 51 | input_type: static_input 52 | - name: mean_glucose 53 | data_type: real_valued 54 | input_type: static_input 55 | - name: sd_glucose 56 | data_type: real_valued 57 | input_type: static_input 58 | - name: range_glucose 59 | data_type: real_valued 60 | input_type: static_input 61 | - name: min_glucose 62 | data_type: real_valued 63 | input_type: static_input 64 | - name: max_glucose 65 | data_type: real_valued 66 | input_type: static_input 67 | - name: quartile.25_glucose 68 | data_type: real_valued 69 | input_type: static_input 70 | - name: median_glucose 71 | data_type: real_valued 72 | input_type: static_input 73 | - name: quartile.75_glucose 74 | data_type: real_valued 75 | input_type: static_input 76 | - name: mean_slope 77 | data_type: real_valued 78 | input_type: static_input 79 | - name: max_slope 80 | data_type: real_valued 81 | input_type: static_input 82 | - name: number_Random140 83 | data_type: real_valued 84 | input_type: static_input 85 | - name: number_Random200 86 | data_type: real_valued 87 | input_type: static_input 88 | - name: percent_below.80 89 | data_type: real_valued 90 | input_type: static_input 91 | - name: se_glucose_mean 92 | data_type: real_valued 93 | input_type: static_input 94 | - name: numGE 95 | data_type: real_valued 96 | input_type: 
static_input 97 | - name: mage 98 | data_type: real_valued 99 | input_type: static_input 100 | - name: j_index 101 | data_type: real_valued 102 | input_type: static_input 103 | - name: IQR 104 | data_type: real_valued 105 | input_type: static_input 106 | - name: modd 107 | data_type: real_valued 108 | input_type: static_input 109 | - name: distance_traveled 110 | data_type: real_valued 111 | input_type: static_input 112 | - name: coef_variation 113 | data_type: real_valued 114 | input_type: static_input 115 | - name: number_Random140_normByDays 116 | data_type: real_valued 117 | input_type: static_input 118 | - name: number_Random200_normByDays 119 | data_type: real_valued 120 | input_type: static_input 121 | - name: numGE_normByDays 122 | data_type: real_valued 123 | input_type: static_input 124 | - name: distance_traveled_normByDays 125 | data_type: real_valued 126 | input_type: static_input 127 | - name: diagnosis 128 | data_type: categorical 129 | input_type: static_input 130 | - name: freq_low 131 | data_type: real_valued 132 | input_type: static_input 133 | - name: freq_moderate 134 | data_type: real_valued 135 | input_type: static_input 136 | - name: freq_severe 137 | data_type: real_valued 138 | input_type: static_input 139 | - name: glucotype 140 | data_type: categorical 141 | input_type: static_input 142 | - name: Height 143 | data_type: real_valued 144 | input_type: static_input 145 | - name: Weight 146 | data_type: real_valued 147 | input_type: static_input 148 | - name: Insulin_rate_dd 149 | data_type: real_valued 150 | input_type: static_input 151 | - name: perc_cgm_prediabetic_range 152 | data_type: real_valued 153 | input_type: static_input 154 | - name: perc_cgm_diabetic_range 155 | data_type: real_valued 156 | input_type: static_input 157 | - name: SSPG 158 | data_type: real_valued 159 | input_type: static_input 160 | 161 | # Drop 162 | drop: 163 | rows: 164 | - 57309 165 | columns: null 166 | 167 | # NA values abbreviation 168 | nan_vals: NA 169 | 170 | # Interpolation parameters 171 | interpolation_params: 172 | gap_threshold: 30 # in minutes 173 | min_drop_length: 192 # number of points 174 | 175 | # Splitting parameters 176 | split_params: 177 | test_percent_subjects: 0.1 178 | length_segment: 192 179 | random_state: 0 180 | 181 | # Encoding parameters 182 | encoding_params: 183 | date: 184 | - year 185 | - month 186 | - day 187 | - hour 188 | - minute 189 | 190 | # Scaling parameters 191 | scaling_params: 192 | scaler: None 193 | 194 | # Model params 195 | max_length_input: 144 196 | length_pred: 12 197 | 198 | linreg: 199 | in_len: 84 200 | 201 | linreg_covariates: 202 | in_len: 60 203 | 204 | tft: 205 | in_len: 96 206 | max_samples_per_ts: 50 207 | hidden_size: 160 208 | num_attention_heads: 2 209 | dropout: 0.12663651999137013 210 | lr: 0.0003909069464830342 211 | batch_size: 48 212 | max_grad_norm: 0.42691316697261855 213 | 214 | tft_covariates: 215 | in_len: 132 216 | max_samples_per_ts: 50 217 | hidden_size: 64 218 | num_attention_heads: 3 219 | dropout: 0.1514203549391074 220 | lr: 0.002278316839625157 221 | batch_size: 32 222 | max_grad_norm: 0.6617473571712074 223 | 224 | xgboost: 225 | in_len: 60 226 | lr: 0.515 227 | subsample: 0.9 228 | min_child_weight: 3.0 229 | colsample_bytree: 0.9 230 | max_depth: 6 231 | gamma: 2.0 232 | alpha: 0.099 233 | lambda_: 0.134 234 | n_estimators: 256 235 | 236 | xgboost_covariates: 237 | in_len: 120 238 | lr: 0.17200000000000001 239 | subsample: 0.7 240 | min_child_weight: 2.0 241 | colsample_bytree: 0.9 242 | max_depth: 
6 243 | gamma: 1.0 244 | alpha: 0.167 245 | lambda_: 0.017 246 | n_estimators: 320 247 | 248 | transformer: 249 | in_len: 144 250 | max_samples_per_ts: 200 251 | d_model: 64 252 | n_heads: 4 253 | num_encoder_layers: 1 254 | num_decoder_layers: 1 255 | dim_feedforward: 96 256 | dropout: 0.014744750937083516 257 | lr: 0.00035186058101597097 258 | batch_size: 48 259 | lr_epochs: 14 260 | max_grad_norm: 0.43187285340924153 261 | 262 | transformer_covariates: 263 | in_len: 132 264 | max_samples_per_ts: 150 265 | d_model: 64 266 | n_heads: 4 267 | num_encoder_layers: 1 268 | num_decoder_layers: 3 269 | dim_feedforward: 192 270 | dropout: 0.1260638882066075 271 | lr: 0.0006944648317764303 272 | batch_size: 48 273 | lr_epochs: 4 274 | max_grad_norm: 0.22914229299130273 275 | 276 | nhits: 277 | in_len: 144 278 | max_samples_per_ts: 100 279 | kernel_sizes: 4 280 | dropout: 0.046869296882493555 281 | lr: 0.00011524084800602483 282 | batch_size: 48 283 | lr_epochs: 2 284 | 285 | nhits_covariates: 286 | in_len: 120 287 | max_samples_per_ts: 50 288 | kernel_sizes: 5 289 | dropout: 0.18679300209273494 290 | lr: 0.0004763622305085654 291 | batch_size: 48 292 | lr_epochs: 4 293 | 294 | gluformer: 295 | in_len: 96 296 | max_samples_per_ts: 200 297 | d_model: 384 298 | n_heads: 4 299 | d_fcn: 1024 300 | num_enc_layers: 1 301 | num_dec_layers: 1 302 | 303 | latentode: 304 | in_len: 48 305 | max_samples_per_ts: 100 306 | latents: 20 307 | rec_dims: 40 308 | rec_layers: 3 309 | gen_layers: 3 310 | units: 100 311 | gru_units: 100 -------------------------------------------------------------------------------- /config/iglu.yaml: -------------------------------------------------------------------------------- 1 | # Dataset 2 | ds_name: iglu 3 | data_csv_path: ./raw_data/iglu.csv 4 | index_col: -1 5 | observation_interval: 5min 6 | 7 | # Columns 8 | column_definition: 9 | - name: id 10 | data_type: categorical 11 | input_type: id 12 | - name: time 13 | data_type: date 14 | input_type: time 15 | - name: gl 16 | data_type: real_valued 17 | input_type: target 18 | 19 | # Drop 20 | drop: null 21 | 22 | # NA values abbreviation 23 | nan_vals: null 24 | 25 | # Interpolation parameters 26 | interpolation_params: 27 | gap_threshold: 45 # in minutes 28 | min_drop_length: 240 # in number of points (20 hrs) 29 | 30 | # Splitting parameters 31 | split_params: 32 | test_percent_subjects: 0.1 33 | length_segment: 240 34 | random_state: 0 35 | 36 | # Encoding parameters 37 | encoding_params: 38 | date: 39 | - year 40 | - month 41 | - day 42 | - hour 43 | - minute 44 | - second 45 | 46 | # Scaling parameters 47 | scaling_params: 48 | scaler: None 49 | 50 | # Model params 51 | max_length_input: 192 # in number of points (16 hrs) 52 | length_pred: 12 # in number of points (predict 1 hr) 53 | 54 | linreg: 55 | in_len: 192 56 | 57 | linreg_covariates: 58 | in_len: 12 59 | 60 | tft: 61 | in_len: 168 62 | max_samples_per_ts: 50 63 | hidden_size: 80 64 | num_attention_heads: 4 65 | dropout: 0.12792080253276716 66 | lr: 0.003164601797779577 67 | batch_size: 32 68 | max_grad_norm: 0.5265925565310886 69 | 70 | tft_covariates: 71 | in_len: 96 72 | max_samples_per_ts: 50 73 | hidden_size: 80 74 | num_attention_heads: 3 75 | dropout: 0.22790916758695268 76 | lr: 0.005050238867376333 77 | batch_size: 32 78 | max_grad_norm: 0.026706367007025333 79 | 80 | xgboost: 81 | in_len: 84 82 | lr: 0.506 83 | subsample: 0.9 84 | min_child_weight: 2.0 85 | colsample_bytree: 0.8 86 | max_depth: 9 87 | gamma: 0.5 88 | alpha: 0.124 89 | lambda_: 0.089 90 | 
n_estimators: 416 91 | 92 | xgboost_covariates: 93 | in_len: 96 94 | lr: 0.387 95 | subsample: 0.8 96 | min_child_weight: 1.0 97 | colsample_bytree: 1.0 98 | max_depth: 8 99 | gamma: 1.0 100 | alpha: 0.199 101 | lambda_: 0.018000000000000002 102 | n_estimators: 288 103 | 104 | transformer: 105 | in_len: 96 106 | max_samples_per_ts: 50 107 | d_model: 96 108 | n_heads: 4 109 | num_encoder_layers: 4 110 | num_decoder_layers: 1 111 | dim_feedforward: 448 112 | dropout: 0.10161152207464333 113 | lr: 0.000840888489686657 114 | batch_size: 32 115 | lr_epochs: 16 116 | max_grad_norm: 0.6740479322943925 117 | 118 | transformer_covariates: 119 | in_len: 108 120 | max_samples_per_ts: 50 121 | d_model: 128 122 | n_heads: 2 123 | num_encoder_layers: 2 124 | num_decoder_layers: 2 125 | dim_feedforward: 160 126 | dropout: 0.044926981080245884 127 | lr: 0.00029632347559614453 128 | batch_size: 32 129 | lr_epochs: 20 130 | max_grad_norm: 0.8890169619043728 131 | 132 | nhits: 133 | in_len: 96 134 | max_samples_per_ts: 50 135 | kernel_sizes: 5 136 | dropout: 0.12695408586813234 137 | lr: 0.0004510532358403777 138 | batch_size: 64 139 | lr_epochs: 16 140 | 141 | nhits_covariates: 142 | in_len: 144 143 | max_samples_per_ts: 50 144 | kernel_sizes: 3 145 | dropout: 0.09469970402653531 146 | lr: 0.0009786650965760999 147 | batch_size: 32 148 | lr_epochs: 10 149 | 150 | gluformer: 151 | in_len: 96 152 | max_samples_per_ts: 100 153 | d_model: 512 154 | n_heads: 4 155 | d_fcn: 512 156 | num_enc_layers: 1 157 | num_dec_layers: 4 158 | 159 | latentode: 160 | in_len: 48 161 | max_samples_per_ts: 100 162 | latents: 20 163 | rec_dims: 40 164 | rec_layers: 3 165 | gen_layers: 3 166 | units: 100 167 | gru_units: 100 168 | 169 | -------------------------------------------------------------------------------- /config/weinstock.yaml: -------------------------------------------------------------------------------- 1 | # Dataset 2 | ds_name: weinstock2016 3 | data_csv_path: ./raw_data/weinstock.csv 4 | index_col: -1 5 | observation_interval: 5min 6 | 7 | # Columns 8 | column_definition: 9 | - name: id 10 | data_type: categorical 11 | input_type: id 12 | - name: time 13 | data_type: date 14 | input_type: time 15 | - name: gl 16 | data_type: real_valued 17 | input_type: target 18 | - name: Height 19 | data_type: real_valued 20 | input_type: static_input 21 | - name: Weight 22 | data_type: real_valued 23 | input_type: static_input 24 | - name: Gender 25 | data_type: categorical 26 | input_type: static_input 27 | - name: Race 28 | data_type: categorical 29 | input_type: static_input 30 | - name: EduLevel 31 | data_type: categorical 32 | input_type: static_input 33 | - name: AnnualInc 34 | data_type: real_valued 35 | input_type: static_input 36 | - name: MaritalStatus 37 | data_type: categorical 38 | input_type: static_input 39 | - name: DaysWkEx 40 | data_type: real_valued 41 | input_type: static_input 42 | - name: DaysWkDrinkAlc 43 | data_type: real_valued 44 | input_type: static_input 45 | - name: DaysMonBingeAlc 46 | data_type: real_valued 47 | input_type: static_input 48 | - name: T1DDiagAge 49 | data_type: real_valued 50 | input_type: static_input 51 | - name: NumHospDKA 52 | data_type: real_valued 53 | input_type: static_input 54 | - name: NumSHSinceT1DDiag 55 | data_type: real_valued 56 | input_type: static_input 57 | - name: InsDeliveryMethod 58 | data_type: categorical 59 | input_type: static_input 60 | - name: UnitsInsTotal 61 | data_type: real_valued 62 | input_type: static_input 63 | - name: NumMeterCheckDay 64 | 
data_type: real_valued 65 | input_type: static_input 66 | - name: Aspirin 67 | data_type: real_valued 68 | input_type: static_input 69 | - name: Simvastatin 70 | data_type: real_valued 71 | input_type: static_input 72 | - name: Lisinopril 73 | data_type: real_valued 74 | input_type: static_input 75 | - name: "Vitamin D" 76 | data_type: real_valued 77 | input_type: static_input 78 | - name: "Multivitamin preparation" 79 | data_type: real_valued 80 | input_type: static_input 81 | - name: Omeprazole 82 | data_type: real_valued 83 | input_type: static_input 84 | - name: atorvastatin 85 | data_type: real_valued 86 | input_type: static_input 87 | - name: Synthroid 88 | data_type: real_valued 89 | input_type: static_input 90 | - name: "vitamin D3" 91 | data_type: real_valued 92 | input_type: static_input 93 | - name: Hypertension 94 | data_type: real_valued 95 | input_type: static_input 96 | - name: Hyperlipidemia 97 | data_type: real_valued 98 | input_type: static_input 99 | - name: Hypothyroidism 100 | data_type: real_valued 101 | input_type: static_input 102 | - name: Depression 103 | data_type: real_valued 104 | input_type: static_input 105 | - name: "Coronary artery disease" 106 | data_type: real_valued 107 | input_type: static_input 108 | - name: "Diabetic peripheral neuropathy" 109 | data_type: real_valued 110 | input_type: static_input 111 | - name: Dyslipidemia 112 | data_type: real_valued 113 | input_type: static_input 114 | - name: "Chronic kidney disease" 115 | data_type: real_valued 116 | input_type: static_input 117 | - name: Osteoporosis 118 | data_type: real_valued 119 | input_type: static_input 120 | - name: "Proliferative diabetic retinopathy" 121 | data_type: real_valued 122 | input_type: static_input 123 | - name: Hypercholesterolemia 124 | data_type: real_valued 125 | input_type: static_input 126 | - name: "Erectile dysfunction" 127 | data_type: real_valued 128 | input_type: static_input 129 | - name: "Type I diabetes mellitus" 130 | data_type: real_valued 131 | input_type: static_input 132 | 133 | # Drop 134 | drop: null 135 | 136 | # NA values abbreviation 137 | nan_vals: null 138 | 139 | # Interpolation parameters 140 | interpolation_params: 141 | gap_threshold: 45 142 | min_drop_length: 240 143 | 144 | # Splitting parameters 145 | split_params: 146 | test_percent_subjects: 0.1 147 | length_segment: 240 148 | random_state: 0 149 | 150 | # Encoding parameters 151 | encoding_params: 152 | date: 153 | - year 154 | - month 155 | - day 156 | - hour 157 | - minute 158 | 159 | # Scaling parameters 160 | scaling_params: 161 | scaler: None 162 | 163 | # Model params 164 | max_length_input: 192 165 | length_pred: 12 166 | 167 | linreg: 168 | in_len: 84 169 | 170 | linreg_covariates: 171 | in_len: 84 172 | 173 | tft: 174 | in_len: 132 175 | max_samples_per_ts: 200 176 | hidden_size: 96 177 | num_attention_heads: 3 178 | dropout: 0.14019930679548182 179 | lr: 0.003399303384204884 180 | batch_size: 48 181 | max_grad_norm: 0.9962755235072169 182 | 183 | tft_covariates: 184 | in_len: 108 185 | max_samples_per_ts: 50 186 | hidden_size: 112 187 | num_attention_heads: 2 188 | dropout: 0.1504541564537306 189 | lr: 0.0018430630797167395 190 | batch_size: 48 191 | max_grad_norm: 0.9530046023189843 192 | 193 | xgboost: 194 | in_len: 84 195 | lr: 0.722 196 | subsample: 0.9 197 | min_child_weight: 5.0 198 | colsample_bytree: 1.0 199 | max_depth: 10 200 | gamma: 0.5 201 | alpha: 0.271 202 | lambda_: 0.07100000000000001 203 | n_estimators: 416 204 | 205 | xgboost_covariates: 206 | in_len: 96 207 | 
lr: 0.48000000000000004 208 | subsample: 1.0 209 | min_child_weight: 2.0 210 | colsample_bytree: 0.9 211 | max_depth: 6 212 | gamma: 1.5 213 | alpha: 0.159 214 | lambda_: 0.025 215 | n_estimators: 320 216 | 217 | transformer: 218 | in_len: 96 219 | max_samples_per_ts: 50 220 | d_model: 128 221 | n_heads: 2 222 | num_encoder_layers: 2 223 | num_decoder_layers: 4 224 | dim_feedforward: 64 225 | dropout: 0.0017011626095738697 226 | lr: 0.0007790307889667749 227 | batch_size: 32 228 | lr_epochs: 4 229 | max_grad_norm: 0.4226615744655383 230 | 231 | transformer_covariates: 232 | in_len: 96 233 | max_samples_per_ts: 50 234 | d_model: 128 235 | n_heads: 4 236 | num_encoder_layers: 1 237 | num_decoder_layers: 4 238 | dim_feedforward: 448 239 | dropout: 0.1901296977134417 240 | lr: 0.000965351785309486 241 | batch_size: 48 242 | lr_epochs: 4 243 | max_grad_norm: 0.19219462323820113 244 | 245 | nhits: 246 | in_len: 96 247 | max_samples_per_ts: 200 248 | kernel_sizes: 4 249 | dropout: 0.12642017123585755 250 | lr: 0.00032840023694932384 251 | batch_size: 64 252 | lr_epochs: 16 253 | 254 | nhits_covariates: 255 | in_len: 96 256 | max_samples_per_ts: 50 257 | kernel_sizes: 3 258 | dropout: 0.10162895545943862 259 | lr: 0.0009200129411689094 260 | batch_size: 32 261 | lr_epochs: 2 262 | 263 | gluformer: 264 | in_len: 144 265 | max_samples_per_ts: 100 266 | d_model: 512 267 | n_heads: 8 268 | d_fcn: 1408 269 | num_enc_layers: 1 270 | num_dec_layers: 4 271 | 272 | latentode: 273 | in_len: 48 274 | max_samples_per_ts: 100 275 | latents: 20 276 | rec_dims: 40 277 | rec_layers: 3 278 | gen_layers: 3 279 | units: 100 280 | gru_units: 100 -------------------------------------------------------------------------------- /data_formatter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/data_formatter/__init__.py -------------------------------------------------------------------------------- /data_formatter/base.py: -------------------------------------------------------------------------------- 1 | '''Defines a generic data formatter for CGM data sets.''' 2 | import sys 3 | import warnings 4 | import numpy as np 5 | import pandas as pd 6 | import sklearn.preprocessing 7 | import data_formatter.types as types 8 | import data_formatter.utils as utils 9 | 10 | DataTypes = types.DataTypes 11 | InputTypes = types.InputTypes 12 | 13 | dict_data_type = {'categorical': DataTypes.CATEGORICAL, 14 | 'real_valued': DataTypes.REAL_VALUED, 15 | 'date': DataTypes.DATE} 16 | dict_input_type = {'target': InputTypes.TARGET, 17 | 'observed_input': InputTypes.OBSERVED_INPUT, 18 | 'known_input': InputTypes.KNOWN_INPUT, 19 | 'static_input': InputTypes.STATIC_INPUT, 20 | 'id': InputTypes.ID, 21 | 'time': InputTypes.TIME} 22 | 23 | 24 | class DataFormatter(): 25 | # Defines and formats data for the IGLU dataset. 
26 | 27 | def __init__(self, cnf, study_file = None): 28 | """Initialises formatter.""" 29 | # load parameters from the config file 30 | self.params = cnf 31 | # write progress to file if specified 32 | self.study_file = study_file 33 | stdout = sys.stdout 34 | f = open(study_file, 'a') if study_file is not None else sys.stdout 35 | sys.stdout = f 36 | 37 | # load column definition 38 | print('-'*32) 39 | print('Loading column definition...') 40 | self.__process_column_definition() 41 | 42 | # check that column definition is valid 43 | print('Checking column definition...') 44 | self.__check_column_definition() 45 | 46 | # load data 47 | # check if data table has index col: -1 if not, index >= 0 if yes 48 | print('Loading data...') 49 | self.params['index_col'] = False if self.params['index_col'] == -1 else self.params['index_col'] 50 | # read data table 51 | self.data = pd.read_csv(self.params['data_csv_path'], index_col=self.params['index_col']) 52 | 53 | # drop columns / rows 54 | print('Dropping columns / rows...') 55 | self.__drop() 56 | 57 | # check NA values 58 | print('Checking for NA values...') 59 | self.__check_nan() 60 | 61 | # set data types in DataFrame to match column definition 62 | print('Setting data types...') 63 | self.__set_data_types() 64 | 65 | # drop columns / rows 66 | print('Dropping columns / rows...') 67 | self.__drop() 68 | 69 | # encode 70 | print('Encoding data...') 71 | self._encoding_params = self.params['encoding_params'] 72 | self.__encode() 73 | 74 | # interpolate 75 | print('Interpolating data...') 76 | self._interpolation_params = self.params['interpolation_params'] 77 | self._interpolation_params['interval_length'] = self.params['observation_interval'] 78 | self.__interpolate() 79 | 80 | # split data 81 | print('Splitting data...') 82 | self._split_params = self.params['split_params'] 83 | self._split_params['max_length_input'] = self.params['max_length_input'] 84 | self.__split_data() 85 | 86 | # scale 87 | print('Scaling data...') 88 | self._scaling_params = self.params['scaling_params'] 89 | self.__scale() 90 | 91 | print('Data formatting complete.') 92 | print('-'*32) 93 | if study_file is not None: 94 | f.close() 95 | sys.stdout = stdout 96 | 97 | 98 | def __process_column_definition(self): 99 | self._column_definition = [] 100 | for col in self.params['column_definition']: 101 | self._column_definition.append((col['name'], 102 | dict_data_type[col['data_type']], 103 | dict_input_type[col['input_type']])) 104 | 105 | def __check_column_definition(self): 106 | # check that there is unique ID column 107 | assert len([col for col in self._column_definition if col[2] == InputTypes.ID]) == 1, 'There must be exactly one ID column.' 108 | # check that there is unique time column 109 | assert len([col for col in self._column_definition if col[2] == InputTypes.TIME]) == 1, 'There must be exactly one time column.' 110 | # check that there is at least one target column 111 | assert len([col for col in self._column_definition if col[2] == InputTypes.TARGET]) >= 1, 'There must be at least one target column.' 
112 | 113 | def __set_data_types(self): 114 | # set time column as datetime format in pandas 115 | for col in self._column_definition: 116 | if col[1] == DataTypes.DATE: 117 | self.data[col[0]] = pd.to_datetime(self.data[col[0]]) 118 | if col[1] == DataTypes.CATEGORICAL: 119 | self.data[col[0]] = self.data[col[0]].astype('category') 120 | if col[1] == DataTypes.REAL_VALUED: 121 | self.data[col[0]] = self.data[col[0]].astype(np.float32) 122 | 123 | def __check_nan(self): 124 | # delete rows where target, time, or id are na 125 | self.data = self.data.dropna(subset=[col[0] 126 | for col in self._column_definition 127 | if col[2] in [InputTypes.TARGET, InputTypes.TIME, InputTypes.ID]]) 128 | # assert that there are no na values in the data 129 | assert self.data.isna().sum().sum() == 0, 'There are NA values in the data even after dropping with missing time, glucose, or id.' 130 | 131 | def __drop(self): 132 | # drop columns that are not in the column definition 133 | self.data = self.data[[col[0] for col in self._column_definition]] 134 | # drop rows based on conditions set in the formatter 135 | if self.params['drop'] is not None: 136 | if self.params['drop']['rows'] is not None: 137 | # drop row at indices in the list self.params['drop']['rows'] 138 | self.data = self.data.drop(self.params['drop']['rows']) 139 | self.data = self.data.reset_index(drop=True) 140 | if self.params['drop']['columns'] is not None: 141 | for col in self.params['drop']['columns'].keys(): 142 | # drop rows where specified columns have values in the list self.params['drop']['columns'][col] 143 | self.data = self.data.loc[~self.data[col].isin(self.params['drop']['columns'][col])].copy() 144 | 145 | def __interpolate(self): 146 | self.data, self._column_definition = utils.interpolate(self.data, 147 | self._column_definition, 148 | **self._interpolation_params) 149 | 150 | def __split_data(self): 151 | if self.params['split_params']['test_percent_subjects'] == 0 or \ 152 | self.params['split_params']['length_segment'] == 0: 153 | print('\tNo splitting performed since test_percent_subjects or length_segment is 0.') 154 | self.train_idx, self.val_idx, self.test_idx, self.test_idx_ood = None, None, None, None 155 | self.train_data, self.val_data, self.test_data = self.data, None, None 156 | else: 157 | assert self.params['split_params']['length_segment'] > self.params['length_pred'], \ 158 | 'length_segment for test / val must be greater than length_pred.' 
159 | self.train_idx, self.val_idx, self.test_idx, self.test_idx_ood = utils.split(self.data, 160 | self._column_definition, 161 | **self._split_params) 162 | self.train_data, self.val_data, self.test_data = self.data.iloc[self.train_idx], \ 163 | self.data.iloc[self.val_idx], \ 164 | self.data.iloc[self.test_idx + self.test_idx_ood] 165 | 166 | def __encode(self): 167 | self.data, self._column_definition, self.encoders = utils.encode(self.data, 168 | self._column_definition, 169 | **self._encoding_params) 170 | 171 | def __scale(self): 172 | self.train_data, self.val_data, self.test_data, self.scalers = utils.scale(self.train_data, 173 | self.val_data, 174 | self.test_data, 175 | self._column_definition, 176 | **self.params['scaling_params']) 177 | 178 | def reshuffle(self, seed): 179 | stdout = sys.stdout 180 | f = open(self.study_file, 'a') 181 | sys.stdout = f 182 | self.params['split_params']['random_state'] = seed 183 | # split data 184 | self.train_idx, self.val_idx, self.test_idx, self.test_idx_ood = utils.split(self.data, 185 | self._column_definition, 186 | **self._split_params) 187 | self.train_data, self.val_data, self.test_data = self.data.iloc[self.train_idx], \ 188 | self.data.iloc[self.val_idx], \ 189 | self.data.iloc[self.test_idx+self.test_idx_ood] 190 | # re-scale data 191 | self.train_data, self.val_data, self.test_data, self.scalers = utils.scale(self.train_data, 192 | self.val_data, 193 | self.test_data, 194 | self._column_definition, 195 | **self.params['scaling_params']) 196 | sys.stdout = stdout 197 | f.close() 198 | 199 | def get_column(self, column_name): 200 | # write cases for time, id, target, future, static, dynamic covariates 201 | if column_name == 'time': 202 | return [col[0] for col in self._column_definition if col[2] == InputTypes.TIME][0] 203 | elif column_name == 'id': 204 | return [col[0] for col in self._column_definition if col[2] == InputTypes.ID][0] 205 | elif column_name == 'sid': 206 | return [col[0] for col in self._column_definition if col[2] == InputTypes.SID][0] 207 | elif column_name == 'target': 208 | return [col[0] for col in self._column_definition if col[2] == InputTypes.TARGET] 209 | elif column_name == 'future_covs': 210 | future_covs = [col[0] for col in self._column_definition if col[2] == InputTypes.KNOWN_INPUT] 211 | return future_covs if len(future_covs) > 0 else None 212 | elif column_name == 'static_covs': 213 | static_covs = [col[0] for col in self._column_definition if col[2] == InputTypes.STATIC_INPUT] 214 | return static_covs if len(static_covs) > 0 else None 215 | elif column_name == 'dynamic_covs': 216 | dynamic_covs = [col[0] for col in self._column_definition if col[2] == InputTypes.OBSERVED_INPUT] 217 | return dynamic_covs if len(dynamic_covs) > 0 else None 218 | else: 219 | raise ValueError('Column {} not found.'.format(column_name)) 220 | 221 | -------------------------------------------------------------------------------- /data_formatter/types.py: -------------------------------------------------------------------------------- 1 | '''Defines data and input types of each column in the dataset.''' 2 | 3 | import enum 4 | 5 | class DataTypes(enum.IntEnum): 6 | """Defines numerical types of each column.""" 7 | REAL_VALUED = 0 8 | CATEGORICAL = 1 9 | DATE = 2 10 | 11 | class InputTypes(enum.IntEnum): 12 | """Defines input types of each column.""" 13 | TARGET = 0 14 | OBSERVED_INPUT = 1 15 | KNOWN_INPUT = 2 16 | STATIC_INPUT = 3 17 | ID = 4 # Single column used as an entity identifier 18 | SID = 5 # Single column used 
as a segment identifier 19 | TIME = 6 # Single column exclusively used as a time index -------------------------------------------------------------------------------- /exploratory_analysis/colas.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "oFh18yX0FrjN" 7 | }, 8 | "source": [ 9 | "# Loading libraries" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "id": "eawH97OBFrjS" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import sys\n", 21 | "import os\n", 22 | "import yaml\n", 23 | "import pandas as pd\n", 24 | "import numpy as np\n", 25 | "sys.path.insert(1, '..')\n", 26 | "os.chdir('..')\n", 27 | "\n", 28 | "import seaborn as sns\n", 29 | "sns.set_style('whitegrid')\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "import statsmodels.api as sm\n", 32 | "import sklearn\n", 33 | "import optuna\n", 34 | "\n", 35 | "from darts import models\n", 36 | "from darts import metrics\n", 37 | "from darts import TimeSeries\n", 38 | "from darts.dataprocessing.transformers import Scaler\n", 39 | "\n", 40 | "from statsforecast.models import AutoARIMA\n", 41 | "\n", 42 | "from data_formatter.base import *\n", 43 | "from bin.utils import *" 44 | ] 45 | }, 46 | { 47 | "attachments": {}, 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "# Processing" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "filenames = []\n", 61 | "for root, dir, files in os.walk('raw_data/Colas2019'):\n", 62 | " for file in files:\n", 63 | " if '.csv' in file:\n", 64 | " filenames.append(os.path.join(root, file))\n", 65 | " \n", 66 | "# next we loop through each file\n", 67 | "nfiles = len(files)\n", 68 | "\n", 69 | "count = 0\n", 70 | "for file in filenames:\n", 71 | " # read in data and extract id from filename\n", 72 | " curr = pd.read_csv(file)\n", 73 | " curr['id'] = int(file.split()[1].split(\".\")[0])\n", 74 | " # select desired columns, rename, and drop nas\n", 75 | " curr = curr[['id', 'hora', 'glucemia']]\n", 76 | " curr.rename(columns = {'hora': 'time', 'glucemia': 'gl'}, inplace=True)\n", 77 | " curr.dropna(inplace=True)\n", 78 | "\n", 79 | " # calculate time (only given in hms) as follows:\n", 80 | " # (1) get the time per day in seconds, (2) get the time differences, and correct for the day crossove (< 0)\n", 81 | " # (3) take the cumulative sum and add the cumulative number of seconds from start to the base date\n", 82 | " # thus the hms are real, while the year, month, day are fake\n", 83 | " time_secs = []\n", 84 | " for i in curr['time']:\n", 85 | " time_secs.append(int(i.split(\":\")[0])*60*60 + int(i.split(\":\")[1])*60 + int(i.split(\":\")[2])*1)\n", 86 | " time_diff = np.diff(np.array(time_secs)).tolist()\n", 87 | " time_diff_adj = [x if x > 0 else 24*60*60 + x for x in time_diff]\n", 88 | " time_diff_adj.insert(0, 0)\n", 89 | " cumin = np.cumsum(time_diff_adj)\n", 90 | " datetime = pd.to_datetime('2012-01-01') + pd.to_timedelta(cumin, unit='sec')\n", 91 | " curr['time'] = datetime\n", 92 | " curr['id'] = curr['id'].astype('int')\n", 93 | " curr.reset_index(drop=True, inplace=True)\n", 94 | "\n", 95 | " if count == 0:\n", 96 | " df = curr\n", 97 | " count += 1\n", 98 | " else:\n", 99 | " df = pd.concat([df, curr])" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": 
{}, 106 | "outputs": [], 107 | "source": [ 108 | "# join with covariates\n", 109 | "covariates = pd.read_csv('raw_data/Colas2019/clinical_data.txt', sep = \" \")\n", 110 | "covariates['id'] = covariates.index\n", 111 | "\n", 112 | "combined = pd.merge(\n", 113 | " df, covariates, how = \"left\"\n", 114 | ")\n", 115 | "\n", 116 | "# define NA fill values for covariates\n", 117 | "values = {\n", 118 | " 'gender': 2, # if gender is NA, create own category\n", 119 | " 'age': combined['age'].mean(),\n", 120 | " 'BMI': combined['BMI'].mean(),\n", 121 | " 'glycaemia': combined['glycaemia'].mean(),\n", 122 | " 'HbA1c': combined['HbA1c'].mean(),\n", 123 | " 'follow.up': combined['follow.up'].mean(),\n", 124 | " 'T2DM': False\n", 125 | "}\n", 126 | "combined = combined.fillna(value = values)\n", 127 | "\n", 128 | "# write to csv\n", 129 | "combined.to_csv('raw_data/colas.csv')" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": { 135 | "id": "WeAHZmAmFrjV" 136 | }, 137 | "source": [ 138 | "# Check statistics of the data" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": { 145 | "colab": { 146 | "base_uri": "https://localhost:8080/" 147 | }, 148 | "id": "pkOzK6gcFrjW", 149 | "outputId": "769510ff-79ba-4020-8d9c-dc78a7cdb7ff" 150 | }, 151 | "outputs": [], 152 | "source": [ 153 | "import matplotlib.pyplot as plt\n", 154 | "\n", 155 | "# load yaml config file\n", 156 | "with open('./config/colas.yaml', 'r') as f:\n", 157 | " config = yaml.safe_load(f)\n", 158 | "\n", 159 | "# set interpolation params for no interpolation\n", 160 | "new_config = config.copy()\n", 161 | "new_config['interpolation_params']['gap_threshold'] = 5\n", 162 | "new_config['interpolation_params']['min_drop_length'] = 0\n", 163 | "# set split params for no splitting\n", 164 | "new_config['split_params']['test_percent_subjects'] = 0\n", 165 | "new_config['split_params']['length_segment'] = 0\n", 166 | "# set scaling params for no scaling\n", 167 | "new_config['scaling_params']['scaler'] = 'None'\n", 168 | "\n", 169 | "formatter = DataFormatter(new_config)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": { 176 | "colab": { 177 | "base_uri": "https://localhost:8080/" 178 | }, 179 | "id": "eCBgEjuAFrjX", 180 | "outputId": "1d40e5fa-1fd5-45ea-ae93-41d14226a0c5" 181 | }, 182 | "outputs": [], 183 | "source": [ 184 | "# print min, max, median, mean, std of segment lengths\n", 185 | "segment_lens = []\n", 186 | "for group, data in formatter.train_data.groupby('id_segment'):\n", 187 | " segment_lens.append(len(data))\n", 188 | "print('Train segment lengths:')\n", 189 | "print('\\tMin: ', min(segment_lens))\n", 190 | "print('\\tMax: ', max(segment_lens))\n", 191 | "print('\\t1st Quartile: ', np.quantile(segment_lens, 0.25))\n", 192 | "print('\\tMedian: ', np.median(segment_lens))\n", 193 | "print('\\tMean: ', np.mean(segment_lens))\n", 194 | "print('\\tStd: ', np.std(segment_lens))\n", 195 | "\n", 196 | "# plot first 9 segments\n", 197 | "num_segments = 9\n", 198 | "plot_data = formatter.train_data\n", 199 | "\n", 200 | "fig, axs = plt.subplots(1, num_segments, figsize=(30, 5))\n", 201 | "for i, (group, data) in enumerate(plot_data.groupby('id_segment')):\n", 202 | " data.plot(x='time', y='gl', ax=axs[i], title='Segment {}'.format(group))\n", 203 | " if i >= num_segments - 1:\n", 204 | " break" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": { 211 | "colab": { 212 | 
"base_uri": "https://localhost:8080/", 213 | "height": 341 214 | }, 215 | "id": "iU2AUHTfFrjZ", 216 | "outputId": "ac25dfa6-4eee-4fc8-9c0c-11efa1b4fa14" 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "# plot acf of random samples from first 9 segments segments\n", 221 | "fig, ax = plt.subplots(2, num_segments, figsize=(30, 5))\n", 222 | "lags = 300; k = 0\n", 223 | "for i, (group, data) in enumerate(plot_data.groupby('id_segment')):\n", 224 | " data = data['gl']\n", 225 | " if len(data) < lags:\n", 226 | " print('Segment {} is too short'.format(group))\n", 227 | " continue\n", 228 | " else:\n", 229 | " # select 10 random samples from index of data\n", 230 | " sample = np.random.choice(range(len(data))[:-lags], 10, replace=False)\n", 231 | " # plot acf / pacf of each sample\n", 232 | " for j in sample:\n", 233 | " acf, acf_ci = sm.tsa.stattools.acf(data[j:j+lags], nlags=lags, alpha=0.05)\n", 234 | " pacf, pacf_ci = sm.tsa.stattools.pacf(data[j:j+lags], method='ols-adjusted', alpha=0.05)\n", 235 | " ax[0, k].plot(acf)\n", 236 | " ax[1, k].plot(pacf)\n", 237 | " k += 1\n", 238 | " if k >= num_segments:\n", 239 | " break" 240 | ] 241 | } 242 | ], 243 | "metadata": { 244 | "colab": { 245 | "provenance": [] 246 | }, 247 | "kernelspec": { 248 | "display_name": "base", 249 | "language": "python", 250 | "name": "python3" 251 | }, 252 | "language_info": { 253 | "codemirror_mode": { 254 | "name": "ipython", 255 | "version": 3 256 | }, 257 | "file_extension": ".py", 258 | "mimetype": "text/x-python", 259 | "name": "python", 260 | "nbconvert_exporter": "python", 261 | "pygments_lexer": "ipython3", 262 | "version": "3.6.9" 263 | }, 264 | "orig_nbformat": 4, 265 | "vscode": { 266 | "interpreter": { 267 | "hash": "ad2bdc8ecc057115af97d19610ffacc2b4e99fae6737bb82f5d7fb13d2f2c186" 268 | } 269 | } 270 | }, 271 | "nbformat": 4, 272 | "nbformat_minor": 0 273 | } 274 | -------------------------------------------------------------------------------- /exploratory_analysis/iglu.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Loading libraries" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# DANGER: only run 1x otherwise will chdir too many times\n", 17 | "import sys\n", 18 | "import os\n", 19 | "import yaml\n", 20 | "\n", 21 | "sys.path.insert(1, '..')\n", 22 | "os.chdir('..')\n", 23 | "\n", 24 | "import seaborn as sns\n", 25 | "sns.set_style('whitegrid')\n", 26 | "\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "import statsmodels.api as sm\n", 29 | "import sklearn\n", 30 | "import optuna\n", 31 | "\n", 32 | "from darts import models, metrics, TimeSeries\n", 33 | "from darts.dataprocessing.transformers import Scaler\n", 34 | "\n", 35 | "from data_formatter.base import * # TODO: inefficient" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "# Check statistics of the data" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# load yaml config file\n", 52 | "with open('./config/iglu.yaml', 'r') as f:\n", 53 | " config = yaml.safe_load(f)\n", 54 | "\n", 55 | "# set interpolation params for no interpolation\n", 56 | "new_config = config.copy()\n", 57 | "new_config['interpolation_params']['gap_threshold'] = 30\n", 58 | 
"new_config['interpolation_params']['min_drop_length'] = 0\n", 59 | "# set split params for no splitting\n", 60 | "new_config['split_params']['test_percent_subjects'] = 0\n", 61 | "new_config['split_params']['length_segment'] = 0\n", 62 | "# set scaling params for no scaling\n", 63 | "new_config['scaling_params']['scaler'] = 'None'\n", 64 | "\n", 65 | "formatter = DataFormatter(new_config)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "%%capture\n", 75 | "\n", 76 | "# Need: Tradeoff between interpolation and segment length\n", 77 | "# Problem: Manually tuning is slow and potentially imprecise\n", 78 | "# Idea: have automated function that can help determine what the gap threshold should be\n", 79 | "# Proof of concept below\n", 80 | "\n", 81 | "import numpy as np\n", 82 | "\n", 83 | "def calc_percent(a, b):\n", 84 | " return a*100/b\n", 85 | "\n", 86 | "gap_threshold = np.arange(5, 70, 1)\n", 87 | "percent_valid = []\n", 88 | "for i in gap_threshold:\n", 89 | " new_config['interpolation_params']['gap_threshold'] = i\n", 90 | " df = DataFormatter(new_config).train_data\n", 91 | " \n", 92 | " segment_lens = []\n", 93 | " for group, data in df.groupby('id_segment'):\n", 94 | " segment_lens.append(len(data))\n", 95 | " \n", 96 | " threshold = 240\n", 97 | " valid_ids = df.groupby('id_segment')['time'].count().loc[lambda x : x>threshold].reset_index()['id_segment']\n", 98 | " \n", 99 | " percent_valid.append((len(valid_ids)*100/len(segment_lens)))" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "# Plot results\n", 109 | "plt.plot(gap_threshold, percent_valid)\n", 110 | "plt.title(\"Gap Threshold affect on % Segments > 240 Length\")\n", 111 | "plt.ylabel(\"% Above Threshhold\")\n", 112 | "plt.xlabel(\"Gap Threshold (min)\")" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "# print min, max, median, mean, std of segment lengths\n", 122 | "df = formatter.train_data\n", 123 | "segment_lens = []\n", 124 | "for group, data in df.groupby('id_segment'):\n", 125 | " segment_lens.append(len(data))\n", 126 | "\n", 127 | "print('Train segment lengths:')\n", 128 | "print('\\tMin: ', min(segment_lens))\n", 129 | "print('\\tMax: ', max(segment_lens))\n", 130 | "print('\\tMedian: ', np.median(segment_lens))\n", 131 | "print('\\tMean: ', np.mean(segment_lens))\n", 132 | "print('\\tStd: ', np.std(segment_lens))\n", 133 | "\n", 134 | "# Visualize segment lengths to see approx # of valid ones (>240)\n", 135 | "plt.title(\"Segment Lengths (Line at 240)\")\n", 136 | "plt.hist(segment_lens)\n", 137 | "plt.axvline(240, color='r', linestyle='dashed', linewidth=1)\n", 138 | "\n", 139 | "# filter to get valid indices\n", 140 | "threshold = 240\n", 141 | "valid_ids = df.groupby('id_segment')['time'].count().loc[lambda x : x>threshold].reset_index()['id_segment']\n", 142 | "df_filtered = df.loc[df['id_segment'].isin(valid_ids)]\n", 143 | "\n", 144 | "# plot each segment\n", 145 | "num_segments = df_filtered['id_segment'].nunique()\n", 146 | "\n", 147 | "fig, axs = plt.subplots(1, num_segments, figsize=(30, 5))\n", 148 | "for i, (group, data) in enumerate(df_filtered.groupby('id_segment')):\n", 149 | " data.plot(x='time', y='gl', ax=axs[i], title='Segment {}'.format(group))" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 
154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "df.head(10)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "# plot acf of random samples from segments\n", 168 | "fig, ax = plt.subplots(2, 5, figsize=(30, 5))\n", 169 | "lags = 240\n", 170 | "for i, (group, data) in enumerate(df_filtered.groupby('id_segment')):\n", 171 | " # only view top 5\n", 172 | " if i < 5:\n", 173 | " data = data['gl']\n", 174 | " if len(data) < lags: # TODO: Could probably do filtering in pandas which would be faster\n", 175 | " print('Segment {} is too short'.format(group))\n", 176 | " continue\n", 177 | " # select 10 random samples from index of data\n", 178 | " sample = np.random.choice(range(len(data))[:-lags], 10, replace=False)\n", 179 | " # plot acf / pacf of each sample\n", 180 | " for j in sample:\n", 181 | " acf, acf_ci = sm.tsa.stattools.acf(data[j:j+lags], nlags=lags, alpha=0.05)\n", 182 | " pacf, pacf_ci = sm.tsa.stattools.pacf(data[j:j+lags], method='ols-adjusted', alpha=0.05)\n", 183 | " ax[0, i].plot(acf)\n", 184 | " ax[1, i].plot(pacf)\n" 185 | ] 186 | } 187 | ], 188 | "metadata": { 189 | "kernelspec": { 190 | "display_name": "Python 3 (ipykernel)", 191 | "language": "python", 192 | "name": "python3" 193 | }, 194 | "language_info": { 195 | "codemirror_mode": { 196 | "name": "ipython", 197 | "version": 3 198 | }, 199 | "file_extension": ".py", 200 | "mimetype": "text/x-python", 201 | "name": "python", 202 | "nbconvert_exporter": "python", 203 | "pygments_lexer": "ipython3", 204 | "version": "3.10.1" 205 | }, 206 | "vscode": { 207 | "interpreter": { 208 | "hash": "95662931fb0811c75e2373330a012ba90aa4548ba779055436524c46bd94b0ee" 209 | } 210 | } 211 | }, 212 | "nbformat": 4, 213 | "nbformat_minor": 2 214 | } 215 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/lib/__init__.py -------------------------------------------------------------------------------- /lib/arima.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Dict 2 | import sys 3 | import os 4 | import yaml 5 | import datetime 6 | import argparse 7 | 8 | from statsforecast.models import AutoARIMA 9 | 10 | # import data formatter 11 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 12 | from data_formatter.base import * 13 | 14 | def test_model(test_data, scaler, in_len, out_len, stride, target_col, group_col): 15 | errors = [] 16 | for group, data in test_data.groupby(group_col): 17 | train_set = data[target_col].iloc[:in_len].values.flatten() 18 | # fit model 19 | model = AutoARIMA(start_p = 0, 20 | max_p = 10, 21 | start_q = 0, 22 | max_q = 10, 23 | start_P = 0, 24 | max_P = 10, 25 | start_Q=0, 26 | max_Q=10, 27 | allowdrift=True, 28 | allowmean=True, 29 | parallel=False) 30 | model.fit(train_set) 31 | # get valid sampling locations for future prediction 32 | start_idx = np.arange(start=stride, stop=len(data) - in_len - out_len + 1, step=stride) 33 | end_idx = start_idx + in_len 34 | # iterate and collect predictions 35 | for i in range(len(start_idx)): 36 | input = data[target_col].iloc[start_idx[i]:end_idx[i]].values.flatten() 37 | true = 
data[target_col].iloc[end_idx[i]:(end_idx[i]+out_len)].values.flatten() 38 | prediction = model.forward(input, h=out_len)['mean'] 39 | # unscale true and prediction 40 | true = scaler.inverse_transform(true.reshape(-1, 1)).flatten() 41 | prediction = scaler.inverse_transform(prediction.reshape(-1, 1)).flatten() 42 | # collect errors 43 | errors.append(np.array([np.mean((true - prediction)**2), np.mean(np.abs(true - prediction))])) 44 | errors = np.vstack(errors) 45 | return errors 46 | 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--dataset', type=str, default='weinstock') 49 | parser.add_argument('--reduction1', type=str, default='mean') 50 | parser.add_argument('--reduction2', type=str, default='median') 51 | parser.add_argument('--reduction3', type=str, default=None) 52 | args = parser.parse_args() 53 | reductions = [args.reduction1, args.reduction2, args.reduction3] 54 | if __name__ == '__main__': 55 | # study file 56 | study_file = f'./output/arima_{args.dataset}.txt' 57 | # check that file exists otherwise create it 58 | if not os.path.exists(study_file): 59 | with open(study_file, "w") as f: 60 | # write current date and time 61 | f.write(f"Optimization started at {datetime.datetime.now()}\n") 62 | # load data 63 | with open(f'./config/{args.dataset}.yaml', 'r') as f: 64 | config = yaml.safe_load(f) 65 | config['scaling_params']['scaler'] = 'MinMaxScaler' 66 | formatter = DataFormatter(config, study_file = study_file) 67 | 68 | # set params 69 | in_len = formatter.params['max_length_input'] 70 | out_len = formatter.params['length_pred'] 71 | stride = formatter.params['length_pred'] // 2 72 | target_col = formatter.get_column('target') 73 | group_col = formatter.get_column('sid') 74 | 75 | seeds = list(range(10, 20)) 76 | id_errors_cv = {key: [] for key in reductions if key is not None} 77 | ood_errors_cv = {key: [] for key in reductions if key is not None} 78 | for seed in seeds: 79 | formatter.reshuffle(seed) 80 | test_data = formatter.test_data.loc[~formatter.test_data.index.isin(formatter.test_idx_ood)] 81 | test_data_ood = formatter.test_data.loc[formatter.test_data.index.isin(formatter.test_idx_ood)] 82 | 83 | # backtest on the ID test set 84 | id_errors_sample = test_model(test_data, 85 | formatter.scalers[target_col[0]], 86 | in_len, 87 | out_len, 88 | stride, 89 | target_col, 90 | group_col) 91 | # backtest on the ood test set 92 | ood_errors_sample = test_model(test_data_ood, 93 | formatter.scalers[target_col[0]], 94 | in_len, 95 | out_len, 96 | stride, 97 | target_col, 98 | group_col) 99 | # compute, save, and print results 100 | with open(study_file, "a") as f: 101 | for reduction in reductions: 102 | if reduction is not None: 103 | # compute 104 | reduction_f = getattr(np, reduction) 105 | id_errors_sample_red = reduction_f(id_errors_sample, axis=0) 106 | ood_errors_sample_red = reduction_f(ood_errors_sample, axis=0) 107 | # save 108 | id_errors_cv[reduction].append(id_errors_sample_red) 109 | ood_errors_cv[reduction].append(ood_errors_sample_red) 110 | # print 111 | f.write(f"\tSeed: {seed} ID {reduction} of (MSE, MAE): {id_errors_sample_red.tolist()}\n") 112 | f.write(f"\tSeed: {seed} OOD {reduction} of (MSE, MAE) stats: {ood_errors_sample_red.tolist()}\n") 113 | # compute, save, and print results 114 | with open(study_file, "a") as f: 115 | for reduction in reductions: 116 | if reduction is not None: 117 | # compute 118 | id_errors_cv[reduction] = np.vstack(id_errors_cv[reduction]) 119 | ood_errors_cv[reduction] = 
np.vstack(ood_errors_cv[reduction]) 120 | id_errors_cv[reduction] = np.mean(id_errors_cv[reduction], axis=0) 121 | ood_errors_cv[reduction] = np.mean(ood_errors_cv[reduction], axis=0) 122 | # print 123 | f.write(f"ID {reduction} of (MSE, MAE): {id_errors_cv[reduction].tolist()}\n") 124 | f.write(f"OOD {reduction} of (MSE, MAE): {ood_errors_cv[reduction].tolist()}\n") -------------------------------------------------------------------------------- /lib/gluformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/lib/gluformer/__init__.py -------------------------------------------------------------------------------- /lib/gluformer/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from math import sqrt 6 | 7 | class CausalConv1d(torch.nn.Conv1d): 8 | def __init__(self, 9 | in_channels, 10 | out_channels, 11 | kernel_size, 12 | stride=1, 13 | dilation=1, 14 | groups=1, 15 | bias=True): 16 | self.__padding = (kernel_size - 1) * dilation 17 | 18 | super(CausalConv1d, self).__init__( 19 | in_channels, 20 | out_channels, 21 | kernel_size=kernel_size, 22 | stride=stride, 23 | padding=self.__padding, 24 | dilation=dilation, 25 | groups=groups, 26 | bias=bias) 27 | 28 | def forward(self, input): 29 | result = super(CausalConv1d, self).forward(input) 30 | if self.__padding != 0: 31 | return result[:, :, :-self.__padding] 32 | return result 33 | 34 | class TriangularCausalMask(): 35 | def __init__(self, b, n, device="cpu"): 36 | mask_shape = [b, 1, n, n] 37 | with torch.no_grad(): 38 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) 39 | 40 | @property 41 | def mask(self): 42 | return self._mask 43 | 44 | class MultiheadAttention(nn.Module): 45 | def __init__(self, d_model, n_heads, d_keys, mask_flag, r_att_drop=0.1): 46 | super(MultiheadAttention, self).__init__() 47 | self.h, self.d, self.mask_flag= n_heads, d_keys, mask_flag 48 | self.proj_q = nn.Linear(d_model, self.h * self.d) 49 | self.proj_k = nn.Linear(d_model, self.h * self.d) 50 | self.proj_v = nn.Linear(d_model, self.h * self.d) 51 | self.proj_out = nn.Linear(self.h * self.d, d_model) 52 | self.dropout = nn.Dropout(r_att_drop) 53 | 54 | def forward(self, q, k, v): 55 | b, n_q, n_k, h, d = q.size(0), q.size(1), k.size(1), self.h, self.d 56 | 57 | q, k, v = self.proj_q(q), self.proj_k(k), self.proj_v(v) # b, n_*, h*d 58 | q, k, v = map(lambda x: x.reshape(b, -1, h, d), [q, k, v]) # b, n_*, h, d 59 | scores = torch.einsum('bnhd,bmhd->bhnm', (q,k)) # b, h, n_q, n_k 60 | 61 | if self.mask_flag: 62 | att_mask = TriangularCausalMask(b, n_q, device=q.device) 63 | scores.masked_fill_(att_mask.mask, -np.inf) 64 | 65 | att = F.softmax(scores / (self.d ** .5), dim=-1) # b, h, n_q, n_k 66 | att = self.dropout(att) 67 | att_out = torch.einsum('bhnm,bmhd->bnhd', (att,v)) # b, n_q, h, d 68 | att_out = att_out.reshape(b, -1, h*d) # b, n_q, h*d 69 | out = self.proj_out(att_out) # b, n_q, d_model 70 | return out 71 | -------------------------------------------------------------------------------- /lib/gluformer/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .attention import * 6 | 7 | class 
DecoderLayer(nn.Module): 8 | def __init__(self, self_att, cross_att, d_model, d_fcn, 9 | r_drop, activ="relu"): 10 | super(DecoderLayer, self).__init__() 11 | 12 | self.self_att = self_att 13 | self.cross_att = cross_att 14 | 15 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1) 16 | self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1) 17 | 18 | self.norm1 = nn.LayerNorm(d_model) 19 | self.norm2 = nn.LayerNorm(d_model) 20 | self.norm3 = nn.LayerNorm(d_model) 21 | 22 | self.dropout = nn.Dropout(r_drop) 23 | self.activ = F.relu if activ == "relu" else F.gelu 24 | 25 | def forward(self, x_dec, x_enc): 26 | x_dec = x_dec + self.self_att(x_dec, x_dec, x_dec) 27 | x_dec = self.norm1(x_dec) 28 | 29 | x_dec = x_dec + self.cross_att(x_dec, x_enc, x_enc) 30 | res = x_dec = self.norm2(x_dec) 31 | 32 | res = self.dropout(self.activ(self.conv1(res.transpose(-1,1)))) 33 | res = self.dropout(self.conv2(res).transpose(-1,1)) 34 | 35 | return self.norm3(x_dec+res) 36 | 37 | class Decoder(nn.Module): 38 | def __init__(self, layers, norm_layer=None): 39 | super(Decoder, self).__init__() 40 | self.layers = nn.ModuleList(layers) 41 | self.norm = norm_layer 42 | 43 | def forward(self, x_dec, x_enc): 44 | for layer in self.layers: 45 | x_dec = layer(x_dec, x_enc) 46 | 47 | if self.norm is not None: 48 | x_dec = self.norm(x_dec) 49 | 50 | return x_dec 51 | -------------------------------------------------------------------------------- /lib/gluformer/embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class PositionalEmbedding(nn.Module): 7 | def __init__(self, d_model, max_len=5000): 8 | super(PositionalEmbedding, self).__init__() 9 | # Compute the positional encodings once in log space. 
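    # In closed form (a worked restatement of the code below):
    #   pos_emb[p, 2i]   = sin(p / 10000^(2i/d_model))
    #   pos_emb[p, 2i+1] = cos(p / 10000^(2i/d_model))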
10 | pos_emb = torch.zeros(max_len, d_model) 11 | pos_emb.require_grad = False 12 | 13 | position = torch.arange(0, max_len).unsqueeze(1) 14 | div_term = (torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)).exp() 15 | 16 | pos_emb[:, 0::2] = torch.sin(position * div_term) 17 | pos_emb[:, 1::2] = torch.cos(position * div_term) 18 | 19 | pos_emb = pos_emb.unsqueeze(0) 20 | self.register_buffer('pos_emb', pos_emb) 21 | 22 | def forward(self, x): 23 | return self.pos_emb[:, :x.size(1)] 24 | 25 | class TokenEmbedding(nn.Module): 26 | def __init__(self, d_model): 27 | super(TokenEmbedding, self).__init__() 28 | D_INP = 1 # one sequence 29 | self.conv = nn.Conv1d(in_channels=D_INP, out_channels=d_model, 30 | kernel_size=3, padding=1, padding_mode='circular') 31 | # nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='leaky_relu') 32 | 33 | def forward(self, x): 34 | x = self.conv(x.transpose(-1, 1)).transpose(-1, 1) 35 | return x 36 | 37 | class TemporalEmbedding(nn.Module): 38 | def __init__(self, d_model, num_features): 39 | super(TemporalEmbedding, self).__init__() 40 | self.embed = nn.Linear(num_features, d_model) 41 | 42 | def forward(self, x): 43 | return self.embed(x) 44 | 45 | class SubjectEmbedding(nn.Module): 46 | def __init__(self, d_model, num_features): 47 | super(SubjectEmbedding, self).__init__() 48 | self.id_embedding = nn.Linear(num_features, d_model) 49 | 50 | def forward(self, x): 51 | embed_x = self.id_embedding(x) 52 | 53 | return embed_x 54 | 55 | class DataEmbedding(nn.Module): 56 | def __init__(self, d_model, r_drop, num_dynamic_features, num_static_features): 57 | super(DataEmbedding, self).__init__() 58 | # note: d_model // 2 == 0 59 | self.value_embedding = TokenEmbedding(d_model) 60 | self.time_embedding = TemporalEmbedding(d_model, num_dynamic_features) # alternative: TimeFeatureEmbedding 61 | self.positional_embedding = PositionalEmbedding(d_model) 62 | self.subject_embedding = SubjectEmbedding(d_model, num_static_features) 63 | self.dropout = nn.Dropout(r_drop) 64 | 65 | def forward(self, x_id, x, x_mark): 66 | x = self.value_embedding(x) + self.positional_embedding(x) + self.time_embedding(x_mark) 67 | x_id = self.subject_embedding(x_id) 68 | x = torch.cat((x_id.unsqueeze(1), x), dim = 1) 69 | return self.dropout(x) 70 | -------------------------------------------------------------------------------- /lib/gluformer/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .attention import * 6 | 7 | class ConvLayer(nn.Module): 8 | def __init__(self, d_model): 9 | super(ConvLayer, self).__init__() 10 | self.downConv = nn.Conv1d(in_channels=d_model, out_channels=d_model, 11 | kernel_size=3, padding=1, padding_mode='circular') 12 | self.norm = nn.BatchNorm1d(d_model) 13 | self.activ = nn.ELU() 14 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 15 | 16 | def forward(self, x): 17 | x = self.downConv(x.transpose(-1, 1)) 18 | x = self.norm(x) 19 | x = self.activ(x) 20 | x = self.maxPool(x) 21 | x = x.transpose(-1,1) 22 | return x 23 | 24 | class EncoderLayer(nn.Module): 25 | def __init__(self, att, d_model, d_fcn, r_drop, activ="relu"): 26 | super(EncoderLayer, self).__init__() 27 | 28 | self.att = att 29 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1) 30 | self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1) 31 | self.norm1 = 
nn.LayerNorm(d_model) 32 | self.norm2 = nn.LayerNorm(d_model) 33 | self.dropout = nn.Dropout(r_drop) 34 | self.activ = F.relu if activ == "relu" else F.gelu 35 | 36 | def forward(self, x): 37 | new_x = self.att(x, x, x) 38 | x = x + self.dropout(new_x) 39 | 40 | res = x = self.norm1(x) 41 | res = self.dropout(self.activ(self.conv1(res.transpose(-1,1)))) 42 | res = self.dropout(self.conv2(res).transpose(-1,1)) 43 | 44 | return self.norm2(x+res) 45 | 46 | class Encoder(nn.Module): 47 | def __init__(self, enc_layers, conv_layers=None, norm_layer=None): 48 | super(Encoder, self).__init__() 49 | self.enc_layers = nn.ModuleList(enc_layers) 50 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 51 | self.norm = norm_layer 52 | 53 | def forward(self, x): 54 | # x [B, L, D] 55 | if self.conv_layers is not None: 56 | for enc_layer, conv_layer in zip(self.enc_layers, self.conv_layers): 57 | x = enc_layer(x) 58 | x = conv_layer(x) 59 | x = self.enc_layers[-1](x) 60 | else: 61 | for enc_layer in self.enc_layers: 62 | x = enc_layer(x) 63 | 64 | if self.norm is not None: 65 | x = self.norm(x) 66 | 67 | return x -------------------------------------------------------------------------------- /lib/gluformer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/lib/gluformer/utils/__init__.py -------------------------------------------------------------------------------- /lib/gluformer/utils/collate.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple, Union, Mapping, Sequence 2 | import torch 3 | import re 4 | import collections 5 | 6 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 7 | 8 | # Replace torch._six.string_classes with this 9 | string_classes = (str, bytes) 10 | 11 | def default_convert(data: Any) -> Any: 12 | r"""Converts each NumPy array data field into a tensor""" 13 | elem_type = type(data) 14 | if isinstance(data, torch.Tensor): 15 | return data 16 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 17 | and elem_type.__name__ != 'string_': 18 | # array of string classes and object 19 | if elem_type.__name__ == 'ndarray' \ 20 | and np_str_obj_array_pattern.search(data.dtype.str) is not None: 21 | return data 22 | return torch.as_tensor(data) 23 | elif isinstance(data, collections.abc.Mapping): 24 | return {key: default_convert(data[key]) for key in data} 25 | elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple 26 | return elem_type(*(default_convert(d) for d in data)) 27 | elif isinstance(data, collections.abc.Sequence) and not isinstance(data, string_classes): 28 | return [default_convert(d) for d in data] 29 | else: 30 | return data 31 | 32 | default_collate_err_msg_format = ( 33 | "default_collate: batch must contain tensors, numpy arrays, numbers, " 34 | "dicts or lists; found {}") 35 | 36 | def default_collate(batch: List[Any]) -> Union[torch.Tensor, List, Dict, Tuple]: 37 | r"""Puts each data field into a tensor with outer dimension batch size""" 38 | 39 | elem = batch[0] 40 | elem_type = type(elem) 41 | if isinstance(elem, torch.Tensor): 42 | out = None 43 | if torch.utils.data.get_worker_info() is not None: 44 | # If we're in a background process, concatenate directly into a 45 | # shared memory tensor to avoid an extra copy 46 | numel = sum(x.numel() for x in batch) 47 | storage = 
elem.storage()._new_shared(numel) 48 | out = elem.new(storage) 49 | return torch.stack(batch, 0, out=out) 50 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 51 | and elem_type.__name__ != 'string_': 52 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 53 | # array of string classes and object 54 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 55 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 56 | 57 | return default_collate([torch.as_tensor(b) for b in batch]) 58 | elif elem.shape == (): # scalars 59 | return torch.as_tensor(batch) 60 | elif isinstance(elem, float): 61 | return torch.tensor(batch, dtype=torch.float64) 62 | elif isinstance(elem, int): 63 | return torch.tensor(batch) 64 | elif isinstance(elem, string_classes): 65 | return batch 66 | elif isinstance(elem, Mapping): 67 | return {key: default_collate([d[key] for d in batch]) for key in elem} 68 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 69 | return elem_type(*(default_collate(samples) for samples in zip(*batch))) 70 | elif isinstance(elem, Sequence): 71 | # check to make sure that the elements in batch have consistent size 72 | it = iter(batch) 73 | elem_size = len(next(it)) 74 | if not all(len(elem) == elem_size for elem in it): 75 | raise RuntimeError('each element in list of batch should be of equal size') 76 | transposed = zip(*batch) 77 | return [default_collate(samples) for samples in transposed] 78 | 79 | raise TypeError(default_collate_err_msg_format.format(elem_type)) -------------------------------------------------------------------------------- /lib/gluformer/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import yaml 4 | import random 5 | from typing import Any, \ 6 | BinaryIO, \ 7 | Callable, \ 8 | Dict, \ 9 | List, \ 10 | Optional, \ 11 | Sequence, \ 12 | Tuple, \ 13 | Union 14 | 15 | import numpy as np 16 | import scipy as sp 17 | import pandas as pd 18 | import torch 19 | 20 | # import data formatter 21 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 22 | 23 | def test(series: np.ndarray, 24 | forecasts: np.ndarray, 25 | var: np.ndarray, 26 | cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11), 27 | ): 28 | """ 29 | Test the (rescaled to original scale) forecasts on the series. 30 | 31 | Parameters 32 | ---------- 33 | series 34 | The target time series of shape (n, t), 35 | where t is length of prediction. 36 | forecasts 37 | The forecasted means of mixture components of shape (n, t, k), 38 | where k is the number of mixture components. 39 | var 40 | The forecasted variances of mixture components of shape (n, 1, k), 41 | where k is the number of mixture components. 42 | metric 43 | The metric or metrics to use for backtesting. 44 | cal_thresholds 45 | The thresholds to use for computing the calibration error. 46 | 47 | Returns 48 | ------- 49 | np.ndarray 50 | Error array. Array of shape (n, p) 51 | where n = series.shape[0] = forecasts.shape[0] and p = len(metric). 52 | float 53 | The estimated log-likelihood of the model on the data. 54 | np.ndarray 55 | The ECE for each time point in the forecast. 
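    Illustrative shapes (example values, not taken from the repository): with
    n = 100 windows, a t = 12 step horizon, and k = 10 mixture components,
    `series` is (100, 12), `forecasts` is (100, 12, 10), and `var` is (100, 1, 10).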
56 | """ 57 | # compute errors: 1) get samples 2) compute errors using median 58 | samples = np.random.normal(loc=forecasts[..., None], 59 | scale=np.sqrt(var)[..., None], 60 | size=(forecasts.shape[0], 61 | forecasts.shape[1], 62 | forecasts.shape[2], 63 | 30)) 64 | samples = samples.reshape(samples.shape[0], samples.shape[1], -1) 65 | mse = np.mean((series.squeeze() - forecasts.mean(axis=-1))**2, axis=-1) 66 | mae = np.mean(np.abs(series.squeeze() - forecasts.mean(axis=-1)), axis=-1) 67 | errors = np.stack([mse, mae], axis=-1) 68 | 69 | # compute likelihood 70 | log_likelihood = sp.special.logsumexp((forecasts - series)**2 / (2 * var) - 71 | 0.5 * np.log(2 * np.pi * var), axis=-1) 72 | log_likelihood = np.mean(log_likelihood) 73 | 74 | # compute calibration error: 75 | cal_error = np.zeros(forecasts.shape[1]) 76 | for p in cal_thresholds: 77 | q = np.quantile(samples, p, axis=-1) 78 | est_p = np.mean(series.squeeze() <= q, axis=0) 79 | cal_error += (est_p - p) ** 2 80 | 81 | return errors, log_likelihood, cal_error -------------------------------------------------------------------------------- /lib/gluformer/utils/training.py: -------------------------------------------------------------------------------- 1 | from re import sub 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | 6 | from .collate import * 7 | 8 | class EarlyStop: 9 | def __init__(self, patience, delta): 10 | self.patience = patience 11 | self.delta = delta 12 | self.counter = 0 13 | self.best_loss = np.Inf 14 | self.stop = False 15 | 16 | def __call__(self, loss, model, path): 17 | if loss < self.best_loss: 18 | self.best_loss = loss 19 | self.counter = 0 20 | torch.save(model.state_dict(), path) 21 | elif loss > self.best_loss + self.delta: 22 | self.counter = self.counter + 1 23 | if self.counter >= self.patience: 24 | self.stop = True 25 | 26 | class ExpLikeliLoss(nn.Module): 27 | def __init__(self, num_samples = 100): 28 | # , var = 0.3 29 | super(ExpLikeliLoss, self).__init__() 30 | self.num_samples = num_samples 31 | 32 | def forward(self, pred, true, logvar): 33 | # pred & true: [b, l, d] 34 | b, l, d = pred.size(0), pred.size(1), pred.size(2) 35 | true = true.transpose(0,1).reshape(l, -1, self.num_samples).transpose(0, 1) 36 | pred = pred.transpose(0,1).reshape(l, -1, self.num_samples).transpose(0, 1) 37 | logvar = logvar.reshape(-1, self.num_samples) 38 | 39 | loss = torch.mean((-1) * torch.logsumexp((-l / 2) * logvar + (-1 / (2 * torch.exp(logvar))) * torch.sum((true - pred) ** 2, dim=1), dim=1)) 40 | return loss 41 | 42 | def modify_collate(num_samples): 43 | ''' 44 | Repeat each sample in the dataset. 
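    Presumably paired with ExpLikeliLoss above: repeating each input num_samples
    times lets the per-draw predictions be reshaped to (-1, num_samples) and
    combined inside the log-sum-exp of the mixture likelihood.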
45 | ''' 46 | def wrapper(batch): 47 | batch_rep = [sample for sample in batch for i in range(num_samples)] 48 | return default_collate(batch_rep) 49 | 50 | return wrapper 51 | 52 | def adjust_learning_rate(model_optim, epoch, lr): 53 | lr = lr * (0.5 ** epoch) 54 | print("Learning rate halfing...") 55 | print(f"New lr: {lr:.7f}") 56 | for param_group in model_optim.param_groups: 57 | param_group['lr'] = lr 58 | 59 | def process_batch(subj_id, 60 | batch_x, batch_y, 61 | batch_x_mark, batch_y_mark, 62 | len_pred, len_label, 63 | model, device): 64 | # read data 65 | subj_id = subj_id.long().to(device) 66 | batch_x = batch_x.float().to(device) 67 | batch_y = batch_y.float() 68 | batch_x_mark = batch_x_mark.float().to(device) 69 | batch_y_mark = batch_y_mark.float().to(device) 70 | 71 | # extract true 72 | true = batch_y[:, -len_pred:, :].to(device) 73 | 74 | # decoder input 75 | dec_inp = torch.zeros([batch_y.shape[0], len_pred, batch_y.shape[-1]]).float() 76 | dec_inp = torch.cat([batch_y[:, :len_label, :], dec_inp], dim=1).float().to(device) 77 | 78 | # model prediction 79 | # pred = model(subj_id, batch_x, batch_x_mark, dec_inp, batch_y_mark) 80 | pred, logvar = model(subj_id, batch_x, batch_x_mark, dec_inp, batch_y_mark) 81 | 82 | # clean cache 83 | del subj_id 84 | del batch_x 85 | del batch_x_mark 86 | del dec_inp 87 | del batch_y_mark 88 | return pred, true, logvar -------------------------------------------------------------------------------- /lib/gluformer/variance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Variance(nn.Module): 6 | def __init__(self, d_model, r_drop, len_seq): 7 | super(Variance, self).__init__() 8 | 9 | self.proj1 = nn.Linear(d_model, 1) 10 | self.dropout = nn.Dropout(r_drop) 11 | self.activ1 = nn.ReLU() 12 | # + 1 (for seq) for embedded person token 13 | self.proj2 = nn.Linear(len_seq+1, 1) 14 | self.activ2 = nn.Tanh() 15 | 16 | def forward(self, x): 17 | x = self.proj1(x) 18 | x = self.activ1(x) 19 | x = self.dropout(x) 20 | x = x.transpose(-1, 1) 21 | x = self.proj2(x) 22 | # scale to [-10, 10] range 23 | x = 10 * self.activ2(x) 24 | return x -------------------------------------------------------------------------------- /lib/latent_ode/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yulia Rubanova 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /lib/latent_ode/README.md: -------------------------------------------------------------------------------- 1 | # Latent ODEs for Irregularly-Sampled Time Series 2 | 3 | Code for the paper: 4 | > Yulia Rubanova, Ricky Chen, David Duvenaud. "Latent ODEs for Irregularly-Sampled Time Series" (2019) 5 | [[arxiv]](https://arxiv.org/abs/1907.03907) 6 | 7 |
8 | *(figures omitted)* 9 |
10 | 11 | ## Prerequisites 12 | 13 | Install `torchdiffeq` from https://github.com/rtqichen/torchdiffeq. 14 | 15 | ## Experiments on different datasets 16 | 17 | By default, the dataset are downloadeded and processed when script is run for the first time. 18 | 19 | Raw datasets: 20 | [[MuJoCo]](http://www.cs.toronto.edu/~rtqichen/datasets/HopperPhysics/training.pt) 21 | [[Physionet]](https://physionet.org/physiobank/database/challenge/2012/) 22 | [[Human Activity]](https://archive.ics.uci.edu/ml/datasets/Localization+Data+for+Person+Activity/) 23 | 24 | To generate MuJoCo trajectories from scratch, [DeepMind Control Suite](https://github.com/deepmind/dm_control/) is required 25 | 26 | 27 | * Toy dataset of 1d periodic functions 28 | ``` 29 | python3 run_models.py --niters 500 -n 1000 -s 50 -l 10 --dataset periodic --latent-ode --noise-weight 0.01 30 | ``` 31 | 32 | * MuJoCo 33 | 34 | ``` 35 | python3 run_models.py --niters 300 -n 10000 -l 15 --dataset hopper --latent-ode --rec-dims 30 --gru-units 100 --units 300 --gen-layers 3 --rec-layers 3 36 | ``` 37 | 38 | * Physionet (discretization by 1 min) 39 | ``` 40 | python3 run_models.py --niters 100 -n 8000 -l 20 --dataset physionet --latent-ode --rec-dims 40 --rec-layers 3 --gen-layers 3 --units 50 --gru-units 50 --quantization 0.016 --classif 41 | 42 | ``` 43 | 44 | * Human Activity 45 | ``` 46 | python3 run_models.py --niters 200 -n 10000 -l 15 --dataset activity --latent-ode --rec-dims 100 --rec-layers 4 --gen-layers 2 --units 500 --gru-units 50 --classif --linear-classif 47 | 48 | ``` 49 | 50 | 51 | ### Running different models 52 | 53 | * ODE-RNN 54 | ``` 55 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --ode-rnn 56 | ``` 57 | 58 | * Latent ODE with ODE-RNN encoder 59 | ``` 60 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --latent-ode 61 | ``` 62 | 63 | * Latent ODE with ODE-RNN encoder and poisson likelihood 64 | ``` 65 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --latent-ode --poisson 66 | ``` 67 | 68 | * Latent ODE with RNN encoder (Chen et al, 2018) 69 | ``` 70 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --latent-ode --z0-encoder rnn 71 | ``` 72 | 73 | * RNN-VAE 74 | ``` 75 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --rnn-vae 76 | ``` 77 | 78 | * Classic RNN 79 | ``` 80 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --classic-rnn 81 | ``` 82 | 83 | * GRU-D 84 | 85 | GRU-D consists of two parts: input imputation (--input-decay) and exponential decay of the hidden state (--rnn-cell expdecay) 86 | 87 | ``` 88 | python3 run_models.py --niters 500 -n 100 -b 30 -l 10 --dataset periodic --classic-rnn --input-decay --rnn-cell expdecay 89 | ``` 90 | 91 | 92 | ### Making the visualization 93 | ``` 94 | python3 run_models.py --niters 100 -n 5000 -b 100 -l 3 --dataset periodic --latent-ode --noise-weight 0.5 --lr 0.01 --viz --rec-layers 2 --gen-layers 2 -u 100 -c 30 95 | ``` 96 | -------------------------------------------------------------------------------- /lib/latent_ode/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/lib/latent_ode/__init__.py -------------------------------------------------------------------------------- /lib/latent_ode/base_models.py: -------------------------------------------------------------------------------- 1 | 
########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.functional import relu 10 | 11 | from . import utils as utils 12 | from .encoder_decoder import * 13 | from .likelihood_eval import * 14 | 15 | from torch.distributions.multivariate_normal import MultivariateNormal 16 | from torch.distributions.normal import Normal 17 | from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase 18 | 19 | from torch.distributions.normal import Normal 20 | from torch.distributions import Independent 21 | from torch.nn.parameter import Parameter 22 | 23 | 24 | def create_classifier(z0_dim, n_labels): 25 | return nn.Sequential( 26 | nn.Linear(z0_dim, 300), 27 | nn.ReLU(), 28 | nn.Linear(300, 300), 29 | nn.ReLU(), 30 | nn.Linear(300, n_labels),) 31 | 32 | 33 | class Baseline(nn.Module): 34 | def __init__(self, input_dim, latent_dim, device, 35 | obsrv_std = 0.01, use_binary_classif = False, 36 | classif_per_tp = False, 37 | use_poisson_proc = False, 38 | linear_classifier = False, 39 | n_labels = 1, 40 | train_classif_w_reconstr = False): 41 | super(Baseline, self).__init__() 42 | 43 | self.input_dim = input_dim 44 | self.latent_dim = latent_dim 45 | self.n_labels = n_labels 46 | 47 | self.obsrv_std = torch.Tensor([obsrv_std]).to(device) 48 | self.device = device 49 | 50 | self.use_binary_classif = use_binary_classif 51 | self.classif_per_tp = classif_per_tp 52 | self.use_poisson_proc = use_poisson_proc 53 | self.linear_classifier = linear_classifier 54 | self.train_classif_w_reconstr = train_classif_w_reconstr 55 | 56 | z0_dim = latent_dim 57 | if use_poisson_proc: 58 | z0_dim += latent_dim 59 | 60 | if use_binary_classif: 61 | if linear_classifier: 62 | self.classifier = nn.Sequential( 63 | nn.Linear(z0_dim, n_labels)) 64 | else: 65 | self.classifier = create_classifier(z0_dim, n_labels) 66 | utils.init_network_weights(self.classifier) 67 | 68 | 69 | def get_gaussian_likelihood(self, truth, pred_y, mask = None): 70 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim] 71 | # truth shape [n_traj, n_tp, n_dim] 72 | if mask is not None: 73 | mask = mask.repeat(pred_y.size(0), 1, 1, 1) 74 | 75 | # Compute likelihood of the data under the predictions 76 | log_density_data = masked_gaussian_log_density(pred_y, truth, 77 | obsrv_std = self.obsrv_std, mask = mask) 78 | log_density_data = log_density_data.permute(1,0) 79 | 80 | # Compute the total density 81 | # Take mean over n_traj_samples 82 | log_density = torch.mean(log_density_data, 0) 83 | 84 | # shape: [n_traj] 85 | return log_density 86 | 87 | 88 | def get_mse(self, truth, pred_y, mask = None): 89 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim] 90 | # truth shape [n_traj, n_tp, n_dim] 91 | if mask is not None: 92 | mask = mask.repeat(pred_y.size(0), 1, 1, 1) 93 | 94 | # Compute likelihood of the data under the predictions 95 | log_density_data = compute_mse(pred_y, truth, mask = mask) 96 | # shape: [1] 97 | return torch.mean(log_density_data) 98 | 99 | 100 | def compute_all_losses(self, batch_dict, 101 | n_tp_to_sample = None, n_traj_samples = 1, kl_coef = 1.): 102 | 103 | # Condition on subsampled points 104 | # Make predictions for all the points 105 | pred_x, info = self.get_reconstruction(batch_dict["tp_to_predict"], 106 | batch_dict["observed_data"], batch_dict["observed_tp"], 107 | mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples, 108 | mode = 
batch_dict["mode"]) 109 | 110 | # Compute likelihood of all the points 111 | likelihood = self.get_gaussian_likelihood(batch_dict["data_to_predict"], pred_x, 112 | mask = batch_dict["mask_predicted_data"]) 113 | 114 | mse = self.get_mse(batch_dict["data_to_predict"], pred_x, 115 | mask = batch_dict["mask_predicted_data"]) 116 | 117 | ################################ 118 | # Compute CE loss for binary classification on Physionet 119 | # Use only last attribute -- mortatility in the hospital 120 | device = get_device(batch_dict["data_to_predict"]) 121 | ce_loss = torch.Tensor([0.]).to(device) 122 | 123 | if (batch_dict["labels"] is not None) and self.use_binary_classif: 124 | if (batch_dict["labels"].size(-1) == 1) or (len(batch_dict["labels"].size()) == 1): 125 | ce_loss = compute_binary_CE_loss( 126 | info["label_predictions"], 127 | batch_dict["labels"]) 128 | else: 129 | ce_loss = compute_multiclass_CE_loss( 130 | info["label_predictions"], 131 | batch_dict["labels"], 132 | mask = batch_dict["mask_predicted_data"]) 133 | 134 | if torch.isnan(ce_loss): 135 | print("label pred") 136 | print(info["label_predictions"]) 137 | print("labels") 138 | print( batch_dict["labels"]) 139 | raise Exception("CE loss is Nan!") 140 | 141 | pois_log_likelihood = torch.Tensor([0.]).to(get_device(batch_dict["data_to_predict"])) 142 | if self.use_poisson_proc: 143 | pois_log_likelihood = compute_poisson_proc_likelihood( 144 | batch_dict["data_to_predict"], pred_x, 145 | info, mask = batch_dict["mask_predicted_data"]) 146 | # Take mean over n_traj 147 | pois_log_likelihood = torch.mean(pois_log_likelihood, 1) 148 | 149 | loss = - torch.mean(likelihood) 150 | 151 | if self.use_poisson_proc: 152 | loss = loss - 0.1 * pois_log_likelihood 153 | 154 | if self.use_binary_classif: 155 | if self.train_classif_w_reconstr: 156 | loss = loss + ce_loss * 100 157 | else: 158 | loss = ce_loss 159 | 160 | # Take mean over the number of samples in a batch 161 | results = {} 162 | results["loss"] = torch.mean(loss) 163 | results["likelihood"] = torch.mean(likelihood).detach() 164 | results["mse"] = torch.mean(mse).detach() 165 | results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach() 166 | results["ce_loss"] = torch.mean(ce_loss).detach() 167 | results["kl"] = 0. 168 | results["kl_first_p"] = 0. 169 | results["std_first_p"] = 0. 
170 | 171 | if batch_dict["labels"] is not None and self.use_binary_classif: 172 | results["label_predictions"] = info["label_predictions"].detach() 173 | return results 174 | 175 | 176 | 177 | class VAE_Baseline(nn.Module): 178 | def __init__(self, input_dim, latent_dim, 179 | z0_prior, device, 180 | obsrv_std = 0.01, 181 | use_binary_classif = False, 182 | classif_per_tp = False, 183 | use_poisson_proc = False, 184 | linear_classifier = False, 185 | n_labels = 1, 186 | train_classif_w_reconstr = False): 187 | 188 | super(VAE_Baseline, self).__init__() 189 | 190 | self.input_dim = input_dim 191 | self.latent_dim = latent_dim 192 | self.device = device 193 | self.n_labels = n_labels 194 | 195 | self.obsrv_std = torch.Tensor([obsrv_std]).to(device) 196 | 197 | self.z0_prior = z0_prior 198 | self.use_binary_classif = use_binary_classif 199 | self.classif_per_tp = classif_per_tp 200 | self.use_poisson_proc = use_poisson_proc 201 | self.linear_classifier = linear_classifier 202 | self.train_classif_w_reconstr = train_classif_w_reconstr 203 | 204 | z0_dim = latent_dim 205 | if use_poisson_proc: 206 | z0_dim += latent_dim 207 | 208 | if use_binary_classif: 209 | if linear_classifier: 210 | self.classifier = nn.Sequential( 211 | nn.Linear(z0_dim, n_labels)) 212 | else: 213 | self.classifier = create_classifier(z0_dim, n_labels) 214 | utils.init_network_weights(self.classifier) 215 | 216 | 217 | def get_gaussian_likelihood(self, truth, pred_y, mask = None): 218 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim] 219 | # truth shape [n_traj, n_tp, n_dim] 220 | n_traj, n_tp, n_dim = truth.size() 221 | 222 | # Compute likelihood of the data under the predictions 223 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1) 224 | 225 | if mask is not None: 226 | mask = mask.repeat(pred_y.size(0), 1, 1, 1) 227 | log_density_data = masked_gaussian_log_density(pred_y, truth_repeated, 228 | obsrv_std = self.obsrv_std, mask = mask) 229 | log_density_data = log_density_data.permute(1,0) 230 | log_density = torch.mean(log_density_data, 1) 231 | 232 | # shape: [n_traj_samples] 233 | return log_density 234 | 235 | 236 | def get_mse(self, truth, pred_y, mask = None): 237 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim] 238 | # truth shape [n_traj, n_tp, n_dim] 239 | n_traj, n_tp, n_dim = truth.size() 240 | 241 | # Compute likelihood of the data under the predictions 242 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1) 243 | 244 | if mask is not None: 245 | mask = mask.repeat(pred_y.size(0), 1, 1, 1) 246 | 247 | # Compute likelihood of the data under the predictions 248 | log_density_data = compute_mse(pred_y, truth_repeated, mask = mask) 249 | # shape: [1] 250 | return torch.mean(log_density_data) 251 | 252 | 253 | def compute_all_losses(self, batch_dict, n_traj_samples = 1, kl_coef = 1.): 254 | # Condition on subsampled points 255 | # Make predictions for all the points 256 | pred_y, info = self.get_reconstruction(batch_dict["tp_to_predict"], 257 | batch_dict["observed_data"], batch_dict["observed_tp"], 258 | mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples, 259 | mode = batch_dict["mode"]) 260 | 261 | #print("get_reconstruction done -- computing likelihood") 262 | fp_mu, fp_std, fp_enc = info["first_point"] 263 | fp_std = fp_std.abs().clamp(min = 1e-2) 264 | fp_distr = Normal(fp_mu, fp_std) 265 | 266 | assert(torch.sum(fp_std < 0) == 0.) 
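		# KL between the approximate posterior over the initial latent state,
		# q(z0 | x) = Normal(fp_mu, fp_std), and the prior z0_prior; it enters the
		# IWAE-style objective below weighted by kl_coef.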
267 | 268 | kldiv_z0 = kl_divergence(fp_distr, self.z0_prior) 269 | 270 | if torch.isnan(kldiv_z0).any(): 271 | print(fp_mu) 272 | print(fp_std) 273 | raise Exception("kldiv_z0 is Nan!") 274 | 275 | # Mean over number of latent dimensions 276 | # kldiv_z0 shape: [n_traj_samples, n_traj, n_latent_dims] if prior is a mixture of gaussians (KL is estimated) 277 | # kldiv_z0 shape: [1, n_traj, n_latent_dims] if prior is a standard gaussian (KL is computed exactly) 278 | # shape after: [n_traj_samples] 279 | kldiv_z0 = torch.mean(kldiv_z0,(1,2)) 280 | 281 | # Compute likelihood of all the points 282 | rec_likelihood = self.get_gaussian_likelihood( 283 | batch_dict["data_to_predict"], pred_y, 284 | mask = batch_dict["mask_predicted_data"]) 285 | 286 | mse = self.get_mse( 287 | batch_dict["data_to_predict"], pred_y, 288 | mask = batch_dict["mask_predicted_data"]) 289 | 290 | pois_log_likelihood = torch.Tensor([0.]).to(get_device(batch_dict["data_to_predict"])) 291 | if self.use_poisson_proc: 292 | pois_log_likelihood = compute_poisson_proc_likelihood( 293 | batch_dict["data_to_predict"], pred_y, 294 | info, mask = batch_dict["mask_predicted_data"]) 295 | # Take mean over n_traj 296 | pois_log_likelihood = torch.mean(pois_log_likelihood, 1) 297 | 298 | ################################ 299 | # Compute CE loss for binary classification on Physionet 300 | device = get_device(batch_dict["data_to_predict"]) 301 | ce_loss = torch.Tensor([0.]).to(device) 302 | if (batch_dict["labels"] is not None) and self.use_binary_classif: 303 | 304 | if (batch_dict["labels"].size(-1) == 1) or (len(batch_dict["labels"].size()) == 1): 305 | ce_loss = compute_binary_CE_loss( 306 | info["label_predictions"], 307 | batch_dict["labels"]) 308 | else: 309 | ce_loss = compute_multiclass_CE_loss( 310 | info["label_predictions"], 311 | batch_dict["labels"], 312 | mask = batch_dict["mask_predicted_data"]) 313 | 314 | # IWAE loss 315 | loss = - torch.logsumexp(rec_likelihood - kl_coef * kldiv_z0,0) 316 | if torch.isnan(loss): 317 | loss = - torch.mean(rec_likelihood - kl_coef * kldiv_z0,0) 318 | 319 | if self.use_poisson_proc: 320 | loss = loss - 0.1 * pois_log_likelihood 321 | 322 | if self.use_binary_classif: 323 | if self.train_classif_w_reconstr: 324 | loss = loss + ce_loss * 100 325 | else: 326 | loss = ce_loss 327 | 328 | results = {} 329 | results["loss"] = torch.mean(loss) 330 | results["likelihood"] = torch.mean(rec_likelihood).detach() 331 | results["mse"] = torch.mean(mse).detach() 332 | results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach() 333 | results["ce_loss"] = torch.mean(ce_loss).detach() 334 | results["kl_first_p"] = torch.mean(kldiv_z0).detach() 335 | results["std_first_p"] = torch.mean(fp_std).detach() 336 | 337 | if batch_dict["labels"] is not None and self.use_binary_classif: 338 | results["label_predictions"] = info["label_predictions"].detach() 339 | 340 | return results 341 | 342 | 343 | 344 | -------------------------------------------------------------------------------- /lib/latent_ode/create_latent_ode_model.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import os 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn.functional import relu 12 | 13 | from . 
import utils as utils 14 | from .latent_ode import LatentODE 15 | from .encoder_decoder import * 16 | from .diffeq_solver import DiffeqSolver 17 | 18 | from torch.distributions.normal import Normal 19 | from .ode_func import ODEFunc, ODEFunc_w_Poisson 20 | 21 | ##################################################################################################### 22 | 23 | def create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device, 24 | classif_per_tp = False, n_labels = 1): 25 | 26 | dim = args.latents 27 | if args.poisson: 28 | lambda_net = utils.create_net(dim, input_dim, 29 | n_layers = 1, n_units = args.units, nonlinear = nn.Tanh) 30 | 31 | # ODE function produces the gradient for latent state and for poisson rate 32 | ode_func_net = utils.create_net(dim * 2, args.latents * 2, 33 | n_layers = args.gen_layers, n_units = args.units, nonlinear = nn.Tanh) 34 | 35 | gen_ode_func = ODEFunc_w_Poisson( 36 | input_dim = input_dim, 37 | latent_dim = args.latents * 2, 38 | ode_func_net = ode_func_net, 39 | lambda_net = lambda_net, 40 | device = device).to(device) 41 | else: 42 | dim = args.latents 43 | ode_func_net = utils.create_net(dim, args.latents, 44 | n_layers = args.gen_layers, n_units = args.units, nonlinear = nn.Tanh) 45 | 46 | gen_ode_func = ODEFunc( 47 | input_dim = input_dim, 48 | latent_dim = args.latents, 49 | ode_func_net = ode_func_net, 50 | device = device).to(device) 51 | 52 | z0_diffeq_solver = None 53 | n_rec_dims = args.rec_dims 54 | enc_input_dim = int(input_dim) * 2 # we concatenate the mask 55 | gen_data_dim = input_dim 56 | 57 | z0_dim = args.latents 58 | if args.poisson: 59 | z0_dim += args.latents # predict the initial poisson rate 60 | 61 | if args.z0_encoder == "odernn": 62 | ode_func_net = utils.create_net(n_rec_dims, n_rec_dims, 63 | n_layers = args.rec_layers, n_units = args.units, nonlinear = nn.Tanh) 64 | 65 | rec_ode_func = ODEFunc( 66 | input_dim = enc_input_dim, 67 | latent_dim = n_rec_dims, 68 | ode_func_net = ode_func_net, 69 | device = device).to(device) 70 | 71 | z0_diffeq_solver = DiffeqSolver(enc_input_dim, rec_ode_func, "euler", args.latents, 72 | odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device) 73 | 74 | encoder_z0 = Encoder_z0_ODE_RNN(n_rec_dims, enc_input_dim, z0_diffeq_solver, 75 | z0_dim = z0_dim, n_gru_units = args.gru_units, device = device).to(device) 76 | 77 | elif args.z0_encoder == "rnn": 78 | encoder_z0 = Encoder_z0_RNN(z0_dim, enc_input_dim, 79 | lstm_output_size = n_rec_dims, device = device).to(device) 80 | else: 81 | raise Exception("Unknown encoder for Latent ODE model: " + args.z0_encoder) 82 | 83 | decoder = Decoder(args.latents, gen_data_dim).to(device) 84 | 85 | diffeq_solver = DiffeqSolver(gen_data_dim, gen_ode_func, 'dopri5', args.latents, 86 | odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device) 87 | 88 | model = LatentODE( 89 | input_dim = gen_data_dim, 90 | latent_dim = args.latents, 91 | encoder_z0 = encoder_z0, 92 | decoder = decoder, 93 | diffeq_solver = diffeq_solver, 94 | z0_prior = z0_prior, 95 | device = device, 96 | obsrv_std = obsrv_std, 97 | use_poisson_proc = args.poisson, 98 | use_binary_classif = args.classif, 99 | linear_classifier = args.linear_classif, 100 | classif_per_tp = classif_per_tp, 101 | n_labels = n_labels, 102 | train_classif_w_reconstr = (args.dataset == "physionet") 103 | ).to(device) 104 | 105 | return model 106 | -------------------------------------------------------------------------------- /lib/latent_ode/diffeq_solver.py: 
-------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import time 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from . import utils as utils 13 | from torch.distributions.multivariate_normal import MultivariateNormal 14 | 15 | # git clone https://github.com/rtqichen/torchdiffeq.git 16 | from torchdiffeq import odeint as odeint 17 | 18 | ##################################################################################################### 19 | 20 | class DiffeqSolver(nn.Module): 21 | def __init__(self, input_dim, ode_func, method, latents, 22 | odeint_rtol = 1e-4, odeint_atol = 1e-5, device = torch.device("cpu")): 23 | super(DiffeqSolver, self).__init__() 24 | 25 | self.ode_method = method 26 | self.latents = latents 27 | self.device = device 28 | self.ode_func = ode_func 29 | 30 | self.odeint_rtol = odeint_rtol 31 | self.odeint_atol = odeint_atol 32 | 33 | def forward(self, first_point, time_steps_to_predict, backwards = False): 34 | """ 35 | # Decode the trajectory through ODE Solver 36 | """ 37 | n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1] 38 | n_dims = first_point.size()[-1] 39 | 40 | pred_y = odeint(self.ode_func, first_point, time_steps_to_predict, 41 | rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method) 42 | pred_y = pred_y.permute(1,2,0,3) 43 | 44 | assert(torch.mean(pred_y[:, :, 0, :] - first_point) < 0.001) 45 | assert(pred_y.size()[0] == n_traj_samples) 46 | assert(pred_y.size()[1] == n_traj) 47 | 48 | return pred_y 49 | 50 | def sample_traj_from_prior(self, starting_point_enc, time_steps_to_predict, 51 | n_traj_samples = 1): 52 | """ 53 | # Decode the trajectory through ODE Solver using samples from the prior 54 | 55 | time_steps_to_predict: time steps at which we want to sample the new trajectory 56 | """ 57 | func = self.ode_func.sample_next_point_from_prior 58 | 59 | pred_y = odeint(func, starting_point_enc, time_steps_to_predict, 60 | rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method) 61 | # shape: [n_traj_samples, n_traj, n_tp, n_dim] 62 | pred_y = pred_y.permute(1,2,0,3) 63 | return pred_y 64 | 65 | 66 | -------------------------------------------------------------------------------- /lib/latent_ode/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.functional import relu 10 | from . 
import utils as utils 11 | from torch.distributions import Categorical, Normal 12 | from torch.nn.modules.rnn import LSTM, GRU 13 | from .utils import get_device 14 | 15 | 16 | # GRU description: 17 | # http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/ 18 | class GRU_unit(nn.Module): 19 | def __init__(self, latent_dim, input_dim, 20 | update_gate = None, 21 | reset_gate = None, 22 | new_state_net = None, 23 | n_units = 100, 24 | device = torch.device("cpu")): 25 | super(GRU_unit, self).__init__() 26 | 27 | if update_gate is None: 28 | self.update_gate = nn.Sequential( 29 | nn.Linear(latent_dim * 2 + input_dim, n_units), 30 | nn.Tanh(), 31 | nn.Linear(n_units, latent_dim), 32 | nn.Sigmoid()) 33 | utils.init_network_weights(self.update_gate) 34 | else: 35 | self.update_gate = update_gate 36 | 37 | if reset_gate is None: 38 | self.reset_gate = nn.Sequential( 39 | nn.Linear(latent_dim * 2 + input_dim, n_units), 40 | nn.Tanh(), 41 | nn.Linear(n_units, latent_dim), 42 | nn.Sigmoid()) 43 | utils.init_network_weights(self.reset_gate) 44 | else: 45 | self.reset_gate = reset_gate 46 | 47 | if new_state_net is None: 48 | self.new_state_net = nn.Sequential( 49 | nn.Linear(latent_dim * 2 + input_dim, n_units), 50 | nn.Tanh(), 51 | nn.Linear(n_units, latent_dim * 2)) 52 | utils.init_network_weights(self.new_state_net) 53 | else: 54 | self.new_state_net = new_state_net 55 | 56 | 57 | def forward(self, y_mean, y_std, x, masked_update = True): 58 | y_concat = torch.cat([y_mean, y_std, x], -1) 59 | 60 | update_gate = self.update_gate(y_concat) 61 | reset_gate = self.reset_gate(y_concat) 62 | concat = torch.cat([y_mean * reset_gate, y_std * reset_gate, x], -1) 63 | 64 | new_state, new_state_std = utils.split_last_dim(self.new_state_net(concat)) 65 | new_state_std = new_state_std.abs() 66 | 67 | new_y = (1-update_gate) * new_state + update_gate * y_mean 68 | new_y_std = (1-update_gate) * new_state_std + update_gate * y_std 69 | 70 | assert(not torch.isnan(new_y).any()) 71 | 72 | if masked_update: 73 | # IMPORTANT: assumes that x contains both data and mask 74 | # update only the hidden states for hidden state only if at least one feature is present for the current time point 75 | n_data_dims = x.size(-1)//2 76 | mask = x[:, :, n_data_dims:] 77 | utils.check_mask(x[:, :, :n_data_dims], mask) 78 | 79 | mask = (torch.sum(mask, -1, keepdim = True) > 0).float() 80 | 81 | assert(not torch.isnan(mask).any()) 82 | 83 | new_y = mask * new_y + (1-mask) * y_mean 84 | new_y_std = mask * new_y_std + (1-mask) * y_std 85 | 86 | if torch.isnan(new_y).any(): 87 | print("new_y is nan!") 88 | print(mask) 89 | print(y_mean) 90 | print(prev_new_y) 91 | exit() 92 | 93 | new_y_std = new_y_std.abs() 94 | return new_y, new_y_std 95 | 96 | 97 | 98 | class Encoder_z0_RNN(nn.Module): 99 | def __init__(self, latent_dim, input_dim, lstm_output_size = 20, 100 | use_delta_t = True, device = torch.device("cpu")): 101 | 102 | super(Encoder_z0_RNN, self).__init__() 103 | 104 | self.gru_rnn_output_size = lstm_output_size 105 | self.latent_dim = latent_dim 106 | self.input_dim = input_dim 107 | self.device = device 108 | self.use_delta_t = use_delta_t 109 | 110 | self.hiddens_to_z0 = nn.Sequential( 111 | nn.Linear(self.gru_rnn_output_size, 50), 112 | nn.Tanh(), 113 | nn.Linear(50, latent_dim * 2),) 114 | 115 | utils.init_network_weights(self.hiddens_to_z0) 116 | 117 | input_dim = self.input_dim 118 | 119 | if use_delta_t: 120 | self.input_dim += 1 121 | self.gru_rnn = 
GRU(self.input_dim, self.gru_rnn_output_size).to(device) 122 | 123 | def forward(self, data, time_steps, run_backwards = True): 124 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 125 | 126 | # data shape: [n_traj, n_tp, n_dims] 127 | # shape required for rnn: (seq_len, batch, input_size) 128 | # t0: not used here 129 | n_traj = data.size(0) 130 | 131 | assert(not torch.isnan(data).any()) 132 | assert(not torch.isnan(time_steps).any()) 133 | 134 | data = data.permute(1,0,2) 135 | 136 | if run_backwards: 137 | # Look at data in the reverse order: from later points to the first 138 | data = utils.reverse(data) 139 | 140 | if self.use_delta_t: 141 | delta_t = time_steps[1:] - time_steps[:-1] 142 | if run_backwards: 143 | # we are going backwards in time with 144 | delta_t = utils.reverse(delta_t) 145 | # append zero delta t in the end 146 | delta_t = torch.cat((delta_t, torch.zeros(1).to(self.device))) 147 | delta_t = delta_t.unsqueeze(1).repeat((1,n_traj)).unsqueeze(-1) 148 | data = torch.cat((delta_t, data),-1) 149 | 150 | outputs, _ = self.gru_rnn(data) 151 | 152 | # LSTM output shape: (seq_len, batch, num_directions * hidden_size) 153 | last_output = outputs[-1] 154 | 155 | self.extra_info ={"rnn_outputs": outputs, "time_points": time_steps} 156 | 157 | mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output)) 158 | std = std.abs() 159 | 160 | assert(not torch.isnan(mean).any()) 161 | assert(not torch.isnan(std).any()) 162 | 163 | return mean.unsqueeze(0), std.unsqueeze(0) 164 | 165 | 166 | 167 | 168 | 169 | class Encoder_z0_ODE_RNN(nn.Module): 170 | # Derive z0 by running ode backwards. 171 | # For every y_i we have two versions: encoded from data and derived from ODE by running it backwards from t_i+1 to t_i 172 | # Compute a weighted sum of y_i from data and y_i from ode. 
Use weighted y_i as an initial value for ODE runing from t_i to t_i-1 173 | # Continue until we get to z0 174 | def __init__(self, latent_dim, input_dim, z0_diffeq_solver = None, 175 | z0_dim = None, GRU_update = None, 176 | n_gru_units = 100, 177 | device = torch.device("cpu")): 178 | 179 | super(Encoder_z0_ODE_RNN, self).__init__() 180 | 181 | if z0_dim is None: 182 | self.z0_dim = latent_dim 183 | else: 184 | self.z0_dim = z0_dim 185 | 186 | if GRU_update is None: 187 | self.GRU_update = GRU_unit(latent_dim, input_dim, 188 | n_units = n_gru_units, 189 | device=device).to(device) 190 | else: 191 | self.GRU_update = GRU_update 192 | 193 | self.z0_diffeq_solver = z0_diffeq_solver 194 | self.latent_dim = latent_dim 195 | self.input_dim = input_dim 196 | self.device = device 197 | self.extra_info = None 198 | 199 | self.transform_z0 = nn.Sequential( 200 | nn.Linear(latent_dim * 2, 100), 201 | nn.Tanh(), 202 | nn.Linear(100, self.z0_dim * 2),) 203 | utils.init_network_weights(self.transform_z0) 204 | 205 | 206 | def forward(self, data, time_steps, run_backwards = True, save_info = False): 207 | # data, time_steps -- observations and their time stamps 208 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 209 | assert(not torch.isnan(data).any()) 210 | assert(not torch.isnan(time_steps).any()) 211 | 212 | n_traj, n_tp, n_dims = data.size() 213 | if len(time_steps) == 1: 214 | prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(self.device) 215 | prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(self.device) 216 | 217 | xi = data[:,0,:].unsqueeze(0) 218 | 219 | last_yi, last_yi_std = self.GRU_update(prev_y, prev_std, xi) 220 | extra_info = None 221 | else: 222 | 223 | last_yi, last_yi_std, _, extra_info = self.run_odernn( 224 | data, time_steps, run_backwards = run_backwards, 225 | save_info = save_info) 226 | 227 | means_z0 = last_yi.reshape(1, n_traj, self.latent_dim) 228 | std_z0 = last_yi_std.reshape(1, n_traj, self.latent_dim) 229 | 230 | mean_z0, std_z0 = utils.split_last_dim( self.transform_z0( torch.cat((means_z0, std_z0), -1))) 231 | std_z0 = std_z0.abs() 232 | if save_info: 233 | self.extra_info = extra_info 234 | 235 | return mean_z0, std_z0 236 | 237 | 238 | def run_odernn(self, data, time_steps, 239 | run_backwards = True, save_info = False): 240 | # IMPORTANT: assumes that 'data' already has mask concatenated to it 241 | 242 | n_traj, n_tp, n_dims = data.size() 243 | extra_info = [] 244 | 245 | t0 = time_steps[-1] 246 | if run_backwards: 247 | t0 = time_steps[0] 248 | 249 | device = get_device(data) 250 | 251 | prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device) 252 | prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device) 253 | 254 | prev_t, t_i = time_steps[-1] + 0.01, time_steps[-1] 255 | 256 | interval_length = time_steps[-1] - time_steps[0] 257 | minimum_step = interval_length / 50 258 | 259 | #print("minimum step: {}".format(minimum_step)) 260 | 261 | assert(not torch.isnan(data).any()) 262 | assert(not torch.isnan(time_steps).any()) 263 | 264 | latent_ys = [] 265 | # Run ODE backwards and combine the y(t) estimates using gating 266 | time_points_iter = range(0, len(time_steps)) 267 | if run_backwards: 268 | time_points_iter = reversed(time_points_iter) 269 | 270 | for i in time_points_iter: 271 | if (prev_t - t_i) < minimum_step: 272 | time_points = torch.stack((prev_t, t_i)) 273 | inc = self.z0_diffeq_solver.ode_func(prev_t, prev_y) * (t_i - prev_t) 274 | 275 | assert(not torch.isnan(inc).any()) 276 | 277 | ode_sol = prev_y + 
inc 278 | ode_sol = torch.stack((prev_y, ode_sol), 2).to(device) 279 | 280 | assert(not torch.isnan(ode_sol).any()) 281 | else: 282 | n_intermediate_tp = max(2, ((prev_t - t_i) / minimum_step).int()) 283 | 284 | time_points = utils.linspace_vector(prev_t, t_i, n_intermediate_tp) 285 | ode_sol = self.z0_diffeq_solver(prev_y, time_points) 286 | 287 | assert(not torch.isnan(ode_sol).any()) 288 | 289 | if torch.mean(ode_sol[:, :, 0, :] - prev_y) >= 0.001: 290 | print("Error: first point of the ODE is not equal to initial value") 291 | print(torch.mean(ode_sol[:, :, 0, :] - prev_y)) 292 | exit() 293 | #assert(torch.mean(ode_sol[:, :, 0, :] - prev_y) < 0.001) 294 | 295 | yi_ode = ode_sol[:, :, -1, :] 296 | xi = data[:,i,:].unsqueeze(0) 297 | 298 | yi, yi_std = self.GRU_update(yi_ode, prev_std, xi) 299 | 300 | prev_y, prev_std = yi, yi_std 301 | prev_t, t_i = time_steps[i], time_steps[i-1] 302 | 303 | latent_ys.append(yi) 304 | 305 | if save_info: 306 | d = {"yi_ode": yi_ode.detach(), #"yi_from_data": yi_from_data, 307 | "yi": yi.detach(), "yi_std": yi_std.detach(), 308 | "time_points": time_points.detach(), "ode_sol": ode_sol.detach()} 309 | extra_info.append(d) 310 | 311 | latent_ys = torch.stack(latent_ys, 1) 312 | 313 | assert(not torch.isnan(yi).any()) 314 | assert(not torch.isnan(yi_std).any()) 315 | 316 | return yi, yi_std, latent_ys, extra_info 317 | 318 | 319 | 320 | class Decoder(nn.Module): 321 | def __init__(self, latent_dim, input_dim): 322 | super(Decoder, self).__init__() 323 | # decode data from latent space where we are solving an ODE back to the data space 324 | 325 | decoder = nn.Sequential( 326 | nn.Linear(latent_dim, input_dim),) 327 | 328 | utils.init_network_weights(decoder) 329 | self.decoder = decoder 330 | 331 | def forward(self, data): 332 | return self.decoder(data) 333 | 334 | 335 | -------------------------------------------------------------------------------- /lib/latent_ode/eval_glunet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import yaml 4 | import random 5 | from typing import Optional 6 | 7 | import numpy as np 8 | import scipy as sp 9 | import pandas as pd 10 | import torch 11 | 12 | # import likelihood evaluation 13 | from .likelihood_eval import masked_gaussian_log_density 14 | 15 | def test(series: np.ndarray, 16 | forecasts: np.ndarray, 17 | obsrv_std: np.ndarray, 18 | cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11), 19 | ): 20 | """ 21 | Test the (rescaled to original scale) forecasts on the series. 22 | 23 | Parameters 24 | ---------- 25 | series 26 | The target time series of shape (n_traj, n_tp, n_dim), 27 | where t is length of prediction. 28 | forecasts 29 | The forecasted means of mixture components of shape (n_traj_samples, n_traj, n_tp, n_dim) 30 | where k is the number of mixture components. 31 | obsrv_std 32 | The forecasted std of mixture components of shape (1). 33 | cal_thresholds 34 | The thresholds to use for computing the calibration error. 35 | 36 | Returns 37 | ------- 38 | np.ndarray 39 | Error array. Array of shape (n_traj, 2), where 40 | along last dimension, we have MSE and MAE. 41 | float 42 | The estimated log-likelihood of the model on the data. 43 | np.ndarray 44 | The ECE for each time point in the forecast. 
45 | """ 46 | mse = np.mean((series - forecasts.mean(axis=0))**2, axis=-2) 47 | mae = np.mean(np.abs(series - forecasts.mean(axis=0)), axis=-2) 48 | errors = np.stack([mse.squeeze(), mae.squeeze()], axis=-1) 49 | 50 | # compute likelihood 51 | series, forecasts = torch.tensor(series), torch.tensor(forecasts) 52 | obsrv_std = torch.Tensor(obsrv_std) 53 | 54 | series_repeated = series.repeat(forecasts.size(0), 1, 1, 1) 55 | log_density_data = masked_gaussian_log_density(forecasts, series_repeated, 56 | obsrv_std = obsrv_std, mask = None) 57 | log_density_data = log_density_data.permute(1,0) 58 | log_density = torch.mean(log_density_data, 1) 59 | log_likelihood = torch.mean(log_density).item() 60 | 61 | # compute calibration error 62 | samples = torch.distributions.Normal(loc=forecasts, scale=obsrv_std).sample((100, )) 63 | samples = samples.view(samples.shape[0] * samples.shape[1], 64 | samples.shape[2], 65 | samples.shape[3]) 66 | series = series.squeeze() 67 | cal_error = torch.zeros(series.shape[1]) 68 | for p in cal_thresholds: 69 | q = torch.quantile(samples, p, dim=0) 70 | est_p = torch.mean((series <= q).float(), dim=0) 71 | cal_error += (est_p - p) ** 2 72 | cal_error = cal_error.numpy() 73 | 74 | return errors, log_likelihood, cal_error -------------------------------------------------------------------------------- /lib/latent_ode/latent_ode.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import numpy as np 7 | import sklearn as sk 8 | import numpy as np 9 | #import gc 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn.functional import relu 13 | 14 | from . 
import utils as utils 15 | from .utils import get_device 16 | from .encoder_decoder import * 17 | from .likelihood_eval import * 18 | 19 | from torch.distributions.multivariate_normal import MultivariateNormal 20 | from torch.distributions.normal import Normal 21 | from torch.distributions import kl_divergence, Independent 22 | from .base_models import VAE_Baseline 23 | 24 | 25 | 26 | class LatentODE(VAE_Baseline): 27 | def __init__(self, input_dim, latent_dim, encoder_z0, decoder, diffeq_solver, 28 | z0_prior, device, obsrv_std = None, 29 | use_binary_classif = False, use_poisson_proc = False, 30 | linear_classifier = False, 31 | classif_per_tp = False, 32 | n_labels = 1, 33 | train_classif_w_reconstr = False): 34 | 35 | super(LatentODE, self).__init__( 36 | input_dim = input_dim, latent_dim = latent_dim, 37 | z0_prior = z0_prior, 38 | device = device, obsrv_std = obsrv_std, 39 | use_binary_classif = use_binary_classif, 40 | classif_per_tp = classif_per_tp, 41 | linear_classifier = linear_classifier, 42 | use_poisson_proc = use_poisson_proc, 43 | n_labels = n_labels, 44 | train_classif_w_reconstr = train_classif_w_reconstr) 45 | 46 | self.encoder_z0 = encoder_z0 47 | self.diffeq_solver = diffeq_solver 48 | self.decoder = decoder 49 | self.use_poisson_proc = use_poisson_proc 50 | 51 | def get_reconstruction(self, time_steps_to_predict, truth, truth_time_steps, 52 | mask = None, n_traj_samples = 1, run_backwards = True, mode = None): 53 | 54 | if isinstance(self.encoder_z0, Encoder_z0_ODE_RNN) or \ 55 | isinstance(self.encoder_z0, Encoder_z0_RNN): 56 | 57 | truth_w_mask = truth 58 | if mask is not None: 59 | truth_w_mask = torch.cat((truth, mask), -1) 60 | first_point_mu, first_point_std = self.encoder_z0( 61 | truth_w_mask, truth_time_steps, run_backwards = run_backwards) 62 | 63 | means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1) 64 | sigma_z0 = first_point_std.repeat(n_traj_samples, 1, 1) 65 | first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0) 66 | 67 | else: 68 | raise Exception("Unknown encoder type {}".format(type(self.encoder_z0).__name__)) 69 | 70 | first_point_std = first_point_std.abs() 71 | assert(torch.sum(first_point_std < 0) == 0.) 72 | 73 | if self.use_poisson_proc: 74 | n_traj_samples, n_traj, n_dims = first_point_enc.size() 75 | # append a vector of zeros to compute the integral of lambda 76 | zeros = torch.zeros([n_traj_samples, n_traj,self.input_dim]).to(get_device(truth)) 77 | first_point_enc_aug = torch.cat((first_point_enc, zeros), -1) 78 | means_z0_aug = torch.cat((means_z0, zeros), -1) 79 | else: 80 | first_point_enc_aug = first_point_enc 81 | means_z0_aug = means_z0 82 | 83 | assert(not torch.isnan(time_steps_to_predict).any()) 84 | assert(not torch.isnan(first_point_enc).any()) 85 | assert(not torch.isnan(first_point_enc_aug).any()) 86 | 87 | # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents] 88 | sol_y = self.diffeq_solver(first_point_enc_aug, time_steps_to_predict) 89 | 90 | if self.use_poisson_proc: 91 | sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(sol_y) 92 | 93 | assert(torch.sum(int_lambda[:,:,0,:]) == 0.) 94 | assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.) 
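# NOTE: sol_y holds the latent trajectory with shape [n_traj_samples, n_traj, n_tp, n_latents];
# the decoder below maps each latent state back to observation space, so pred_x has shape
# [n_traj_samples, n_traj, n_tp, input_dim].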
95 | 96 | pred_x = self.decoder(sol_y) 97 | 98 | all_extra_info = { 99 | "first_point": (first_point_mu, first_point_std, first_point_enc), 100 | "latent_traj": sol_y.detach() 101 | } 102 | 103 | if self.use_poisson_proc: 104 | # intergral of lambda from the last step of ODE Solver 105 | all_extra_info["int_lambda"] = int_lambda[:,:,-1,:] 106 | all_extra_info["log_lambda_y"] = log_lambda_y 107 | 108 | if self.use_binary_classif: 109 | if self.classif_per_tp: 110 | all_extra_info["label_predictions"] = self.classifier(sol_y) 111 | else: 112 | all_extra_info["label_predictions"] = self.classifier(first_point_enc).squeeze(-1) 113 | 114 | return pred_x, all_extra_info 115 | 116 | 117 | def sample_traj_from_prior(self, time_steps_to_predict, n_traj_samples = 1): 118 | # input_dim = starting_point.size()[-1] 119 | # starting_point = starting_point.view(1,1,input_dim) 120 | 121 | # Sample z0 from prior 122 | starting_point_enc = self.z0_prior.sample([n_traj_samples, 1, self.latent_dim]).squeeze(-1) 123 | 124 | starting_point_enc_aug = starting_point_enc 125 | if self.use_poisson_proc: 126 | n_traj_samples, n_traj, n_dims = starting_point_enc.size() 127 | # append a vector of zeros to compute the integral of lambda 128 | zeros = torch.zeros(n_traj_samples, n_traj,self.input_dim).to(self.device) 129 | starting_point_enc_aug = torch.cat((starting_point_enc, zeros), -1) 130 | 131 | sol_y = self.diffeq_solver.sample_traj_from_prior(starting_point_enc_aug, time_steps_to_predict, 132 | n_traj_samples = 3) 133 | 134 | if self.use_poisson_proc: 135 | sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(sol_y) 136 | 137 | return self.decoder(sol_y) 138 | 139 | 140 | -------------------------------------------------------------------------------- /lib/latent_ode/likelihood_eval.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import gc 7 | import numpy as np 8 | import sklearn as sk 9 | import numpy as np 10 | #import gc 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn.functional import relu 14 | 15 | from . 
import utils as utils 16 | from .utils import get_device 17 | from .encoder_decoder import * 18 | from .likelihood_eval import * 19 | 20 | from torch.distributions.multivariate_normal import MultivariateNormal 21 | from torch.distributions.normal import Normal 22 | from torch.distributions import kl_divergence, Independent 23 | 24 | 25 | def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices = None): 26 | n_data_points = mu_2d.size()[-1] 27 | 28 | if n_data_points > 0: 29 | gaussian = Independent(Normal(loc = mu_2d, scale = obsrv_std.repeat(n_data_points)), 1) 30 | log_prob = gaussian.log_prob(data_2d) 31 | log_prob = log_prob / n_data_points 32 | else: 33 | log_prob = torch.zeros([1]).to(get_device(data_2d)).squeeze() 34 | return log_prob 35 | 36 | 37 | def poisson_log_likelihood(masked_log_lambdas, masked_data, indices, int_lambdas): 38 | # masked_log_lambdas and masked_data 39 | n_data_points = masked_data.size()[-1] 40 | 41 | if n_data_points > 0: 42 | log_prob = torch.sum(masked_log_lambdas) - int_lambdas[indices] 43 | #log_prob = log_prob / n_data_points 44 | else: 45 | log_prob = torch.zeros([1]).to(get_device(masked_data)).squeeze() 46 | return log_prob 47 | 48 | 49 | 50 | def compute_binary_CE_loss(label_predictions, mortality_label): 51 | #print("Computing binary classification loss: compute_CE_loss") 52 | 53 | mortality_label = mortality_label.reshape(-1) 54 | 55 | if len(label_predictions.size()) == 1: 56 | label_predictions = label_predictions.unsqueeze(0) 57 | 58 | n_traj_samples = label_predictions.size(0) 59 | label_predictions = label_predictions.reshape(n_traj_samples, -1) 60 | 61 | idx_not_nan = ~torch.isnan(mortality_label) 62 | if torch.sum(idx_not_nan) == 0: 63 | print("All labels are NaN!") 64 | ce_loss = torch.Tensor([0.]).to(get_device(mortality_label)) 65 | 66 | label_predictions = label_predictions[:,idx_not_nan] 67 | mortality_label = mortality_label[idx_not_nan] 68 | 69 | if torch.sum(mortality_label == 0.) == 0 or torch.sum(mortality_label == 1.)
== 0: 70 | print("Warning: all examples in a batch belong to the same class -- please increase the batch size.") 71 | 72 | assert(not torch.isnan(label_predictions).any()) 73 | assert(not torch.isnan(mortality_label).any()) 74 | 75 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them 76 | mortality_label = mortality_label.repeat(n_traj_samples, 1) 77 | ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label) 78 | 79 | # divide by number of patients in a batch 80 | ce_loss = ce_loss / n_traj_samples 81 | return ce_loss 82 | 83 | 84 | def compute_multiclass_CE_loss(label_predictions, true_label, mask): 85 | #print("Computing multi-class classification loss: compute_multiclass_CE_loss") 86 | 87 | if (len(label_predictions.size()) == 3): 88 | label_predictions = label_predictions.unsqueeze(0) 89 | 90 | n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size() 91 | 92 | # assert(not torch.isnan(label_predictions).any()) 93 | # assert(not torch.isnan(true_label).any()) 94 | 95 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them 96 | true_label = true_label.repeat(n_traj_samples, 1, 1) 97 | 98 | label_predictions = label_predictions.reshape(n_traj_samples * n_traj * n_tp, n_dims) 99 | true_label = true_label.reshape(n_traj_samples * n_traj * n_tp, n_dims) 100 | 101 | # choose time points with at least one measurement 102 | mask = torch.sum(mask, -1) > 0 103 | 104 | # repeat the mask for each label to mark that the label for this time point is present 105 | pred_mask = mask.repeat(n_dims, 1,1).permute(1,2,0) 106 | 107 | label_mask = mask 108 | pred_mask = pred_mask.repeat(n_traj_samples,1,1,1) 109 | label_mask = label_mask.repeat(n_traj_samples,1,1,1) 110 | 111 | pred_mask = pred_mask.reshape(n_traj_samples * n_traj * n_tp, n_dims) 112 | label_mask = label_mask.reshape(n_traj_samples * n_traj * n_tp, 1) 113 | 114 | if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1): 115 | assert(label_predictions.size(-1) == true_label.size(-1)) 116 | # targets are in one-hot encoding -- convert to indices 117 | _, true_label = true_label.max(-1) 118 | 119 | res = [] 120 | for i in range(true_label.size(0)): 121 | pred_masked = torch.masked_select(label_predictions[i], pred_mask[i].bool()) 122 | labels = torch.masked_select(true_label[i], label_mask[i].bool()) 123 | 124 | pred_masked = pred_masked.reshape(-1, n_dims) 125 | 126 | if (len(labels) == 0): 127 | continue 128 | 129 | ce_loss = nn.CrossEntropyLoss()(pred_masked, labels.long()) 130 | res.append(ce_loss) 131 | 132 | ce_loss = torch.stack(res, 0).to(get_device(label_predictions)) 133 | ce_loss = torch.mean(ce_loss) 134 | # # divide by number of patients in a batch 135 | # ce_loss = ce_loss / n_traj_samples 136 | return ce_loss 137 | 138 | 139 | 140 | 141 | def compute_masked_likelihood(mu, data, mask, likelihood_func): 142 | # Compute the likelihood per patient and per attribute so that we don't priorize patients with more measurements 143 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size() 144 | 145 | res = [] 146 | for i in range(n_traj_samples): 147 | for k in range(n_traj): 148 | for j in range(n_dims): 149 | data_masked = torch.masked_select(data[i,k,:,j], mask[i,k,:,j].bool()) 150 | 151 | #assert(torch.sum(data_masked == 0.) 
< 10) 152 | 153 | mu_masked = torch.masked_select(mu[i,k,:,j], mask[i,k,:,j].bool()) 154 | log_prob = likelihood_func(mu_masked, data_masked, indices = (i,k,j)) 155 | res.append(log_prob) 156 | # shape: [n_traj*n_traj_samples, 1] 157 | 158 | res = torch.stack(res, 0).to(get_device(data)) 159 | res = res.reshape((n_traj_samples, n_traj, n_dims)) 160 | # Take mean over the number of dimensions 161 | res = torch.mean(res, -1) # !!!!!!!!!!! changed from sum to mean 162 | res = res.transpose(0,1) 163 | return res 164 | 165 | 166 | def masked_gaussian_log_density(mu, data, obsrv_std, mask = None): 167 | # these cases are for plotting through plot_estim_density 168 | if (len(mu.size()) == 3): 169 | # add additional dimension for gp samples 170 | mu = mu.unsqueeze(0) 171 | 172 | if (len(data.size()) == 2): 173 | # add additional dimension for gp samples and time step 174 | data = data.unsqueeze(0).unsqueeze(2) 175 | elif (len(data.size()) == 3): 176 | # add additional dimension for gp samples 177 | data = data.unsqueeze(0) 178 | 179 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size() 180 | 181 | assert(data.size()[-1] == n_dims) 182 | 183 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims] 184 | if mask is None: 185 | mu_flat = mu.reshape(n_traj_samples*n_traj, n_timepoints * n_dims) 186 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size() 187 | data_flat = data.reshape(n_traj_samples*n_traj, n_timepoints * n_dims) 188 | 189 | res = gaussian_log_likelihood(mu_flat, data_flat, obsrv_std) 190 | res = res.reshape(n_traj_samples, n_traj).transpose(0,1) 191 | else: 192 | # Compute the likelihood per patient so that we don't priorize patients with more measurements 193 | func = lambda mu, data, indices: gaussian_log_likelihood(mu, data, obsrv_std = obsrv_std, indices = indices) 194 | res = compute_masked_likelihood(mu, data, mask, func) 195 | return res 196 | 197 | 198 | 199 | def mse(mu, data, indices = None): 200 | n_data_points = mu.size()[-1] 201 | 202 | if n_data_points > 0: 203 | mse = nn.MSELoss()(mu, data) 204 | else: 205 | mse = torch.zeros([1]).to(get_device(data)).squeeze() 206 | return mse 207 | 208 | 209 | def compute_mse(mu, data, mask = None): 210 | # these cases are for plotting through plot_estim_density 211 | if (len(mu.size()) == 3): 212 | # add additional dimension for gp samples 213 | mu = mu.unsqueeze(0) 214 | 215 | if (len(data.size()) == 2): 216 | # add additional dimension for gp samples and time step 217 | data = data.unsqueeze(0).unsqueeze(2) 218 | elif (len(data.size()) == 3): 219 | # add additional dimension for gp samples 220 | data = data.unsqueeze(0) 221 | 222 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size() 223 | assert(data.size()[-1] == n_dims) 224 | 225 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims] 226 | if mask is None: 227 | mu_flat = mu.reshape(n_traj_samples*n_traj, n_timepoints * n_dims) 228 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size() 229 | data_flat = data.reshape(n_traj_samples*n_traj, n_timepoints * n_dims) 230 | res = mse(mu_flat, data_flat) 231 | else: 232 | # Compute the likelihood per patient so that we don't priorize patients with more measurements 233 | res = compute_masked_likelihood(mu, data, mask, mse) 234 | return res 235 | 236 | 237 | 238 | 239 | def compute_poisson_proc_likelihood(truth, pred_y, info, mask = None): 240 | # Compute Poisson likelihood 241 | # 
https://math.stackexchange.com/questions/344487/log-likelihood-of-a-realization-of-a-poisson-process 242 | # Sum log lambdas across all time points 243 | if mask is None: 244 | poisson_log_l = torch.sum(info["log_lambda_y"], 2) - info["int_lambda"] 245 | # Sum over data dims 246 | poisson_log_l = torch.mean(poisson_log_l, -1) 247 | else: 248 | # Compute likelihood of the data under the predictions 249 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1) 250 | mask_repeated = mask.repeat(pred_y.size(0), 1, 1, 1) 251 | 252 | # Compute the likelihood per patient and per attribute so that we don't priorize patients with more measurements 253 | int_lambda = info["int_lambda"] 254 | f = lambda log_lam, data, indices: poisson_log_likelihood(log_lam, data, indices, int_lambda) 255 | poisson_log_l = compute_masked_likelihood(info["log_lambda_y"], truth_repeated, mask_repeated, f) 256 | poisson_log_l = poisson_log_l.permute(1,0) 257 | # Take mean over n_traj 258 | #poisson_log_l = torch.mean(poisson_log_l, 1) 259 | 260 | # poisson_log_l shape: [n_traj_samples, n_traj] 261 | return poisson_log_l 262 | 263 | 264 | 265 | -------------------------------------------------------------------------------- /lib/latent_ode/ode_func.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.utils.spectral_norm import spectral_norm 10 | 11 | from . import utils as utils 12 | 13 | ##################################################################################################### 14 | 15 | class ODEFunc(nn.Module): 16 | def __init__(self, input_dim, latent_dim, ode_func_net, device = torch.device("cpu")): 17 | """ 18 | input_dim: dimensionality of the input 19 | latent_dim: dimensionality used for ODE. Analog of a continous latent state 20 | """ 21 | super(ODEFunc, self).__init__() 22 | 23 | self.input_dim = input_dim 24 | self.device = device 25 | 26 | utils.init_network_weights(ode_func_net) 27 | self.gradient_net = ode_func_net 28 | 29 | def forward(self, t_local, y, backwards = False): 30 | """ 31 | Perform one step in solving ODE. Given current data point y and current time point t_local, returns gradient dy/dt at this time point 32 | 33 | t_local: current time point 34 | y: value at the current time point 35 | """ 36 | grad = self.get_ode_gradient_nn(t_local, y) 37 | if backwards: 38 | grad = -grad 39 | return grad 40 | 41 | def get_ode_gradient_nn(self, t_local, y): 42 | return self.gradient_net(y) 43 | 44 | def sample_next_point_from_prior(self, t_local, y): 45 | """ 46 | t_local: current time point 47 | y: value at the current time point 48 | """ 49 | return self.get_ode_gradient_nn(t_local, y) 50 | 51 | ##################################################################################################### 52 | 53 | class ODEFunc_w_Poisson(ODEFunc): 54 | 55 | def __init__(self, input_dim, latent_dim, ode_func_net, 56 | lambda_net, device = torch.device("cpu")): 57 | """ 58 | input_dim: dimensionality of the input 59 | latent_dim: dimensionality used for ODE. 
Analog of a continous latent state 60 | """ 61 | super(ODEFunc_w_Poisson, self).__init__(input_dim, latent_dim, ode_func_net, device) 62 | 63 | self.latent_ode = ODEFunc(input_dim = input_dim, 64 | latent_dim = latent_dim, 65 | ode_func_net = ode_func_net, 66 | device = device) 67 | 68 | self.latent_dim = latent_dim 69 | self.lambda_net = lambda_net 70 | # The computation of poisson likelihood can become numerically unstable. 71 | #The integral lambda(t) dt can take large values. In fact, it is equal to the expected number of events on the interval [0,T] 72 | #Exponent of lambda can also take large values 73 | # So we divide lambda by the constant and then multiply the integral of lambda by the constant 74 | self.const_for_lambda = torch.Tensor([100.]).to(device) 75 | 76 | def extract_poisson_rate(self, augmented, final_result = True): 77 | y, log_lambdas, int_lambda = None, None, None 78 | 79 | assert(augmented.size(-1) == self.latent_dim + self.input_dim) 80 | latent_lam_dim = self.latent_dim // 2 81 | 82 | if len(augmented.size()) == 3: 83 | int_lambda = augmented[:,:,-self.input_dim:] 84 | y_latent_lam = augmented[:,:,:-self.input_dim] 85 | 86 | log_lambdas = self.lambda_net(y_latent_lam[:,:,-latent_lam_dim:]) 87 | y = y_latent_lam[:,:,:-latent_lam_dim] 88 | 89 | elif len(augmented.size()) == 4: 90 | int_lambda = augmented[:,:,:,-self.input_dim:] 91 | y_latent_lam = augmented[:,:,:,:-self.input_dim] 92 | 93 | log_lambdas = self.lambda_net(y_latent_lam[:,:,:,-latent_lam_dim:]) 94 | y = y_latent_lam[:,:,:,:-latent_lam_dim] 95 | 96 | # Multiply the intergral over lambda by a constant 97 | # only when we have finished the integral computation (i.e. this is not a call in get_ode_gradient_nn) 98 | if final_result: 99 | int_lambda = int_lambda * self.const_for_lambda 100 | 101 | # Latents for performing reconstruction (y) have the same size as latent poisson rate (log_lambdas) 102 | assert(y.size(-1) == latent_lam_dim) 103 | 104 | return y, log_lambdas, int_lambda, y_latent_lam 105 | 106 | 107 | def get_ode_gradient_nn(self, t_local, augmented): 108 | y, log_lam, int_lambda, y_latent_lam = self.extract_poisson_rate(augmented, final_result = False) 109 | dydt_dldt = self.latent_ode(t_local, y_latent_lam) 110 | 111 | log_lam = log_lam - torch.log(self.const_for_lambda) 112 | return torch.cat((dydt_dldt, torch.exp(log_lam)),-1) 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /lib/latent_ode/ode_rnn.py: -------------------------------------------------------------------------------- 1 | ########################### 2 | # Latent ODEs for Irregularly-Sampled Time Series 3 | # Author: Yulia Rubanova 4 | ########################### 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.functional import relu 10 | 11 | from . 
import utils as utils 12 | from .encoder_decoder import * 13 | from .likelihood_eval import * 14 | 15 | from torch.distributions.multivariate_normal import MultivariateNormal 16 | from torch.distributions.normal import Normal 17 | from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase 18 | 19 | from torch.distributions.normal import Normal 20 | from torch.distributions import Independent 21 | from torch.nn.parameter import Parameter 22 | from .base_models import Baseline 23 | 24 | 25 | class ODE_RNN(Baseline): 26 | def __init__(self, input_dim, latent_dim, device = torch.device("cpu"), 27 | z0_diffeq_solver = None, n_gru_units = 100, n_units = 100, 28 | concat_mask = False, obsrv_std = 0.1, use_binary_classif = False, 29 | classif_per_tp = False, n_labels = 1, train_classif_w_reconstr = False): 30 | 31 | Baseline.__init__(self, input_dim, latent_dim, device = device, 32 | obsrv_std = obsrv_std, use_binary_classif = use_binary_classif, 33 | classif_per_tp = classif_per_tp, 34 | n_labels = n_labels, 35 | train_classif_w_reconstr = train_classif_w_reconstr) 36 | 37 | ode_rnn_encoder_dim = latent_dim 38 | 39 | self.ode_gru = Encoder_z0_ODE_RNN( 40 | latent_dim = ode_rnn_encoder_dim, 41 | input_dim = (input_dim) * 2, # input and the mask 42 | z0_diffeq_solver = z0_diffeq_solver, 43 | n_gru_units = n_gru_units, 44 | device = device).to(device) 45 | 46 | self.z0_diffeq_solver = z0_diffeq_solver 47 | 48 | self.decoder = nn.Sequential( 49 | nn.Linear(latent_dim, n_units), 50 | nn.Tanh(), 51 | nn.Linear(n_units, input_dim),) 52 | 53 | utils.init_network_weights(self.decoder) 54 | 55 | 56 | def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps, 57 | mask = None, n_traj_samples = None, mode = None): 58 | 59 | if (len(truth_time_steps) != len(time_steps_to_predict)) or (torch.sum(time_steps_to_predict - truth_time_steps) != 0): 60 | raise Exception("Extrapolation mode not implemented for ODE-RNN") 61 | 62 | # time_steps_to_predict and truth_time_steps should be the same 63 | assert(len(truth_time_steps) == len(time_steps_to_predict)) 64 | assert(mask is not None) 65 | 66 | data_and_mask = data 67 | if mask is not None: 68 | data_and_mask = torch.cat([data, mask],-1) 69 | 70 | _, _, latent_ys, _ = self.ode_gru.run_odernn( 71 | data_and_mask, truth_time_steps, run_backwards = False) 72 | 73 | latent_ys = latent_ys.permute(0,2,1,3) 74 | last_hidden = latent_ys[:,:,-1,:] 75 | 76 | #assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.) 77 | 78 | outputs = self.decoder(latent_ys) 79 | # Shift outputs for computing the loss -- we should compare the first output to the second data point, etc. 
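# NOTE: utils.shift_outputs is not shown here; it presumably drops the last predicted step and
# prepends the first observation, so that outputs[..., t, :] is compared against data[..., t + 1, :]
# (one-step-ahead targets), consistent with the comment above.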
80 | first_point = data[:,0,:] 81 | outputs = utils.shift_outputs(outputs, first_point) 82 | 83 | extra_info = {"first_point": (latent_ys[:,:,-1,:], 0.0, latent_ys[:,:,-1,:])} 84 | 85 | if self.use_binary_classif: 86 | if self.classif_per_tp: 87 | extra_info["label_predictions"] = self.classifier(latent_ys) 88 | else: 89 | extra_info["label_predictions"] = self.classifier(last_hidden).squeeze(-1) 90 | 91 | # outputs shape: [n_traj_samples, n_traj, n_tp, n_dims] 92 | return outputs, extra_info 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /lib/linreg.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Dict 2 | import sys 3 | import os 4 | import yaml 5 | import datetime 6 | import argparse 7 | from functools import partial 8 | import optuna 9 | 10 | from darts import models 11 | from darts import metrics 12 | from darts import TimeSeries 13 | 14 | # import data formatter 15 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 16 | from data_formatter.base import * 17 | from utils.darts_processing import load_data, reshuffle_data 18 | from utils.darts_evaluation import rescale_and_backtest 19 | from utils.darts_training import print_callback 20 | 21 | # lag setter for covariates 22 | def set_lags(in_len, args): 23 | lags_past_covariates = None 24 | lags_future_covariates = None 25 | if args.use_covs == 'True': 26 | if series['train']['dynamic'] is not None: 27 | lags_past_covariates = in_len 28 | if series['train']['future'] is not None: 29 | lags_future_covariates = (in_len, formatter.params['length_pred']) 30 | return lags_past_covariates, lags_future_covariates 31 | 32 | # define objective function 33 | def objective(trial): 34 | # select input and output chunk lengths 35 | out_len = formatter.params["length_pred"] 36 | in_len = trial.suggest_int("in_len", 12, formatter.params["max_length_input"], step=12) # at least 2 hours of predictions left 37 | lags_past_covariates, lags_future_covariates = set_lags(in_len, args) 38 | 39 | # build the Linear Regression model 40 | model = models.LinearRegressionModel(lags = in_len, 41 | lags_past_covariates = lags_past_covariates, 42 | lags_future_covariates = lags_future_covariates, 43 | output_chunk_length = out_len) 44 | 45 | # train the model 46 | model.fit(series['train']['target'], 47 | past_covariates=series['train']['dynamic'], 48 | future_covariates=series['train']['future']) 49 | 50 | # backtest on the validation set 51 | errors = model.backtest(series['val']['target'], 52 | past_covariates=series['val']['dynamic'], 53 | future_covariates=series['val']['future'], 54 | forecast_horizon=out_len, 55 | stride=out_len, 56 | retrain=False, 57 | verbose=False, 58 | metric=metrics.rmse, 59 | last_points_only=False, 60 | ) 61 | avg_error = np.mean(errors) 62 | 63 | return avg_error 64 | 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--dataset', type=str, default='weinstock') 67 | parser.add_argument('--use_covs', type=str, default='False') 68 | parser.add_argument('--optuna', type=str, default='True') 69 | parser.add_argument('--reduction1', type=str, default='mean') 70 | parser.add_argument('--reduction2', type=str, default='median') 71 | parser.add_argument('--reduction3', type=str, default=None) 72 | args = parser.parse_args() 73 | reductions = [args.reduction1, args.reduction2, args.reduction3] 74 | if __name__ == '__main__': 75 | # load data 76 | study_file = 
f'./output/linreg_{args.dataset}.txt' if args.use_covs == 'False' \ 77 | else f'./output/linreg_covariates_{args.dataset}.txt' 78 | if not os.path.exists(study_file): 79 | with open(study_file, "w") as f: 80 | f.write(f"Optimization started at {datetime.datetime.now()}\n") 81 | formatter, series, scalers = load_data(study_file=study_file, 82 | dataset=args.dataset, 83 | use_covs=True if args.use_covs == 'True' else False, 84 | cov_type='mixed', 85 | use_static_covs=True) 86 | 87 | # hyperparameter optimization 88 | best_params = None 89 | if args.optuna == 'True': 90 | study = optuna.create_study(direction="minimize") 91 | print_call = partial(print_callback, study_file=study_file) 92 | study.optimize(objective, n_trials=50, 93 | callbacks=[print_call], 94 | catch=(np.linalg.LinAlgError, KeyError)) 95 | best_params = study.best_trial.params 96 | else: 97 | key = "linreg_covariates" if args.use_covs == 'True' else "linreg" 98 | assert formatter.params[key] is not None, "No saved hyperparameters found for this model" 99 | best_params = formatter.params[key] 100 | 101 | # select best hyperparameters 102 | in_len = best_params['in_len'] 103 | out_len = formatter.params["length_pred"] 104 | stride = out_len // 2 105 | lags_past_covariates, lags_future_covariates = set_lags(in_len, args) 106 | 107 | # test on ID and OOD data 108 | seeds = list(range(10, 20)) 109 | id_errors_cv = {key: [] for key in reductions if key is not None} 110 | ood_errors_cv = {key: [] for key in reductions if key is not None} 111 | id_likelihoods_cv = []; ood_likelihoods_cv = [] 112 | id_cal_errors_cv = []; ood_cal_errors_cv = [] 113 | for seed in seeds: 114 | formatter, series, scalers = reshuffle_data(formatter, 115 | seed, 116 | use_covs=True if args.use_covs == 'True' else False, 117 | cov_type='mixed', 118 | use_static_covs=True) 119 | # build the model 120 | model = models.LinearRegressionModel(lags = in_len, 121 | lags_past_covariates = lags_past_covariates, 122 | lags_future_covariates = lags_future_covariates, 123 | output_chunk_length = formatter.params['length_pred']) 124 | # train the model 125 | model.fit(series['train']['target'], 126 | past_covariates=series['train']['dynamic'], 127 | future_covariates=series['train']['future']) 128 | 129 | # backtest on the test set 130 | forecasts = model.historical_forecasts(series['test']['target'], 131 | past_covariates = series['test']['dynamic'], 132 | future_covariates = series['test']['future'], 133 | forecast_horizon=out_len, 134 | stride=stride, 135 | retrain=False, 136 | verbose=False, 137 | last_points_only=False, 138 | start=formatter.params["max_length_input"]) 139 | id_errors_sample, \ 140 | id_likelihood_sample, \ 141 | id_cal_errors_sample = rescale_and_backtest(series['test']['target'], 142 | forecasts, 143 | [metrics.mse, metrics.mae], 144 | scalers['target'], 145 | reduction=None) 146 | # backtest on the OOD set 147 | forecasts = model.historical_forecasts(series['test_ood']['target'], 148 | past_covariates = series['test_ood']['dynamic'], 149 | future_covariates = series['test_ood']['future'], 150 | forecast_horizon=out_len, 151 | stride=stride, 152 | retrain=False, 153 | verbose=False, 154 | last_points_only=False, 155 | start=formatter.params["max_length_input"]) 156 | 157 | ood_errors_sample, \ 158 | ood_likelihood_sample, \ 159 | ood_cal_errors_sample = rescale_and_backtest(series['test_ood']['target'], 160 | forecasts, 161 | [metrics.mse, metrics.mae], 162 | scalers['target'], 163 | reduction=None) 164 | 165 | # compute, save, and print results 
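# NOTE: this first block reduces the per-sample (MSE, MAE) errors for the current seed with each
# requested reduction (e.g. mean / median), appends the reduced values together with the likelihood
# and calibration results to the *_cv accumulators, and logs them per seed; the identically named
# block after the seed loop then averages the accumulated results across seeds.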
166 | with open(study_file, "a") as f: 167 | for reduction in reductions: 168 | if reduction is not None: 169 | # compute 170 | reduction_f = getattr(np, reduction) 171 | id_errors_sample_red = reduction_f(id_errors_sample, axis=0) 172 | ood_errors_sample_red = reduction_f(ood_errors_sample, axis=0) 173 | # save 174 | id_errors_cv[reduction].append(id_errors_sample_red) 175 | ood_errors_cv[reduction].append(ood_errors_sample_red) 176 | # print 177 | f.write(f"\tSeed: {seed} ID {reduction} of (MSE, MAE): {id_errors_sample_red.tolist()}\n") 178 | f.write(f"\tSeed: {seed} OOD {reduction} of (MSE, MAE) stats: {ood_errors_sample_red.tolist()}\n") 179 | # save 180 | id_likelihoods_cv.append(id_likelihood_sample) 181 | ood_likelihoods_cv.append(ood_likelihood_sample) 182 | id_cal_errors_cv.append(id_cal_errors_sample) 183 | ood_cal_errors_cv.append(ood_cal_errors_sample) 184 | # print 185 | f.write(f"\tSeed: {seed} ID likelihoods: {id_likelihood_sample}\n") 186 | f.write(f"\tSeed: {seed} OOD likelihoods: {ood_likelihood_sample}\n") 187 | f.write(f"\tSeed: {seed} ID calibration errors: {id_cal_errors_sample.tolist()}\n") 188 | f.write(f"\tSeed: {seed} OOD calibration errors: {ood_cal_errors_sample.tolist()}\n") 189 | 190 | # compute, save, and print results 191 | with open(study_file, "a") as f: 192 | for reduction in reductions: 193 | if reduction is not None: 194 | # compute 195 | id_errors_cv[reduction] = np.vstack(id_errors_cv[reduction]) 196 | ood_errors_cv[reduction] = np.vstack(ood_errors_cv[reduction]) 197 | id_errors_cv[reduction] = np.mean(id_errors_cv[reduction], axis=0) 198 | ood_errors_cv[reduction] = np.mean(ood_errors_cv[reduction], axis=0) 199 | # print 200 | f.write(f"ID {reduction} of (MSE, MAE): {id_errors_cv[reduction].tolist()}\n") 201 | f.write(f"OOD {reduction} of (MSE, MAE): {ood_errors_cv[reduction].tolist()}\n") 202 | # compute 203 | id_likelihoods_cv = np.mean(id_likelihoods_cv) 204 | ood_likelihoods_cv = np.mean(ood_likelihoods_cv) 205 | id_cal_errors_cv = np.vstack(id_cal_errors_cv) 206 | ood_cal_errors_cv = np.vstack(ood_cal_errors_cv) 207 | id_cal_errors_cv = np.mean(id_cal_errors_cv, axis=0) 208 | ood_cal_errors_cv = np.mean(ood_cal_errors_cv, axis=0) 209 | # print 210 | f.write(f"ID likelihoods: {id_likelihoods_cv}\n") 211 | f.write(f"OOD likelihoods: {ood_likelihoods_cv}\n") 212 | f.write(f"ID calibration errors: {id_cal_errors_cv.tolist()}\n") 213 | f.write(f"OOD calibration errors: {ood_cal_errors_cv.tolist()}\n") -------------------------------------------------------------------------------- /output/arima_colas.txt: -------------------------------------------------------------------------------- 1 | Optimization started at 2023-03-22 17:28:57.493837 2 | -------------------------------- 3 | Loading column definition... 4 | Checking column definition... 5 | Loading data... 6 | Dropping columns / rows... 7 | Checking for NA values... 8 | Setting data types... 9 | Dropping columns / rows... 10 | Encoding data... 
11 | Updated column definition: 12 | id: REAL_VALUED (ID) 13 | time: DATE (TIME) 14 | gl: REAL_VALUED (TARGET) 15 | gender: REAL_VALUED (STATIC_INPUT) 16 | age: REAL_VALUED (STATIC_INPUT) 17 | BMI: REAL_VALUED (STATIC_INPUT) 18 | glycaemia: REAL_VALUED (STATIC_INPUT) 19 | HbA1c: REAL_VALUED (STATIC_INPUT) 20 | follow.up: REAL_VALUED (STATIC_INPUT) 21 | T2DM: REAL_VALUED (STATIC_INPUT) 22 | time_year: REAL_VALUED (KNOWN_INPUT) 23 | time_month: REAL_VALUED (KNOWN_INPUT) 24 | time_day: REAL_VALUED (KNOWN_INPUT) 25 | time_hour: REAL_VALUED (KNOWN_INPUT) 26 | time_minute: REAL_VALUED (KNOWN_INPUT) 27 | Interpolating data... 28 | Dropped segments: 63 29 | Extracted segments: 205 30 | Interpolated values: 241 31 | Percent of values interpolated: 0.22% 32 | Splitting data... 33 | Train: 72275 (45.89%) 34 | Val: 35713 (22.68%) 35 | Test: 38253 (24.29%) 36 | Test OOD: 11242 (7.14%) 37 | Scaling data... 38 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 39 | Data formatting complete. 40 | -------------------------------- 41 | Train: 72173 (45.75%) 42 | Val: 35885 (22.74%) 43 | Test: 38253 (24.25%) 44 | Test OOD: 11460 (7.26%) 45 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 46 | Seed: 10 ID mean of (MSE, MAE): [120.6773731841762, 6.6476186379323945] 47 | Seed: 10 OOD mean of (MSE, MAE) stats: [98.49733495015732, 5.962563865719668] 48 | Seed: 10 ID median of (MSE, MAE): [34.119099014633285, 4.8265690635532] 49 | Seed: 10 OOD median of (MSE, MAE) stats: [25.749051321984638, 4.250000388361499] 50 | Train: 71945 (45.73%) 51 | Val: 35713 (22.70%) 52 | Test: 38037 (24.18%) 53 | Test OOD: 11644 (7.40%) 54 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 55 | Seed: 11 ID mean of (MSE, MAE): [110.98849348314229, 6.512092322826006] 56 | Seed: 11 OOD mean of (MSE, MAE) stats: [143.8660811618338, 6.893749882596932] 57 | Seed: 11 ID median of (MSE, MAE): [33.800056814408784, 4.806109972796595] 58 | Seed: 11 OOD median of (MSE, MAE) stats: [31.027526450957808, 4.534814927485767] 59 | Train: 71565 (45.53%) 60 | Val: 35497 (22.58%) 61 | Test: 38037 (24.20%) 62 | Test OOD: 12096 (7.69%) 63 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 64 | Seed: 12 ID mean of (MSE, MAE): [116.43544050531845, 6.5224223472168985] 65 | Seed: 12 OOD mean of (MSE, MAE) stats: [172.39000789417685, 7.705362619675043] 66 | Seed: 12 ID median of (MSE, MAE): [33.086893503512236, 4.795547888826327] 67 | Seed: 12 OOD median of (MSE, MAE) stats: [45.878204185631716, 5.544957629357246] 68 | Train: 73201 (46.27%) 69 | Val: 36332 (22.97%) 70 | Test: 38469 (24.32%) 71 | Test OOD: 10201 (6.45%) 72 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 73 | Seed: 13 ID mean of (MSE, MAE): [120.10565412076012, 6.670143753397182] 74 | Seed: 13 OOD mean of (MSE, MAE) stats: [156.32090280031198, 7.36687788419089] 75 | Seed: 13 ID median of (MSE, MAE): [34.19618027031381, 4.847393085094296] 76 | Seed: 13 OOD median of (MSE, MAE) stats: [40.91801930218372, 5.3125520415657235] 77 | Train: 
72721 (45.97%) 78 | Val: 36577 (23.12%) 79 | Test: 38240 (24.17%) 80 | Test OOD: 10665 (6.74%) 81 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 82 | Seed: 14 ID mean of (MSE, MAE): [119.96416492732897, 6.6778244052465086] 83 | Seed: 14 OOD mean of (MSE, MAE) stats: [116.36936191941496, 6.324921143713053] 84 | Seed: 14 ID median of (MSE, MAE): [34.350508699577546, 4.845801287754875] 85 | Seed: 14 OOD median of (MSE, MAE) stats: [28.11721839067325, 4.371174125759447] 86 | Train: 72280 (45.90%) 87 | Val: 35929 (22.81%) 88 | Test: 38037 (24.15%) 89 | Test OOD: 11237 (7.14%) 90 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 91 | Seed: 15 ID mean of (MSE, MAE): [115.49432265308754, 6.4833270353370915] 92 | Seed: 15 OOD mean of (MSE, MAE) stats: [141.04209618074884, 7.00637395501393] 93 | Seed: 15 ID median of (MSE, MAE): [32.41810861864848, 4.718699547205426] 94 | Seed: 15 OOD median of (MSE, MAE) stats: [32.61929274481187, 4.849921090236984] 95 | Train: 71826 (45.65%) 96 | Val: 35713 (22.70%) 97 | Test: 38037 (24.18%) 98 | Test OOD: 11763 (7.48%) 99 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 100 | Seed: 16 ID mean of (MSE, MAE): [118.60316998673208, 6.570685408763114] 101 | Seed: 16 OOD mean of (MSE, MAE) stats: [172.53892554528767, 7.4691531687181705] 102 | Seed: 16 ID median of (MSE, MAE): [32.9922624794985, 4.776514077411481] 103 | Seed: 16 OOD median of (MSE, MAE) stats: [37.722542194470094, 5.025782263597927] 104 | Train: 72187 (45.92%) 105 | Val: 35497 (22.58%) 106 | Test: 38037 (24.20%) 107 | Test OOD: 11474 (7.30%) 108 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 109 | Seed: 17 ID mean of (MSE, MAE): [118.4226276104948, 6.565824634002014] 110 | Seed: 17 OOD mean of (MSE, MAE) stats: [120.57096069178137, 6.718860547848144] 111 | Seed: 17 ID median of (MSE, MAE): [32.75335854245767, 4.776514077411481] 112 | Seed: 17 OOD median of (MSE, MAE) stats: [36.189480921506764, 4.916667081997722] 113 | Train: 71880 (45.73%) 114 | Val: 35497 (22.58%) 115 | Test: 38037 (24.20%) 116 | Test OOD: 11781 (7.49%) 117 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 118 | Seed: 18 ID mean of (MSE, MAE): [118.39438829343224, 6.60487045660232] 119 | Seed: 18 OOD mean of (MSE, MAE) stats: [146.7202339119919, 7.267020937497021] 120 | Seed: 18 ID median of (MSE, MAE): [34.3461620511888, 4.825534360244473] 121 | Seed: 18 OOD median of (MSE, MAE) stats: [34.70171129616214, 4.99999870546165] 122 | Train: 72349 (45.90%) 123 | Val: 36145 (22.93%) 124 | Test: 38037 (24.13%) 125 | Test OOD: 11096 (7.04%) 126 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 127 | Seed: 19 ID mean of (MSE, MAE): [121.40247352207783, 6.655633166667995] 128 | Seed: 19 OOD mean of (MSE, MAE) stats: [137.0253490270386, 6.895576622625327] 129 | Seed: 19 ID median of (MSE, MAE): [33.92787207032102, 4.81699414126529] 130 | Seed: 19 OOD median of 
(MSE, MAE) stats: [36.766653251293356, 4.943558197736419] 131 | ID mean of (MSE, MAE): [118.04881082865504, 6.591044216799152] +- [2.9257223155159435, 0.06704249579628511] 132 | OOD mean of (MSE, MAE): [140.53412540827432, 6.961046062759818] +- [22.582784874200478, 0.5041473132551306] 133 | ID median of (MSE, MAE): [33.59905020645601, 4.803567750156345] +- [0.6810816239186338, 0.036923417227094066] 134 | OOD median of (MSE, MAE): [34.968970005967535, 4.874942645156038] +- [5.654969833183597, 0.3802513881996165] 135 | -------------------------------------------------------------------------------- /output/arima_dubosson.txt: -------------------------------------------------------------------------------- 1 | Optimization started at 2023-03-22 17:28:57.405260 2 | -------------------------------- 3 | Loading column definition... 4 | Checking column definition... 5 | Loading data... 6 | Dropping columns / rows... 7 | Checking for NA values... 8 | Setting data types... 9 | Dropping columns / rows... 10 | Encoding data... 11 | Updated column definition: 12 | id: REAL_VALUED (ID) 13 | time: DATE (TIME) 14 | gl: REAL_VALUED (TARGET) 15 | fast_insulin: REAL_VALUED (OBSERVED_INPUT) 16 | slow_insulin: REAL_VALUED (OBSERVED_INPUT) 17 | calories: REAL_VALUED (OBSERVED_INPUT) 18 | balance: REAL_VALUED (OBSERVED_INPUT) 19 | quality: REAL_VALUED (OBSERVED_INPUT) 20 | HR: REAL_VALUED (OBSERVED_INPUT) 21 | BR: REAL_VALUED (OBSERVED_INPUT) 22 | Posture: REAL_VALUED (OBSERVED_INPUT) 23 | Activity: REAL_VALUED (OBSERVED_INPUT) 24 | HRV: REAL_VALUED (OBSERVED_INPUT) 25 | CoreTemp: REAL_VALUED (OBSERVED_INPUT) 26 | time_year: REAL_VALUED (KNOWN_INPUT) 27 | time_month: REAL_VALUED (KNOWN_INPUT) 28 | time_day: REAL_VALUED (KNOWN_INPUT) 29 | time_hour: REAL_VALUED (KNOWN_INPUT) 30 | time_minute: REAL_VALUED (KNOWN_INPUT) 31 | Interpolating data... 32 | Dropped segments: 1 33 | Extracted segments: 8 34 | Interpolated values: 0 35 | Percent of values interpolated: 0.00% 36 | Splitting data... 37 | Train: 4654 (47.17%) 38 | Val: 2016 (20.43%) 39 | Test: 2057 (20.85%) 40 | Test OOD: 1140 (11.55%) 41 | Scaling data... 42 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 43 | Data formatting complete. 
44 | -------------------------------- 45 | Train: 4825 (48.90%) 46 | Val: 2016 (20.43%) 47 | Test: 2057 (20.85%) 48 | Test OOD: 969 (9.82%) 49 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 50 | Seed: 10 ID mean of (MSE, MAE): [630.3234671862778, 14.982298035310993] 51 | Seed: 10 OOD mean of (MSE, MAE) stats: [2311.552668832363, 29.388185991283358] 52 | Seed: 10 ID median of (MSE, MAE): [148.09641012473497, 10.29883954624072] 53 | Seed: 10 OOD median of (MSE, MAE) stats: [686.3397895458669, 21.1499645439295] 54 | Train: 4825 (48.90%) 55 | Val: 2016 (20.43%) 56 | Test: 2057 (20.85%) 57 | Test OOD: 969 (9.82%) 58 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 59 | Seed: 11 ID mean of (MSE, MAE): [630.3234671862778, 14.982298035310993] 60 | Seed: 11 OOD mean of (MSE, MAE) stats: [2311.552668832363, 29.388185991283358] 61 | Seed: 11 ID median of (MSE, MAE): [148.09641012473497, 10.29883954624072] 62 | Seed: 11 OOD median of (MSE, MAE) stats: [686.3397895458669, 21.1499645439295] 63 | Train: 4514 (45.75%) 64 | Val: 2016 (20.43%) 65 | Test: 2057 (20.85%) 66 | Test OOD: 1280 (12.97%) 67 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 68 | Seed: 12 ID mean of (MSE, MAE): [1568.9926209773005, 19.97402098946009] 69 | Seed: 12 OOD mean of (MSE, MAE) stats: [390.67696842721256, 13.13029240748665] 70 | Seed: 12 ID median of (MSE, MAE): [239.79360040284143, 12.464939426312228] 71 | Seed: 12 OOD median of (MSE, MAE) stats: [168.3885649689339, 10.784758322794026] 72 | Train: 4738 (48.02%) 73 | Val: 2016 (20.43%) 74 | Test: 2057 (20.85%) 75 | Test OOD: 1056 (10.70%) 76 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 77 | Seed: 13 ID mean of (MSE, MAE): [1335.955292005022, 16.945752079471546] 78 | Seed: 13 OOD mean of (MSE, MAE) stats: [1267.9886795415546, 21.3231923989978] 79 | Seed: 13 ID median of (MSE, MAE): [146.7812227135625, 10.124277894741939] 80 | Seed: 13 OOD median of (MSE, MAE) stats: [307.56874380049413, 14.881323057095171] 81 | Train: 4654 (47.17%) 82 | Val: 2016 (20.43%) 83 | Test: 2057 (20.85%) 84 | Test OOD: 1140 (11.55%) 85 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 86 | Seed: 14 ID mean of (MSE, MAE): [1562.139924622624, 19.56357352526093] 87 | Seed: 14 OOD mean of (MSE, MAE) stats: [575.304073148924, 15.900844263774417] 88 | Seed: 14 ID median of (MSE, MAE): [202.41938206965023, 11.364012783638028] 89 | Seed: 14 OOD median of (MSE, MAE) stats: [243.24603852839633, 12.629752971974511] 90 | Train: 4885 (49.51%) 91 | Val: 2016 (20.43%) 92 | Test: 2057 (20.85%) 93 | Test OOD: 909 (9.21%) 94 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 
'time_month', 'time_day', 'time_hour', 'time_minute'] 95 | Seed: 15 ID mean of (MSE, MAE): [1529.236218668717, 19.214881713534723] 96 | Seed: 15 OOD mean of (MSE, MAE) stats: [306.96453496345606, 11.01309437759319] 97 | Seed: 15 ID median of (MSE, MAE): [171.44486633813915, 10.68677369459521] 98 | Seed: 15 OOD median of (MSE, MAE) stats: [92.69485559816981, 7.574119570527739] 99 | Train: 4825 (48.90%) 100 | Val: 2016 (20.43%) 101 | Test: 2057 (20.85%) 102 | Test OOD: 969 (9.82%) 103 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 104 | Seed: 16 ID mean of (MSE, MAE): [630.3234671862778, 14.982298035310993] 105 | Seed: 16 OOD mean of (MSE, MAE) stats: [2311.552668832363, 29.388185991283358] 106 | Seed: 16 ID median of (MSE, MAE): [148.09641012473497, 10.29883954624072] 107 | Seed: 16 OOD median of (MSE, MAE) stats: [686.3397895458669, 21.1499645439295] 108 | Train: 4514 (45.75%) 109 | Val: 2016 (20.43%) 110 | Test: 2057 (20.85%) 111 | Test OOD: 1280 (12.97%) 112 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 113 | Seed: 17 ID mean of (MSE, MAE): [1568.9926209773005, 19.97402098946009] 114 | Seed: 17 OOD mean of (MSE, MAE) stats: [390.67696842721256, 13.13029240748665] 115 | Seed: 17 ID median of (MSE, MAE): [239.79360040284143, 12.464939426312228] 116 | Seed: 17 OOD median of (MSE, MAE) stats: [168.3885649689339, 10.784758322794026] 117 | Train: 4514 (45.75%) 118 | Val: 2016 (20.43%) 119 | Test: 2057 (20.85%) 120 | Test OOD: 1280 (12.97%) 121 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 122 | Seed: 18 ID mean of (MSE, MAE): [1568.9926209773005, 19.97402098946009] 123 | Seed: 18 OOD mean of (MSE, MAE) stats: [390.67696842721256, 13.13029240748665] 124 | Seed: 18 ID median of (MSE, MAE): [239.79360040284143, 12.464939426312228] 125 | Seed: 18 OOD median of (MSE, MAE) stats: [168.3885649689339, 10.784758322794026] 126 | Train: 4738 (48.02%) 127 | Val: 2016 (20.43%) 128 | Test: 2057 (20.85%) 129 | Test OOD: 1056 (10.70%) 130 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute'] 131 | Seed: 19 ID mean of (MSE, MAE): [1335.955292005022, 16.945752079471546] 132 | Seed: 19 OOD mean of (MSE, MAE) stats: [1267.9886795415546, 21.3231923989978] 133 | Seed: 19 ID median of (MSE, MAE): [146.7812227135625, 10.124277894741939] 134 | Seed: 19 OOD median of (MSE, MAE) stats: [307.56874380049413, 14.881323057095171] 135 | ID mean of (MSE, MAE): [1236.123499179212, 17.7538916472052] +- [405.6502363774581, 2.1106334578285266] 136 | OOD mean of (MSE, MAE): [1152.4934878974216, 19.711575863567326] +- [827.4111971731463, 7.114984874764907] 137 | ID median of (MSE, MAE): [183.1096725417644, 11.059067918537597] +- [40.57762165059761, 0.9819635818277976] 138 | OOD median of (MSE, MAE): [351.52634452719565, 14.577068725686317] +- [227.92572769481512, 4.75154932415569] 139 | -------------------------------------------------------------------------------- 
/output/arima_iglu.txt: -------------------------------------------------------------------------------- 1 | Optimization started at 2023-03-22 17:28:57.399282 2 | -------------------------------- 3 | Loading column definition... 4 | Checking column definition... 5 | Loading data... 6 | Dropping columns / rows... 7 | Checking for NA values... 8 | Setting data types... 9 | Dropping columns / rows... 10 | Encoding data... 11 | Updated column definition: 12 | id: REAL_VALUED (ID) 13 | time: DATE (TIME) 14 | gl: REAL_VALUED (TARGET) 15 | time_year: REAL_VALUED (KNOWN_INPUT) 16 | time_month: REAL_VALUED (KNOWN_INPUT) 17 | time_day: REAL_VALUED (KNOWN_INPUT) 18 | time_hour: REAL_VALUED (KNOWN_INPUT) 19 | time_minute: REAL_VALUED (KNOWN_INPUT) 20 | time_second: REAL_VALUED (KNOWN_INPUT) 21 | Interpolating data... 22 | Dropped segments: 17 23 | Extracted segments: 15 24 | Interpolated values: 561 25 | Percent of values interpolated: 4.37% 26 | Splitting data... 27 | Train: 9056 (64.79%) 28 | Val: 1774 (12.69%) 29 | Test: 1848 (13.22%) 30 | Test OOD: 1300 (9.30%) 31 | Scaling data... 32 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second'] 33 | Data formatting complete. 34 | -------------------------------- 35 | Train: 9056 (64.79%) 36 | Val: 1774 (12.69%) 37 | Test: 1848 (13.22%) 38 | Test OOD: 1300 (9.30%) 39 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second'] 40 | Seed: 10 ID mean of (MSE, MAE): [389.2651923831744, 13.158376410922166] 41 | Seed: 10 OOD mean of (MSE, MAE) stats: [444.1124199595163, 12.280007585027754] 42 | Seed: 10 ID median of (MSE, MAE): [137.96437105020195, 9.472417880109182] 43 | Seed: 10 OOD median of (MSE, MAE) stats: [73.18435077204008, 7.078929258345878] 44 | Train: 9056 (64.79%) 45 | Val: 1774 (12.69%) 46 | Test: 1848 (13.22%) 47 | Test OOD: 1300 (9.30%) 48 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second'] 49 | Seed: 11 ID mean of (MSE, MAE): [389.2651923831744, 13.158376410922166] 50 | Seed: 11 OOD mean of (MSE, MAE) stats: [444.1124199595163, 12.280007585027754] 51 | Seed: 11 ID median of (MSE, MAE): [137.96437105020195, 9.472417880109182] 52 | Seed: 11 OOD median of (MSE, MAE) stats: [73.18435077204008, 7.078929258345878] 53 | Train: 8110 (59.66%) 54 | Val: 1342 (9.87%) 55 | Test: 2017 (14.84%) 56 | Test OOD: 2125 (15.63%) 57 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second'] 58 | Seed: 12 ID mean of (MSE, MAE): [427.0167933906876, 12.995149726570737] 59 | Seed: 12 OOD mean of (MSE, MAE) stats: [531.8962502619843, 14.276775793295084] 60 | Seed: 12 ID median of (MSE, MAE): [122.93745002201426, 9.253760038373024] 61 | Seed: 12 OOD median of (MSE, MAE) stats: [144.73854388702725, 10.216134214504272] 62 | Train: 7661 (55.57%) 63 | Val: 1296 (9.40%) 64 | Test: 2017 (14.63%) 65 | Test OOD: 2812 (20.40%) 66 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second'] 67 | Seed: 13 ID mean of (MSE, MAE): [433.3503441184686, 12.76526222735374] 68 | Seed: 13 OOD mean of (MSE, MAE) stats: [376.08592954422636, 13.031398874300812] 69 | Seed: 13 ID median of (MSE, MAE): [93.73748299466513, 8.118844711107416] 70 | Seed: 13 OOD median of (MSE, MAE) stats: [140.69631926285496, 10.14154025440078] 71 | Train: 7661 (55.57%) 72 | Val: 1296 (9.40%) 73 | Test: 2017 (14.63%) 74 | Test OOD: 2812 
(20.40%) 75 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second'] 76 | Seed: 14 ID mean of (MSE, MAE): [433.3503441184686, 12.76526222735374] 77 | Seed: 14 OOD mean of (MSE, MAE) stats: [376.08592954422636, 13.031398874300812] 78 | Seed: 14 ID median of (MSE, MAE): [93.73748299466513, 8.118844711107416] 79 | Seed: 14 OOD median of (MSE, MAE) stats: [140.69631926285496, 10.14154025440078] 80 | Train: 9056 (64.79%) 81 | Val: 1774 (12.69%) 82 | Test: 1848 (13.22%) 83 | Test OOD: 1300 (9.30%) 84 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second'] 85 | Seed: 15 ID mean of (MSE, MAE): [389.2651923831744, 13.158376410922166] 86 | Seed: 15 OOD mean of (MSE, MAE) stats: [444.1124199595163, 12.280007585027754] 87 | Seed: 15 ID median of (MSE, MAE): [137.96437105020195, 9.472417880109182] 88 | Seed: 15 OOD median of (MSE, MAE) stats: [73.18435077204008, 7.078929258345878] 89 | Train: 8110 (59.66%) 90 | Val: 1342 (9.87%) 91 | Test: 2017 (14.84%) 92 | Test OOD: 2125 (15.63%) 93 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second'] 94 | Seed: 16 ID mean of (MSE, MAE): [427.0167933906876, 12.995149726570737] 95 | Seed: 16 OOD mean of (MSE, MAE) stats: [531.8962502619843, 14.276775793295084] 96 | Seed: 16 ID median of (MSE, MAE): [122.93745002201426, 9.253760038373024] 97 | Seed: 16 OOD median of (MSE, MAE) stats: [144.73854388702725, 10.216134214504272] 98 | Train: 7643 (55.44%) 99 | Val: 1342 (9.73%) 100 | Test: 1897 (13.76%) 101 | Test OOD: 2904 (21.06%) 102 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second'] 103 | Seed: 17 ID mean of (MSE, MAE): [415.69378849524696, 12.258175847193586] 104 | Seed: 17 OOD mean of (MSE, MAE) stats: [917.813246785934, 18.642029569388345] 105 | Seed: 17 ID median of (MSE, MAE): [83.99256708897218, 7.697683488797355] 106 | Seed: 17 OOD median of (MSE, MAE) stats: [224.85990324440925, 12.498202781002497] 107 | Train: 7643 (55.44%) 108 | Val: 1342 (9.73%) 109 | Test: 1897 (13.76%) 110 | Test OOD: 2904 (21.06%) 111 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second'] 112 | Seed: 18 ID mean of (MSE, MAE): [415.69378849524696, 12.258175847193586] 113 | Seed: 18 OOD mean of (MSE, MAE) stats: [917.813246785934, 18.642029569388345] 114 | Seed: 18 ID median of (MSE, MAE): [83.99256708897218, 7.697683488797355] 115 | Seed: 18 OOD median of (MSE, MAE) stats: [224.85990324440925, 12.498202781002497] 116 | Train: 7661 (55.57%) 117 | Val: 1296 (9.40%) 118 | Test: 2017 (14.63%) 119 | Test OOD: 2812 (20.40%) 120 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second'] 121 | Seed: 19 ID mean of (MSE, MAE): [433.3503441184686, 12.76526222735374] 122 | Seed: 19 OOD mean of (MSE, MAE) stats: [376.08592954422636, 13.031398874300812] 123 | Seed: 19 ID median of (MSE, MAE): [93.73748299466513, 8.118844711107416] 124 | Seed: 19 OOD median of (MSE, MAE) stats: [140.69631926285496, 10.14154025440078] 125 | ID mean of (MSE, MAE): [415.3267773276798, 12.827756706235636] +- [18.126631939072656, 0.32319195589326777] 126 | OOD mean of (MSE, MAE): [536.0014042607065, 14.177183010335256] +- [198.40461290827625, 2.337217034828997] 127 | ID median of (MSE, MAE): [110.89655963565743, 8.667667482799056] +- [21.952031091293335, 0.7358828558087035] 128 | OOD 
median of (MSE, MAE): [138.0838904367558, 9.709008252925353] +- [52.73049091547017, 1.933565723289778] 129 | -------------------------------------------------------------------------------- /paper_results/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import numpy as np 4 | from typing import Sequence, Dict 5 | 6 | def avg_results(model_names: str, 7 | model_names_with_covs: str = None, 8 | time_steps: int = 12)->Sequence[Dict[str, Dict[str, np.array]]]: 9 | """ 10 | Function to load final model results: averaged across random seeds / folds 11 | ---------- 12 | model_names: str 13 | path to models' results file 14 | model_name_with_covs: str 15 | path to models' results file with covariates 16 | time_steps: int 17 | number of time steps that were predicted 18 | NOTE: model_names and model_names_with_covs should be in the same order 19 | 20 | Output 21 | ------ 22 | Computes the following set of dictionaries: 23 | dict1: 24 | Dictionary of MSE / MAE values for ID / OOD sets with and without covariates 25 | key1: id / od, key2: covs / no_covs 26 | dict2: 27 | Dictionary of likelihood / calibration values for ID / OOD sets with and without covariates 28 | key1: id / od, key2: covs / no_covs 29 | """ 30 | def parser(model_names): 31 | arr_id_errors = np.full((len(model_names), 2), np.nan) 32 | arr_ood_errors = arr_id_errors.copy() 33 | arr_id_likelihoods = arr_id_errors.copy() 34 | arr_ood_likelihoods = arr_id_errors.copy() 35 | arr_id_errors_std = np.full((len(model_names), 2, 2), np.nan) 36 | arr_ood_errors_std = arr_id_errors_std.copy() 37 | arr_id_likelihoods_std = arr_id_errors_std.copy() 38 | arr_ood_likelihoods_std = arr_id_errors_std.copy() 39 | for model_name in model_names: 40 | if not os.path.isfile(model_name): 41 | continue 42 | with open(model_name, 'r') as f: 43 | for line in f: 44 | if line.startswith('ID median of (MSE, MAE):'): 45 | id_mse_mae = re.findall(r'\d+\.\d+(?:e-\d+)?', line) 46 | arr_id_errors[model_names.index(model_name), 0] = float(id_mse_mae[0]) 47 | arr_id_errors[model_names.index(model_name), 1] = float(id_mse_mae[1]) 48 | if len(id_mse_mae) > 2: 49 | arr_id_errors_std[model_names.index(model_name), 0, 0] = float(id_mse_mae[2]) 50 | arr_id_errors_std[model_names.index(model_name), 0, 1] = float(id_mse_mae[3]) 51 | if len(id_mse_mae) > 4: 52 | arr_id_errors_std[model_names.index(model_name), 1, 0] = float(id_mse_mae[4]) 53 | arr_id_errors_std[model_names.index(model_name), 1, 1] = float(id_mse_mae[5]) 54 | elif line.startswith('OOD median of (MSE, MAE):'): 55 | ood_mse_mae = re.findall(r'\d+\.\d+(?:e-\d+)?', line) 56 | arr_ood_errors[model_names.index(model_name), 0] = float(ood_mse_mae[0]) 57 | arr_ood_errors[model_names.index(model_name), 1] = float(ood_mse_mae[1]) 58 | if len(ood_mse_mae) > 2: 59 | arr_ood_errors_std[model_names.index(model_name), 0, 0] = float(ood_mse_mae[2]) 60 | arr_ood_errors_std[model_names.index(model_name), 0, 1] = float(ood_mse_mae[3]) 61 | if len(ood_mse_mae) > 4: 62 | arr_ood_errors_std[model_names.index(model_name), 1, 0] = float(ood_mse_mae[4]) 63 | arr_ood_errors_std[model_names.index(model_name), 1, 1] = float(ood_mse_mae[5]) 64 | elif line.startswith('ID likelihoods:'): 65 | id_likelihoods = re.findall(r'-?\d+\.\d+(?:e-\d+)?', line) 66 | arr_id_likelihoods[model_names.index(model_name), 0] = float(id_likelihoods[0]) 67 | if len(id_likelihoods) > 1: 68 | arr_id_likelihoods_std[model_names.index(model_name), 0, 0] = float(id_likelihoods[1]) 69 | if 
len(id_likelihoods) > 2: 70 | arr_id_likelihoods_std[model_names.index(model_name), 1, 0] = float(id_likelihoods[2]) 71 | elif line.startswith('OOD likelihoods:'): 72 | ood_likelihoods = re.findall(r'-?\d+\.\d+(?:e-\d+)?', line) 73 | arr_ood_likelihoods[model_names.index(model_name), 0] = float(ood_likelihoods[0]) 74 | if len(ood_likelihoods) > 1: 75 | arr_ood_likelihoods_std[model_names.index(model_name), 0, 0] = float(ood_likelihoods[1]) 76 | if len(ood_likelihoods) > 2: 77 | arr_ood_likelihoods_std[model_names.index(model_name), 1, 0] = float(ood_likelihoods[2]) 78 | elif line.startswith('ID calibration errors:'): 79 | id_calib = re.findall(r'-?\d+\.\d+(?:e-\d+)?', line) 80 | arr_id_likelihoods[model_names.index(model_name), 1] = np.mean([float(x) for x in id_calib[:time_steps]]) 81 | if len(id_calib) > time_steps: 82 | arr_id_likelihoods_std[model_names.index(model_name), 0, 1] = np.mean([float(x) for x in id_calib[time_steps:]]) 83 | if len(id_calib) > 2*time_steps: 84 | arr_id_likelihoods_std[model_names.index(model_name), 1, 1] = np.mean([float(x) for x in id_calib[2*time_steps:]]) 85 | elif line.startswith('OOD calibration errors:'): 86 | ood_calib = re.findall(r'-?\d+\.\d+(?:e-\d+)?', line) 87 | arr_ood_likelihoods[model_names.index(model_name), 1] = np.mean([float(x) for x in ood_calib[:time_steps]]) 88 | if len(ood_calib) > time_steps: 89 | arr_ood_likelihoods_std[model_names.index(model_name), 0, 1] = np.mean([float(x) for x in ood_calib[time_steps:]]) 90 | if len(ood_calib) > 2*time_steps: 91 | arr_ood_likelihoods_std[model_names.index(model_name), 1, 1] = np.mean([float(x) for x in ood_calib[2*time_steps:]]) 92 | return (arr_id_errors, arr_ood_errors, arr_id_likelihoods, arr_ood_likelihoods), \ 93 | (arr_id_errors_std, arr_ood_errors_std, arr_id_likelihoods_std, arr_ood_likelihoods_std) 94 | 95 | dict_errors, dict_errors_std = {}, {} 96 | dict_likelihoods, dict_likelihoods_std = {}, {} 97 | error, error_std = parser(model_names) 98 | dict_errors['id'] = {'no_covs': error[0]} 99 | dict_errors['ood'] = {'no_covs': error[1]} 100 | dict_likelihoods['id'] = {'no_covs': error[2]} 101 | dict_likelihoods['ood'] = {'no_covs': error[3]} 102 | dict_errors_std['id'] = {'no_covs': error_std[0]} 103 | dict_errors_std['ood'] = {'no_covs': error_std[1]} 104 | dict_likelihoods_std['id'] = {'no_covs': error_std[2]} 105 | dict_likelihoods_std['ood'] = {'no_covs': error_std[3]} 106 | 107 | if model_names_with_covs is not None: 108 | error, error_std = parser(model_names_with_covs) 109 | dict_errors['id']['covs'] = error[0] 110 | dict_errors['ood']['covs'] = error[1] 111 | dict_likelihoods['id']['covs'] = error[2] 112 | dict_likelihoods['ood']['covs'] = error[3] 113 | dict_errors_std['id']['covs'] = error_std[0] 114 | dict_errors_std['ood']['covs'] = error_std[1] 115 | dict_likelihoods_std['id']['covs'] = error_std[2] 116 | dict_likelihoods_std['ood']['covs'] = error_std[3] 117 | 118 | return (dict_errors, dict_likelihoods), (dict_errors_std, dict_likelihoods_std) -------------------------------------------------------------------------------- /paper_results/plots/figure2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure2.pdf -------------------------------------------------------------------------------- /paper_results/plots/figure3.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure3.pdf -------------------------------------------------------------------------------- /paper_results/plots/figure3_annot.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure3_annot.pptx -------------------------------------------------------------------------------- /paper_results/plots/figure4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure4.pdf -------------------------------------------------------------------------------- /paper_results/plots/figure5.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure5.pdf -------------------------------------------------------------------------------- /paper_results/plots/figure6.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure6.pdf -------------------------------------------------------------------------------- /paper_results/plots/figure6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure6.png -------------------------------------------------------------------------------- /paper_results/plots/nhits_single_prediction.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/nhits_single_prediction.pdf -------------------------------------------------------------------------------- /raw_data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/raw_data.zip -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | darts==0.29.0 2 | matplotlib==3.8.4 3 | numpy==1.26.4 4 | optuna==3.6.1 5 | pandas==2.2.2 6 | pillow==10.3.0 7 | pmdarima==2.0.4 8 | pytorch-lightning==2.2.4 9 | PyYAML==6.0.1 10 | requests==2.31.0 11 | scikit-learn==1.4.2 12 | scipy==1.13.0 13 | seaborn==0.13.2 14 | statsforecast==1.7.4 15 | statsmodels==0.14.2 16 | torch==2.3.0 17 | torchaudio==2.3.0 18 | torchmetrics==1.3.2 19 | torchvision==0.18.0 20 | tqdm==4.66.4 21 | xgboost==2.0.3 22 | jupyter==1.0.0 23 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/utils/__init__.py -------------------------------------------------------------------------------- /utils/darts_evaluation.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import yaml 4 | import random 5 | from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union 6 | 7 | import numpy as np 8 | from scipy import stats 9 | import pandas as pd 10 | import darts 11 | 12 | from darts import models 13 | from darts import metrics 14 | from darts import TimeSeries 15 | 16 | # import data formatter 17 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 18 | from data_formatter.base import * 19 | from utils.darts_processing import * 20 | 21 | def _get_values( 22 | series: TimeSeries, stochastic_quantile: Optional[float] = 0.5 23 | ) -> np.ndarray: 24 | """ 25 | Returns the numpy values of a time series. 26 | For stochastic series, return either all sample values with (stochastic_quantile=None) or the quantile sample value 27 | with (stochastic_quantile {>=0,<=1}) 28 | """ 29 | if series.is_deterministic: 30 | series_values = series.univariate_values() 31 | else: # stochastic 32 | if stochastic_quantile is None: 33 | series_values = series.all_values(copy=False) 34 | else: 35 | series_values = series.quantile_timeseries( 36 | quantile=stochastic_quantile 37 | ).univariate_values() 38 | return series_values 39 | 40 | def _get_values_or_raise( 41 | series_a: TimeSeries, 42 | series_b: TimeSeries, 43 | intersect: bool, 44 | stochastic_quantile: Optional[float] = 0.5, 45 | remove_nan_union: bool = False, 46 | ) -> Tuple[np.ndarray, np.ndarray]: 47 | """Returns the processed numpy values of two time series. Processing can be customized with arguments 48 | `intersect, stochastic_quantile, remove_nan_union`. 49 | 50 | Raises a ValueError if the two time series (or their intersection) do not have the same time index. 51 | 52 | Parameters 53 | ---------- 54 | series_a 55 | A univariate deterministic ``TimeSeries`` instance (the actual series). 56 | series_b 57 | A univariate (deterministic or stochastic) ``TimeSeries`` instance (the predicted series). 58 | intersect 59 | A boolean for whether or not to only consider the time intersection between `series_a` and `series_b` 60 | stochastic_quantile 61 | Optionally, for stochastic predicted series, return either all sample values with (`stochastic_quantile=None`) 62 | or any deterministic quantile sample values by setting `stochastic_quantile=quantile` {>=0,<=1}. 63 | remove_nan_union 64 | By setting `remove_nan_union` to True, remove all indices from `series_a` and `series_b` which have a NaN value 65 | in either of the two input series.
66 | """ 67 | series_a_common = series_a.slice_intersect(series_b) if intersect else series_a 68 | series_b_common = series_b.slice_intersect(series_a) if intersect else series_b 69 | 70 | series_a_det = _get_values(series_a_common, stochastic_quantile=stochastic_quantile) 71 | series_b_det = _get_values(series_b_common, stochastic_quantile=stochastic_quantile) 72 | 73 | if not remove_nan_union: 74 | return series_a_det, series_b_det 75 | 76 | b_is_deterministic = bool(len(series_b_det.shape) == 1) 77 | if b_is_deterministic: 78 | isnan_mask = np.logical_or(np.isnan(series_a_det), np.isnan(series_b_det)) 79 | else: 80 | isnan_mask = np.logical_or( 81 | np.isnan(series_a_det), np.isnan(series_b_det).any(axis=2).flatten() 82 | ) 83 | return np.delete(series_a_det, isnan_mask), np.delete( 84 | series_b_det, isnan_mask, axis=0 85 | ) 86 | 87 | def rescale_and_backtest(series: Union[TimeSeries, 88 | Sequence[TimeSeries]], 89 | forecasts: Union[TimeSeries, 90 | Sequence[TimeSeries], 91 | Sequence[Sequence[TimeSeries]]], 92 | metric: Union[ 93 | Callable[[TimeSeries, TimeSeries], float], 94 | List[Callable[[TimeSeries, TimeSeries], float]], 95 | ], 96 | scaler: Callable[[TimeSeries], TimeSeries] = None, 97 | reduction: Union[Callable[[np.ndarray], float], None] = np.mean, 98 | likelihood: str = "GaussianMean", 99 | cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11), 100 | ): 101 | """ 102 | Backtest the historical forecasts (as provided by Darts) on the series. 103 | 104 | Parameters 105 | ---------- 106 | series 107 | The target time series. 108 | forecasts 109 | The forecasts. 110 | scaler 111 | The scaler used to scale the series. 112 | metric 113 | The metric or metrics to use for backtesting. 114 | reduction 115 | The reduction to apply to the metric. 116 | likelihood 117 | The likelihood to use for evaluating the model. 118 | cal_thresholds 119 | The thresholds to use for computing the calibration error. 120 | 121 | Returns 122 | ------- 123 | np.ndarray 124 | Error array. If the reduction is none, array is of shape (n, p) 125 | where n is the total number of samples (forecasts) and p is the number of metrics. 126 | If the reduction is not none, array is of shape (k, p), where k is the number of series. 127 | float 128 | The estimated log-likelihood of the model on the data. 129 | np.ndarray 130 | The ECE for each time point in the forecast. 
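Examples
--------
A minimal sketch of the intended call, mirroring how ``lib/linreg.py`` backtests on the test split (the ``series``, ``forecasts``, and ``scalers`` objects there come from ``load_data`` and ``model.historical_forecasts``)::

    id_errors, id_likelihood, id_cal_errors = rescale_and_backtest(
        series['test']['target'],        # list of target TimeSeries
        forecasts,                       # historical forecasts with last_points_only=False
        [metrics.mse, metrics.mae],      # metrics computed for each forecast
        scalers['target'],               # undo scaling before computing errors
        reduction=None,                  # keep one (MSE, MAE) row per forecast
    )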
131 | """ 132 | series = [series] if isinstance(series, TimeSeries) else series 133 | forecasts = [forecasts] if isinstance(forecasts, TimeSeries) else forecasts 134 | metric = [metric] if not isinstance(metric, list) else metric 135 | 136 | # compute errors: 1) reverse scaling forecasts and true values, 2)compute errors 137 | backtest_list = [] 138 | for idx in range(len(series)): 139 | if scaler is not None: 140 | series[idx] = scaler.inverse_transform(series[idx]) 141 | forecasts[idx] = [scaler.inverse_transform(f) for f in forecasts[idx]] 142 | errors = [ 143 | [metric_f(series[idx], f) for metric_f in metric] 144 | if len(metric) > 1 145 | else metric[0](series[idx], f) 146 | for f in forecasts[idx] 147 | ] 148 | if reduction is None: 149 | backtest_list.append(np.array(errors)) 150 | else: 151 | backtest_list.append(reduction(np.array(errors), axis=0)) 152 | backtest_list = np.vstack(backtest_list) 153 | 154 | if likelihood == "GaussianMean": 155 | # compute likelihood 156 | est_var = [] 157 | for idx, target_ts in enumerate(series): 158 | est_var += [metrics.mse(target_ts, f) for f in forecasts[idx]] 159 | est_var = np.mean(est_var) 160 | forecast_len = forecasts[0][0].n_timesteps 161 | log_likelihood = -0.5*forecast_len - 0.5*np.log(2*np.pi*est_var) 162 | 163 | # compute calibration error: 1) cdf values 2) compute calibration error 164 | # compute the cdf values 165 | cdf_vals = [] 166 | for idx in range(len(series)): 167 | for forecast in forecasts[idx]: 168 | y_true, y_pred = _get_values_or_raise(series[idx], 169 | forecast, 170 | intersect=True, 171 | remove_nan_union=True) 172 | y_true, y_pred = y_true.flatten(), y_pred.flatten() 173 | cdf_vals.append(stats.norm.cdf(y_true, loc=y_pred, scale=np.sqrt(est_var))) 174 | cdf_vals = np.vstack(cdf_vals) 175 | # compute the prediction calibration 176 | cal_error = np.zeros(forecasts[0][0].n_timesteps) 177 | for p in cal_thresholds: 178 | est_p = (cdf_vals <= p).astype(float) 179 | est_p = np.mean(est_p, axis=0) 180 | cal_error += (est_p - p) ** 2 181 | 182 | return backtest_list, log_likelihood, cal_error 183 | 184 | def rescale_and_test(series: Union[TimeSeries, 185 | Sequence[TimeSeries]], 186 | forecasts: Union[TimeSeries, 187 | Sequence[TimeSeries]], 188 | metric: Union[ 189 | Callable[[TimeSeries, TimeSeries], float], 190 | List[Callable[[TimeSeries, TimeSeries], float]], 191 | ], 192 | scaler: Callable[[TimeSeries], TimeSeries] = None, 193 | likelihood: str = "GaussianMean", 194 | cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11), 195 | ): 196 | """ 197 | Test the forecasts on the series. 198 | 199 | Parameters 200 | ---------- 201 | series 202 | The target time series. 203 | forecasts 204 | The forecasts. 205 | scaler 206 | The scaler used to scale the series. 207 | metric 208 | The metric or metrics to use for backtesting. 209 | reduction 210 | The reduction to apply to the metric. 211 | likelihood 212 | The likelihood to use for evaluating the likelihood and calibration of model. 213 | cal_thresholds 214 | The thresholds to use for computing the calibration error. 215 | 216 | Returns 217 | ------- 218 | np.ndarray 219 | Error array. If the reduction is none, array is of shape (n, p) 220 | where n is the total number of samples (forecasts) and p is the number of metrics. 221 | If the reduction is not none, array is of shape (k, p), where k is the number of series. 222 | float 223 | The estimated log-likelihood of the model on the data. 224 | np.ndarray 225 | The ECE for each time point in the forecast. 
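Examples
--------
A minimal sketch under the same conventions as ``rescale_and_backtest`` above; the ``series``, ``forecasts``, and ``scalers`` objects are assumed to come from one of the model scripts in ``lib/``, and ``likelihood="Quantile"`` is only appropriate for models whose forecasts carry samples or quantiles::

    errors, log_likelihood, cal_error = rescale_and_test(
        series['test']['target'],
        forecasts,
        [metrics.mse, metrics.mae],
        scaler=scalers['target'],
        likelihood="Quantile",
    )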
226 | """ 227 | series = [series] if isinstance(series, TimeSeries) else series 228 | forecasts = [forecasts] if isinstance(forecasts, TimeSeries) else forecasts 229 | metric = [metric] if not isinstance(metric, list) else metric 230 | 231 | # compute errors: 1) reverse scaling forecasts and true values, 2)compute errors 232 | series = scaler.inverse_transform(series) 233 | forecasts = scaler.inverse_transform(forecasts) 234 | errors = [ 235 | [metric_f(t, f) for metric_f in metric] 236 | if len(metric) > 1 237 | else metric[0](t, f) 238 | for (t, f) in zip(series, forecasts) 239 | ] 240 | errors = np.array(errors) 241 | 242 | if likelihood == "GaussianMean": 243 | # compute likelihood 244 | est_var = [metrics.mse(t, f) for (t, f) in zip(series, forecasts)] 245 | est_var = np.mean(est_var) 246 | forecast_len = forecasts[0].n_timesteps 247 | log_likelihood = -0.5*forecast_len - 0.5*np.log(2*np.pi*est_var) 248 | 249 | # compute calibration error: 1) cdf values 2) compute calibration error 250 | # compute the cdf values 251 | cdf_vals = [] 252 | for t, f in zip(series, forecasts): 253 | t, f = _get_values_or_raise(t, f, intersect=True, remove_nan_union=True) 254 | t, f = t.flatten(), f.flatten() 255 | cdf_vals.append(stats.norm.cdf(t, loc=f, scale=np.sqrt(est_var))) 256 | cdf_vals = np.vstack(cdf_vals) 257 | # compute the prediction calibration 258 | cal_error = np.zeros(forecasts[0].n_timesteps) 259 | for p in cal_thresholds: 260 | est_p = (cdf_vals <= p).astype(float) 261 | est_p = np.mean(est_p, axis=0) 262 | cal_error += (est_p - p) ** 2 263 | 264 | if likelihood == "Quantile": 265 | # no likelihood since we don't have a parametric model 266 | log_likelihood = 0 267 | 268 | # compute calibration error: 1) get quantiles 2) compute calibration error 269 | cal_error = np.zeros(forecasts[0].n_timesteps) 270 | for p in cal_thresholds: 271 | est_p = 0 272 | for t, f in zip(series, forecasts): 273 | q = f.quantile(p) 274 | t, q = _get_values_or_raise(t, q, intersect=True, remove_nan_union=True) 275 | t, q = t.flatten(), q.flatten() 276 | est_p += (t <= q).astype(float) 277 | est_p = (est_p / len(series)).flatten() 278 | cal_error += (est_p - p) ** 2 279 | 280 | return errors, log_likelihood, cal_error -------------------------------------------------------------------------------- /utils/darts_training.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import yaml 4 | import random 5 | from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union 6 | 7 | import numpy as np 8 | from scipy import stats 9 | import pandas as pd 10 | import darts 11 | 12 | from darts import models 13 | from darts import metrics 14 | from darts import TimeSeries 15 | from pytorch_lightning.callbacks import Callback 16 | from darts.logging import get_logger, raise_if_not 17 | 18 | # for optuna callback 19 | import warnings 20 | import optuna 21 | from optuna.storages._cached_storage import _CachedStorage 22 | from optuna.storages._rdb.storage import RDBStorage 23 | # Define key names of `Trial.system_attrs`. 
24 | _PRUNED_KEY = "ddp_pl:pruned" 25 | _EPOCH_KEY = "ddp_pl:epoch" 26 | with optuna._imports.try_import() as _imports: 27 | import pytorch_lightning as pl 28 | from pytorch_lightning import LightningModule 29 | from pytorch_lightning import Trainer 30 | from pytorch_lightning.callbacks import Callback 31 | if not _imports.is_successful(): 32 | Callback = object # type: ignore # NOQA 33 | LightningModule = object # type: ignore # NOQA 34 | Trainer = object # type: ignore # NOQA 35 | 36 | def print_callback(study, trial, study_file=None): 37 | # write output to a file 38 | with open(study_file, "a") as f: 39 | f.write(f"Current value: {trial.value}, Current params: {trial.params}\n") 40 | f.write(f"Best value: {study.best_value}, Best params: {study.best_trial.params}\n") 41 | 42 | def early_stopping_check(study, 43 | trial, 44 | study_file, 45 | early_stopping_rounds=10): 46 | """ 47 | Early stopping callback for Optuna. 48 | This function checks the current trial number and the best trial number. 49 | """ 50 | current_trial_number = trial.number 51 | best_trial_number = study.best_trial.number 52 | should_stop = (current_trial_number - best_trial_number) >= early_stopping_rounds 53 | if should_stop: 54 | with open(study_file, 'a') as f: 55 | f.write('\nEarly stopping at trial {} (best trial: {})'.format(current_trial_number, best_trial_number)) 56 | study.stop() 57 | 58 | class LossLogger(Callback): 59 | def __init__(self): 60 | self.train_loss = [] 61 | self.val_loss = [] 62 | 63 | def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 64 | self.train_loss.append(float(trainer.callback_metrics["train_loss"])) 65 | 66 | def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 67 | self.val_loss.append(float(trainer.callback_metrics["val_loss"])) 68 | 69 | class PyTorchLightningPruningCallback(Callback): 70 | """PyTorch Lightning callback to prune unpromising trials. 71 | See `the example <https://github.com/optuna/optuna-examples/blob/ 72 | main/pytorch/pytorch_lightning_simple.py>`__ 73 | if you want to add a pruning callback which observes accuracy. 74 | Args: 75 | trial: 76 | A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the 77 | objective function. 78 | monitor: 79 | An evaluation metric for pruning, e.g., ``val_loss`` or 80 | ``val_acc``. The metrics are obtained from the returned dictionaries from e.g. 81 | ``pytorch_lightning.LightningModule.training_step`` or 82 | ``pytorch_lightning.LightningModule.validation_epoch_end`` and the names thus depend on 83 | how this dictionary is formatted. 84 | """ 85 | 86 | def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None: 87 | super().__init__() 88 | 89 | self._trial = trial 90 | self.monitor = monitor 91 | 92 | def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None: 93 | # When the trainer calls `on_validation_end` for sanity check, 94 | # do not call `trial.report` to avoid calling `trial.report` multiple times 95 | # at epoch 0. The related page is 96 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391. 97 | if trainer.sanity_checking: 98 | return 99 | 100 | epoch = pl_module.current_epoch 101 | 102 | current_score = trainer.callback_metrics.get(self.monitor) 103 | if current_score is None: 104 | message = ( 105 | "The metric '{}' is not in the evaluation logs for pruning. 
" 106 | "Please make sure you set the correct metric name.".format(self.monitor) 107 | ) 108 | warnings.warn(message) 109 | return 110 | 111 | self._trial.report(current_score, step=epoch) 112 | if self._trial.should_prune(): 113 | message = "Trial was pruned at epoch {}.".format(epoch) 114 | raise optuna.TrialPruned(message) --------------------------------------------------------------------------------