├── .gitignore
├── README.md
├── bin
│   ├── arima.sh
│   ├── gluformer.sh
│   ├── latentode.sh
│   ├── linreg.sh
│   ├── nhits.sh
│   ├── tft.sh
│   ├── transformer.sh
│   └── xgbtree.sh
├── config
│   ├── colas.yaml
│   ├── dubosson.yaml
│   ├── hall.yaml
│   ├── iglu.yaml
│   └── weinstock.yaml
├── data_formatter
│   ├── __init__.py
│   ├── base.py
│   ├── types.py
│   └── utils.py
├── example.ipynb
├── exploratory_analysis
│   ├── colas.ipynb
│   ├── dubosson.ipynb
│   ├── hall.ipynb
│   ├── iglu.ipynb
│   └── weinstock.ipynb
├── lib
│   ├── __init__.py
│   ├── arima.py
│   ├── gluformer.py
│   ├── gluformer
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── decoder.py
│   │   ├── embed.py
│   │   ├── encoder.py
│   │   ├── model.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── collate.py
│   │   │   ├── evaluation.py
│   │   │   └── training.py
│   │   └── variance.py
│   ├── latent_ode
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── base_models.py
│   │   ├── create_latent_ode_model.py
│   │   ├── diffeq_solver.py
│   │   ├── encoder_decoder.py
│   │   ├── eval_glunet.py
│   │   ├── latent_ode.py
│   │   ├── likelihood_eval.py
│   │   ├── ode_func.py
│   │   ├── ode_rnn.py
│   │   ├── plotting.py
│   │   ├── rnn_baselines.py
│   │   ├── run_models.py
│   │   ├── test_glunet.ipynb
│   │   ├── trainer_glunet.py
│   │   └── utils.py
│   ├── latentode.py
│   ├── linreg.py
│   ├── nhits.py
│   ├── tft.py
│   ├── transformer.py
│   └── xgbtree.py
├── output
│   ├── arima_colas.txt
│   ├── arima_dubosson.txt
│   ├── arima_hall.txt
│   ├── arima_iglu.txt
│   ├── arima_weinstock.txt
│   ├── gluformer_colas.txt
│   ├── gluformer_dubosson.txt
│   ├── gluformer_hall.txt
│   ├── gluformer_iglu.txt
│   ├── gluformer_weinstock.txt
│   ├── latentode_colas.txt
│   ├── latentode_dubosson.txt
│   ├── latentode_hall.txt
│   ├── latentode_iglu.txt
│   ├── latentode_weinstock.txt
│   ├── linreg_colas.txt
│   ├── linreg_covariates_colas.txt
│   ├── linreg_covariates_dubosson.txt
│   ├── linreg_covariates_hall.txt
│   ├── linreg_covariates_iglu.txt
│   ├── linreg_covariates_weinstock.txt
│   ├── linreg_dubosson.txt
│   ├── linreg_hall.txt
│   ├── linreg_iglu.txt
│   ├── linreg_weinstock.txt
│   ├── nhits_colas.txt
│   ├── nhits_covariates_colas.txt
│   ├── nhits_covariates_dubosson.txt
│   ├── nhits_covariates_hall.txt
│   ├── nhits_covariates_iglu.txt
│   ├── nhits_covariates_weinstock.txt
│   ├── nhits_dubosson.txt
│   ├── nhits_hall.txt
│   ├── nhits_iglu.txt
│   ├── nhits_weinstock.txt
│   ├── tft_colas.txt
│   ├── tft_covariates_colas.txt
│   ├── tft_covariates_dubosson.txt
│   ├── tft_covariates_hall.txt
│   ├── tft_covariates_iglu.txt
│   ├── tft_covariates_weinstock.txt
│   ├── tft_dubosson.txt
│   ├── tft_hall.txt
│   ├── tft_iglu.txt
│   ├── tft_weinstock.txt
│   ├── transformer_colas.txt
│   ├── transformer_covariates_colas.txt
│   ├── transformer_covariates_dubosson.txt
│   ├── transformer_covariates_hall.txt
│   ├── transformer_covariates_iglu.txt
│   ├── transformer_covariates_weinstock.txt
│   ├── transformer_dubosson.txt
│   ├── transformer_hall.txt
│   ├── transformer_iglu.txt
│   ├── transformer_weinstock.txt
│   ├── xgboost_colas.txt
│   ├── xgboost_covariates_colas.txt
│   ├── xgboost_covariates_dubosson.txt
│   ├── xgboost_covariates_hall.txt
│   ├── xgboost_covariates_iglu.txt
│   ├── xgboost_covariates_weinstock.txt
│   ├── xgboost_dubosson.txt
│   ├── xgboost_hall.txt
│   ├── xgboost_iglu.txt
│   └── xgboost_weinstock.txt
├── paper_results
│   ├── covariate_importance_xgb.ipynb
│   ├── figure2.ipynb
│   ├── figure3.ipynb
│   ├── figure4.ipynb
│   ├── figure5.ipynb
│   ├── forecast_compute.ipynb
│   ├── forecast_plot.ipynb
│   ├── parser.py
│   ├── plots
│   │   ├── figure2.pdf
│   │   ├── figure3.pdf
│   │   ├── figure3_annot.pptx
│   │   ├── figure4.pdf
│   │   ├── figure5.pdf
│   │   ├── figure6.pdf
│   │   ├── figure6.png
│   │   └── nhits_single_prediction.pdf
│   └── tables.ipynb
├── raw_data.zip
├── requirements.txt
└── utils
    ├── __init__.py
    ├── darts_dataset.py
    ├── darts_evaluation.py
    ├── darts_processing.py
    └── darts_training.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore model and data files
2 | *.pth
3 | *.npy
4 | *.pkl
5 | cache/*
6 | *.out
7 | track_*
8 | output/tensorboard_*/*
9 |
10 | # ignore folder raw_data
11 | raw_data/
12 |
13 | # ignore pre-compiled code
14 | *.cpython-37.pyc
15 | __pycache__/
16 | *.py[cod]
17 | *.pyc
18 | *.DS_Store
19 | .ipynb_checkpoints
20 |
21 | # ignore irrelevant files
22 | legacy
23 | papers
24 | meeting_notes.md
25 |
26 |
27 | # ignore workspace
28 | *.code-workspace
29 |
30 | # PyCharm
31 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
32 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
33 | # and can be added to the global gitignore or merged into this file. For a more nuclear
34 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
35 | .idea/
36 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GlucoBench
2 |
3 | The official implementation of the paper "GlucoBench: Curated List of Continuous Glucose Monitoring Datasets with Prediction Benchmarks."
4 | If you find our work interesting and plan to reuse the code, please cite us as:
5 | ```
6 | @article{
7 | author = {Renat Sergazinov and Valeriya Rogovchenko and Elizabeth Chun and Nathaniel Fernandes and Irina Gaynanova},
8 | title = {GlucoBench: Curated List of Continuous Glucose Monitoring Datasets with Prediction Benchmarks},
9 | journal = {arXiv},
10 | year = {2023},
11 | }
12 | ```
13 |
14 | # Dependencies
15 |
16 | We recommend setting up a clean Python environment with [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html "conda-env") by running `conda create -n glucobench python=3.10`. Then install all dependencies by running `pip install -r requirements.txt`.
17 |
18 | To run the [Latent ODE](https://github.com/YuliaRubanova/latent_ode) model, additionally install [`torchdiffeq`](https://github.com/rtqichen/torchdiffeq).
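
For example, assuming the standard PyPI package works in your environment:

```
pip install torchdiffeq
```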
19 |
20 | # Code organization
21 |
22 | The code is organized as follows:
23 |
24 | - `bin/`: training commands for all models
25 | - `config/`: configuration files for all datasets
26 | - `data_formatter/`
27 | - base.py: performs **all** pre-processing for all CGM datasets
28 | - `exploratory_analysis/`: notebooks with processing steps for pulling the data and converting to `.csv` files
29 | - `lib/`
30 | - `gluformer/`: model implementation
31 | - `latent_ode/`: model implementation
32 | - `*.py`: hyper-parameter tuning, training, validation, and testing scripts
33 | - `output/`: hyper-parameter optimization and testing logs
34 | - `paper_results/`: code for producing the tables and plots found in the paper
35 | - `utils/`: helper functions for model training and testing
36 | - `raw_data.zip`: web-pulled CGM data (processed using `exploratory_analysis`)
37 | - `requirements.txt`: Python dependencies
38 |
39 | # Data
40 |
41 | The datasets are distributed under the licenses listed in the table below and can be downloaded from the linked sources.
42 |
43 | | Dataset | License | Number of patients | CGM Frequency |
44 | | ------- | ------- | ------------------ | ------------- |
45 | | [Colas](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0225817#sec018) | [Creative Commons 4.0](https://creativecommons.org/licenses/by/3.0/us/) | 208 | 5 minutes |
46 | | [Dubosson](https://doi.org/10.5281/zenodo.1421615) | [Creative Commons 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode) | 9 | 5 minutes |
47 | | [Hall](https://journals.plos.org/plosbiology/article?id=10.1371/journal.pbio.2005143#pbio.2005143.s010) | [Creative Commons 4.0](https://creativecommons.org/licenses/by/4.0/) | 57 | 5 minutes |
48 | | [Broll](https://github.com/irinagain/iglu) | [GPL-2](https://www.r-project.org/Licenses/GPL-2) | 5 | 5 minutes |
49 | | [Weinstock](https://public.jaeb.org/dataset/537) | [Creative Commons 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode) | 200 | 5 minutes |
50 |
51 | To process the data, follow the instructions in the `exploratory_analysis/` folder. Processed datasets should be saved in the `raw_data/` folder. We provide examples in the `raw_data.zip` file.
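
To start from the provided examples instead of re-pulling the data, unpack the archive into `raw_data/`. A minimal sketch, assuming the per-dataset CSV files sit at the top level of the archive:

```
unzip raw_data.zip -d raw_data
```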
52 |
53 | # How to reproduce results?
54 |
55 | ## Setting up the environment
56 |
57 | We recommend setting up a clean Python environment using [Conda](https://docs.conda.io/). Follow these steps:
58 |
59 | 1. Create a new environment named `glucobench` with Python 3.10 by running:
60 | ```
61 | conda create -n glucobench python=3.10
62 | ```
63 |
64 | 2. Activate the environment with:
65 | ```
66 | conda activate glucobench
67 | ```
68 |
69 | 3. Install all required dependencies by running:
70 | ```
71 | pip install -r requirements.txt
72 | ```
73 |
74 | 4. (Optional) To confirm that you're installing in the correct environment, run:
75 | ```
76 | which pip
77 | ```
78 | This should display the path to the `pip` executable within the `glucobench` environment.
79 |
80 | ## Changing the configs
81 |
82 | The `config/` folder stores the best hyper-parameters (selected with [Optuna](https://optuna.org)) for each dataset and model, as well as the dataset-specific parameters for interpolation, dropping, splitting, and scaling. To train and evaluate a model with these defaults, run:
83 |
84 | ```
85 | python ./lib/model.py --dataset dataset --use_covs False --optuna False
86 | ```
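
For example, to train and evaluate NHiTS on the `iglu` dataset with and without covariates (mirroring the commands in `bin/nhits.sh`):

```
python ./lib/nhits.py --dataset iglu --use_covs False --optuna False
python ./lib/nhits.py --dataset iglu --use_covs True --optuna False
```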
87 |
88 | ## Changing the hyper-parameters
89 |
90 | To change the hyper-parameter search grid, modify the `objective()` function in the corresponding `./lib/model.py` file: adjust the `trial.suggest_*` calls to set the desired ranges. Then re-run the hyper-parameter optimization with:
91 |
92 | ```
93 | python ./lib/model.py --dataset dataset --use_covs False --optuna True
94 | ```
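
The exact search space differs per model; the names and ranges below are purely illustrative (not the ones shipped in `lib/`), but the `trial.suggest_*` calls inside `objective()` follow this pattern:

```python
def objective(trial):
    # illustrative ranges -- edit these to change the search grid
    in_len = trial.suggest_int('in_len', 96, 144, step=12)
    hidden_size = trial.suggest_int('hidden_size', 32, 256, step=32)
    dropout = trial.suggest_float('dropout', 0.0, 0.3)
    lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
    # ... build the model with these values, train it, and return the validation loss
```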
95 |
96 | # How to work with the repository?
97 |
98 | We provide a detailed example of the workflow in the `example.ipynb` notebook. Below, we offer some general guidance in order of increasing complexity.
99 |
100 | ## Just the data
101 | To start experimenting with the data, run the following snippet:
102 |
103 | ```python
104 | import yaml
105 | from data_formatter.base import DataFormatter
106 |
107 | with open(f'./config/{dataset}.yaml', 'r') as f:
108 | config = yaml.safe_load(f)
109 | formatter = DataFormatter(config)
110 | ```
111 |
112 | Running this snippet creates a `DataFormatter` object, which automatically pre-processes the data upon initialization. The pre-processing steps can be controlled via the `config/` files. The `DataFormatter` object exposes the following attributes:
113 |
114 | 1. `formatter.train_data`: training data (as `pandas.DataFrame`)
115 | 2. `formatter.val_data`: validation data
116 | 3. `formatter.test_data`: testing (in-distribution and out-of-distribution) data
117 | i. `formatter.test_data.loc[~formatter.test_data.index.isin(formatter.test_idx_ood)]`: in-distribution testing data
118 | ii. `formatter.test_data.loc[formatter.test_data.index.isin(formatter.test_idx_ood)]`: out-of-distribution testing data
119 | 4. `formatter.data`: unscaled full data
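
Column names for each role can be looked up with `formatter.get_column()` (defined in `data_formatter/base.py`). For example, to retrieve the target column and separate the two test sets:

```python
target_col = formatter.get_column('target')   # list of target columns, e.g. ['gl']
id_col = formatter.get_column('id')           # name of the subject id column

test = formatter.test_data
test_id = test.loc[~test.index.isin(formatter.test_idx_ood)]    # in-distribution
test_ood = test.loc[test.index.isin(formatter.test_idx_ood)]    # out-of-distribution
```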
120 |
121 | ## Integration with PyTorch
122 |
123 | Training models with PyTorch typically boils down to (1) defining a `Dataset` class with `__getitem__()` method, (2) wrapping it into a `DataLoader`, (3) defining a `torch.nn.Module` class with `forward()` method that implements the model, and (4) optimizing the model with `torch.optim` in a training loop.
124 |
125 | **Parts (1) and (2)** crucially depend on the definition of the `Dataset` class: given the data in table format (e.g. `formatter.train_data`), how do we sample input-output pairs and pass along the covariate information? The `Dataset` classes in `utils/darts_dataset.py`, adapted from the `Darts` library (see [here](https://unit8co.github.io/darts/generated_api/darts.utils.data.training_dataset.html)), offer one way to do this. They differ in what information is provided to the model:
126 |
127 | 1. `SamplingDatasetPast`: supports only past covariates
128 | 2. `SamplingDatasetDual`: supports only future-known covariates
129 | 3. `SamplingDatasetMixed`: supports both past and future-known covariates
130 |
131 | Below we give an example of loading the data and wrapping it into a `Dataset`:
132 |
133 | ```python
134 | from utils.darts_processing import load_data
135 | from utils.darts_dataset import SamplingDatasetDual
136 |
137 | formatter, series, scalers = load_data(seed=0,
138 | dataset=dataset,
139 | use_covs=True,
140 | cov_type='dual',
141 | use_static_covs=True)
142 | dataset_train = SamplingDatasetDual(series['train']['target'],
143 | series['train']['future'],
144 | output_chunk_length=out_len,
145 | input_chunk_length=in_len,
146 | use_static_covariates=True,
147 | max_samples_per_ts=max_samples_per_ts,)
148 | ```
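
Continuing the example, the sampled dataset can be wrapped in a standard `torch.utils.data.DataLoader`. A minimal sketch; it relies on the default `collate_fn`, whereas some models, e.g. Gluformer, define their own (see `lib/gluformer/utils/collate.py`):

```python
from torch.utils.data import DataLoader

loader_train = DataLoader(dataset_train, batch_size=32, shuffle=True)

for batch in loader_train:
    # the tuple layout (past target, future-known covariates, static covariates,
    # future target) is defined by SamplingDatasetDual.__getitem__
    break
```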
149 |
150 | **Parts (3) and (4)** are model-specific, so we omit their discussion. For inspiration, we suggest taking a look at the `lib/gluformer/model.py` and `lib/latent_ode/trainer_glunet.py` files.
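
For completeness, here is a generic sketch of parts (3) and (4) with a hypothetical toy model; it only illustrates the training-loop structure and is not the repository's training code:

```python
import torch
from torch import nn

class ToyForecaster(nn.Module):
    '''Hypothetical model: maps the last in_len glucose values to the next out_len.'''
    def __init__(self, in_len: int, out_len: int):
        super().__init__()
        self.net = nn.Linear(in_len, out_len)

    def forward(self, past_target: torch.Tensor) -> torch.Tensor:
        return self.net(past_target)

model = ToyForecaster(in_len=96, out_len=12)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for epoch in range(5):
    for batch in loader_train:
        # assumption: the first element of the batch is the past target and the last
        # is the future target; both are squeezed to (batch, length) for the toy model
        past_target = batch[0].float().squeeze(-1)
        future_target = batch[-1].float().squeeze(-1)
        optimizer.zero_grad()
        loss = loss_fn(model(past_target), future_target)
        loss.backward()
        optimizer.step()
```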
151 |
152 |
153 |
154 |
155 |
156 |
--------------------------------------------------------------------------------
/bin/arima.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # execute in parallel
4 | nohup python ./lib/arima.py --dataset weinstock > ./output/track_arima_weinstock.txt &
5 | nohup python ./lib/arima.py --dataset colas > ./output/track_arima_colas.txt &
6 | nohup python ./lib/arima.py --dataset dubosson > ./output/track_arima_dubosson.txt &
7 | nohup python ./lib/arima.py --dataset hall > ./output/track_arima_hall.txt &
8 | nohup python ./lib/arima.py --dataset iglu > ./output/track_arima_iglu.txt &
9 |
10 |
--------------------------------------------------------------------------------
/bin/gluformer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # execute in parallel
4 | nohup python ./lib/gluformer.py --dataset weinstock --gpu_id 1 --optuna False > ./output/track_gluformer_weinstock.txt &
5 | # nohup python ./lib/gluformer.py --dataset colas --gpu_id 2 --optuna False > ./output/track_gluformer_colas.txt &
6 | # nohup python ./lib/gluformer.py --dataset dubosson --gpu_id 0 --optuna False > ./output/track_gluformer_dubosson.txt &
7 | # nohup python ./lib/gluformer.py --dataset hall --gpu_id 3 --optuna False > ./output/track_gluformer_hall.txt &
8 | # nohup python ./lib/gluformer.py --dataset iglu --gpu_id 0 --optuna False > ./output/track_gluformer_iglu.txt &
9 |
10 |
--------------------------------------------------------------------------------
/bin/latentode.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # execute in parallel
4 | nohup python ./lib/latentode.py --dataset weinstock --gpu_id 3 --optuna False > ./output/track_latentode_weinstock.txt &
5 | nohup python ./lib/latentode.py --dataset colas --gpu_id 2 --optuna False > ./output/track_latentode_colas.txt &
6 | nohup python ./lib/latentode.py --dataset dubosson --gpu_id 0 --optuna False > ./output/track_latentode_dubosson.txt &
7 | nohup python ./lib/latentode.py --dataset hall --gpu_id 1 --optuna False > ./output/track_latentode_hall.txt &
8 | nohup python ./lib/latentode.py --dataset iglu --gpu_id 0 --optuna False > ./output/track_latentode_iglu.txt &
9 |
10 |
--------------------------------------------------------------------------------
/bin/linreg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # execute in parallel
4 | nohup python ./lib/linreg.py --dataset weinstock --use_covs False --optuna False > ./output/track_linreg_weinstock.txt &
5 | nohup python ./lib/linreg.py --dataset weinstock --use_covs True --optuna False > ./output/track_linreg_covariates_weinstock.txt &
6 | nohup python ./lib/linreg.py --dataset colas --use_covs False --optuna False > ./output/track_linreg_colas.txt &
7 | nohup python ./lib/linreg.py --dataset colas --use_covs True --optuna False > ./output/track_linreg_covariates_colas.txt &
8 | nohup python ./lib/linreg.py --dataset dubosson --use_covs False --optuna False > ./output/track_linreg_dubosson.txt &
9 | nohup python ./lib/linreg.py --dataset dubosson --use_covs True --optuna False > ./output/track_linreg_covariates_dubosson.txt &
10 | nohup python ./lib/linreg.py --dataset hall --use_covs False --optuna False > ./output/track_linreg_hall.txt &
11 | nohup python ./lib/linreg.py --dataset hall --use_covs True --optuna False > ./output/track_linreg_covariates_hall.txt &
12 | nohup python ./lib/linreg.py --dataset iglu --use_covs False --optuna False > ./output/track_linreg_iglu.txt &
13 | nohup python ./lib/linreg.py --dataset iglu --use_covs True --optuna False > ./output/track_linreg_covariates_iglu.txt &
14 |
--------------------------------------------------------------------------------
/bin/nhits.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # execute in parallel
4 | nohup python ./lib/nhits.py --dataset weinstock --use_covs False --optuna False > ./output/track_nhits_weinstock.txt &
5 | nohup python ./lib/nhits.py --dataset weinstock --use_covs True --optuna False > ./output/track_nhits_covariates_weinstock.txt &
6 | nohup python ./lib/nhits.py --dataset colas --use_covs False --optuna False > ./output/track_nhits_colas.txt &
7 | nohup python ./lib/nhits.py --dataset colas --use_covs True --optuna False > ./output/track_nhits_covariates_colas.txt &
8 | nohup python ./lib/nhits.py --dataset dubosson --use_covs False --optuna False > ./output/track_nhits_dubosson.txt &
9 | nohup python ./lib/nhits.py --dataset dubosson --use_covs True --optuna False > ./output/track_nhits_covariates_dubosson.txt &
10 | nohup python ./lib/nhits.py --dataset hall --use_covs False --optuna False > ./output/track_nhits_hall.txt &
11 | nohup python ./lib/nhits.py --dataset hall --use_covs True --optuna False > ./output/track_nhits_covariates_hall.txt &
12 | nohup python ./lib/nhits.py --dataset iglu --use_covs False --optuna False > ./output/track_nhits_iglu.txt &
13 | nohup python ./lib/nhits.py --dataset iglu --use_covs True --optuna False > ./output/track_nhits_covariates_iglu.txt &
14 |
--------------------------------------------------------------------------------
/bin/tft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # execute in parallel
4 | nohup python ./lib/tft.py --dataset weinstock --use_covs False --optuna False > ./output/track_tft_weinstock.txt &
5 | # nohup python ./lib/tft.py --dataset weinstock --use_covs True --optuna False > ./output/track_tft_covariates_weinstock.txt &
6 | # nohup python ./lib/tft.py --dataset colas --use_covs False --optuna False > ./output/track_tft_colas.txt &
7 | # nohup python ./lib/tft.py --dataset colas --use_covs True --optuna False > ./output/track_tft_covariates_colas.txt &
8 | # nohup python ./lib/tft.py --dataset dubosson --use_covs False --optuna False > ./output/track_tft_dubosson.txt &
9 | # nohup python ./lib/tft.py --dataset dubosson --use_covs True --optuna False > ./output/track_tft_covariates_dubosson.txt &
10 | # nohup python ./lib/tft.py --dataset hall --use_covs False --optuna False > ./output/track_tft_hall.txt &
11 | # nohup python ./lib/tft.py --dataset hall --use_covs True --optuna False > ./output/track_tft_covariates_hall.txt &
12 | # nohup python ./lib/tft.py --dataset iglu --use_covs False --optuna False > ./output/track_tft_iglu.txt &
13 | # nohup python ./lib/tft.py --dataset iglu --use_covs True --optuna False > ./output/track_tft_covariates_iglu.txt &
14 |
--------------------------------------------------------------------------------
/bin/transformer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # execute in parallel
4 | nohup python ./lib/transformer.py --dataset weinstock --use_covs False --optuna False > ./output/track_transformer_weinstock.txt &
5 | nohup python ./lib/transformer.py --dataset weinstock --use_covs True --optuna False > ./output/track_transformer_covariates_weinstock.txt &
6 | nohup python ./lib/transformer.py --dataset colas --use_covs False --optuna False > ./output/track_transformer_colas.txt &
7 | nohup python ./lib/transformer.py --dataset colas --use_covs True --optuna False > ./output/track_transformer_covariates_colas.txt &
8 | nohup python ./lib/transformer.py --dataset dubosson --use_covs False --optuna False > ./output/track_transformer_dubosson.txt &
9 | nohup python ./lib/transformer.py --dataset dubosson --use_covs True --optuna False > ./output/track_transformer_covariates_dubosson.txt &
10 | nohup python ./lib/transformer.py --dataset hall --use_covs False --optuna False > ./output/track_transformer_hall.txt &
11 | nohup python ./lib/transformer.py --dataset hall --use_covs True --optuna False > ./output/track_transformer_covariates_hall.txt &
12 | nohup python ./lib/transformer.py --dataset iglu --use_covs False --optuna False > ./output/track_transformer_iglu.txt &
13 | nohup python ./lib/transformer.py --dataset iglu --use_covs True --optuna False > ./output/track_transformer_covariates_iglu.txt &
14 |
--------------------------------------------------------------------------------
/bin/xgbtree.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # execute in parallel
4 | nohup python ./lib/xgbtree.py --dataset weinstock --use_covs False --optuna False > ./output/track_xgboost_weinstock.txt &
5 | nohup python ./lib/xgbtree.py --dataset weinstock --use_covs True --optuna False > ./output/track_xgboost_covariates_weinstock.txt &
6 | nohup python ./lib/xgbtree.py --dataset colas --use_covs False --optuna False > ./output/track_xgboost_colas.txt &
7 | nohup python ./lib/xgbtree.py --dataset colas --use_covs True --optuna False > ./output/track_xgboost_covariates_colas.txt &
8 | nohup python ./lib/xgbtree.py --dataset dubosson --use_covs False --optuna False > ./output/track_xgboost_dubosson.txt &
9 | nohup python ./lib/xgbtree.py --dataset dubosson --use_covs True --optuna False > ./output/track_xgboost_covariates_dubosson.txt &
10 | nohup python ./lib/xgbtree.py --dataset hall --use_covs False --optuna False > ./output/track_xgboost_hall.txt &
11 | nohup python ./lib/xgbtree.py --dataset hall --use_covs True --optuna False > ./output/track_xgboost_covariates_hall.txt &
12 | nohup python ./lib/xgbtree.py --dataset iglu --use_covs False --optuna False > ./output/track_xgboost_iglu.txt &
13 | nohup python ./lib/xgbtree.py --dataset iglu --use_covs True --optuna False > ./output/track_xgboost_covariates_iglu.txt &
14 |
--------------------------------------------------------------------------------
/config/colas.yaml:
--------------------------------------------------------------------------------
1 | # Dataset
2 | ds_name: colas2019
3 | data_csv_path: ./raw_data/colas.csv
4 | index_col: -1
5 | observation_interval: 5min
6 |
7 | # Columns
8 | column_definition:
9 | - name: id
10 | data_type: categorical
11 | input_type: id
12 | - name: time
13 | data_type: date
14 | input_type: time
15 | - name: gl
16 | data_type: real_valued
17 | input_type: target
18 | - name: gender
19 | data_type: categorical
20 | input_type: static_input
21 | - name: age
22 | data_type: real_valued
23 | input_type: static_input
24 | - name: BMI
25 | data_type: real_valued
26 | input_type: static_input
27 | - name: glycaemia
28 | data_type: real_valued
29 | input_type: static_input
30 | - name: HbA1c
31 | data_type: real_valued
32 | input_type: static_input
33 | - name: follow.up
34 | data_type: real_valued
35 | input_type: static_input
36 | - name: T2DM
37 | data_type: categorical
38 | input_type: static_input
39 |
40 | # Drop
41 | drop: null
42 |
43 | # NA values abbreviation
44 | nan_vals: null
45 |
46 | # Interpolation parameters
47 | interpolation_params:
48 | gap_threshold: 45
49 | min_drop_length: 192
50 |
51 | # Splitting parameters
52 | split_params:
53 | test_percent_subjects: 0.1
54 | length_segment: 72
55 | random_state: 0
56 |
57 | # Encoding parameters
58 | encoding_params:
59 | date:
60 | - year
61 | - month
62 | - day
63 | - hour
64 | - minute
65 |
66 | # Scaling parameters
67 | scaling_params:
68 | scaler: None
69 |
70 | # Model params
71 | max_length_input: 144
72 | length_pred: 12
73 |
74 | linreg:
75 | in_len: 12
76 |
77 | linreg_covariates:
78 | in_len: 12
79 |
80 | tft:
81 | in_len: 132
82 | max_samples_per_ts: 200
83 | hidden_size: 256
84 | num_attention_heads: 3
85 | dropout: 0.22683125764190215
86 | lr: 0.0005939103829095587
87 | batch_size: 32
88 | max_grad_norm: 0.9791152645996767
89 |
90 | tft_covariates:
91 | in_len: 120
92 | max_samples_per_ts: 100
93 | hidden_size: 32
94 | num_attention_heads: 3
95 | dropout: 0.10643530677029577
96 | lr: 0.004702414513886559
97 | batch_size: 32
98 | max_grad_norm: 0.8047252326588638
99 |
100 | xgboost:
101 | in_len: 120
102 | lr: 0.509
103 | subsample: 0.9
104 | min_child_weight: 5.0
105 | colsample_bytree: 0.9
106 | max_depth: 7
107 | gamma: 0.5
108 | alpha: 0.216
109 | lambda_: 0.241
110 | n_estimators: 352
111 |
112 | xgboost_covariates:
113 | in_len: 144
114 | lr: 0.883
115 | subsample: 0.9
116 | min_child_weight: 3.0
117 | colsample_bytree: 0.8
118 | max_depth: 5
119 | gamma: 0.5
120 | alpha: 0.055
121 | lambda_: 0.08700000000000001
122 | n_estimators: 416
123 |
124 | transformer:
125 | in_len: 108
126 | max_samples_per_ts: 200
127 | d_model: 64
128 | n_heads: 2
129 | num_encoder_layers: 3
130 | num_decoder_layers: 3
131 | dim_feedforward: 480
132 | dropout: 0.12434517563324206
133 | lr: 0.00048663109178350133
134 | batch_size: 32
135 | lr_epochs: 8
136 | max_grad_norm: 0.8299004621292704
137 |
138 | transformer_covariates:
139 | in_len: 120
140 | max_samples_per_ts: 200
141 | d_model: 128
142 | n_heads: 4
143 | num_encoder_layers: 4
144 | num_decoder_layers: 1
145 | dim_feedforward: 128
146 | dropout: 0.19572808311258694
147 | lr: 0.0008814762155445509
148 | batch_size: 32
149 | lr_epochs: 18
150 | max_grad_norm: 0.8168361106999547
151 |
152 | nhits:
153 | in_len: 132
154 | max_samples_per_ts: 100
155 | kernel_sizes: 3
156 | dropout: 0.18002875427414997
157 | lr: 0.0006643638126306677
158 | batch_size: 32
159 | lr_epochs: 2
160 |
161 | nhits_covariates:
162 | in_len: 96
163 | max_samples_per_ts: 50
164 | kernel_sizes: 3
165 | dropout: 0.13142967835347927
166 | lr: 0.0008921763677516184
167 | batch_size: 48
168 | lr_epochs: 16
169 |
170 | gluformer:
171 | in_len: 96
172 | max_samples_per_ts: 150
173 | d_model: 384
174 | n_heads: 12
175 | d_fcn: 512
176 | num_enc_layers: 1
177 | num_dec_layers: 1
178 |
179 | latentode:
180 | in_len: 48
181 | max_samples_per_ts: 100
182 | latents: 20
183 | rec_dims: 40
184 | rec_layers: 3
185 | gen_layers: 3
186 | units: 100
187 | gru_units: 100
188 |
--------------------------------------------------------------------------------
/config/dubosson.yaml:
--------------------------------------------------------------------------------
1 | # Dataset
2 | ds_name: dubosson2018
3 | data_csv_path: ./raw_data/dubosson.csv
4 | index_col: -1
5 | observation_interval: 5min
6 |
7 | # Columns
8 | column_definition:
9 | - name: id
10 | data_type: categorical
11 | input_type: id
12 | - name: time
13 | data_type: date
14 | input_type: time
15 | - name: gl
16 | data_type: real_valued
17 | input_type: target
18 | - name: fast_insulin
19 | data_type: real_valued
20 | input_type: observed_input
21 | - name: slow_insulin
22 | data_type: real_valued
23 | input_type: observed_input
24 | - name: calories
25 | data_type: real_valued
26 | input_type: observed_input
27 | - name: balance
28 | data_type: categorical
29 | input_type: observed_input
30 | - name: quality
31 | data_type: categorical
32 | input_type: observed_input
33 | - name: HR
34 | data_type: real_valued
35 | input_type: observed_input
36 | - name: BR
37 | data_type: real_valued
38 | input_type: observed_input
39 | - name: Posture
40 | data_type: real_valued
41 | input_type: observed_input
42 | - name: Activity
43 | data_type: real_valued
44 | input_type: observed_input
45 | - name: HRV
46 | data_type: real_valued
47 | input_type: observed_input
48 | - name: CoreTemp
49 | data_type: real_valued
50 | input_type: observed_input
51 |
52 | # Drop
53 | drop:
54 | rows: null
55 | columns:
56 | id:
57 | - 9
58 |
59 | # NA values abbreviation
60 | nan_vals: null
61 |
62 | # Interpolation parameters
63 | interpolation_params:
64 | gap_threshold: 30 # in minutes
65 | min_drop_length: 240 # number of points
66 |
67 | # Splitting parameters
68 | split_params:
69 | test_percent_subjects: 0.1
70 | length_segment: 144
71 | random_state: 0
72 |
73 | # Encoding parameters
74 | encoding_params:
75 | date:
76 | - year
77 | - month
78 | - day
79 | - hour
80 | - minute
81 |
82 | # Scaling parameters
83 | scaling_params:
84 | scaler: None
85 |
86 | # Model params
87 | max_length_input: 192
88 | length_pred: 12
89 |
90 | linreg:
91 | in_len: 12
92 |
93 | linreg_covariates:
94 | in_len: 12
95 |
96 | tft:
97 | in_len: 168
98 | max_samples_per_ts: 50
99 | hidden_size: 240
100 | num_attention_heads: 2
101 | dropout: 0.24910705171945197
102 | lr: 0.003353965994113796
103 | batch_size: 64
104 | max_grad_norm: 0.999584070166802
105 |
106 | tft_covariates:
107 | in_len: 120
108 | max_samples_per_ts: 50
109 | hidden_size: 240
110 | num_attention_heads: 1
111 | dropout: 0.2354005483884536
112 | lr: 0.0014372065280028868
113 | batch_size: 32
114 | max_grad_norm: 0.08770929102027172
115 |
116 | xgboost:
117 | in_len: 168
118 | lr: 0.6910000000000001
119 | subsample: 0.8
120 | min_child_weight: 5.0
121 | colsample_bytree: 0.8
122 | max_depth: 10
123 | gamma: 0.5
124 | alpha: 0.201
125 | lambda_: 0.279
126 | n_estimators: 416
127 |
128 | xgboost_covariates:
129 | in_len: 36
130 | lr: 0.651
131 | subsample: 0.8
132 | min_child_weight: 2.0
133 | colsample_bytree: 1.0
134 | max_depth: 6
135 | gamma: 1.5
136 | alpha: 0.148
137 | lambda_: 0.094
138 | n_estimators: 480
139 |
140 | transformer:
141 | in_len: 108
142 | max_samples_per_ts: 50
143 | d_model: 32
144 | n_heads: 2
145 | num_encoder_layers: 1
146 | num_decoder_layers: 1
147 | dim_feedforward: 384
148 | dropout: 0.038691123579122515
149 | lr: 0.0004450217945481336
150 | batch_size: 32
151 | lr_epochs: 6
152 | max_grad_norm: 0.20863935142150056
153 |
154 | transformer_covariates:
155 | in_len: 156
156 | max_samples_per_ts: 50
157 | d_model: 64
158 | n_heads: 2
159 | num_encoder_layers: 2
160 | num_decoder_layers: 1
161 | dim_feedforward: 384
162 | dropout: 0.0026811942171770446
163 | lr: 0.000998963295875978
164 | batch_size: 48
165 | lr_epochs: 20
166 | max_grad_norm: 0.1004169110387992
167 |
168 | nhits:
169 | in_len: 108
170 | max_samples_per_ts: 50
171 | kernel_sizes: 3
172 | dropout: 0.06496948174462439
173 | lr: 0.0003359362814711015
174 | batch_size: 32
175 | lr_epochs: 2
176 |
177 | nhits_covariates:
178 | in_len: 120
179 | max_samples_per_ts: 50
180 | kernel_sizes: 2
181 | dropout: 0.16272090435698405
182 | lr: 0.0004806891979994542
183 | batch_size: 48
184 | lr_epochs: 12
185 |
186 | gluformer:
187 | in_len: 108
188 | max_samples_per_ts: 100
189 | d_model: 384
190 | n_heads: 8
191 | d_fcn: 1024
192 | num_enc_layers: 1
193 | num_dec_layers: 3
194 |
195 | latentode:
196 | in_len: 48
197 | max_samples_per_ts: 100
198 | latents: 20
199 | rec_dims: 40
200 | rec_layers: 3
201 | gen_layers: 3
202 | units: 100
203 | gru_units: 100
204 |
205 |
206 |
--------------------------------------------------------------------------------
/config/hall.yaml:
--------------------------------------------------------------------------------
1 | # Dataset
2 | ds_name: hall2018
3 | data_csv_path: ./raw_data/hall.csv
4 | #./raw_data/Hall2018_processed_akhil.csv
5 | index_col: -1
6 | observation_interval: 5min
7 |
8 | # Columns
9 | column_definition:
10 | - name: id
11 | data_type: categorical
12 | input_type: id
13 | - name: time
14 | data_type: date
15 | input_type: time
16 | - name: gl
17 | data_type: real_valued
18 | input_type: target
19 | - name: Age
20 | data_type: real_valued
21 | input_type: static_input
22 | - name: BMI
23 | data_type: real_valued
24 | input_type: static_input
25 | - name: A1C
26 | data_type: real_valued
27 | input_type: static_input
28 | - name: FBG
29 | data_type: real_valued
30 | input_type: static_input
31 | - name: ogtt.2hr
32 | data_type: real_valued
33 | input_type: static_input
34 | - name: insulin
35 | data_type: real_valued
36 | input_type: static_input
37 | - name: hs.CRP
38 | data_type: real_valued
39 | input_type: static_input
40 | - name: Tchol
41 | data_type: real_valued
42 | input_type: static_input
43 | - name: Trg
44 | data_type: real_valued
45 | input_type: static_input
46 | - name: HDL
47 | data_type: real_valued
48 | input_type: static_input
49 | - name: LDL
50 | data_type: real_valued
51 | input_type: static_input
52 | - name: mean_glucose
53 | data_type: real_valued
54 | input_type: static_input
55 | - name: sd_glucose
56 | data_type: real_valued
57 | input_type: static_input
58 | - name: range_glucose
59 | data_type: real_valued
60 | input_type: static_input
61 | - name: min_glucose
62 | data_type: real_valued
63 | input_type: static_input
64 | - name: max_glucose
65 | data_type: real_valued
66 | input_type: static_input
67 | - name: quartile.25_glucose
68 | data_type: real_valued
69 | input_type: static_input
70 | - name: median_glucose
71 | data_type: real_valued
72 | input_type: static_input
73 | - name: quartile.75_glucose
74 | data_type: real_valued
75 | input_type: static_input
76 | - name: mean_slope
77 | data_type: real_valued
78 | input_type: static_input
79 | - name: max_slope
80 | data_type: real_valued
81 | input_type: static_input
82 | - name: number_Random140
83 | data_type: real_valued
84 | input_type: static_input
85 | - name: number_Random200
86 | data_type: real_valued
87 | input_type: static_input
88 | - name: percent_below.80
89 | data_type: real_valued
90 | input_type: static_input
91 | - name: se_glucose_mean
92 | data_type: real_valued
93 | input_type: static_input
94 | - name: numGE
95 | data_type: real_valued
96 | input_type: static_input
97 | - name: mage
98 | data_type: real_valued
99 | input_type: static_input
100 | - name: j_index
101 | data_type: real_valued
102 | input_type: static_input
103 | - name: IQR
104 | data_type: real_valued
105 | input_type: static_input
106 | - name: modd
107 | data_type: real_valued
108 | input_type: static_input
109 | - name: distance_traveled
110 | data_type: real_valued
111 | input_type: static_input
112 | - name: coef_variation
113 | data_type: real_valued
114 | input_type: static_input
115 | - name: number_Random140_normByDays
116 | data_type: real_valued
117 | input_type: static_input
118 | - name: number_Random200_normByDays
119 | data_type: real_valued
120 | input_type: static_input
121 | - name: numGE_normByDays
122 | data_type: real_valued
123 | input_type: static_input
124 | - name: distance_traveled_normByDays
125 | data_type: real_valued
126 | input_type: static_input
127 | - name: diagnosis
128 | data_type: categorical
129 | input_type: static_input
130 | - name: freq_low
131 | data_type: real_valued
132 | input_type: static_input
133 | - name: freq_moderate
134 | data_type: real_valued
135 | input_type: static_input
136 | - name: freq_severe
137 | data_type: real_valued
138 | input_type: static_input
139 | - name: glucotype
140 | data_type: categorical
141 | input_type: static_input
142 | - name: Height
143 | data_type: real_valued
144 | input_type: static_input
145 | - name: Weight
146 | data_type: real_valued
147 | input_type: static_input
148 | - name: Insulin_rate_dd
149 | data_type: real_valued
150 | input_type: static_input
151 | - name: perc_cgm_prediabetic_range
152 | data_type: real_valued
153 | input_type: static_input
154 | - name: perc_cgm_diabetic_range
155 | data_type: real_valued
156 | input_type: static_input
157 | - name: SSPG
158 | data_type: real_valued
159 | input_type: static_input
160 |
161 | # Drop
162 | drop:
163 | rows:
164 | - 57309
165 | columns: null
166 |
167 | # NA values abbreviation
168 | nan_vals: NA
169 |
170 | # Interpolation parameters
171 | interpolation_params:
172 | gap_threshold: 30 # in minutes
173 | min_drop_length: 192 # number of points
174 |
175 | # Splitting parameters
176 | split_params:
177 | test_percent_subjects: 0.1
178 | length_segment: 192
179 | random_state: 0
180 |
181 | # Encoding parameters
182 | encoding_params:
183 | date:
184 | - year
185 | - month
186 | - day
187 | - hour
188 | - minute
189 |
190 | # Scaling parameters
191 | scaling_params:
192 | scaler: None
193 |
194 | # Model params
195 | max_length_input: 144
196 | length_pred: 12
197 |
198 | linreg:
199 | in_len: 84
200 |
201 | linreg_covariates:
202 | in_len: 60
203 |
204 | tft:
205 | in_len: 96
206 | max_samples_per_ts: 50
207 | hidden_size: 160
208 | num_attention_heads: 2
209 | dropout: 0.12663651999137013
210 | lr: 0.0003909069464830342
211 | batch_size: 48
212 | max_grad_norm: 0.42691316697261855
213 |
214 | tft_covariates:
215 | in_len: 132
216 | max_samples_per_ts: 50
217 | hidden_size: 64
218 | num_attention_heads: 3
219 | dropout: 0.1514203549391074
220 | lr: 0.002278316839625157
221 | batch_size: 32
222 | max_grad_norm: 0.6617473571712074
223 |
224 | xgboost:
225 | in_len: 60
226 | lr: 0.515
227 | subsample: 0.9
228 | min_child_weight: 3.0
229 | colsample_bytree: 0.9
230 | max_depth: 6
231 | gamma: 2.0
232 | alpha: 0.099
233 | lambda_: 0.134
234 | n_estimators: 256
235 |
236 | xgboost_covariates:
237 | in_len: 120
238 | lr: 0.17200000000000001
239 | subsample: 0.7
240 | min_child_weight: 2.0
241 | colsample_bytree: 0.9
242 | max_depth: 6
243 | gamma: 1.0
244 | alpha: 0.167
245 | lambda_: 0.017
246 | n_estimators: 320
247 |
248 | transformer:
249 | in_len: 144
250 | max_samples_per_ts: 200
251 | d_model: 64
252 | n_heads: 4
253 | num_encoder_layers: 1
254 | num_decoder_layers: 1
255 | dim_feedforward: 96
256 | dropout: 0.014744750937083516
257 | lr: 0.00035186058101597097
258 | batch_size: 48
259 | lr_epochs: 14
260 | max_grad_norm: 0.43187285340924153
261 |
262 | transformer_covariates:
263 | in_len: 132
264 | max_samples_per_ts: 150
265 | d_model: 64
266 | n_heads: 4
267 | num_encoder_layers: 1
268 | num_decoder_layers: 3
269 | dim_feedforward: 192
270 | dropout: 0.1260638882066075
271 | lr: 0.0006944648317764303
272 | batch_size: 48
273 | lr_epochs: 4
274 | max_grad_norm: 0.22914229299130273
275 |
276 | nhits:
277 | in_len: 144
278 | max_samples_per_ts: 100
279 | kernel_sizes: 4
280 | dropout: 0.046869296882493555
281 | lr: 0.00011524084800602483
282 | batch_size: 48
283 | lr_epochs: 2
284 |
285 | nhits_covariates:
286 | in_len: 120
287 | max_samples_per_ts: 50
288 | kernel_sizes: 5
289 | dropout: 0.18679300209273494
290 | lr: 0.0004763622305085654
291 | batch_size: 48
292 | lr_epochs: 4
293 |
294 | gluformer:
295 | in_len: 96
296 | max_samples_per_ts: 200
297 | d_model: 384
298 | n_heads: 4
299 | d_fcn: 1024
300 | num_enc_layers: 1
301 | num_dec_layers: 1
302 |
303 | latentode:
304 | in_len: 48
305 | max_samples_per_ts: 100
306 | latents: 20
307 | rec_dims: 40
308 | rec_layers: 3
309 | gen_layers: 3
310 | units: 100
311 | gru_units: 100
--------------------------------------------------------------------------------
/config/iglu.yaml:
--------------------------------------------------------------------------------
1 | # Dataset
2 | ds_name: iglu
3 | data_csv_path: ./raw_data/iglu.csv
4 | index_col: -1
5 | observation_interval: 5min
6 |
7 | # Columns
8 | column_definition:
9 | - name: id
10 | data_type: categorical
11 | input_type: id
12 | - name: time
13 | data_type: date
14 | input_type: time
15 | - name: gl
16 | data_type: real_valued
17 | input_type: target
18 |
19 | # Drop
20 | drop: null
21 |
22 | # NA values abbreviation
23 | nan_vals: null
24 |
25 | # Interpolation parameters
26 | interpolation_params:
27 | gap_threshold: 45 # in minutes
28 | min_drop_length: 240 # in number of points (20 hrs)
29 |
30 | # Splitting parameters
31 | split_params:
32 | test_percent_subjects: 0.1
33 | length_segment: 240
34 | random_state: 0
35 |
36 | # Encoding parameters
37 | encoding_params:
38 | date:
39 | - year
40 | - month
41 | - day
42 | - hour
43 | - minute
44 | - second
45 |
46 | # Scaling parameters
47 | scaling_params:
48 | scaler: None
49 |
50 | # Model params
51 | max_length_input: 192 # in number of points (16 hrs)
52 | length_pred: 12 # in number of points (predict 1 hr)
53 |
54 | linreg:
55 | in_len: 192
56 |
57 | linreg_covariates:
58 | in_len: 12
59 |
60 | tft:
61 | in_len: 168
62 | max_samples_per_ts: 50
63 | hidden_size: 80
64 | num_attention_heads: 4
65 | dropout: 0.12792080253276716
66 | lr: 0.003164601797779577
67 | batch_size: 32
68 | max_grad_norm: 0.5265925565310886
69 |
70 | tft_covariates:
71 | in_len: 96
72 | max_samples_per_ts: 50
73 | hidden_size: 80
74 | num_attention_heads: 3
75 | dropout: 0.22790916758695268
76 | lr: 0.005050238867376333
77 | batch_size: 32
78 | max_grad_norm: 0.026706367007025333
79 |
80 | xgboost:
81 | in_len: 84
82 | lr: 0.506
83 | subsample: 0.9
84 | min_child_weight: 2.0
85 | colsample_bytree: 0.8
86 | max_depth: 9
87 | gamma: 0.5
88 | alpha: 0.124
89 | lambda_: 0.089
90 | n_estimators: 416
91 |
92 | xgboost_covariates:
93 | in_len: 96
94 | lr: 0.387
95 | subsample: 0.8
96 | min_child_weight: 1.0
97 | colsample_bytree: 1.0
98 | max_depth: 8
99 | gamma: 1.0
100 | alpha: 0.199
101 | lambda_: 0.018000000000000002
102 | n_estimators: 288
103 |
104 | transformer:
105 | in_len: 96
106 | max_samples_per_ts: 50
107 | d_model: 96
108 | n_heads: 4
109 | num_encoder_layers: 4
110 | num_decoder_layers: 1
111 | dim_feedforward: 448
112 | dropout: 0.10161152207464333
113 | lr: 0.000840888489686657
114 | batch_size: 32
115 | lr_epochs: 16
116 | max_grad_norm: 0.6740479322943925
117 |
118 | transformer_covariates:
119 | in_len: 108
120 | max_samples_per_ts: 50
121 | d_model: 128
122 | n_heads: 2
123 | num_encoder_layers: 2
124 | num_decoder_layers: 2
125 | dim_feedforward: 160
126 | dropout: 0.044926981080245884
127 | lr: 0.00029632347559614453
128 | batch_size: 32
129 | lr_epochs: 20
130 | max_grad_norm: 0.8890169619043728
131 |
132 | nhits:
133 | in_len: 96
134 | max_samples_per_ts: 50
135 | kernel_sizes: 5
136 | dropout: 0.12695408586813234
137 | lr: 0.0004510532358403777
138 | batch_size: 64
139 | lr_epochs: 16
140 |
141 | nhits_covariates:
142 | in_len: 144
143 | max_samples_per_ts: 50
144 | kernel_sizes: 3
145 | dropout: 0.09469970402653531
146 | lr: 0.0009786650965760999
147 | batch_size: 32
148 | lr_epochs: 10
149 |
150 | gluformer:
151 | in_len: 96
152 | max_samples_per_ts: 100
153 | d_model: 512
154 | n_heads: 4
155 | d_fcn: 512
156 | num_enc_layers: 1
157 | num_dec_layers: 4
158 |
159 | latentode:
160 | in_len: 48
161 | max_samples_per_ts: 100
162 | latents: 20
163 | rec_dims: 40
164 | rec_layers: 3
165 | gen_layers: 3
166 | units: 100
167 | gru_units: 100
168 |
169 |
--------------------------------------------------------------------------------
/config/weinstock.yaml:
--------------------------------------------------------------------------------
1 | # Dataset
2 | ds_name: weinstock2016
3 | data_csv_path: ./raw_data/weinstock.csv
4 | index_col: -1
5 | observation_interval: 5min
6 |
7 | # Columns
8 | column_definition:
9 | - name: id
10 | data_type: categorical
11 | input_type: id
12 | - name: time
13 | data_type: date
14 | input_type: time
15 | - name: gl
16 | data_type: real_valued
17 | input_type: target
18 | - name: Height
19 | data_type: real_valued
20 | input_type: static_input
21 | - name: Weight
22 | data_type: real_valued
23 | input_type: static_input
24 | - name: Gender
25 | data_type: categorical
26 | input_type: static_input
27 | - name: Race
28 | data_type: categorical
29 | input_type: static_input
30 | - name: EduLevel
31 | data_type: categorical
32 | input_type: static_input
33 | - name: AnnualInc
34 | data_type: real_valued
35 | input_type: static_input
36 | - name: MaritalStatus
37 | data_type: categorical
38 | input_type: static_input
39 | - name: DaysWkEx
40 | data_type: real_valued
41 | input_type: static_input
42 | - name: DaysWkDrinkAlc
43 | data_type: real_valued
44 | input_type: static_input
45 | - name: DaysMonBingeAlc
46 | data_type: real_valued
47 | input_type: static_input
48 | - name: T1DDiagAge
49 | data_type: real_valued
50 | input_type: static_input
51 | - name: NumHospDKA
52 | data_type: real_valued
53 | input_type: static_input
54 | - name: NumSHSinceT1DDiag
55 | data_type: real_valued
56 | input_type: static_input
57 | - name: InsDeliveryMethod
58 | data_type: categorical
59 | input_type: static_input
60 | - name: UnitsInsTotal
61 | data_type: real_valued
62 | input_type: static_input
63 | - name: NumMeterCheckDay
64 | data_type: real_valued
65 | input_type: static_input
66 | - name: Aspirin
67 | data_type: real_valued
68 | input_type: static_input
69 | - name: Simvastatin
70 | data_type: real_valued
71 | input_type: static_input
72 | - name: Lisinopril
73 | data_type: real_valued
74 | input_type: static_input
75 | - name: "Vitamin D"
76 | data_type: real_valued
77 | input_type: static_input
78 | - name: "Multivitamin preparation"
79 | data_type: real_valued
80 | input_type: static_input
81 | - name: Omeprazole
82 | data_type: real_valued
83 | input_type: static_input
84 | - name: atorvastatin
85 | data_type: real_valued
86 | input_type: static_input
87 | - name: Synthroid
88 | data_type: real_valued
89 | input_type: static_input
90 | - name: "vitamin D3"
91 | data_type: real_valued
92 | input_type: static_input
93 | - name: Hypertension
94 | data_type: real_valued
95 | input_type: static_input
96 | - name: Hyperlipidemia
97 | data_type: real_valued
98 | input_type: static_input
99 | - name: Hypothyroidism
100 | data_type: real_valued
101 | input_type: static_input
102 | - name: Depression
103 | data_type: real_valued
104 | input_type: static_input
105 | - name: "Coronary artery disease"
106 | data_type: real_valued
107 | input_type: static_input
108 | - name: "Diabetic peripheral neuropathy"
109 | data_type: real_valued
110 | input_type: static_input
111 | - name: Dyslipidemia
112 | data_type: real_valued
113 | input_type: static_input
114 | - name: "Chronic kidney disease"
115 | data_type: real_valued
116 | input_type: static_input
117 | - name: Osteoporosis
118 | data_type: real_valued
119 | input_type: static_input
120 | - name: "Proliferative diabetic retinopathy"
121 | data_type: real_valued
122 | input_type: static_input
123 | - name: Hypercholesterolemia
124 | data_type: real_valued
125 | input_type: static_input
126 | - name: "Erectile dysfunction"
127 | data_type: real_valued
128 | input_type: static_input
129 | - name: "Type I diabetes mellitus"
130 | data_type: real_valued
131 | input_type: static_input
132 |
133 | # Drop
134 | drop: null
135 |
136 | # NA values abbreviation
137 | nan_vals: null
138 |
139 | # Interpolation parameters
140 | interpolation_params:
141 | gap_threshold: 45
142 | min_drop_length: 240
143 |
144 | # Splitting parameters
145 | split_params:
146 | test_percent_subjects: 0.1
147 | length_segment: 240
148 | random_state: 0
149 |
150 | # Encoding parameters
151 | encoding_params:
152 | date:
153 | - year
154 | - month
155 | - day
156 | - hour
157 | - minute
158 |
159 | # Scaling parameters
160 | scaling_params:
161 | scaler: None
162 |
163 | # Model params
164 | max_length_input: 192
165 | length_pred: 12
166 |
167 | linreg:
168 | in_len: 84
169 |
170 | linreg_covariates:
171 | in_len: 84
172 |
173 | tft:
174 | in_len: 132
175 | max_samples_per_ts: 200
176 | hidden_size: 96
177 | num_attention_heads: 3
178 | dropout: 0.14019930679548182
179 | lr: 0.003399303384204884
180 | batch_size: 48
181 | max_grad_norm: 0.9962755235072169
182 |
183 | tft_covariates:
184 | in_len: 108
185 | max_samples_per_ts: 50
186 | hidden_size: 112
187 | num_attention_heads: 2
188 | dropout: 0.1504541564537306
189 | lr: 0.0018430630797167395
190 | batch_size: 48
191 | max_grad_norm: 0.9530046023189843
192 |
193 | xgboost:
194 | in_len: 84
195 | lr: 0.722
196 | subsample: 0.9
197 | min_child_weight: 5.0
198 | colsample_bytree: 1.0
199 | max_depth: 10
200 | gamma: 0.5
201 | alpha: 0.271
202 | lambda_: 0.07100000000000001
203 | n_estimators: 416
204 |
205 | xgboost_covariates:
206 | in_len: 96
207 | lr: 0.48000000000000004
208 | subsample: 1.0
209 | min_child_weight: 2.0
210 | colsample_bytree: 0.9
211 | max_depth: 6
212 | gamma: 1.5
213 | alpha: 0.159
214 | lambda_: 0.025
215 | n_estimators: 320
216 |
217 | transformer:
218 | in_len: 96
219 | max_samples_per_ts: 50
220 | d_model: 128
221 | n_heads: 2
222 | num_encoder_layers: 2
223 | num_decoder_layers: 4
224 | dim_feedforward: 64
225 | dropout: 0.0017011626095738697
226 | lr: 0.0007790307889667749
227 | batch_size: 32
228 | lr_epochs: 4
229 | max_grad_norm: 0.4226615744655383
230 |
231 | transformer_covariates:
232 | in_len: 96
233 | max_samples_per_ts: 50
234 | d_model: 128
235 | n_heads: 4
236 | num_encoder_layers: 1
237 | num_decoder_layers: 4
238 | dim_feedforward: 448
239 | dropout: 0.1901296977134417
240 | lr: 0.000965351785309486
241 | batch_size: 48
242 | lr_epochs: 4
243 | max_grad_norm: 0.19219462323820113
244 |
245 | nhits:
246 | in_len: 96
247 | max_samples_per_ts: 200
248 | kernel_sizes: 4
249 | dropout: 0.12642017123585755
250 | lr: 0.00032840023694932384
251 | batch_size: 64
252 | lr_epochs: 16
253 |
254 | nhits_covariates:
255 | in_len: 96
256 | max_samples_per_ts: 50
257 | kernel_sizes: 3
258 | dropout: 0.10162895545943862
259 | lr: 0.0009200129411689094
260 | batch_size: 32
261 | lr_epochs: 2
262 |
263 | gluformer:
264 | in_len: 144
265 | max_samples_per_ts: 100
266 | d_model: 512
267 | n_heads: 8
268 | d_fcn: 1408
269 | num_enc_layers: 1
270 | num_dec_layers: 4
271 |
272 | latentode:
273 | in_len: 48
274 | max_samples_per_ts: 100
275 | latents: 20
276 | rec_dims: 40
277 | rec_layers: 3
278 | gen_layers: 3
279 | units: 100
280 | gru_units: 100
--------------------------------------------------------------------------------
/data_formatter/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/data_formatter/__init__.py
--------------------------------------------------------------------------------
/data_formatter/base.py:
--------------------------------------------------------------------------------
1 | '''Defines a generic data formatter for CGM data sets.'''
2 | import sys
3 | import warnings
4 | import numpy as np
5 | import pandas as pd
6 | import sklearn.preprocessing
7 | import data_formatter.types as types
8 | import data_formatter.utils as utils
9 |
10 | DataTypes = types.DataTypes
11 | InputTypes = types.InputTypes
12 |
13 | dict_data_type = {'categorical': DataTypes.CATEGORICAL,
14 | 'real_valued': DataTypes.REAL_VALUED,
15 | 'date': DataTypes.DATE}
16 | dict_input_type = {'target': InputTypes.TARGET,
17 | 'observed_input': InputTypes.OBSERVED_INPUT,
18 | 'known_input': InputTypes.KNOWN_INPUT,
19 | 'static_input': InputTypes.STATIC_INPUT,
20 | 'id': InputTypes.ID,
21 | 'time': InputTypes.TIME}
22 |
23 |
24 | class DataFormatter():
25 | # Defines and formats data for CGM data sets.
26 |
27 | def __init__(self, cnf, study_file = None):
28 | """Initialises formatter."""
29 | # load parameters from the config file
30 | self.params = cnf
31 | # write progress to file if specified
32 | self.study_file = study_file
33 | stdout = sys.stdout
34 | f = open(study_file, 'a') if study_file is not None else sys.stdout
35 | sys.stdout = f
36 |
37 | # load column definition
38 | print('-'*32)
39 | print('Loading column definition...')
40 | self.__process_column_definition()
41 |
42 | # check that column definition is valid
43 | print('Checking column definition...')
44 | self.__check_column_definition()
45 |
46 | # load data
47 | # check if data table has index col: -1 if not, index >= 0 if yes
48 | print('Loading data...')
49 | self.params['index_col'] = False if self.params['index_col'] == -1 else self.params['index_col']
50 | # read data table
51 | self.data = pd.read_csv(self.params['data_csv_path'], index_col=self.params['index_col'])
52 |
53 | # drop columns / rows
54 | print('Dropping columns / rows...')
55 | self.__drop()
56 |
57 | # check NA values
58 | print('Checking for NA values...')
59 | self.__check_nan()
60 |
61 | # set data types in DataFrame to match column definition
62 | print('Setting data types...')
63 | self.__set_data_types()
64 |
65 | # drop columns / rows
66 | print('Dropping columns / rows...')
67 | self.__drop()
68 |
69 | # encode
70 | print('Encoding data...')
71 | self._encoding_params = self.params['encoding_params']
72 | self.__encode()
73 |
74 | # interpolate
75 | print('Interpolating data...')
76 | self._interpolation_params = self.params['interpolation_params']
77 | self._interpolation_params['interval_length'] = self.params['observation_interval']
78 | self.__interpolate()
79 |
80 | # split data
81 | print('Splitting data...')
82 | self._split_params = self.params['split_params']
83 | self._split_params['max_length_input'] = self.params['max_length_input']
84 | self.__split_data()
85 |
86 | # scale
87 | print('Scaling data...')
88 | self._scaling_params = self.params['scaling_params']
89 | self.__scale()
90 |
91 | print('Data formatting complete.')
92 | print('-'*32)
93 | if study_file is not None:
94 | f.close()
95 | sys.stdout = stdout
96 |
97 |
98 | def __process_column_definition(self):
99 | self._column_definition = []
100 | for col in self.params['column_definition']:
101 | self._column_definition.append((col['name'],
102 | dict_data_type[col['data_type']],
103 | dict_input_type[col['input_type']]))
104 |
105 | def __check_column_definition(self):
106 | # check that there is unique ID column
107 | assert len([col for col in self._column_definition if col[2] == InputTypes.ID]) == 1, 'There must be exactly one ID column.'
108 | # check that there is unique time column
109 | assert len([col for col in self._column_definition if col[2] == InputTypes.TIME]) == 1, 'There must be exactly one time column.'
110 | # check that there is at least one target column
111 | assert len([col for col in self._column_definition if col[2] == InputTypes.TARGET]) >= 1, 'There must be at least one target column.'
112 |
113 | def __set_data_types(self):
114 | # set time column as datetime format in pandas
115 | for col in self._column_definition:
116 | if col[1] == DataTypes.DATE:
117 | self.data[col[0]] = pd.to_datetime(self.data[col[0]])
118 | if col[1] == DataTypes.CATEGORICAL:
119 | self.data[col[0]] = self.data[col[0]].astype('category')
120 | if col[1] == DataTypes.REAL_VALUED:
121 | self.data[col[0]] = self.data[col[0]].astype(np.float32)
122 |
123 | def __check_nan(self):
124 | # delete rows where target, time, or id are na
125 | self.data = self.data.dropna(subset=[col[0]
126 | for col in self._column_definition
127 | if col[2] in [InputTypes.TARGET, InputTypes.TIME, InputTypes.ID]])
128 | # assert that there are no na values in the data
129 | assert self.data.isna().sum().sum() == 0, 'There are NA values in the data even after dropping rows with missing time, glucose, or id.'
130 |
131 | def __drop(self):
132 | # drop columns that are not in the column definition
133 | self.data = self.data[[col[0] for col in self._column_definition]]
134 | # drop rows based on conditions set in the formatter
135 | if self.params['drop'] is not None:
136 | if self.params['drop']['rows'] is not None:
137 | # drop row at indices in the list self.params['drop']['rows']
138 | self.data = self.data.drop(self.params['drop']['rows'])
139 | self.data = self.data.reset_index(drop=True)
140 | if self.params['drop']['columns'] is not None:
141 | for col in self.params['drop']['columns'].keys():
142 | # drop rows where specified columns have values in the list self.params['drop']['columns'][col]
143 | self.data = self.data.loc[~self.data[col].isin(self.params['drop']['columns'][col])].copy()
144 |
145 | def __interpolate(self):
146 | self.data, self._column_definition = utils.interpolate(self.data,
147 | self._column_definition,
148 | **self._interpolation_params)
149 |
150 | def __split_data(self):
151 | if self.params['split_params']['test_percent_subjects'] == 0 or \
152 | self.params['split_params']['length_segment'] == 0:
153 | print('\tNo splitting performed since test_percent_subjects or length_segment is 0.')
154 | self.train_idx, self.val_idx, self.test_idx, self.test_idx_ood = None, None, None, None
155 | self.train_data, self.val_data, self.test_data = self.data, None, None
156 | else:
157 | assert self.params['split_params']['length_segment'] > self.params['length_pred'], \
158 | 'length_segment for test / val must be greater than length_pred.'
159 | self.train_idx, self.val_idx, self.test_idx, self.test_idx_ood = utils.split(self.data,
160 | self._column_definition,
161 | **self._split_params)
162 | self.train_data, self.val_data, self.test_data = self.data.iloc[self.train_idx], \
163 | self.data.iloc[self.val_idx], \
164 | self.data.iloc[self.test_idx + self.test_idx_ood]
165 |
166 | def __encode(self):
167 | self.data, self._column_definition, self.encoders = utils.encode(self.data,
168 | self._column_definition,
169 | **self._encoding_params)
170 |
171 | def __scale(self):
172 | self.train_data, self.val_data, self.test_data, self.scalers = utils.scale(self.train_data,
173 | self.val_data,
174 | self.test_data,
175 | self._column_definition,
176 | **self.params['scaling_params'])
177 |
178 | def reshuffle(self, seed):
179 | stdout = sys.stdout
180 | f = open(self.study_file, 'a')
181 | sys.stdout = f
182 | self.params['split_params']['random_state'] = seed
183 | # split data
184 | self.train_idx, self.val_idx, self.test_idx, self.test_idx_ood = utils.split(self.data,
185 | self._column_definition,
186 | **self._split_params)
187 | self.train_data, self.val_data, self.test_data = self.data.iloc[self.train_idx], \
188 | self.data.iloc[self.val_idx], \
189 | self.data.iloc[self.test_idx+self.test_idx_ood]
190 | # re-scale data
191 | self.train_data, self.val_data, self.test_data, self.scalers = utils.scale(self.train_data,
192 | self.val_data,
193 | self.test_data,
194 | self._column_definition,
195 | **self.params['scaling_params'])
196 | sys.stdout = stdout
197 | f.close()
198 |
199 | def get_column(self, column_name):
200 | # return the column name(s) for the time, id, sid, target, and future/static/dynamic covariate roles
201 | if column_name == 'time':
202 | return [col[0] for col in self._column_definition if col[2] == InputTypes.TIME][0]
203 | elif column_name == 'id':
204 | return [col[0] for col in self._column_definition if col[2] == InputTypes.ID][0]
205 | elif column_name == 'sid':
206 | return [col[0] for col in self._column_definition if col[2] == InputTypes.SID][0]
207 | elif column_name == 'target':
208 | return [col[0] for col in self._column_definition if col[2] == InputTypes.TARGET]
209 | elif column_name == 'future_covs':
210 | future_covs = [col[0] for col in self._column_definition if col[2] == InputTypes.KNOWN_INPUT]
211 | return future_covs if len(future_covs) > 0 else None
212 | elif column_name == 'static_covs':
213 | static_covs = [col[0] for col in self._column_definition if col[2] == InputTypes.STATIC_INPUT]
214 | return static_covs if len(static_covs) > 0 else None
215 | elif column_name == 'dynamic_covs':
216 | dynamic_covs = [col[0] for col in self._column_definition if col[2] == InputTypes.OBSERVED_INPUT]
217 | return dynamic_covs if len(dynamic_covs) > 0 else None
218 | else:
219 | raise ValueError('Column {} not found.'.format(column_name))
220 |
221 |
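222 | # Added usage sketch (illustrative, not part of the original file). It mirrors how
223 | # the formatter is driven elsewhere in this repository (lib/arima.py and the
224 | # exploratory notebooks):
225 | #
226 | #   with open('./config/iglu.yaml', 'r') as f:
227 | #       config = yaml.safe_load(f)
228 | #   formatter = DataFormatter(config)
229 | #   target_col = formatter.get_column('target')   # list of target column names
230 | #   group_col = formatter.get_column('sid')       # segment-id column name
231 | #
232 | # When constructed with a study_file (as in lib/arima.py), reshuffle(seed) re-splits
233 | # and re-scales the data, logging to that file.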
--------------------------------------------------------------------------------
/data_formatter/types.py:
--------------------------------------------------------------------------------
1 | '''Defines data and input types of each column in the dataset.'''
2 |
3 | import enum
4 |
5 | class DataTypes(enum.IntEnum):
6 | """Defines numerical types of each column."""
7 | REAL_VALUED = 0
8 | CATEGORICAL = 1
9 | DATE = 2
10 |
11 | class InputTypes(enum.IntEnum):
12 | """Defines input types of each column."""
13 | TARGET = 0
14 | OBSERVED_INPUT = 1
15 | KNOWN_INPUT = 2
16 | STATIC_INPUT = 3
17 | ID = 4 # Single column used as an entity identifier
18 | SID = 5 # Single column used as a segment identifier
19 | TIME = 6 # Single column exclusively used as a time index
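20 | 
21 | # Added illustrative note (not part of the original file): data_formatter.base.DataFormatter
22 | # combines these enums into (name, DataTypes, InputTypes) tuples built from the dataset
23 | # config, e.g.
24 | #   ('time', DataTypes.DATE, InputTypes.TIME)
25 | #   ('gl', DataTypes.REAL_VALUED, InputTypes.TARGET)
26 | # with the remaining covariate columns following the same pattern.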
--------------------------------------------------------------------------------
/exploratory_analysis/colas.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "oFh18yX0FrjN"
7 | },
8 | "source": [
9 | "# Loading libraries"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "eawH97OBFrjS"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "import sys\n",
21 | "import os\n",
22 | "import yaml\n",
23 | "import pandas as pd\n",
24 | "import numpy as np\n",
25 | "sys.path.insert(1, '..')\n",
26 | "os.chdir('..')\n",
27 | "\n",
28 | "import seaborn as sns\n",
29 | "sns.set_style('whitegrid')\n",
30 | "import matplotlib.pyplot as plt\n",
31 | "import statsmodels.api as sm\n",
32 | "import sklearn\n",
33 | "import optuna\n",
34 | "\n",
35 | "from darts import models\n",
36 | "from darts import metrics\n",
37 | "from darts import TimeSeries\n",
38 | "from darts.dataprocessing.transformers import Scaler\n",
39 | "\n",
40 | "from statsforecast.models import AutoARIMA\n",
41 | "\n",
42 | "from data_formatter.base import *\n",
43 | "from bin.utils import *"
44 | ]
45 | },
46 | {
47 | "attachments": {},
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "# Processing"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "filenames = []\n",
61 | "for root, dir, files in os.walk('raw_data/Colas2019'):\n",
62 | " for file in files:\n",
63 | " if '.csv' in file:\n",
64 | " filenames.append(os.path.join(root, file))\n",
65 | " \n",
66 | "# next we loop through each file\n",
67 | "nfiles = len(files)\n",
68 | "\n",
69 | "count = 0\n",
70 | "for file in filenames:\n",
71 | " # read in data and extract id from filename\n",
72 | " curr = pd.read_csv(file)\n",
73 | " curr['id'] = int(file.split()[1].split(\".\")[0])\n",
74 | " # select desired columns, rename, and drop nas\n",
75 | " curr = curr[['id', 'hora', 'glucemia']]\n",
76 | " curr.rename(columns = {'hora': 'time', 'glucemia': 'gl'}, inplace=True)\n",
77 | " curr.dropna(inplace=True)\n",
78 | "\n",
79 | " # calculate time (only given in hms) as follows:\n",
80 | " # (1) get the time per day in seconds, (2) get the time differences, and correct for the day crossove (< 0)\n",
81 | " # (3) take the cumulative sum and add the cumulative number of seconds from start to the base date\n",
82 | " # thus the hms are real, while the year, month, day are fake\n",
83 | " time_secs = []\n",
84 | " for i in curr['time']:\n",
85 | " time_secs.append(int(i.split(\":\")[0])*60*60 + int(i.split(\":\")[1])*60 + int(i.split(\":\")[2])*1)\n",
86 | " time_diff = np.diff(np.array(time_secs)).tolist()\n",
87 | " time_diff_adj = [x if x > 0 else 24*60*60 + x for x in time_diff]\n",
88 | " time_diff_adj.insert(0, 0)\n",
89 | " cumin = np.cumsum(time_diff_adj)\n",
90 | " datetime = pd.to_datetime('2012-01-01') + pd.to_timedelta(cumin, unit='sec')\n",
91 | " curr['time'] = datetime\n",
92 | " curr['id'] = curr['id'].astype('int')\n",
93 | " curr.reset_index(drop=True, inplace=True)\n",
94 | "\n",
95 | " if count == 0:\n",
96 | " df = curr\n",
97 | " count += 1\n",
98 | " else:\n",
99 | " df = pd.concat([df, curr])"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "# join with covariates\n",
109 | "covariates = pd.read_csv('raw_data/Colas2019/clinical_data.txt', sep = \" \")\n",
110 | "covariates['id'] = covariates.index\n",
111 | "\n",
112 | "combined = pd.merge(\n",
113 | " df, covariates, how = \"left\"\n",
114 | ")\n",
115 | "\n",
116 | "# define NA fill values for covariates\n",
117 | "values = {\n",
118 | " 'gender': 2, # if gender is NA, create own category\n",
119 | " 'age': combined['age'].mean(),\n",
120 | " 'BMI': combined['BMI'].mean(),\n",
121 | " 'glycaemia': combined['glycaemia'].mean(),\n",
122 | " 'HbA1c': combined['HbA1c'].mean(),\n",
123 | " 'follow.up': combined['follow.up'].mean(),\n",
124 | " 'T2DM': False\n",
125 | "}\n",
126 | "combined = combined.fillna(value = values)\n",
127 | "\n",
128 | "# write to csv\n",
129 | "combined.to_csv('raw_data/colas.csv')"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {
135 | "id": "WeAHZmAmFrjV"
136 | },
137 | "source": [
138 | "# Check statistics of the data"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "metadata": {
145 | "colab": {
146 | "base_uri": "https://localhost:8080/"
147 | },
148 | "id": "pkOzK6gcFrjW",
149 | "outputId": "769510ff-79ba-4020-8d9c-dc78a7cdb7ff"
150 | },
151 | "outputs": [],
152 | "source": [
153 | "import matplotlib.pyplot as plt\n",
154 | "\n",
155 | "# load yaml config file\n",
156 | "with open('./config/colas.yaml', 'r') as f:\n",
157 | " config = yaml.safe_load(f)\n",
158 | "\n",
159 | "# set interpolation params for no interpolation\n",
160 | "new_config = config.copy()\n",
161 | "new_config['interpolation_params']['gap_threshold'] = 5\n",
162 | "new_config['interpolation_params']['min_drop_length'] = 0\n",
163 | "# set split params for no splitting\n",
164 | "new_config['split_params']['test_percent_subjects'] = 0\n",
165 | "new_config['split_params']['length_segment'] = 0\n",
166 | "# set scaling params for no scaling\n",
167 | "new_config['scaling_params']['scaler'] = 'None'\n",
168 | "\n",
169 | "formatter = DataFormatter(new_config)"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": null,
175 | "metadata": {
176 | "colab": {
177 | "base_uri": "https://localhost:8080/"
178 | },
179 | "id": "eCBgEjuAFrjX",
180 | "outputId": "1d40e5fa-1fd5-45ea-ae93-41d14226a0c5"
181 | },
182 | "outputs": [],
183 | "source": [
184 | "# print min, max, median, mean, std of segment lengths\n",
185 | "segment_lens = []\n",
186 | "for group, data in formatter.train_data.groupby('id_segment'):\n",
187 | " segment_lens.append(len(data))\n",
188 | "print('Train segment lengths:')\n",
189 | "print('\\tMin: ', min(segment_lens))\n",
190 | "print('\\tMax: ', max(segment_lens))\n",
191 | "print('\\t1st Quartile: ', np.quantile(segment_lens, 0.25))\n",
192 | "print('\\tMedian: ', np.median(segment_lens))\n",
193 | "print('\\tMean: ', np.mean(segment_lens))\n",
194 | "print('\\tStd: ', np.std(segment_lens))\n",
195 | "\n",
196 | "# plot first 9 segments\n",
197 | "num_segments = 9\n",
198 | "plot_data = formatter.train_data\n",
199 | "\n",
200 | "fig, axs = plt.subplots(1, num_segments, figsize=(30, 5))\n",
201 | "for i, (group, data) in enumerate(plot_data.groupby('id_segment')):\n",
202 | " data.plot(x='time', y='gl', ax=axs[i], title='Segment {}'.format(group))\n",
203 | " if i >= num_segments - 1:\n",
204 | " break"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "metadata": {
211 | "colab": {
212 | "base_uri": "https://localhost:8080/",
213 | "height": 341
214 | },
215 | "id": "iU2AUHTfFrjZ",
216 | "outputId": "ac25dfa6-4eee-4fc8-9c0c-11efa1b4fa14"
217 | },
218 | "outputs": [],
219 | "source": [
220 | "# plot acf of random samples from first 9 segments segments\n",
221 | "fig, ax = plt.subplots(2, num_segments, figsize=(30, 5))\n",
222 | "lags = 300; k = 0\n",
223 | "for i, (group, data) in enumerate(plot_data.groupby('id_segment')):\n",
224 | " data = data['gl']\n",
225 | " if len(data) < lags:\n",
226 | " print('Segment {} is too short'.format(group))\n",
227 | " continue\n",
228 | " else:\n",
229 | " # select 10 random samples from index of data\n",
230 | " sample = np.random.choice(range(len(data))[:-lags], 10, replace=False)\n",
231 | " # plot acf / pacf of each sample\n",
232 | " for j in sample:\n",
233 | " acf, acf_ci = sm.tsa.stattools.acf(data[j:j+lags], nlags=lags, alpha=0.05)\n",
234 | " pacf, pacf_ci = sm.tsa.stattools.pacf(data[j:j+lags], method='ols-adjusted', alpha=0.05)\n",
235 | " ax[0, k].plot(acf)\n",
236 | " ax[1, k].plot(pacf)\n",
237 | " k += 1\n",
238 | " if k >= num_segments:\n",
239 | " break"
240 | ]
241 | }
242 | ],
243 | "metadata": {
244 | "colab": {
245 | "provenance": []
246 | },
247 | "kernelspec": {
248 | "display_name": "base",
249 | "language": "python",
250 | "name": "python3"
251 | },
252 | "language_info": {
253 | "codemirror_mode": {
254 | "name": "ipython",
255 | "version": 3
256 | },
257 | "file_extension": ".py",
258 | "mimetype": "text/x-python",
259 | "name": "python",
260 | "nbconvert_exporter": "python",
261 | "pygments_lexer": "ipython3",
262 | "version": "3.6.9"
263 | },
264 | "orig_nbformat": 4,
265 | "vscode": {
266 | "interpreter": {
267 | "hash": "ad2bdc8ecc057115af97d19610ffacc2b4e99fae6737bb82f5d7fb13d2f2c186"
268 | }
269 | }
270 | },
271 | "nbformat": 4,
272 | "nbformat_minor": 0
273 | }
274 |
--------------------------------------------------------------------------------
/exploratory_analysis/iglu.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Loading libraries"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# DANGER: only run 1x otherwise will chdir too many times\n",
17 | "import sys\n",
18 | "import os\n",
19 | "import yaml\n",
20 | "\n",
21 | "sys.path.insert(1, '..')\n",
22 | "os.chdir('..')\n",
23 | "\n",
24 | "import seaborn as sns\n",
25 | "sns.set_style('whitegrid')\n",
26 | "\n",
27 | "import matplotlib.pyplot as plt\n",
28 | "import statsmodels.api as sm\n",
29 | "import sklearn\n",
30 | "import optuna\n",
31 | "\n",
32 | "from darts import models, metrics, TimeSeries\n",
33 | "from darts.dataprocessing.transformers import Scaler\n",
34 | "\n",
35 | "from data_formatter.base import * # TODO: inefficient"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "# Check statistics of the data"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "# load yaml config file\n",
52 | "with open('./config/iglu.yaml', 'r') as f:\n",
53 | " config = yaml.safe_load(f)\n",
54 | "\n",
55 | "# set interpolation params for no interpolation\n",
56 | "new_config = config.copy()\n",
57 | "new_config['interpolation_params']['gap_threshold'] = 30\n",
58 | "new_config['interpolation_params']['min_drop_length'] = 0\n",
59 | "# set split params for no splitting\n",
60 | "new_config['split_params']['test_percent_subjects'] = 0\n",
61 | "new_config['split_params']['length_segment'] = 0\n",
62 | "# set scaling params for no scaling\n",
63 | "new_config['scaling_params']['scaler'] = 'None'\n",
64 | "\n",
65 | "formatter = DataFormatter(new_config)"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "%%capture\n",
75 | "\n",
76 | "# Need: Tradeoff between interpolation and segment length\n",
77 | "# Problem: Manually tuning is slow and potentially imprecise\n",
78 | "# Idea: have automated function that can help determine what the gap threshold should be\n",
79 | "# Proof of concept below\n",
80 | "\n",
81 | "import numpy as np\n",
82 | "\n",
83 | "def calc_percent(a, b):\n",
84 | " return a*100/b\n",
85 | "\n",
86 | "gap_threshold = np.arange(5, 70, 1)\n",
87 | "percent_valid = []\n",
88 | "for i in gap_threshold:\n",
89 | " new_config['interpolation_params']['gap_threshold'] = i\n",
90 | " df = DataFormatter(new_config).train_data\n",
91 | " \n",
92 | " segment_lens = []\n",
93 | " for group, data in df.groupby('id_segment'):\n",
94 | " segment_lens.append(len(data))\n",
95 | " \n",
96 | " threshold = 240\n",
97 | " valid_ids = df.groupby('id_segment')['time'].count().loc[lambda x : x>threshold].reset_index()['id_segment']\n",
98 | " \n",
99 | " percent_valid.append((len(valid_ids)*100/len(segment_lens)))"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "# Plot results\n",
109 | "plt.plot(gap_threshold, percent_valid)\n",
110 | "plt.title(\"Gap Threshold affect on % Segments > 240 Length\")\n",
111 | "plt.ylabel(\"% Above Threshhold\")\n",
112 | "plt.xlabel(\"Gap Threshold (min)\")"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "# print min, max, median, mean, std of segment lengths\n",
122 | "df = formatter.train_data\n",
123 | "segment_lens = []\n",
124 | "for group, data in df.groupby('id_segment'):\n",
125 | " segment_lens.append(len(data))\n",
126 | "\n",
127 | "print('Train segment lengths:')\n",
128 | "print('\\tMin: ', min(segment_lens))\n",
129 | "print('\\tMax: ', max(segment_lens))\n",
130 | "print('\\tMedian: ', np.median(segment_lens))\n",
131 | "print('\\tMean: ', np.mean(segment_lens))\n",
132 | "print('\\tStd: ', np.std(segment_lens))\n",
133 | "\n",
134 | "# Visualize segment lengths to see approx # of valid ones (>240)\n",
135 | "plt.title(\"Segment Lengths (Line at 240)\")\n",
136 | "plt.hist(segment_lens)\n",
137 | "plt.axvline(240, color='r', linestyle='dashed', linewidth=1)\n",
138 | "\n",
139 | "# filter to get valid indices\n",
140 | "threshold = 240\n",
141 | "valid_ids = df.groupby('id_segment')['time'].count().loc[lambda x : x>threshold].reset_index()['id_segment']\n",
142 | "df_filtered = df.loc[df['id_segment'].isin(valid_ids)]\n",
143 | "\n",
144 | "# plot each segment\n",
145 | "num_segments = df_filtered['id_segment'].nunique()\n",
146 | "\n",
147 | "fig, axs = plt.subplots(1, num_segments, figsize=(30, 5))\n",
148 | "for i, (group, data) in enumerate(df_filtered.groupby('id_segment')):\n",
149 | " data.plot(x='time', y='gl', ax=axs[i], title='Segment {}'.format(group))"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "df.head(10)"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "# plot acf of random samples from segments\n",
168 | "fig, ax = plt.subplots(2, 5, figsize=(30, 5))\n",
169 | "lags = 240\n",
170 | "for i, (group, data) in enumerate(df_filtered.groupby('id_segment')):\n",
171 | " # only view top 5\n",
172 | " if i < 5:\n",
173 | " data = data['gl']\n",
174 | " if len(data) < lags: # TODO: Could probably do filtering in pandas which would be faster\n",
175 | " print('Segment {} is too short'.format(group))\n",
176 | " continue\n",
177 | " # select 10 random samples from index of data\n",
178 | " sample = np.random.choice(range(len(data))[:-lags], 10, replace=False)\n",
179 | " # plot acf / pacf of each sample\n",
180 | " for j in sample:\n",
181 | " acf, acf_ci = sm.tsa.stattools.acf(data[j:j+lags], nlags=lags, alpha=0.05)\n",
182 | " pacf, pacf_ci = sm.tsa.stattools.pacf(data[j:j+lags], method='ols-adjusted', alpha=0.05)\n",
183 | " ax[0, i].plot(acf)\n",
184 | " ax[1, i].plot(pacf)\n"
185 | ]
186 | }
187 | ],
188 | "metadata": {
189 | "kernelspec": {
190 | "display_name": "Python 3 (ipykernel)",
191 | "language": "python",
192 | "name": "python3"
193 | },
194 | "language_info": {
195 | "codemirror_mode": {
196 | "name": "ipython",
197 | "version": 3
198 | },
199 | "file_extension": ".py",
200 | "mimetype": "text/x-python",
201 | "name": "python",
202 | "nbconvert_exporter": "python",
203 | "pygments_lexer": "ipython3",
204 | "version": "3.10.1"
205 | },
206 | "vscode": {
207 | "interpreter": {
208 | "hash": "95662931fb0811c75e2373330a012ba90aa4548ba779055436524c46bd94b0ee"
209 | }
210 | }
211 | },
212 | "nbformat": 4,
213 | "nbformat_minor": 2
214 | }
215 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/lib/__init__.py
--------------------------------------------------------------------------------
/lib/arima.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Dict
2 | import sys
3 | import os
4 | import yaml
5 | import datetime
6 | import argparse
7 |
8 | from statsforecast.models import AutoARIMA
9 |
10 | # import data formatter
11 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
12 | from data_formatter.base import *
13 |
14 | def test_model(test_data, scaler, in_len, out_len, stride, target_col, group_col):
15 | errors = []
16 | for group, data in test_data.groupby(group_col):
17 | train_set = data[target_col].iloc[:in_len].values.flatten()
18 | # fit model
19 | model = AutoARIMA(start_p = 0,
20 | max_p = 10,
21 | start_q = 0,
22 | max_q = 10,
23 | start_P = 0,
24 | max_P = 10,
25 | start_Q=0,
26 | max_Q=10,
27 | allowdrift=True,
28 | allowmean=True,
29 | parallel=False)
30 | model.fit(train_set)
31 | # get valid sampling locations for future prediction
32 | start_idx = np.arange(start=stride, stop=len(data) - in_len - out_len + 1, step=stride)
33 | end_idx = start_idx + in_len
34 | # iterate and collect predictions
35 | for i in range(len(start_idx)):
36 | input = data[target_col].iloc[start_idx[i]:end_idx[i]].values.flatten()
37 | true = data[target_col].iloc[end_idx[i]:(end_idx[i]+out_len)].values.flatten()
38 | prediction = model.forward(input, h=out_len)['mean']
39 | # unscale true and prediction
40 | true = scaler.inverse_transform(true.reshape(-1, 1)).flatten()
41 | prediction = scaler.inverse_transform(prediction.reshape(-1, 1)).flatten()
42 | # collect errors
43 | errors.append(np.array([np.mean((true - prediction)**2), np.mean(np.abs(true - prediction))]))
44 | errors = np.vstack(errors)
45 | return errors
46 |
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument('--dataset', type=str, default='weinstock')
49 | parser.add_argument('--reduction1', type=str, default='mean')
50 | parser.add_argument('--reduction2', type=str, default='median')
51 | parser.add_argument('--reduction3', type=str, default=None)
52 | args = parser.parse_args()
53 | reductions = [args.reduction1, args.reduction2, args.reduction3]
54 | if __name__ == '__main__':
55 | # study file
56 | study_file = f'./output/arima_{args.dataset}.txt'
57 | # check that file exists otherwise create it
58 | if not os.path.exists(study_file):
59 | with open(study_file, "w") as f:
60 | # write current date and time
61 | f.write(f"Optimization started at {datetime.datetime.now()}\n")
62 | # load data
63 | with open(f'./config/{args.dataset}.yaml', 'r') as f:
64 | config = yaml.safe_load(f)
65 | config['scaling_params']['scaler'] = 'MinMaxScaler'
66 | formatter = DataFormatter(config, study_file = study_file)
67 |
68 | # set params
69 | in_len = formatter.params['max_length_input']
70 | out_len = formatter.params['length_pred']
71 | stride = formatter.params['length_pred'] // 2
72 | target_col = formatter.get_column('target')
73 | group_col = formatter.get_column('sid')
74 |
75 | seeds = list(range(10, 20))
76 | id_errors_cv = {key: [] for key in reductions if key is not None}
77 | ood_errors_cv = {key: [] for key in reductions if key is not None}
78 | for seed in seeds:
79 | formatter.reshuffle(seed)
80 | test_data = formatter.test_data.loc[~formatter.test_data.index.isin(formatter.test_idx_ood)]
81 | test_data_ood = formatter.test_data.loc[formatter.test_data.index.isin(formatter.test_idx_ood)]
82 |
83 | # backtest on the ID test set
84 | id_errors_sample = test_model(test_data,
85 | formatter.scalers[target_col[0]],
86 | in_len,
87 | out_len,
88 | stride,
89 | target_col,
90 | group_col)
91 | # backtest on the ood test set
92 | ood_errors_sample = test_model(test_data_ood,
93 | formatter.scalers[target_col[0]],
94 | in_len,
95 | out_len,
96 | stride,
97 | target_col,
98 | group_col)
99 | # compute, save, and print results
100 | with open(study_file, "a") as f:
101 | for reduction in reductions:
102 | if reduction is not None:
103 | # compute
104 | reduction_f = getattr(np, reduction)
105 | id_errors_sample_red = reduction_f(id_errors_sample, axis=0)
106 | ood_errors_sample_red = reduction_f(ood_errors_sample, axis=0)
107 | # save
108 | id_errors_cv[reduction].append(id_errors_sample_red)
109 | ood_errors_cv[reduction].append(ood_errors_sample_red)
110 | # print
111 | f.write(f"\tSeed: {seed} ID {reduction} of (MSE, MAE): {id_errors_sample_red.tolist()}\n")
112 | f.write(f"\tSeed: {seed} OOD {reduction} of (MSE, MAE) stats: {ood_errors_sample_red.tolist()}\n")
113 | # compute, save, and print results
114 | with open(study_file, "a") as f:
115 | for reduction in reductions:
116 | if reduction is not None:
117 | # compute
118 | id_errors_cv[reduction] = np.vstack(id_errors_cv[reduction])
119 | ood_errors_cv[reduction] = np.vstack(ood_errors_cv[reduction])
120 | id_errors_cv[reduction] = np.mean(id_errors_cv[reduction], axis=0)
121 | ood_errors_cv[reduction] = np.mean(ood_errors_cv[reduction], axis=0)
122 | # print
123 | f.write(f"ID {reduction} of (MSE, MAE): {id_errors_cv[reduction].tolist()}\n")
124 | f.write(f"OOD {reduction} of (MSE, MAE): {ood_errors_cv[reduction].tolist()}\n")
--------------------------------------------------------------------------------
/lib/gluformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/lib/gluformer/__init__.py
--------------------------------------------------------------------------------
/lib/gluformer/attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from math import sqrt
6 |
7 | class CausalConv1d(torch.nn.Conv1d):
8 | def __init__(self,
9 | in_channels,
10 | out_channels,
11 | kernel_size,
12 | stride=1,
13 | dilation=1,
14 | groups=1,
15 | bias=True):
16 | self.__padding = (kernel_size - 1) * dilation
17 |
18 | super(CausalConv1d, self).__init__(
19 | in_channels,
20 | out_channels,
21 | kernel_size=kernel_size,
22 | stride=stride,
23 | padding=self.__padding,
24 | dilation=dilation,
25 | groups=groups,
26 | bias=bias)
27 |
28 | def forward(self, input):
29 | result = super(CausalConv1d, self).forward(input)
30 | if self.__padding != 0:
31 | return result[:, :, :-self.__padding]
32 | return result
33 |
34 | class TriangularCausalMask():
35 | def __init__(self, b, n, device="cpu"):
36 | mask_shape = [b, 1, n, n]
37 | with torch.no_grad():
38 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
39 |
40 | @property
41 | def mask(self):
42 | return self._mask
43 |
44 | class MultiheadAttention(nn.Module):
45 | def __init__(self, d_model, n_heads, d_keys, mask_flag, r_att_drop=0.1):
46 | super(MultiheadAttention, self).__init__()
47 | self.h, self.d, self.mask_flag= n_heads, d_keys, mask_flag
48 | self.proj_q = nn.Linear(d_model, self.h * self.d)
49 | self.proj_k = nn.Linear(d_model, self.h * self.d)
50 | self.proj_v = nn.Linear(d_model, self.h * self.d)
51 | self.proj_out = nn.Linear(self.h * self.d, d_model)
52 | self.dropout = nn.Dropout(r_att_drop)
53 |
54 | def forward(self, q, k, v):
55 | b, n_q, n_k, h, d = q.size(0), q.size(1), k.size(1), self.h, self.d
56 |
57 | q, k, v = self.proj_q(q), self.proj_k(k), self.proj_v(v) # b, n_*, h*d
58 | q, k, v = map(lambda x: x.reshape(b, -1, h, d), [q, k, v]) # b, n_*, h, d
59 | scores = torch.einsum('bnhd,bmhd->bhnm', (q,k)) # b, h, n_q, n_k
60 |
61 | if self.mask_flag:
62 | att_mask = TriangularCausalMask(b, n_q, device=q.device)
63 | scores.masked_fill_(att_mask.mask, -np.inf)
64 |
65 | att = F.softmax(scores / (self.d ** .5), dim=-1) # b, h, n_q, n_k
66 | att = self.dropout(att)
67 | att_out = torch.einsum('bhnm,bmhd->bnhd', (att,v)) # b, n_q, h, d
68 | att_out = att_out.reshape(b, -1, h*d) # b, n_q, h*d
69 | out = self.proj_out(att_out) # b, n_q, d_model
70 | return out
71 |
--------------------------------------------------------------------------------
/lib/gluformer/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .attention import *
6 |
7 | class DecoderLayer(nn.Module):
8 | def __init__(self, self_att, cross_att, d_model, d_fcn,
9 | r_drop, activ="relu"):
10 | super(DecoderLayer, self).__init__()
11 |
12 | self.self_att = self_att
13 | self.cross_att = cross_att
14 |
15 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
16 | self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)
17 |
18 | self.norm1 = nn.LayerNorm(d_model)
19 | self.norm2 = nn.LayerNorm(d_model)
20 | self.norm3 = nn.LayerNorm(d_model)
21 |
22 | self.dropout = nn.Dropout(r_drop)
23 | self.activ = F.relu if activ == "relu" else F.gelu
24 |
25 | def forward(self, x_dec, x_enc):
26 | x_dec = x_dec + self.self_att(x_dec, x_dec, x_dec)
27 | x_dec = self.norm1(x_dec)
28 |
29 | x_dec = x_dec + self.cross_att(x_dec, x_enc, x_enc)
30 | res = x_dec = self.norm2(x_dec)
31 |
32 | res = self.dropout(self.activ(self.conv1(res.transpose(-1,1))))
33 | res = self.dropout(self.conv2(res).transpose(-1,1))
34 |
35 | return self.norm3(x_dec+res)
36 |
37 | class Decoder(nn.Module):
38 | def __init__(self, layers, norm_layer=None):
39 | super(Decoder, self).__init__()
40 | self.layers = nn.ModuleList(layers)
41 | self.norm = norm_layer
42 |
43 | def forward(self, x_dec, x_enc):
44 | for layer in self.layers:
45 | x_dec = layer(x_dec, x_enc)
46 |
47 | if self.norm is not None:
48 | x_dec = self.norm(x_dec)
49 |
50 | return x_dec
51 |
--------------------------------------------------------------------------------
/lib/gluformer/embed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | class PositionalEmbedding(nn.Module):
7 | def __init__(self, d_model, max_len=5000):
8 | super(PositionalEmbedding, self).__init__()
9 | # Compute the positional encodings once in log space.
10 | pos_emb = torch.zeros(max_len, d_model)
11 | pos_emb.requires_grad = False
12 |
13 | position = torch.arange(0, max_len).unsqueeze(1)
14 | div_term = (torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)).exp()
15 |
16 | pos_emb[:, 0::2] = torch.sin(position * div_term)
17 | pos_emb[:, 1::2] = torch.cos(position * div_term)
18 |
19 | pos_emb = pos_emb.unsqueeze(0)
20 | self.register_buffer('pos_emb', pos_emb)
21 |
22 | def forward(self, x):
23 | return self.pos_emb[:, :x.size(1)]
24 |
25 | class TokenEmbedding(nn.Module):
26 | def __init__(self, d_model):
27 | super(TokenEmbedding, self).__init__()
28 | D_INP = 1 # univariate input: one value channel per time step
29 | self.conv = nn.Conv1d(in_channels=D_INP, out_channels=d_model,
30 | kernel_size=3, padding=1, padding_mode='circular')
31 | # nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='leaky_relu')
32 |
33 | def forward(self, x):
34 | x = self.conv(x.transpose(-1, 1)).transpose(-1, 1)
35 | return x
36 |
37 | class TemporalEmbedding(nn.Module):
38 | def __init__(self, d_model, num_features):
39 | super(TemporalEmbedding, self).__init__()
40 | self.embed = nn.Linear(num_features, d_model)
41 |
42 | def forward(self, x):
43 | return self.embed(x)
44 |
45 | class SubjectEmbedding(nn.Module):
46 | def __init__(self, d_model, num_features):
47 | super(SubjectEmbedding, self).__init__()
48 | self.id_embedding = nn.Linear(num_features, d_model)
49 |
50 | def forward(self, x):
51 | embed_x = self.id_embedding(x)
52 |
53 | return embed_x
54 |
55 | class DataEmbedding(nn.Module):
56 | def __init__(self, d_model, r_drop, num_dynamic_features, num_static_features):
57 | super(DataEmbedding, self).__init__()
58 | # note: d_model must be even (d_model % 2 == 0) for the sin/cos positional encoding
59 | self.value_embedding = TokenEmbedding(d_model)
60 | self.time_embedding = TemporalEmbedding(d_model, num_dynamic_features) # alternative: TimeFeatureEmbedding
61 | self.positional_embedding = PositionalEmbedding(d_model)
62 | self.subject_embedding = SubjectEmbedding(d_model, num_static_features)
63 | self.dropout = nn.Dropout(r_drop)
64 |
65 | def forward(self, x_id, x, x_mark):
66 | x = self.value_embedding(x) + self.positional_embedding(x) + self.time_embedding(x_mark)
67 | x_id = self.subject_embedding(x_id)
68 | x = torch.cat((x_id.unsqueeze(1), x), dim = 1)
69 | return self.dropout(x)
70 |
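71 | # Added shape note (inferred from the forward passes above):
72 | #   x       [batch, seq_len, 1]            -> TokenEmbedding    -> [batch, seq_len, d_model]
73 | #   x_mark  [batch, seq_len, num_dynamic]  -> TemporalEmbedding -> [batch, seq_len, d_model]
74 | #   x_id    [batch, num_static]            -> SubjectEmbedding  -> [batch, d_model]
75 | # The subject embedding is prepended as an extra token, so DataEmbedding returns a
76 | # tensor of shape [batch, seq_len + 1, d_model].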
--------------------------------------------------------------------------------
/lib/gluformer/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .attention import *
6 |
7 | class ConvLayer(nn.Module):
8 | def __init__(self, d_model):
9 | super(ConvLayer, self).__init__()
10 | self.downConv = nn.Conv1d(in_channels=d_model, out_channels=d_model,
11 | kernel_size=3, padding=1, padding_mode='circular')
12 | self.norm = nn.BatchNorm1d(d_model)
13 | self.activ = nn.ELU()
14 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
15 |
16 | def forward(self, x):
17 | x = self.downConv(x.transpose(-1, 1))
18 | x = self.norm(x)
19 | x = self.activ(x)
20 | x = self.maxPool(x)
21 | x = x.transpose(-1,1)
22 | return x
23 |
24 | class EncoderLayer(nn.Module):
25 | def __init__(self, att, d_model, d_fcn, r_drop, activ="relu"):
26 | super(EncoderLayer, self).__init__()
27 |
28 | self.att = att
29 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
30 | self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)
31 | self.norm1 = nn.LayerNorm(d_model)
32 | self.norm2 = nn.LayerNorm(d_model)
33 | self.dropout = nn.Dropout(r_drop)
34 | self.activ = F.relu if activ == "relu" else F.gelu
35 |
36 | def forward(self, x):
37 | new_x = self.att(x, x, x)
38 | x = x + self.dropout(new_x)
39 |
40 | res = x = self.norm1(x)
41 | res = self.dropout(self.activ(self.conv1(res.transpose(-1,1))))
42 | res = self.dropout(self.conv2(res).transpose(-1,1))
43 |
44 | return self.norm2(x+res)
45 |
46 | class Encoder(nn.Module):
47 | def __init__(self, enc_layers, conv_layers=None, norm_layer=None):
48 | super(Encoder, self).__init__()
49 | self.enc_layers = nn.ModuleList(enc_layers)
50 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
51 | self.norm = norm_layer
52 |
53 | def forward(self, x):
54 | # x [B, L, D]
55 | if self.conv_layers is not None:
56 | for enc_layer, conv_layer in zip(self.enc_layers, self.conv_layers):
57 | x = enc_layer(x)
58 | x = conv_layer(x)
59 | x = self.enc_layers[-1](x)
60 | else:
61 | for enc_layer in self.enc_layers:
62 | x = enc_layer(x)
63 |
64 | if self.norm is not None:
65 | x = self.norm(x)
66 |
67 | return x
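68 | 
69 | # Added note (inferred from ConvLayer above): when conv_layers are supplied, each
70 | # attention block is followed by a convolution and a stride-2 max-pool, so the
71 | # sequence length is roughly halved between encoder layers (a distilling step in the
72 | # style of Informer).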
--------------------------------------------------------------------------------
/lib/gluformer/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/lib/gluformer/utils/__init__.py
--------------------------------------------------------------------------------
/lib/gluformer/utils/collate.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Tuple, Union, Mapping, Sequence
2 | import torch
3 | import re
4 | import collections
5 |
6 | np_str_obj_array_pattern = re.compile(r'[SaUO]')
7 |
8 | # Replace torch._six.string_classes with this
9 | string_classes = (str, bytes)
10 |
11 | def default_convert(data: Any) -> Any:
12 | r"""Converts each NumPy array data field into a tensor"""
13 | elem_type = type(data)
14 | if isinstance(data, torch.Tensor):
15 | return data
16 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
17 | and elem_type.__name__ != 'string_':
18 | # array of string classes and object
19 | if elem_type.__name__ == 'ndarray' \
20 | and np_str_obj_array_pattern.search(data.dtype.str) is not None:
21 | return data
22 | return torch.as_tensor(data)
23 | elif isinstance(data, collections.abc.Mapping):
24 | return {key: default_convert(data[key]) for key in data}
25 | elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
26 | return elem_type(*(default_convert(d) for d in data))
27 | elif isinstance(data, collections.abc.Sequence) and not isinstance(data, string_classes):
28 | return [default_convert(d) for d in data]
29 | else:
30 | return data
31 |
32 | default_collate_err_msg_format = (
33 | "default_collate: batch must contain tensors, numpy arrays, numbers, "
34 | "dicts or lists; found {}")
35 |
36 | def default_collate(batch: List[Any]) -> Union[torch.Tensor, List, Dict, Tuple]:
37 | r"""Puts each data field into a tensor with outer dimension batch size"""
38 |
39 | elem = batch[0]
40 | elem_type = type(elem)
41 | if isinstance(elem, torch.Tensor):
42 | out = None
43 | if torch.utils.data.get_worker_info() is not None:
44 | # If we're in a background process, concatenate directly into a
45 | # shared memory tensor to avoid an extra copy
46 | numel = sum(x.numel() for x in batch)
47 | storage = elem.storage()._new_shared(numel)
48 | out = elem.new(storage)
49 | return torch.stack(batch, 0, out=out)
50 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
51 | and elem_type.__name__ != 'string_':
52 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
53 | # array of string classes and object
54 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
55 | raise TypeError(default_collate_err_msg_format.format(elem.dtype))
56 |
57 | return default_collate([torch.as_tensor(b) for b in batch])
58 | elif elem.shape == (): # scalars
59 | return torch.as_tensor(batch)
60 | elif isinstance(elem, float):
61 | return torch.tensor(batch, dtype=torch.float64)
62 | elif isinstance(elem, int):
63 | return torch.tensor(batch)
64 | elif isinstance(elem, string_classes):
65 | return batch
66 | elif isinstance(elem, Mapping):
67 | return {key: default_collate([d[key] for d in batch]) for key in elem}
68 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
69 | return elem_type(*(default_collate(samples) for samples in zip(*batch)))
70 | elif isinstance(elem, Sequence):
71 | # check to make sure that the elements in batch have consistent size
72 | it = iter(batch)
73 | elem_size = len(next(it))
74 | if not all(len(elem) == elem_size for elem in it):
75 | raise RuntimeError('each element in list of batch should be of equal size')
76 | transposed = zip(*batch)
77 | return [default_collate(samples) for samples in transposed]
78 |
79 | raise TypeError(default_collate_err_msg_format.format(elem_type))
--------------------------------------------------------------------------------
/lib/gluformer/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import yaml
4 | import random
5 | from typing import Any, \
6 | BinaryIO, \
7 | Callable, \
8 | Dict, \
9 | List, \
10 | Optional, \
11 | Sequence, \
12 | Tuple, \
13 | Union
14 |
15 | import numpy as np
16 | import scipy as sp
17 | import pandas as pd
18 | import torch
19 |
20 | # import data formatter
21 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
22 |
23 | def test(series: np.ndarray,
24 | forecasts: np.ndarray,
25 | var: np.ndarray,
26 | cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11),
27 | ):
28 | """
29 | Test the (rescaled to original scale) forecasts on the series.
30 |
31 | Parameters
32 | ----------
33 | series
34 | The target time series of shape (n, t),
35 | where t is the length of the prediction window.
36 | forecasts
37 | The forecasted means of the mixture components, of shape (n, t, k),
38 | where k is the number of mixture components.
39 | var
40 | The forecasted variances of the mixture components, of shape (n, 1, k),
41 | where k is the number of mixture components.
42 | cal_thresholds
43 | The quantile levels used to compute the calibration error.
44 | 
45 | Returns
46 | -------
47 | np.ndarray
48 | Error array of shape (n, 2) holding (MSE, MAE) per series,
49 | where n = series.shape[0] = forecasts.shape[0].
50 | float
51 | The estimated log-likelihood of the model on the data.
52 | np.ndarray
53 | The expected calibration error (ECE) for each time point
54 | in the forecast.
55 | """
56 | 
57 | # compute errors: 1) get samples 2) compute errors using median
58 | samples = np.random.normal(loc=forecasts[..., None],
59 | scale=np.sqrt(var)[..., None],
60 | size=(forecasts.shape[0],
61 | forecasts.shape[1],
62 | forecasts.shape[2],
63 | 30))
64 | samples = samples.reshape(samples.shape[0], samples.shape[1], -1)
65 | mse = np.mean((series.squeeze() - forecasts.mean(axis=-1))**2, axis=-1)
66 | mae = np.mean(np.abs(series.squeeze() - forecasts.mean(axis=-1)), axis=-1)
67 | errors = np.stack([mse, mae], axis=-1)
68 |
69 | # compute likelihood
70 | log_likelihood = sp.special.logsumexp(-(forecasts - series)**2 / (2 * var) -
71 | 0.5 * np.log(2 * np.pi * var), axis=-1)
72 | log_likelihood = np.mean(log_likelihood)
73 |
74 | # compute calibration error:
75 | cal_error = np.zeros(forecasts.shape[1])
76 | for p in cal_thresholds:
77 | q = np.quantile(samples, p, axis=-1)
78 | est_p = np.mean(series.squeeze() <= q, axis=0)
79 | cal_error += (est_p - p) ** 2
80 |
81 | return errors, log_likelihood, cal_error
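82 | 
83 | 
84 | if __name__ == '__main__':
85 |     # Added smoke test: an illustrative sketch, not part of the original module.
86 |     # Shapes follow the docstring above; series is given a trailing singleton axis so
87 |     # that it broadcasts against the (n, t, k) forecasts in the likelihood term.
88 |     n, t, k = 4, 12, 3
89 |     rng = np.random.default_rng(0)
90 |     series = rng.normal(loc=150, scale=20, size=(n, t, 1))
91 |     forecasts = rng.normal(loc=150, scale=20, size=(n, t, k))
92 |     var = np.full((n, 1, k), 25.0)
93 |     errors, log_likelihood, cal_error = test(series, forecasts, var)
94 |     print(errors.shape, float(log_likelihood), cal_error.shape)  # (4, 2), scalar, (12,)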
--------------------------------------------------------------------------------
/lib/gluformer/utils/training.py:
--------------------------------------------------------------------------------
1 | from re import sub
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 |
6 | from .collate import *
7 |
8 | class EarlyStop:
9 | def __init__(self, patience, delta):
10 | self.patience = patience
11 | self.delta = delta
12 | self.counter = 0
13 | self.best_loss = np.Inf
14 | self.stop = False
15 |
16 | def __call__(self, loss, model, path):
17 | if loss < self.best_loss:
18 | self.best_loss = loss
19 | self.counter = 0
20 | torch.save(model.state_dict(), path)
21 | elif loss > self.best_loss + self.delta:
22 | self.counter = self.counter + 1
23 | if self.counter >= self.patience:
24 | self.stop = True
25 |
26 | class ExpLikeliLoss(nn.Module):
27 | def __init__(self, num_samples = 100):
28 | # , var = 0.3
29 | super(ExpLikeliLoss, self).__init__()
30 | self.num_samples = num_samples
31 |
32 | def forward(self, pred, true, logvar):
33 | # pred & true: [b, l, d]
34 | b, l, d = pred.size(0), pred.size(1), pred.size(2)
35 | true = true.transpose(0,1).reshape(l, -1, self.num_samples).transpose(0, 1)
36 | pred = pred.transpose(0,1).reshape(l, -1, self.num_samples).transpose(0, 1)
37 | logvar = logvar.reshape(-1, self.num_samples)
38 |
39 | loss = torch.mean((-1) * torch.logsumexp((-l / 2) * logvar + (-1 / (2 * torch.exp(logvar))) * torch.sum((true - pred) ** 2, dim=1), dim=1))
40 | return loss
41 |
42 | def modify_collate(num_samples):
43 | '''
44 | Repeat each sample in the dataset.
45 | '''
46 | def wrapper(batch):
47 | batch_rep = [sample for sample in batch for i in range(num_samples)]
48 | return default_collate(batch_rep)
49 |
50 | return wrapper
51 |
52 | def adjust_learning_rate(model_optim, epoch, lr):
53 | lr = lr * (0.5 ** epoch)
54 | print("Learning rate halfing...")
55 | print(f"New lr: {lr:.7f}")
56 | for param_group in model_optim.param_groups:
57 | param_group['lr'] = lr
58 |
59 | def process_batch(subj_id,
60 | batch_x, batch_y,
61 | batch_x_mark, batch_y_mark,
62 | len_pred, len_label,
63 | model, device):
64 | # read data
65 | subj_id = subj_id.long().to(device)
66 | batch_x = batch_x.float().to(device)
67 | batch_y = batch_y.float()
68 | batch_x_mark = batch_x_mark.float().to(device)
69 | batch_y_mark = batch_y_mark.float().to(device)
70 |
71 | # extract true
72 | true = batch_y[:, -len_pred:, :].to(device)
73 |
74 | # decoder input
75 | dec_inp = torch.zeros([batch_y.shape[0], len_pred, batch_y.shape[-1]]).float()
76 | dec_inp = torch.cat([batch_y[:, :len_label, :], dec_inp], dim=1).float().to(device)
77 |
78 | # model prediction
79 | # pred = model(subj_id, batch_x, batch_x_mark, dec_inp, batch_y_mark)
80 | pred, logvar = model(subj_id, batch_x, batch_x_mark, dec_inp, batch_y_mark)
81 |
82 | # clean cache
83 | del subj_id
84 | del batch_x
85 | del batch_x_mark
86 | del dec_inp
87 | del batch_y_mark
88 | return pred, true, logvar
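89 | 
90 | # Added note (how the pieces above appear to fit together): modify_collate repeats every
91 | # sample `num_samples` times, the model produces one (pred, logvar) draw per repetition,
92 | # and ExpLikeliLoss groups the repetitions back together and marginalizes over them with
93 | # logsumexp, i.e. a Monte-Carlo mixture likelihood per original sample.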
--------------------------------------------------------------------------------
/lib/gluformer/variance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class Variance(nn.Module):
6 | def __init__(self, d_model, r_drop, len_seq):
7 | super(Variance, self).__init__()
8 |
9 | self.proj1 = nn.Linear(d_model, 1)
10 | self.dropout = nn.Dropout(r_drop)
11 | self.activ1 = nn.ReLU()
12 | # + 1 (for seq) for embedded person token
13 | self.proj2 = nn.Linear(len_seq+1, 1)
14 | self.activ2 = nn.Tanh()
15 |
16 | def forward(self, x):
17 | x = self.proj1(x)
18 | x = self.activ1(x)
19 | x = self.dropout(x)
20 | x = x.transpose(-1, 1)
21 | x = self.proj2(x)
22 | # scale to [-10, 10] range
23 | x = 10 * self.activ2(x)
24 | return x
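25 | 
26 | # Added note (inferred from gluformer/utils/training.py, where this head's output is
27 | # passed to ExpLikeliLoss as `logvar`): the tanh-scaled output acts as a log-variance,
28 | # so the implied variance is bounded to roughly [exp(-10), exp(10)].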
--------------------------------------------------------------------------------
/lib/latent_ode/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Yulia Rubanova
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/lib/latent_ode/README.md:
--------------------------------------------------------------------------------
1 | # Latent ODEs for Irregularly-Sampled Time Series
2 |
3 | Code for the paper:
4 | > Yulia Rubanova, Ricky Chen, David Duvenaud. "Latent ODEs for Irregularly-Sampled Time Series" (2019)
5 | [[arxiv]](https://arxiv.org/abs/1907.03907)
6 |
7 |
8 |
9 |
10 |
11 | ## Prerequisites
12 |
13 | Install `torchdiffeq` from https://github.com/rtqichen/torchdiffeq.
14 |
15 | ## Experiments on different datasets
16 |
17 | By default, the datasets are downloaded and processed when the script is run for the first time.
18 |
19 | Raw datasets:
20 | [[MuJoCo]](http://www.cs.toronto.edu/~rtqichen/datasets/HopperPhysics/training.pt)
21 | [[Physionet]](https://physionet.org/physiobank/database/challenge/2012/)
22 | [[Human Activity]](https://archive.ics.uci.edu/ml/datasets/Localization+Data+for+Person+Activity/)
23 |
24 | To generate MuJoCo trajectories from scratch, [DeepMind Control Suite](https://github.com/deepmind/dm_control/) is required
25 |
26 |
27 | * Toy dataset of 1d periodic functions
28 | ```
29 | python3 run_models.py --niters 500 -n 1000 -s 50 -l 10 --dataset periodic --latent-ode --noise-weight 0.01
30 | ```
31 |
32 | * MuJoCo
33 |
34 | ```
35 | python3 run_models.py --niters 300 -n 10000 -l 15 --dataset hopper --latent-ode --rec-dims 30 --gru-units 100 --units 300 --gen-layers 3 --rec-layers 3
36 | ```
37 |
38 | * Physionet (discretization by 1 min)
39 | ```
40 | python3 run_models.py --niters 100 -n 8000 -l 20 --dataset physionet --latent-ode --rec-dims 40 --rec-layers 3 --gen-layers 3 --units 50 --gru-units 50 --quantization 0.016 --classif
41 |
42 | ```
43 |
44 | * Human Activity
45 | ```
46 | python3 run_models.py --niters 200 -n 10000 -l 15 --dataset activity --latent-ode --rec-dims 100 --rec-layers 4 --gen-layers 2 --units 500 --gru-units 50 --classif --linear-classif
47 |
48 | ```
49 |
50 |
51 | ### Running different models
52 |
53 | * ODE-RNN
54 | ```
55 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --ode-rnn
56 | ```
57 |
58 | * Latent ODE with ODE-RNN encoder
59 | ```
60 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --latent-ode
61 | ```
62 |
63 | * Latent ODE with ODE-RNN encoder and poisson likelihood
64 | ```
65 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --latent-ode --poisson
66 | ```
67 |
68 | * Latent ODE with RNN encoder (Chen et al, 2018)
69 | ```
70 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --latent-ode --z0-encoder rnn
71 | ```
72 |
73 | * RNN-VAE
74 | ```
75 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --rnn-vae
76 | ```
77 |
78 | * Classic RNN
79 | ```
80 | python3 run_models.py --niters 500 -n 1000 -l 10 --dataset periodic --classic-rnn
81 | ```
82 |
83 | * GRU-D
84 |
85 | GRU-D consists of two parts: input imputation (--input-decay) and exponential decay of the hidden state (--rnn-cell expdecay)
86 |
87 | ```
88 | python3 run_models.py --niters 500 -n 100 -b 30 -l 10 --dataset periodic --classic-rnn --input-decay --rnn-cell expdecay
89 | ```
90 |
91 |
92 | ### Making the visualization
93 | ```
94 | python3 run_models.py --niters 100 -n 5000 -b 100 -l 3 --dataset periodic --latent-ode --noise-weight 0.5 --lr 0.01 --viz --rec-layers 2 --gen-layers 2 -u 100 -c 30
95 | ```
96 |
--------------------------------------------------------------------------------
/lib/latent_ode/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/lib/latent_ode/__init__.py
--------------------------------------------------------------------------------
/lib/latent_ode/base_models.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.functional import relu
10 |
11 | from . import utils as utils
12 | from .encoder_decoder import *
13 | from .likelihood_eval import *
14 |
15 | from torch.distributions.multivariate_normal import MultivariateNormal
16 | from torch.distributions.normal import Normal
17 | from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase
18 |
19 | from torch.distributions.normal import Normal
20 | from torch.distributions import Independent
21 | from torch.nn.parameter import Parameter
22 |
23 |
24 | def create_classifier(z0_dim, n_labels):
25 | return nn.Sequential(
26 | nn.Linear(z0_dim, 300),
27 | nn.ReLU(),
28 | nn.Linear(300, 300),
29 | nn.ReLU(),
30 | nn.Linear(300, n_labels),)
31 |
32 |
33 | class Baseline(nn.Module):
34 | def __init__(self, input_dim, latent_dim, device,
35 | obsrv_std = 0.01, use_binary_classif = False,
36 | classif_per_tp = False,
37 | use_poisson_proc = False,
38 | linear_classifier = False,
39 | n_labels = 1,
40 | train_classif_w_reconstr = False):
41 | super(Baseline, self).__init__()
42 |
43 | self.input_dim = input_dim
44 | self.latent_dim = latent_dim
45 | self.n_labels = n_labels
46 |
47 | self.obsrv_std = torch.Tensor([obsrv_std]).to(device)
48 | self.device = device
49 |
50 | self.use_binary_classif = use_binary_classif
51 | self.classif_per_tp = classif_per_tp
52 | self.use_poisson_proc = use_poisson_proc
53 | self.linear_classifier = linear_classifier
54 | self.train_classif_w_reconstr = train_classif_w_reconstr
55 |
56 | z0_dim = latent_dim
57 | if use_poisson_proc:
58 | z0_dim += latent_dim
59 |
60 | if use_binary_classif:
61 | if linear_classifier:
62 | self.classifier = nn.Sequential(
63 | nn.Linear(z0_dim, n_labels))
64 | else:
65 | self.classifier = create_classifier(z0_dim, n_labels)
66 | utils.init_network_weights(self.classifier)
67 |
68 |
69 | def get_gaussian_likelihood(self, truth, pred_y, mask = None):
70 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim]
71 | # truth shape [n_traj, n_tp, n_dim]
72 | if mask is not None:
73 | mask = mask.repeat(pred_y.size(0), 1, 1, 1)
74 |
75 | # Compute likelihood of the data under the predictions
76 | log_density_data = masked_gaussian_log_density(pred_y, truth,
77 | obsrv_std = self.obsrv_std, mask = mask)
78 | log_density_data = log_density_data.permute(1,0)
79 |
80 | # Compute the total density
81 | # Take mean over n_traj_samples
82 | log_density = torch.mean(log_density_data, 0)
83 |
84 | # shape: [n_traj]
85 | return log_density
86 |
87 |
88 | def get_mse(self, truth, pred_y, mask = None):
89 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim]
90 | # truth shape [n_traj, n_tp, n_dim]
91 | if mask is not None:
92 | mask = mask.repeat(pred_y.size(0), 1, 1, 1)
93 |
94 | # Compute likelihood of the data under the predictions
95 | log_density_data = compute_mse(pred_y, truth, mask = mask)
96 | # shape: [1]
97 | return torch.mean(log_density_data)
98 |
99 |
100 | def compute_all_losses(self, batch_dict,
101 | n_tp_to_sample = None, n_traj_samples = 1, kl_coef = 1.):
102 |
103 | # Condition on subsampled points
104 | # Make predictions for all the points
105 | pred_x, info = self.get_reconstruction(batch_dict["tp_to_predict"],
106 | batch_dict["observed_data"], batch_dict["observed_tp"],
107 | mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples,
108 | mode = batch_dict["mode"])
109 |
110 | # Compute likelihood of all the points
111 | likelihood = self.get_gaussian_likelihood(batch_dict["data_to_predict"], pred_x,
112 | mask = batch_dict["mask_predicted_data"])
113 |
114 | mse = self.get_mse(batch_dict["data_to_predict"], pred_x,
115 | mask = batch_dict["mask_predicted_data"])
116 |
117 | ################################
118 | # Compute CE loss for binary classification on Physionet
119 | # Use only the last attribute -- mortality in the hospital
120 | device = get_device(batch_dict["data_to_predict"])
121 | ce_loss = torch.Tensor([0.]).to(device)
122 |
123 | if (batch_dict["labels"] is not None) and self.use_binary_classif:
124 | if (batch_dict["labels"].size(-1) == 1) or (len(batch_dict["labels"].size()) == 1):
125 | ce_loss = compute_binary_CE_loss(
126 | info["label_predictions"],
127 | batch_dict["labels"])
128 | else:
129 | ce_loss = compute_multiclass_CE_loss(
130 | info["label_predictions"],
131 | batch_dict["labels"],
132 | mask = batch_dict["mask_predicted_data"])
133 |
134 | if torch.isnan(ce_loss):
135 | print("label pred")
136 | print(info["label_predictions"])
137 | print("labels")
138 | print( batch_dict["labels"])
139 | raise Exception("CE loss is Nan!")
140 |
141 | pois_log_likelihood = torch.Tensor([0.]).to(get_device(batch_dict["data_to_predict"]))
142 | if self.use_poisson_proc:
143 | pois_log_likelihood = compute_poisson_proc_likelihood(
144 | batch_dict["data_to_predict"], pred_x,
145 | info, mask = batch_dict["mask_predicted_data"])
146 | # Take mean over n_traj
147 | pois_log_likelihood = torch.mean(pois_log_likelihood, 1)
148 |
149 | loss = - torch.mean(likelihood)
150 |
151 | if self.use_poisson_proc:
152 | loss = loss - 0.1 * pois_log_likelihood
153 |
154 | if self.use_binary_classif:
155 | if self.train_classif_w_reconstr:
156 | loss = loss + ce_loss * 100
157 | else:
158 | loss = ce_loss
159 |
160 | # Take mean over the number of samples in a batch
161 | results = {}
162 | results["loss"] = torch.mean(loss)
163 | results["likelihood"] = torch.mean(likelihood).detach()
164 | results["mse"] = torch.mean(mse).detach()
165 | results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach()
166 | results["ce_loss"] = torch.mean(ce_loss).detach()
167 | results["kl"] = 0.
168 | results["kl_first_p"] = 0.
169 | results["std_first_p"] = 0.
170 |
171 | if batch_dict["labels"] is not None and self.use_binary_classif:
172 | results["label_predictions"] = info["label_predictions"].detach()
173 | return results
174 |
175 |
176 |
177 | class VAE_Baseline(nn.Module):
178 | def __init__(self, input_dim, latent_dim,
179 | z0_prior, device,
180 | obsrv_std = 0.01,
181 | use_binary_classif = False,
182 | classif_per_tp = False,
183 | use_poisson_proc = False,
184 | linear_classifier = False,
185 | n_labels = 1,
186 | train_classif_w_reconstr = False):
187 |
188 | super(VAE_Baseline, self).__init__()
189 |
190 | self.input_dim = input_dim
191 | self.latent_dim = latent_dim
192 | self.device = device
193 | self.n_labels = n_labels
194 |
195 | self.obsrv_std = torch.Tensor([obsrv_std]).to(device)
196 |
197 | self.z0_prior = z0_prior
198 | self.use_binary_classif = use_binary_classif
199 | self.classif_per_tp = classif_per_tp
200 | self.use_poisson_proc = use_poisson_proc
201 | self.linear_classifier = linear_classifier
202 | self.train_classif_w_reconstr = train_classif_w_reconstr
203 |
204 | z0_dim = latent_dim
205 | if use_poisson_proc:
206 | z0_dim += latent_dim
207 |
208 | if use_binary_classif:
209 | if linear_classifier:
210 | self.classifier = nn.Sequential(
211 | nn.Linear(z0_dim, n_labels))
212 | else:
213 | self.classifier = create_classifier(z0_dim, n_labels)
214 | utils.init_network_weights(self.classifier)
215 |
216 |
217 | def get_gaussian_likelihood(self, truth, pred_y, mask = None):
218 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim]
219 | # truth shape [n_traj, n_tp, n_dim]
220 | n_traj, n_tp, n_dim = truth.size()
221 |
222 | # Compute likelihood of the data under the predictions
223 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1)
224 |
225 | if mask is not None:
226 | mask = mask.repeat(pred_y.size(0), 1, 1, 1)
227 | log_density_data = masked_gaussian_log_density(pred_y, truth_repeated,
228 | obsrv_std = self.obsrv_std, mask = mask)
229 | log_density_data = log_density_data.permute(1,0)
230 | log_density = torch.mean(log_density_data, 1)
231 |
232 | # shape: [n_traj_samples]
233 | return log_density
234 |
235 |
236 | def get_mse(self, truth, pred_y, mask = None):
237 | # pred_y shape [n_traj_samples, n_traj, n_tp, n_dim]
238 | # truth shape [n_traj, n_tp, n_dim]
239 | n_traj, n_tp, n_dim = truth.size()
240 |
241 | # Compute likelihood of the data under the predictions
242 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1)
243 |
244 | if mask is not None:
245 | mask = mask.repeat(pred_y.size(0), 1, 1, 1)
246 |
247 | # Compute likelihood of the data under the predictions
248 | log_density_data = compute_mse(pred_y, truth_repeated, mask = mask)
249 | # shape: [1]
250 | return torch.mean(log_density_data)
251 |
252 |
253 | def compute_all_losses(self, batch_dict, n_traj_samples = 1, kl_coef = 1.):
254 | # Condition on subsampled points
255 | # Make predictions for all the points
256 | pred_y, info = self.get_reconstruction(batch_dict["tp_to_predict"],
257 | batch_dict["observed_data"], batch_dict["observed_tp"],
258 | mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples,
259 | mode = batch_dict["mode"])
260 |
261 | #print("get_reconstruction done -- computing likelihood")
262 | fp_mu, fp_std, fp_enc = info["first_point"]
263 | fp_std = fp_std.abs().clamp(min = 1e-2)
264 | fp_distr = Normal(fp_mu, fp_std)
265 |
266 | assert(torch.sum(fp_std < 0) == 0.)
267 |
268 | kldiv_z0 = kl_divergence(fp_distr, self.z0_prior)
269 |
270 | if torch.isnan(kldiv_z0).any():
271 | print(fp_mu)
272 | print(fp_std)
273 | raise Exception("kldiv_z0 is Nan!")
274 |
275 | # Mean over number of latent dimensions
276 | # kldiv_z0 shape: [n_traj_samples, n_traj, n_latent_dims] if prior is a mixture of gaussians (KL is estimated)
277 | # kldiv_z0 shape: [1, n_traj, n_latent_dims] if prior is a standard gaussian (KL is computed exactly)
278 | # shape after: [n_traj_samples]
279 | kldiv_z0 = torch.mean(kldiv_z0,(1,2))
280 |
281 | # Compute likelihood of all the points
282 | rec_likelihood = self.get_gaussian_likelihood(
283 | batch_dict["data_to_predict"], pred_y,
284 | mask = batch_dict["mask_predicted_data"])
285 |
286 | mse = self.get_mse(
287 | batch_dict["data_to_predict"], pred_y,
288 | mask = batch_dict["mask_predicted_data"])
289 |
290 | pois_log_likelihood = torch.Tensor([0.]).to(get_device(batch_dict["data_to_predict"]))
291 | if self.use_poisson_proc:
292 | pois_log_likelihood = compute_poisson_proc_likelihood(
293 | batch_dict["data_to_predict"], pred_y,
294 | info, mask = batch_dict["mask_predicted_data"])
295 | # Take mean over n_traj
296 | pois_log_likelihood = torch.mean(pois_log_likelihood, 1)
297 |
298 | ################################
299 | # Compute CE loss for binary classification on Physionet
300 | device = get_device(batch_dict["data_to_predict"])
301 | ce_loss = torch.Tensor([0.]).to(device)
302 | if (batch_dict["labels"] is not None) and self.use_binary_classif:
303 |
304 | if (batch_dict["labels"].size(-1) == 1) or (len(batch_dict["labels"].size()) == 1):
305 | ce_loss = compute_binary_CE_loss(
306 | info["label_predictions"],
307 | batch_dict["labels"])
308 | else:
309 | ce_loss = compute_multiclass_CE_loss(
310 | info["label_predictions"],
311 | batch_dict["labels"],
312 | mask = batch_dict["mask_predicted_data"])
313 |
314 | # IWAE loss
315 | loss = - torch.logsumexp(rec_likelihood - kl_coef * kldiv_z0,0)
316 | if torch.isnan(loss):
317 | loss = - torch.mean(rec_likelihood - kl_coef * kldiv_z0,0)
318 |
319 | if self.use_poisson_proc:
320 | loss = loss - 0.1 * pois_log_likelihood
321 |
322 | if self.use_binary_classif:
323 | if self.train_classif_w_reconstr:
324 | loss = loss + ce_loss * 100
325 | else:
326 | loss = ce_loss
327 |
328 | results = {}
329 | results["loss"] = torch.mean(loss)
330 | results["likelihood"] = torch.mean(rec_likelihood).detach()
331 | results["mse"] = torch.mean(mse).detach()
332 | results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach()
333 | results["ce_loss"] = torch.mean(ce_loss).detach()
334 | results["kl_first_p"] = torch.mean(kldiv_z0).detach()
335 | results["std_first_p"] = torch.mean(fp_std).detach()
336 |
337 | if batch_dict["labels"] is not None and self.use_binary_classif:
338 | results["label_predictions"] = info["label_predictions"].detach()
339 |
340 | return results
341 |
342 |
343 |
344 |
--------------------------------------------------------------------------------
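A quick sketch of the objective assembled in VAE_Baseline.compute_all_losses above: per latent sample, the reconstruction log-likelihood and the KL to the z0 prior are combined and aggregated with logsumexp, an importance-weighted (IWAE-style) bound, with a plain mean as a fallback when the logsumexp is NaN. The numbers below are made up purely for illustration.

import torch

# per-sample reconstruction log-likelihood and KL(q(z0|x) || p(z0)), shape [n_traj_samples]
rec_likelihood = torch.tensor([-1.2, -0.9, -1.5])
kldiv_z0 = torch.tensor([0.3, 0.4, 0.2])
kl_coef = 1.0

loss = -torch.logsumexp(rec_likelihood - kl_coef * kldiv_z0, 0)
if torch.isnan(loss):
    loss = -torch.mean(rec_likelihood - kl_coef * kldiv_z0, 0)
print(loss.item())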
/lib/latent_ode/create_latent_ode_model.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import os
7 | import numpy as np
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn.functional import relu
12 |
13 | from . import utils as utils
14 | from .latent_ode import LatentODE
15 | from .encoder_decoder import *
16 | from .diffeq_solver import DiffeqSolver
17 |
18 | from torch.distributions.normal import Normal
19 | from .ode_func import ODEFunc, ODEFunc_w_Poisson
20 |
21 | #####################################################################################################
22 |
23 | def create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device,
24 | classif_per_tp = False, n_labels = 1):
25 |
26 | dim = args.latents
27 | if args.poisson:
28 | lambda_net = utils.create_net(dim, input_dim,
29 | n_layers = 1, n_units = args.units, nonlinear = nn.Tanh)
30 |
31 | # ODE function produces the gradient for latent state and for poisson rate
32 | ode_func_net = utils.create_net(dim * 2, args.latents * 2,
33 | n_layers = args.gen_layers, n_units = args.units, nonlinear = nn.Tanh)
34 |
35 | gen_ode_func = ODEFunc_w_Poisson(
36 | input_dim = input_dim,
37 | latent_dim = args.latents * 2,
38 | ode_func_net = ode_func_net,
39 | lambda_net = lambda_net,
40 | device = device).to(device)
41 | else:
42 | dim = args.latents
43 | ode_func_net = utils.create_net(dim, args.latents,
44 | n_layers = args.gen_layers, n_units = args.units, nonlinear = nn.Tanh)
45 |
46 | gen_ode_func = ODEFunc(
47 | input_dim = input_dim,
48 | latent_dim = args.latents,
49 | ode_func_net = ode_func_net,
50 | device = device).to(device)
51 |
52 | z0_diffeq_solver = None
53 | n_rec_dims = args.rec_dims
54 | enc_input_dim = int(input_dim) * 2 # we concatenate the mask
55 | gen_data_dim = input_dim
56 |
57 | z0_dim = args.latents
58 | if args.poisson:
59 | z0_dim += args.latents # predict the initial poisson rate
60 |
61 | if args.z0_encoder == "odernn":
62 | ode_func_net = utils.create_net(n_rec_dims, n_rec_dims,
63 | n_layers = args.rec_layers, n_units = args.units, nonlinear = nn.Tanh)
64 |
65 | rec_ode_func = ODEFunc(
66 | input_dim = enc_input_dim,
67 | latent_dim = n_rec_dims,
68 | ode_func_net = ode_func_net,
69 | device = device).to(device)
70 |
71 | z0_diffeq_solver = DiffeqSolver(enc_input_dim, rec_ode_func, "euler", args.latents,
72 | odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device)
73 |
74 | encoder_z0 = Encoder_z0_ODE_RNN(n_rec_dims, enc_input_dim, z0_diffeq_solver,
75 | z0_dim = z0_dim, n_gru_units = args.gru_units, device = device).to(device)
76 |
77 | elif args.z0_encoder == "rnn":
78 | encoder_z0 = Encoder_z0_RNN(z0_dim, enc_input_dim,
79 | lstm_output_size = n_rec_dims, device = device).to(device)
80 | else:
81 | raise Exception("Unknown encoder for Latent ODE model: " + args.z0_encoder)
82 |
83 | decoder = Decoder(args.latents, gen_data_dim).to(device)
84 |
85 | diffeq_solver = DiffeqSolver(gen_data_dim, gen_ode_func, 'dopri5', args.latents,
86 | odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device)
87 |
88 | model = LatentODE(
89 | input_dim = gen_data_dim,
90 | latent_dim = args.latents,
91 | encoder_z0 = encoder_z0,
92 | decoder = decoder,
93 | diffeq_solver = diffeq_solver,
94 | z0_prior = z0_prior,
95 | device = device,
96 | obsrv_std = obsrv_std,
97 | use_poisson_proc = args.poisson,
98 | use_binary_classif = args.classif,
99 | linear_classifier = args.linear_classif,
100 | classif_per_tp = classif_per_tp,
101 | n_labels = n_labels,
102 | train_classif_w_reconstr = (args.dataset == "physionet")
103 | ).to(device)
104 |
105 | return model
106 |
--------------------------------------------------------------------------------
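A minimal construction sketch for create_LatentODE_model (not taken from the repo's training scripts): the attribute names mirror the fields read from args above, but every value here is illustrative, and the imports assume the repository root is on the Python path with torchdiffeq installed.

from types import SimpleNamespace
import torch
from torch.distributions.normal import Normal
from lib.latent_ode.create_latent_ode_model import create_LatentODE_model

# illustrative hyperparameters only
args = SimpleNamespace(latents=10, rec_dims=20, rec_layers=1, gen_layers=1,
                       units=100, gru_units=100, z0_encoder="odernn",
                       poisson=False, classif=False, linear_classif=False,
                       dataset="glunet")
device = torch.device("cpu")
z0_prior = Normal(torch.tensor([0.0]), torch.tensor([1.0]))

model = create_LatentODE_model(args, input_dim=1, z0_prior=z0_prior,
                               obsrv_std=0.01, device=device)
print(model)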
/lib/latent_ode/diffeq_solver.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import time
7 | import numpy as np
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 | from . import utils as utils
13 | from torch.distributions.multivariate_normal import MultivariateNormal
14 |
15 | # git clone https://github.com/rtqichen/torchdiffeq.git
16 | from torchdiffeq import odeint as odeint
17 |
18 | #####################################################################################################
19 |
20 | class DiffeqSolver(nn.Module):
21 | def __init__(self, input_dim, ode_func, method, latents,
22 | odeint_rtol = 1e-4, odeint_atol = 1e-5, device = torch.device("cpu")):
23 | super(DiffeqSolver, self).__init__()
24 |
25 | self.ode_method = method
26 | self.latents = latents
27 | self.device = device
28 | self.ode_func = ode_func
29 |
30 | self.odeint_rtol = odeint_rtol
31 | self.odeint_atol = odeint_atol
32 |
33 | def forward(self, first_point, time_steps_to_predict, backwards = False):
34 | """
35 | # Decode the trajectory through ODE Solver
36 | """
37 | n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1]
38 | n_dims = first_point.size()[-1]
39 |
40 | pred_y = odeint(self.ode_func, first_point, time_steps_to_predict,
41 | rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method)
42 | pred_y = pred_y.permute(1,2,0,3)
43 |
44 | assert(torch.mean(pred_y[:, :, 0, :] - first_point) < 0.001)
45 | assert(pred_y.size()[0] == n_traj_samples)
46 | assert(pred_y.size()[1] == n_traj)
47 |
48 | return pred_y
49 |
50 | def sample_traj_from_prior(self, starting_point_enc, time_steps_to_predict,
51 | n_traj_samples = 1):
52 | """
53 | # Decode the trajectory through ODE Solver using samples from the prior
54 |
55 | time_steps_to_predict: time steps at which we want to sample the new trajectory
56 | """
57 | func = self.ode_func.sample_next_point_from_prior
58 |
59 | pred_y = odeint(func, starting_point_enc, time_steps_to_predict,
60 | rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method)
61 | # shape: [n_traj_samples, n_traj, n_tp, n_dim]
62 | pred_y = pred_y.permute(1,2,0,3)
63 | return pred_y
64 |
65 |
66 |
--------------------------------------------------------------------------------
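A shape sketch for DiffeqSolver.forward, assuming torchdiffeq is installed and the repository root is on the Python path: first_point has shape [n_traj_samples, n_traj, n_dims], odeint returns [n_tp, n_traj_samples, n_traj, n_dims], and the permutation yields [n_traj_samples, n_traj, n_tp, n_dims]. All sizes below are arbitrary.

import torch
import torch.nn as nn
from lib.latent_ode.ode_func import ODEFunc
from lib.latent_ode.diffeq_solver import DiffeqSolver

latents = 4
ode_func_net = nn.Sequential(nn.Linear(latents, 50), nn.Tanh(), nn.Linear(50, latents))
ode_func = ODEFunc(input_dim=1, latent_dim=latents, ode_func_net=ode_func_net)
solver = DiffeqSolver(input_dim=1, ode_func=ode_func, method="dopri5", latents=latents)

first_point = torch.randn(3, 2, latents)   # [n_traj_samples, n_traj, n_dims]
time_steps = torch.linspace(0.0, 1.0, 12)
pred_y = solver(first_point, time_steps)
print(pred_y.shape)                        # torch.Size([3, 2, 12, 4])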
/lib/latent_ode/encoder_decoder.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.functional import relu
10 | from . import utils as utils
11 | from torch.distributions import Categorical, Normal
12 | from torch.nn.modules.rnn import LSTM, GRU
13 | from .utils import get_device
14 |
15 |
16 | # GRU description:
17 | # http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/
18 | class GRU_unit(nn.Module):
19 | def __init__(self, latent_dim, input_dim,
20 | update_gate = None,
21 | reset_gate = None,
22 | new_state_net = None,
23 | n_units = 100,
24 | device = torch.device("cpu")):
25 | super(GRU_unit, self).__init__()
26 |
27 | if update_gate is None:
28 | self.update_gate = nn.Sequential(
29 | nn.Linear(latent_dim * 2 + input_dim, n_units),
30 | nn.Tanh(),
31 | nn.Linear(n_units, latent_dim),
32 | nn.Sigmoid())
33 | utils.init_network_weights(self.update_gate)
34 | else:
35 | self.update_gate = update_gate
36 |
37 | if reset_gate is None:
38 | self.reset_gate = nn.Sequential(
39 | nn.Linear(latent_dim * 2 + input_dim, n_units),
40 | nn.Tanh(),
41 | nn.Linear(n_units, latent_dim),
42 | nn.Sigmoid())
43 | utils.init_network_weights(self.reset_gate)
44 | else:
45 | self.reset_gate = reset_gate
46 |
47 | if new_state_net is None:
48 | self.new_state_net = nn.Sequential(
49 | nn.Linear(latent_dim * 2 + input_dim, n_units),
50 | nn.Tanh(),
51 | nn.Linear(n_units, latent_dim * 2))
52 | utils.init_network_weights(self.new_state_net)
53 | else:
54 | self.new_state_net = new_state_net
55 |
56 |
57 | def forward(self, y_mean, y_std, x, masked_update = True):
58 | y_concat = torch.cat([y_mean, y_std, x], -1)
59 |
60 | update_gate = self.update_gate(y_concat)
61 | reset_gate = self.reset_gate(y_concat)
62 | concat = torch.cat([y_mean * reset_gate, y_std * reset_gate, x], -1)
63 |
64 | new_state, new_state_std = utils.split_last_dim(self.new_state_net(concat))
65 | new_state_std = new_state_std.abs()
66 |
67 | new_y = (1-update_gate) * new_state + update_gate * y_mean
68 | new_y_std = (1-update_gate) * new_state_std + update_gate * y_std
69 |
70 | assert(not torch.isnan(new_y).any())
71 |
72 | if masked_update:
73 | # IMPORTANT: assumes that x contains both data and mask
74 | # update the hidden state only if at least one feature is present for the current time point
75 | n_data_dims = x.size(-1)//2
76 | mask = x[:, :, n_data_dims:]
77 | utils.check_mask(x[:, :, :n_data_dims], mask)
78 |
79 | mask = (torch.sum(mask, -1, keepdim = True) > 0).float()
80 |
81 | assert(not torch.isnan(mask).any())
82 |
83 | new_y = mask * new_y + (1-mask) * y_mean
84 | new_y_std = mask * new_y_std + (1-mask) * y_std
85 |
86 | if torch.isnan(new_y).any():
87 | print("new_y is nan!")
88 | print(mask)
89 | print(y_mean)
90 | print(new_state)
91 | exit()
92 |
93 | new_y_std = new_y_std.abs()
94 | return new_y, new_y_std
95 |
96 |
97 |
98 | class Encoder_z0_RNN(nn.Module):
99 | def __init__(self, latent_dim, input_dim, lstm_output_size = 20,
100 | use_delta_t = True, device = torch.device("cpu")):
101 |
102 | super(Encoder_z0_RNN, self).__init__()
103 |
104 | self.gru_rnn_output_size = lstm_output_size
105 | self.latent_dim = latent_dim
106 | self.input_dim = input_dim
107 | self.device = device
108 | self.use_delta_t = use_delta_t
109 |
110 | self.hiddens_to_z0 = nn.Sequential(
111 | nn.Linear(self.gru_rnn_output_size, 50),
112 | nn.Tanh(),
113 | nn.Linear(50, latent_dim * 2),)
114 |
115 | utils.init_network_weights(self.hiddens_to_z0)
116 |
117 | input_dim = self.input_dim
118 |
119 | if use_delta_t:
120 | self.input_dim += 1
121 | self.gru_rnn = GRU(self.input_dim, self.gru_rnn_output_size).to(device)
122 |
123 | def forward(self, data, time_steps, run_backwards = True):
124 | # IMPORTANT: assumes that 'data' already has mask concatenated to it
125 |
126 | # data shape: [n_traj, n_tp, n_dims]
127 | # shape required for rnn: (seq_len, batch, input_size)
128 | # t0: not used here
129 | n_traj = data.size(0)
130 |
131 | assert(not torch.isnan(data).any())
132 | assert(not torch.isnan(time_steps).any())
133 |
134 | data = data.permute(1,0,2)
135 |
136 | if run_backwards:
137 | # Look at data in the reverse order: from later points to the first
138 | data = utils.reverse(data)
139 |
140 | if self.use_delta_t:
141 | delta_t = time_steps[1:] - time_steps[:-1]
142 | if run_backwards:
143 | # we are going backwards in time, so reverse the deltas as well
144 | delta_t = utils.reverse(delta_t)
145 | # append zero delta t in the end
146 | delta_t = torch.cat((delta_t, torch.zeros(1).to(self.device)))
147 | delta_t = delta_t.unsqueeze(1).repeat((1,n_traj)).unsqueeze(-1)
148 | data = torch.cat((delta_t, data),-1)
149 |
150 | outputs, _ = self.gru_rnn(data)
151 |
152 | # GRU output shape: (seq_len, batch, num_directions * hidden_size)
153 | last_output = outputs[-1]
154 |
155 | self.extra_info ={"rnn_outputs": outputs, "time_points": time_steps}
156 |
157 | mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
158 | std = std.abs()
159 |
160 | assert(not torch.isnan(mean).any())
161 | assert(not torch.isnan(std).any())
162 |
163 | return mean.unsqueeze(0), std.unsqueeze(0)
164 |
165 |
166 |
167 |
168 |
169 | class Encoder_z0_ODE_RNN(nn.Module):
170 | # Derive z0 by running ode backwards.
171 | # For every y_i we have two versions: encoded from data and derived from ODE by running it backwards from t_i+1 to t_i
172 | # Compute a weighted sum of y_i from data and y_i from ode. Use weighted y_i as an initial value for ODE running from t_i to t_i-1
173 | # Continue until we get to z0
174 | def __init__(self, latent_dim, input_dim, z0_diffeq_solver = None,
175 | z0_dim = None, GRU_update = None,
176 | n_gru_units = 100,
177 | device = torch.device("cpu")):
178 |
179 | super(Encoder_z0_ODE_RNN, self).__init__()
180 |
181 | if z0_dim is None:
182 | self.z0_dim = latent_dim
183 | else:
184 | self.z0_dim = z0_dim
185 |
186 | if GRU_update is None:
187 | self.GRU_update = GRU_unit(latent_dim, input_dim,
188 | n_units = n_gru_units,
189 | device=device).to(device)
190 | else:
191 | self.GRU_update = GRU_update
192 |
193 | self.z0_diffeq_solver = z0_diffeq_solver
194 | self.latent_dim = latent_dim
195 | self.input_dim = input_dim
196 | self.device = device
197 | self.extra_info = None
198 |
199 | self.transform_z0 = nn.Sequential(
200 | nn.Linear(latent_dim * 2, 100),
201 | nn.Tanh(),
202 | nn.Linear(100, self.z0_dim * 2),)
203 | utils.init_network_weights(self.transform_z0)
204 |
205 |
206 | def forward(self, data, time_steps, run_backwards = True, save_info = False):
207 | # data, time_steps -- observations and their time stamps
208 | # IMPORTANT: assumes that 'data' already has mask concatenated to it
209 | assert(not torch.isnan(data).any())
210 | assert(not torch.isnan(time_steps).any())
211 |
212 | n_traj, n_tp, n_dims = data.size()
213 | if len(time_steps) == 1:
214 | prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(self.device)
215 | prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(self.device)
216 |
217 | xi = data[:,0,:].unsqueeze(0)
218 |
219 | last_yi, last_yi_std = self.GRU_update(prev_y, prev_std, xi)
220 | extra_info = None
221 | else:
222 |
223 | last_yi, last_yi_std, _, extra_info = self.run_odernn(
224 | data, time_steps, run_backwards = run_backwards,
225 | save_info = save_info)
226 |
227 | means_z0 = last_yi.reshape(1, n_traj, self.latent_dim)
228 | std_z0 = last_yi_std.reshape(1, n_traj, self.latent_dim)
229 |
230 | mean_z0, std_z0 = utils.split_last_dim( self.transform_z0( torch.cat((means_z0, std_z0), -1)))
231 | std_z0 = std_z0.abs()
232 | if save_info:
233 | self.extra_info = extra_info
234 |
235 | return mean_z0, std_z0
236 |
237 |
238 | def run_odernn(self, data, time_steps,
239 | run_backwards = True, save_info = False):
240 | # IMPORTANT: assumes that 'data' already has mask concatenated to it
241 |
242 | n_traj, n_tp, n_dims = data.size()
243 | extra_info = []
244 |
245 | t0 = time_steps[-1]
246 | if run_backwards:
247 | t0 = time_steps[0]
248 |
249 | device = get_device(data)
250 |
251 | prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device)
252 | prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device)
253 |
254 | prev_t, t_i = time_steps[-1] + 0.01, time_steps[-1]
255 |
256 | interval_length = time_steps[-1] - time_steps[0]
257 | minimum_step = interval_length / 50
258 |
259 | #print("minimum step: {}".format(minimum_step))
260 |
261 | assert(not torch.isnan(data).any())
262 | assert(not torch.isnan(time_steps).any())
263 |
264 | latent_ys = []
265 | # Run ODE backwards and combine the y(t) estimates using gating
266 | time_points_iter = range(0, len(time_steps))
267 | if run_backwards:
268 | time_points_iter = reversed(time_points_iter)
269 |
270 | for i in time_points_iter:
271 | if (prev_t - t_i) < minimum_step:
272 | time_points = torch.stack((prev_t, t_i))
273 | inc = self.z0_diffeq_solver.ode_func(prev_t, prev_y) * (t_i - prev_t)
274 |
275 | assert(not torch.isnan(inc).any())
276 |
277 | ode_sol = prev_y + inc
278 | ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)
279 |
280 | assert(not torch.isnan(ode_sol).any())
281 | else:
282 | n_intermediate_tp = max(2, ((prev_t - t_i) / minimum_step).int())
283 |
284 | time_points = utils.linspace_vector(prev_t, t_i, n_intermediate_tp)
285 | ode_sol = self.z0_diffeq_solver(prev_y, time_points)
286 |
287 | assert(not torch.isnan(ode_sol).any())
288 |
289 | if torch.mean(ode_sol[:, :, 0, :] - prev_y) >= 0.001:
290 | print("Error: first point of the ODE is not equal to initial value")
291 | print(torch.mean(ode_sol[:, :, 0, :] - prev_y))
292 | exit()
293 | #assert(torch.mean(ode_sol[:, :, 0, :] - prev_y) < 0.001)
294 |
295 | yi_ode = ode_sol[:, :, -1, :]
296 | xi = data[:,i,:].unsqueeze(0)
297 |
298 | yi, yi_std = self.GRU_update(yi_ode, prev_std, xi)
299 |
300 | prev_y, prev_std = yi, yi_std
301 | prev_t, t_i = time_steps[i], time_steps[i-1]
302 |
303 | latent_ys.append(yi)
304 |
305 | if save_info:
306 | d = {"yi_ode": yi_ode.detach(), #"yi_from_data": yi_from_data,
307 | "yi": yi.detach(), "yi_std": yi_std.detach(),
308 | "time_points": time_points.detach(), "ode_sol": ode_sol.detach()}
309 | extra_info.append(d)
310 |
311 | latent_ys = torch.stack(latent_ys, 1)
312 |
313 | assert(not torch.isnan(yi).any())
314 | assert(not torch.isnan(yi_std).any())
315 |
316 | return yi, yi_std, latent_ys, extra_info
317 |
318 |
319 |
320 | class Decoder(nn.Module):
321 | def __init__(self, latent_dim, input_dim):
322 | super(Decoder, self).__init__()
323 | # decode data from the latent space (where the ODE is solved) back to the data space
324 |
325 | decoder = nn.Sequential(
326 | nn.Linear(latent_dim, input_dim),)
327 |
328 | utils.init_network_weights(decoder)
329 | self.decoder = decoder
330 |
331 | def forward(self, data):
332 | return self.decoder(data)
333 |
334 |
335 |
--------------------------------------------------------------------------------
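A minimal usage sketch of GRU_unit above, showing the expected input layout (data concatenated with its observation mask along the last dimension); all sizes and values are illustrative, and the import assumes the repository root is on the Python path.

import torch
from lib.latent_ode.encoder_decoder import GRU_unit

latent_dim, n_data_dims, n_traj = 6, 2, 3
gru = GRU_unit(latent_dim, input_dim=2 * n_data_dims)

y_mean = torch.zeros(1, n_traj, latent_dim)
y_std = torch.zeros(1, n_traj, latent_dim)
data = torch.randn(1, n_traj, n_data_dims)
mask = torch.ones(1, n_traj, n_data_dims)   # 1 = observed; fully unobserved steps keep y_mean, y_std
x = torch.cat([data * mask, mask], -1)      # [1, n_traj, 2 * n_data_dims]

new_y, new_y_std = gru(y_mean, y_std, x)
print(new_y.shape, new_y_std.shape)         # torch.Size([1, 3, 6]) twice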
/lib/latent_ode/eval_glunet.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import yaml
4 | import random
5 | from typing import Optional
6 |
7 | import numpy as np
8 | import scipy as sp
9 | import pandas as pd
10 | import torch
11 |
12 | # import likelihood evaluation
13 | from .likelihood_eval import masked_gaussian_log_density
14 |
15 | def test(series: np.ndarray,
16 | forecasts: np.ndarray,
17 | obsrv_std: np.ndarray,
18 | cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11),
19 | ):
20 | """
21 | Test the (rescaled to original scale) forecasts on the series.
22 |
23 | Parameters
24 | ----------
25 | series
26 | The target time series of shape (n_traj, n_tp, n_dim),
27 | where n_tp is the length of the prediction.
28 | forecasts
29 | The forecasted means of mixture components of shape (n_traj_samples, n_traj, n_tp, n_dim)
30 | where n_traj_samples is the number of mixture components.
31 | obsrv_std
32 | The forecasted std of mixture components of shape (1).
33 | cal_thresholds
34 | The thresholds to use for computing the calibration error.
35 |
36 | Returns
37 | -------
38 | np.ndarray
39 | Error array. Array of shape (n_traj, 2), where
40 | along last dimension, we have MSE and MAE.
41 | float
42 | The estimated log-likelihood of the model on the data.
43 | np.ndarray
44 | The ECE for each time point in the forecast.
45 | """
46 | mse = np.mean((series - forecasts.mean(axis=0))**2, axis=-2)
47 | mae = np.mean(np.abs(series - forecasts.mean(axis=0)), axis=-2)
48 | errors = np.stack([mse.squeeze(), mae.squeeze()], axis=-1)
49 |
50 | # compute likelihood
51 | series, forecasts = torch.tensor(series), torch.tensor(forecasts)
52 | obsrv_std = torch.Tensor(obsrv_std)
53 |
54 | series_repeated = series.repeat(forecasts.size(0), 1, 1, 1)
55 | log_density_data = masked_gaussian_log_density(forecasts, series_repeated,
56 | obsrv_std = obsrv_std, mask = None)
57 | log_density_data = log_density_data.permute(1,0)
58 | log_density = torch.mean(log_density_data, 1)
59 | log_likelihood = torch.mean(log_density).item()
60 |
61 | # compute calibration error
62 | samples = torch.distributions.Normal(loc=forecasts, scale=obsrv_std).sample((100, ))
63 | samples = samples.view(samples.shape[0] * samples.shape[1],
64 | samples.shape[2],
65 | samples.shape[3])
66 | series = series.squeeze()
67 | cal_error = torch.zeros(series.shape[1])
68 | for p in cal_thresholds:
69 | q = torch.quantile(samples, p, dim=0)
70 | est_p = torch.mean((series <= q).float(), dim=0)
71 | cal_error += (est_p - p) ** 2
72 | cal_error = cal_error.numpy()
73 |
74 | return errors, log_likelihood, cal_error
--------------------------------------------------------------------------------
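An illustrative call to the test function above with random arrays (shapes only, not real CGM data); the import assumes the repository root is on the Python path.

import numpy as np
from lib.latent_ode.eval_glunet import test

n_traj_samples, n_traj, n_tp, n_dim = 5, 8, 12, 1
series = np.random.rand(n_traj, n_tp, n_dim).astype(np.float32)
forecasts = np.random.rand(n_traj_samples, n_traj, n_tp, n_dim).astype(np.float32)
obsrv_std = np.array([0.1], dtype=np.float32)

errors, log_likelihood, cal_error = test(series, forecasts, obsrv_std)
print(errors.shape, cal_error.shape)   # (8, 2) and (12,)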
/lib/latent_ode/latent_ode.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import numpy as np
7 | import sklearn as sk
8 | import numpy as np
9 | #import gc
10 | import torch
11 | import torch.nn as nn
12 | from torch.nn.functional import relu
13 |
14 | from . import utils as utils
15 | from .utils import get_device
16 | from .encoder_decoder import *
17 | from .likelihood_eval import *
18 |
19 | from torch.distributions.multivariate_normal import MultivariateNormal
20 | from torch.distributions.normal import Normal
21 | from torch.distributions import kl_divergence, Independent
22 | from .base_models import VAE_Baseline
23 |
24 |
25 |
26 | class LatentODE(VAE_Baseline):
27 | def __init__(self, input_dim, latent_dim, encoder_z0, decoder, diffeq_solver,
28 | z0_prior, device, obsrv_std = None,
29 | use_binary_classif = False, use_poisson_proc = False,
30 | linear_classifier = False,
31 | classif_per_tp = False,
32 | n_labels = 1,
33 | train_classif_w_reconstr = False):
34 |
35 | super(LatentODE, self).__init__(
36 | input_dim = input_dim, latent_dim = latent_dim,
37 | z0_prior = z0_prior,
38 | device = device, obsrv_std = obsrv_std,
39 | use_binary_classif = use_binary_classif,
40 | classif_per_tp = classif_per_tp,
41 | linear_classifier = linear_classifier,
42 | use_poisson_proc = use_poisson_proc,
43 | n_labels = n_labels,
44 | train_classif_w_reconstr = train_classif_w_reconstr)
45 |
46 | self.encoder_z0 = encoder_z0
47 | self.diffeq_solver = diffeq_solver
48 | self.decoder = decoder
49 | self.use_poisson_proc = use_poisson_proc
50 |
51 | def get_reconstruction(self, time_steps_to_predict, truth, truth_time_steps,
52 | mask = None, n_traj_samples = 1, run_backwards = True, mode = None):
53 |
54 | if isinstance(self.encoder_z0, Encoder_z0_ODE_RNN) or \
55 | isinstance(self.encoder_z0, Encoder_z0_RNN):
56 |
57 | truth_w_mask = truth
58 | if mask is not None:
59 | truth_w_mask = torch.cat((truth, mask), -1)
60 | first_point_mu, first_point_std = self.encoder_z0(
61 | truth_w_mask, truth_time_steps, run_backwards = run_backwards)
62 |
63 | means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1)
64 | sigma_z0 = first_point_std.repeat(n_traj_samples, 1, 1)
65 | first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)
66 |
67 | else:
68 | raise Exception("Unknown encoder type {}".format(type(self.encoder_z0).__name__))
69 |
70 | first_point_std = first_point_std.abs()
71 | assert(torch.sum(first_point_std < 0) == 0.)
72 |
73 | if self.use_poisson_proc:
74 | n_traj_samples, n_traj, n_dims = first_point_enc.size()
75 | # append a vector of zeros to compute the integral of lambda
76 | zeros = torch.zeros([n_traj_samples, n_traj,self.input_dim]).to(get_device(truth))
77 | first_point_enc_aug = torch.cat((first_point_enc, zeros), -1)
78 | means_z0_aug = torch.cat((means_z0, zeros), -1)
79 | else:
80 | first_point_enc_aug = first_point_enc
81 | means_z0_aug = means_z0
82 |
83 | assert(not torch.isnan(time_steps_to_predict).any())
84 | assert(not torch.isnan(first_point_enc).any())
85 | assert(not torch.isnan(first_point_enc_aug).any())
86 |
87 | # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents]
88 | sol_y = self.diffeq_solver(first_point_enc_aug, time_steps_to_predict)
89 |
90 | if self.use_poisson_proc:
91 | sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(sol_y)
92 |
93 | assert(torch.sum(int_lambda[:,:,0,:]) == 0.)
94 | assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.)
95 |
96 | pred_x = self.decoder(sol_y)
97 |
98 | all_extra_info = {
99 | "first_point": (first_point_mu, first_point_std, first_point_enc),
100 | "latent_traj": sol_y.detach()
101 | }
102 |
103 | if self.use_poisson_proc:
104 | # integral of lambda from the last step of ODE Solver
105 | all_extra_info["int_lambda"] = int_lambda[:,:,-1,:]
106 | all_extra_info["log_lambda_y"] = log_lambda_y
107 |
108 | if self.use_binary_classif:
109 | if self.classif_per_tp:
110 | all_extra_info["label_predictions"] = self.classifier(sol_y)
111 | else:
112 | all_extra_info["label_predictions"] = self.classifier(first_point_enc).squeeze(-1)
113 |
114 | return pred_x, all_extra_info
115 |
116 |
117 | def sample_traj_from_prior(self, time_steps_to_predict, n_traj_samples = 1):
118 | # input_dim = starting_point.size()[-1]
119 | # starting_point = starting_point.view(1,1,input_dim)
120 |
121 | # Sample z0 from prior
122 | starting_point_enc = self.z0_prior.sample([n_traj_samples, 1, self.latent_dim]).squeeze(-1)
123 |
124 | starting_point_enc_aug = starting_point_enc
125 | if self.use_poisson_proc:
126 | n_traj_samples, n_traj, n_dims = starting_point_enc.size()
127 | # append a vector of zeros to compute the integral of lambda
128 | zeros = torch.zeros(n_traj_samples, n_traj,self.input_dim).to(self.device)
129 | starting_point_enc_aug = torch.cat((starting_point_enc, zeros), -1)
130 |
131 | sol_y = self.diffeq_solver.sample_traj_from_prior(starting_point_enc_aug, time_steps_to_predict,
132 | n_traj_samples = 3)
133 |
134 | if self.use_poisson_proc:
135 | sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(sol_y)
136 |
137 | return self.decoder(sol_y)
138 |
139 |
140 |
--------------------------------------------------------------------------------
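An end-to-end shape sketch for LatentODE.get_reconstruction above; every hyperparameter and tensor below is illustrative, and the imports assume the repository root is on the Python path with torchdiffeq installed.

from types import SimpleNamespace
import torch
from torch.distributions.normal import Normal
from lib.latent_ode.create_latent_ode_model import create_LatentODE_model

args = SimpleNamespace(latents=6, rec_dims=10, rec_layers=1, gen_layers=1,
                       units=50, gru_units=50, z0_encoder="odernn",
                       poisson=False, classif=False, linear_classif=False,
                       dataset="glunet")
device = torch.device("cpu")
model = create_LatentODE_model(args, input_dim=1,
                               z0_prior=Normal(torch.tensor([0.0]), torch.tensor([1.0])),
                               obsrv_std=0.01, device=device)

n_traj, n_obs_tp, n_pred_tp = 4, 20, 12
observed_data = torch.randn(n_traj, n_obs_tp, 1)
observed_tp = torch.linspace(0.0, 1.0, n_obs_tp)
tp_to_predict = torch.linspace(1.0, 1.6, n_pred_tp)
mask = torch.ones_like(observed_data)

pred_x, info = model.get_reconstruction(tp_to_predict, observed_data, observed_tp,
                                        mask=mask, n_traj_samples=3)
print(pred_x.shape)   # torch.Size([3, 4, 12, 1])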
/lib/latent_ode/likelihood_eval.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import gc
7 | import numpy as np
8 | import sklearn as sk
9 | import numpy as np
10 | #import gc
11 | import torch
12 | import torch.nn as nn
13 | from torch.nn.functional import relu
14 |
15 | from . import utils as utils
16 | from .utils import get_device
17 | from .encoder_decoder import *
18 | from .likelihood_eval import *
19 |
20 | from torch.distributions.multivariate_normal import MultivariateNormal
21 | from torch.distributions.normal import Normal
22 | from torch.distributions import kl_divergence, Independent
23 |
24 |
25 | def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices = None):
26 | n_data_points = mu_2d.size()[-1]
27 |
28 | if n_data_points > 0:
29 | gaussian = Independent(Normal(loc = mu_2d, scale = obsrv_std.repeat(n_data_points)), 1)
30 | log_prob = gaussian.log_prob(data_2d)
31 | log_prob = log_prob / n_data_points
32 | else:
33 | log_prob = torch.zeros([1]).to(get_device(data_2d)).squeeze()
34 | return log_prob
35 |
36 |
37 | def poisson_log_likelihood(masked_log_lambdas, masked_data, indices, int_lambdas):
38 | # masked_log_lambdas and masked_data
39 | n_data_points = masked_data.size()[-1]
40 |
41 | if n_data_points > 0:
42 | log_prob = torch.sum(masked_log_lambdas) - int_lambdas[indices]
43 | #log_prob = log_prob / n_data_points
44 | else:
45 | log_prob = torch.zeros([1]).to(get_device(masked_data)).squeeze()
46 | return log_prob
47 |
48 |
49 |
50 | def compute_binary_CE_loss(label_predictions, mortality_label):
51 | #print("Computing binary classification loss: compute_CE_loss")
52 |
53 | mortality_label = mortality_label.reshape(-1)
54 |
55 | if len(label_predictions.size()) == 1:
56 | label_predictions = label_predictions.unsqueeze(0)
57 |
58 | n_traj_samples = label_predictions.size(0)
59 | label_predictions = label_predictions.reshape(n_traj_samples, -1)
60 |
61 | idx_not_nan = ~torch.isnan(mortality_label)
62 | if len(idx_not_nan) == 0.:
63 | print("All labels are NaNs!")
64 | ce_loss = torch.Tensor([0.]).to(get_device(mortality_label))
65 |
66 | label_predictions = label_predictions[:,idx_not_nan]
67 | mortality_label = mortality_label[idx_not_nan]
68 |
69 | if torch.sum(mortality_label == 0.) == 0 or torch.sum(mortality_label == 1.) == 0:
70 | print("Warning: all examples in a batch belong to the same class -- please increase the batch size.")
71 |
72 | assert(not torch.isnan(label_predictions).any())
73 | assert(not torch.isnan(mortality_label).any())
74 |
75 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
76 | mortality_label = mortality_label.repeat(n_traj_samples, 1)
77 | ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label)
78 |
79 | # divide by number of patients in a batch
80 | ce_loss = ce_loss / n_traj_samples
81 | return ce_loss
82 |
83 |
84 | def compute_multiclass_CE_loss(label_predictions, true_label, mask):
85 | #print("Computing multi-class classification loss: compute_multiclass_CE_loss")
86 |
87 | if (len(label_predictions.size()) == 3):
88 | label_predictions = label_predictions.unsqueeze(0)
89 |
90 | n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size()
91 |
92 | # assert(not torch.isnan(label_predictions).any())
93 | # assert(not torch.isnan(true_label).any())
94 |
95 | # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
96 | true_label = true_label.repeat(n_traj_samples, 1, 1)
97 |
98 | label_predictions = label_predictions.reshape(n_traj_samples * n_traj * n_tp, n_dims)
99 | true_label = true_label.reshape(n_traj_samples * n_traj * n_tp, n_dims)
100 |
101 | # choose time points with at least one measurement
102 | mask = torch.sum(mask, -1) > 0
103 |
104 | # repeat the mask for each label to mark that the label for this time point is present
105 | pred_mask = mask.repeat(n_dims, 1,1).permute(1,2,0)
106 |
107 | label_mask = mask
108 | pred_mask = pred_mask.repeat(n_traj_samples,1,1,1)
109 | label_mask = label_mask.repeat(n_traj_samples,1,1,1)
110 |
111 | pred_mask = pred_mask.reshape(n_traj_samples * n_traj * n_tp, n_dims)
112 | label_mask = label_mask.reshape(n_traj_samples * n_traj * n_tp, 1)
113 |
114 | if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1):
115 | assert(label_predictions.size(-1) == true_label.size(-1))
116 | # targets are in one-hot encoding -- convert to indices
117 | _, true_label = true_label.max(-1)
118 |
119 | res = []
120 | for i in range(true_label.size(0)):
121 | pred_masked = torch.masked_select(label_predictions[i], pred_mask[i].bool())
122 | labels = torch.masked_select(true_label[i], label_mask[i].bool())
123 |
124 | pred_masked = pred_masked.reshape(-1, n_dims)
125 |
126 | if (len(labels) == 0):
127 | continue
128 |
129 | ce_loss = nn.CrossEntropyLoss()(pred_masked, labels.long())
130 | res.append(ce_loss)
131 |
132 | ce_loss = torch.stack(res, 0).to(get_device(label_predictions))
133 | ce_loss = torch.mean(ce_loss)
134 | # # divide by number of patients in a batch
135 | # ce_loss = ce_loss / n_traj_samples
136 | return ce_loss
137 |
138 |
139 |
140 |
141 | def compute_masked_likelihood(mu, data, mask, likelihood_func):
142 | # Compute the likelihood per patient and per attribute so that we don't prioritize patients with more measurements
143 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size()
144 |
145 | res = []
146 | for i in range(n_traj_samples):
147 | for k in range(n_traj):
148 | for j in range(n_dims):
149 | data_masked = torch.masked_select(data[i,k,:,j], mask[i,k,:,j].bool())
150 |
151 | #assert(torch.sum(data_masked == 0.) < 10)
152 |
153 | mu_masked = torch.masked_select(mu[i,k,:,j], mask[i,k,:,j].bool())
154 | log_prob = likelihood_func(mu_masked, data_masked, indices = (i,k,j))
155 | res.append(log_prob)
156 | # shape: [n_traj*n_traj_samples, 1]
157 |
158 | res = torch.stack(res, 0).to(get_device(data))
159 | res = res.reshape((n_traj_samples, n_traj, n_dims))
160 | # Take mean over the number of dimensions
161 | res = torch.mean(res, -1) # !!!!!!!!!!! changed from sum to mean
162 | res = res.transpose(0,1)
163 | return res
164 |
165 |
166 | def masked_gaussian_log_density(mu, data, obsrv_std, mask = None):
167 | # these cases are for plotting through plot_estim_density
168 | if (len(mu.size()) == 3):
169 | # add additional dimension for gp samples
170 | mu = mu.unsqueeze(0)
171 |
172 | if (len(data.size()) == 2):
173 | # add additional dimension for gp samples and time step
174 | data = data.unsqueeze(0).unsqueeze(2)
175 | elif (len(data.size()) == 3):
176 | # add additional dimension for gp samples
177 | data = data.unsqueeze(0)
178 |
179 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size()
180 |
181 | assert(data.size()[-1] == n_dims)
182 |
183 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims]
184 | if mask is None:
185 | mu_flat = mu.reshape(n_traj_samples*n_traj, n_timepoints * n_dims)
186 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size()
187 | data_flat = data.reshape(n_traj_samples*n_traj, n_timepoints * n_dims)
188 |
189 | res = gaussian_log_likelihood(mu_flat, data_flat, obsrv_std)
190 | res = res.reshape(n_traj_samples, n_traj).transpose(0,1)
191 | else:
192 | # Compute the likelihood per patient so that we don't prioritize patients with more measurements
193 | func = lambda mu, data, indices: gaussian_log_likelihood(mu, data, obsrv_std = obsrv_std, indices = indices)
194 | res = compute_masked_likelihood(mu, data, mask, func)
195 | return res
196 |
197 |
198 |
199 | def mse(mu, data, indices = None):
200 | n_data_points = mu.size()[-1]
201 |
202 | if n_data_points > 0:
203 | mse = nn.MSELoss()(mu, data)
204 | else:
205 | mse = torch.zeros([1]).to(get_device(data)).squeeze()
206 | return mse
207 |
208 |
209 | def compute_mse(mu, data, mask = None):
210 | # these cases are for plotting through plot_estim_density
211 | if (len(mu.size()) == 3):
212 | # add additional dimension for gp samples
213 | mu = mu.unsqueeze(0)
214 |
215 | if (len(data.size()) == 2):
216 | # add additional dimension for gp samples and time step
217 | data = data.unsqueeze(0).unsqueeze(2)
218 | elif (len(data.size()) == 3):
219 | # add additional dimension for gp samples
220 | data = data.unsqueeze(0)
221 |
222 | n_traj_samples, n_traj, n_timepoints, n_dims = mu.size()
223 | assert(data.size()[-1] == n_dims)
224 |
225 | # Shape after permutation: [n_traj, n_traj_samples, n_timepoints, n_dims]
226 | if mask is None:
227 | mu_flat = mu.reshape(n_traj_samples*n_traj, n_timepoints * n_dims)
228 | n_traj_samples, n_traj, n_timepoints, n_dims = data.size()
229 | data_flat = data.reshape(n_traj_samples*n_traj, n_timepoints * n_dims)
230 | res = mse(mu_flat, data_flat)
231 | else:
232 | # Compute the likelihood per patient so that we don't prioritize patients with more measurements
233 | res = compute_masked_likelihood(mu, data, mask, mse)
234 | return res
235 |
236 |
237 |
238 |
239 | def compute_poisson_proc_likelihood(truth, pred_y, info, mask = None):
240 | # Compute Poisson likelihood
241 | # https://math.stackexchange.com/questions/344487/log-likelihood-of-a-realization-of-a-poisson-process
242 | # Sum log lambdas across all time points
243 | if mask is None:
244 | poisson_log_l = torch.sum(info["log_lambda_y"], 2) - info["int_lambda"]
245 | # Sum over data dims
246 | poisson_log_l = torch.mean(poisson_log_l, -1)
247 | else:
248 | # Compute likelihood of the data under the predictions
249 | truth_repeated = truth.repeat(pred_y.size(0), 1, 1, 1)
250 | mask_repeated = mask.repeat(pred_y.size(0), 1, 1, 1)
251 |
252 | # Compute the likelihood per patient and per attribute so that we don't prioritize patients with more measurements
253 | int_lambda = info["int_lambda"]
254 | f = lambda log_lam, data, indices: poisson_log_likelihood(log_lam, data, indices, int_lambda)
255 | poisson_log_l = compute_masked_likelihood(info["log_lambda_y"], truth_repeated, mask_repeated, f)
256 | poisson_log_l = poisson_log_l.permute(1,0)
257 | # Take mean over n_traj
258 | #poisson_log_l = torch.mean(poisson_log_l, 1)
259 |
260 | # poisson_log_l shape: [n_traj_samples, n_traj]
261 | return poisson_log_l
262 |
263 |
264 |
265 |
--------------------------------------------------------------------------------
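A shape sketch for the masked likelihood helpers above (random values, purely illustrative; the repository root is assumed to be on the Python path): with a mask, both masked_gaussian_log_density and compute_mse are evaluated per trajectory and per attribute and come back with shape [n_traj, n_traj_samples].

import torch
from lib.latent_ode.likelihood_eval import masked_gaussian_log_density, compute_mse

n_traj_samples, n_traj, n_tp, n_dim = 3, 4, 10, 1
mu = torch.randn(n_traj_samples, n_traj, n_tp, n_dim)
data = torch.randn(n_traj_samples, n_traj, n_tp, n_dim)
mask = torch.ones(n_traj_samples, n_traj, n_tp, n_dim)
obsrv_std = torch.tensor([0.1])

log_density = masked_gaussian_log_density(mu, data, obsrv_std=obsrv_std, mask=mask)
mse = compute_mse(mu, data, mask=mask)
print(log_density.shape, mse.shape)    # torch.Size([4, 3]) twice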
/lib/latent_ode/ode_func.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.utils.spectral_norm import spectral_norm
10 |
11 | from . import utils as utils
12 |
13 | #####################################################################################################
14 |
15 | class ODEFunc(nn.Module):
16 | def __init__(self, input_dim, latent_dim, ode_func_net, device = torch.device("cpu")):
17 | """
18 | input_dim: dimensionality of the input
19 | latent_dim: dimensionality used for ODE. Analog of a continuous latent state
20 | """
21 | super(ODEFunc, self).__init__()
22 |
23 | self.input_dim = input_dim
24 | self.device = device
25 |
26 | utils.init_network_weights(ode_func_net)
27 | self.gradient_net = ode_func_net
28 |
29 | def forward(self, t_local, y, backwards = False):
30 | """
31 | Perform one step in solving ODE. Given current data point y and current time point t_local, returns gradient dy/dt at this time point
32 |
33 | t_local: current time point
34 | y: value at the current time point
35 | """
36 | grad = self.get_ode_gradient_nn(t_local, y)
37 | if backwards:
38 | grad = -grad
39 | return grad
40 |
41 | def get_ode_gradient_nn(self, t_local, y):
42 | return self.gradient_net(y)
43 |
44 | def sample_next_point_from_prior(self, t_local, y):
45 | """
46 | t_local: current time point
47 | y: value at the current time point
48 | """
49 | return self.get_ode_gradient_nn(t_local, y)
50 |
51 | #####################################################################################################
52 |
53 | class ODEFunc_w_Poisson(ODEFunc):
54 |
55 | def __init__(self, input_dim, latent_dim, ode_func_net,
56 | lambda_net, device = torch.device("cpu")):
57 | """
58 | input_dim: dimensionality of the input
59 | latent_dim: dimensionality used for ODE. Analog of a continuous latent state
60 | """
61 | super(ODEFunc_w_Poisson, self).__init__(input_dim, latent_dim, ode_func_net, device)
62 |
63 | self.latent_ode = ODEFunc(input_dim = input_dim,
64 | latent_dim = latent_dim,
65 | ode_func_net = ode_func_net,
66 | device = device)
67 |
68 | self.latent_dim = latent_dim
69 | self.lambda_net = lambda_net
70 | # The computation of the poisson likelihood can become numerically unstable.
71 | # The integral of lambda(t) dt can take large values; in fact, it equals the expected number of events on the interval [0, T].
72 | # The exponent of lambda can also take large values.
73 | # So we divide lambda by a constant and then multiply the integral of lambda by the same constant.
74 | self.const_for_lambda = torch.Tensor([100.]).to(device)
75 |
76 | def extract_poisson_rate(self, augmented, final_result = True):
77 | y, log_lambdas, int_lambda = None, None, None
78 |
79 | assert(augmented.size(-1) == self.latent_dim + self.input_dim)
80 | latent_lam_dim = self.latent_dim // 2
81 |
82 | if len(augmented.size()) == 3:
83 | int_lambda = augmented[:,:,-self.input_dim:]
84 | y_latent_lam = augmented[:,:,:-self.input_dim]
85 |
86 | log_lambdas = self.lambda_net(y_latent_lam[:,:,-latent_lam_dim:])
87 | y = y_latent_lam[:,:,:-latent_lam_dim]
88 |
89 | elif len(augmented.size()) == 4:
90 | int_lambda = augmented[:,:,:,-self.input_dim:]
91 | y_latent_lam = augmented[:,:,:,:-self.input_dim]
92 |
93 | log_lambdas = self.lambda_net(y_latent_lam[:,:,:,-latent_lam_dim:])
94 | y = y_latent_lam[:,:,:,:-latent_lam_dim]
95 |
96 | # Multiply the integral over lambda by a constant
97 | # only when we have finished the integral computation (i.e. this is not a call in get_ode_gradient_nn)
98 | if final_result:
99 | int_lambda = int_lambda * self.const_for_lambda
100 |
101 | # Latents for performing reconstruction (y) have the same size as latent poisson rate (log_lambdas)
102 | assert(y.size(-1) == latent_lam_dim)
103 |
104 | return y, log_lambdas, int_lambda, y_latent_lam
105 |
106 |
107 | def get_ode_gradient_nn(self, t_local, augmented):
108 | y, log_lam, int_lambda, y_latent_lam = self.extract_poisson_rate(augmented, final_result = False)
109 | dydt_dldt = self.latent_ode(t_local, y_latent_lam)
110 |
111 | log_lam = log_lam - torch.log(self.const_for_lambda)
112 | return torch.cat((dydt_dldt, torch.exp(log_lam)),-1)
113 |
114 |
115 |
116 |
117 |
--------------------------------------------------------------------------------
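A small sketch of the augmented state handled by ODEFunc_w_Poisson.extract_poisson_rate (dimensions are illustrative; the repository root is assumed to be on the Python path): the last dimension is laid out as [reconstruction latents | poisson-rate latents | integral of lambda], i.e. latent_dim + input_dim entries, where latent_dim is twice the size used for reconstruction.

import torch
import torch.nn as nn
from lib.latent_ode.ode_func import ODEFunc_w_Poisson

dim, input_dim = 4, 1          # reconstruction latent size and data dimensionality
latent_dim = dim * 2           # reconstruction latents + poisson-rate latents
ode_func_net = nn.Sequential(nn.Linear(latent_dim, 50), nn.Tanh(), nn.Linear(50, latent_dim))
lambda_net = nn.Sequential(nn.Linear(dim, 50), nn.Tanh(), nn.Linear(50, input_dim))
ode_func = ODEFunc_w_Poisson(input_dim=input_dim, latent_dim=latent_dim,
                             ode_func_net=ode_func_net, lambda_net=lambda_net)

augmented = torch.randn(1, 3, latent_dim + input_dim)   # [n_traj_samples, n_traj, latent_dim + input_dim]
y, log_lambdas, int_lambda, _ = ode_func.extract_poisson_rate(augmented)
print(y.shape, log_lambdas.shape, int_lambda.shape)     # (1, 3, 4) (1, 3, 1) (1, 3, 1)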
/lib/latent_ode/ode_rnn.py:
--------------------------------------------------------------------------------
1 | ###########################
2 | # Latent ODEs for Irregularly-Sampled Time Series
3 | # Author: Yulia Rubanova
4 | ###########################
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.functional import relu
10 |
11 | from . import utils as utils
12 | from .encoder_decoder import *
13 | from .likelihood_eval import *
14 |
15 | from torch.distributions.multivariate_normal import MultivariateNormal
16 | from torch.distributions.normal import Normal
17 | from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase
18 |
19 | from torch.distributions.normal import Normal
20 | from torch.distributions import Independent
21 | from torch.nn.parameter import Parameter
22 | from .base_models import Baseline
23 |
24 |
25 | class ODE_RNN(Baseline):
26 | def __init__(self, input_dim, latent_dim, device = torch.device("cpu"),
27 | z0_diffeq_solver = None, n_gru_units = 100, n_units = 100,
28 | concat_mask = False, obsrv_std = 0.1, use_binary_classif = False,
29 | classif_per_tp = False, n_labels = 1, train_classif_w_reconstr = False):
30 |
31 | Baseline.__init__(self, input_dim, latent_dim, device = device,
32 | obsrv_std = obsrv_std, use_binary_classif = use_binary_classif,
33 | classif_per_tp = classif_per_tp,
34 | n_labels = n_labels,
35 | train_classif_w_reconstr = train_classif_w_reconstr)
36 |
37 | ode_rnn_encoder_dim = latent_dim
38 |
39 | self.ode_gru = Encoder_z0_ODE_RNN(
40 | latent_dim = ode_rnn_encoder_dim,
41 | input_dim = (input_dim) * 2, # input and the mask
42 | z0_diffeq_solver = z0_diffeq_solver,
43 | n_gru_units = n_gru_units,
44 | device = device).to(device)
45 |
46 | self.z0_diffeq_solver = z0_diffeq_solver
47 |
48 | self.decoder = nn.Sequential(
49 | nn.Linear(latent_dim, n_units),
50 | nn.Tanh(),
51 | nn.Linear(n_units, input_dim),)
52 |
53 | utils.init_network_weights(self.decoder)
54 |
55 |
56 | def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps,
57 | mask = None, n_traj_samples = None, mode = None):
58 |
59 | if (len(truth_time_steps) != len(time_steps_to_predict)) or (torch.sum(time_steps_to_predict - truth_time_steps) != 0):
60 | raise Exception("Extrapolation mode not implemented for ODE-RNN")
61 |
62 | # time_steps_to_predict and truth_time_steps should be the same
63 | assert(len(truth_time_steps) == len(time_steps_to_predict))
64 | assert(mask is not None)
65 |
66 | data_and_mask = data
67 | if mask is not None:
68 | data_and_mask = torch.cat([data, mask],-1)
69 |
70 | _, _, latent_ys, _ = self.ode_gru.run_odernn(
71 | data_and_mask, truth_time_steps, run_backwards = False)
72 |
73 | latent_ys = latent_ys.permute(0,2,1,3)
74 | last_hidden = latent_ys[:,:,-1,:]
75 |
76 | #assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.)
77 |
78 | outputs = self.decoder(latent_ys)
79 | # Shift outputs for computing the loss -- we should compare the first output to the second data point, etc.
80 | first_point = data[:,0,:]
81 | outputs = utils.shift_outputs(outputs, first_point)
82 |
83 | extra_info = {"first_point": (latent_ys[:,:,-1,:], 0.0, latent_ys[:,:,-1,:])}
84 |
85 | if self.use_binary_classif:
86 | if self.classif_per_tp:
87 | extra_info["label_predictions"] = self.classifier(latent_ys)
88 | else:
89 | extra_info["label_predictions"] = self.classifier(last_hidden).squeeze(-1)
90 |
91 | # outputs shape: [n_traj_samples, n_traj, n_tp, n_dims]
92 | return outputs, extra_info
93 |
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/lib/linreg.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Dict
2 | import sys
3 | import os
4 | import yaml
5 | import datetime
6 | import argparse
7 | from functools import partial
8 | import optuna
9 |
10 | from darts import models
11 | from darts import metrics
12 | from darts import TimeSeries
13 |
14 | # import data formatter
15 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
16 | from data_formatter.base import *
17 | from utils.darts_processing import load_data, reshuffle_data
18 | from utils.darts_evaluation import rescale_and_backtest
19 | from utils.darts_training import print_callback
20 |
21 | # lag setter for covariates
22 | def set_lags(in_len, args):
23 | lags_past_covariates = None
24 | lags_future_covariates = None
25 | if args.use_covs == 'True':
26 | if series['train']['dynamic'] is not None:
27 | lags_past_covariates = in_len
28 | if series['train']['future'] is not None:
29 | lags_future_covariates = (in_len, formatter.params['length_pred'])
30 | return lags_past_covariates, lags_future_covariates
31 |
32 | # define objective function
33 | def objective(trial):
34 | # select input and output chunk lengths
35 | out_len = formatter.params["length_pred"]
36 | in_len = trial.suggest_int("in_len", 12, formatter.params["max_length_input"], step=12) # at least 2 hours of predictions left
37 | lags_past_covariates, lags_future_covariates = set_lags(in_len, args)
38 |
39 | # build the Linear Regression model
40 | model = models.LinearRegressionModel(lags = in_len,
41 | lags_past_covariates = lags_past_covariates,
42 | lags_future_covariates = lags_future_covariates,
43 | output_chunk_length = out_len)
44 |
45 | # train the model
46 | model.fit(series['train']['target'],
47 | past_covariates=series['train']['dynamic'],
48 | future_covariates=series['train']['future'])
49 |
50 | # backtest on the validation set
51 | errors = model.backtest(series['val']['target'],
52 | past_covariates=series['val']['dynamic'],
53 | future_covariates=series['val']['future'],
54 | forecast_horizon=out_len,
55 | stride=out_len,
56 | retrain=False,
57 | verbose=False,
58 | metric=metrics.rmse,
59 | last_points_only=False,
60 | )
61 | avg_error = np.mean(errors)
62 |
63 | return avg_error
64 |
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument('--dataset', type=str, default='weinstock')
67 | parser.add_argument('--use_covs', type=str, default='False')
68 | parser.add_argument('--optuna', type=str, default='True')
69 | parser.add_argument('--reduction1', type=str, default='mean')
70 | parser.add_argument('--reduction2', type=str, default='median')
71 | parser.add_argument('--reduction3', type=str, default=None)
72 | args = parser.parse_args()
73 | reductions = [args.reduction1, args.reduction2, args.reduction3]
74 | if __name__ == '__main__':
75 | # load data
76 | study_file = f'./output/linreg_{args.dataset}.txt' if args.use_covs == 'False' \
77 | else f'./output/linreg_covariates_{args.dataset}.txt'
78 | if not os.path.exists(study_file):
79 | with open(study_file, "w") as f:
80 | f.write(f"Optimization started at {datetime.datetime.now()}\n")
81 | formatter, series, scalers = load_data(study_file=study_file,
82 | dataset=args.dataset,
83 | use_covs=True if args.use_covs == 'True' else False,
84 | cov_type='mixed',
85 | use_static_covs=True)
86 |
87 | # hyperparameter optimization
88 | best_params = None
89 | if args.optuna == 'True':
90 | study = optuna.create_study(direction="minimize")
91 | print_call = partial(print_callback, study_file=study_file)
92 | study.optimize(objective, n_trials=50,
93 | callbacks=[print_call],
94 | catch=(np.linalg.LinAlgError, KeyError))
95 | best_params = study.best_trial.params
96 | else:
97 | key = "linreg_covariates" if args.use_covs == 'True' else "linreg"
98 | assert formatter.params[key] is not None, "No saved hyperparameters found for this model"
99 | best_params = formatter.params[key]
100 |
101 | # select best hyperparameters
102 | in_len = best_params['in_len']
103 | out_len = formatter.params["length_pred"]
104 | stride = out_len // 2
105 | lags_past_covariates, lags_future_covariates = set_lags(in_len, args)
106 |
107 | # test on ID and OOD data
108 | seeds = list(range(10, 20))
109 | id_errors_cv = {key: [] for key in reductions if key is not None}
110 | ood_errors_cv = {key: [] for key in reductions if key is not None}
111 | id_likelihoods_cv = []; ood_likelihoods_cv = []
112 | id_cal_errors_cv = []; ood_cal_errors_cv = []
113 | for seed in seeds:
114 | formatter, series, scalers = reshuffle_data(formatter,
115 | seed,
116 | use_covs=True if args.use_covs == 'True' else False,
117 | cov_type='mixed',
118 | use_static_covs=True)
119 | # build the model
120 | model = models.LinearRegressionModel(lags = in_len,
121 | lags_past_covariates = lags_past_covariates,
122 | lags_future_covariates = lags_future_covariates,
123 | output_chunk_length = formatter.params['length_pred'])
124 | # train the model
125 | model.fit(series['train']['target'],
126 | past_covariates=series['train']['dynamic'],
127 | future_covariates=series['train']['future'])
128 |
129 | # backtest on the test set
130 | forecasts = model.historical_forecasts(series['test']['target'],
131 | past_covariates = series['test']['dynamic'],
132 | future_covariates = series['test']['future'],
133 | forecast_horizon=out_len,
134 | stride=stride,
135 | retrain=False,
136 | verbose=False,
137 | last_points_only=False,
138 | start=formatter.params["max_length_input"])
139 | id_errors_sample, \
140 | id_likelihood_sample, \
141 | id_cal_errors_sample = rescale_and_backtest(series['test']['target'],
142 | forecasts,
143 | [metrics.mse, metrics.mae],
144 | scalers['target'],
145 | reduction=None)
146 | # backtest on the OOD set
147 | forecasts = model.historical_forecasts(series['test_ood']['target'],
148 | past_covariates = series['test_ood']['dynamic'],
149 | future_covariates = series['test_ood']['future'],
150 | forecast_horizon=out_len,
151 | stride=stride,
152 | retrain=False,
153 | verbose=False,
154 | last_points_only=False,
155 | start=formatter.params["max_length_input"])
156 |
157 | ood_errors_sample, \
158 | ood_likelihood_sample, \
159 | ood_cal_errors_sample = rescale_and_backtest(series['test_ood']['target'],
160 | forecasts,
161 | [metrics.mse, metrics.mae],
162 | scalers['target'],
163 | reduction=None)
164 |
165 | # compute, save, and print results
166 | with open(study_file, "a") as f:
167 | for reduction in reductions:
168 | if reduction is not None:
169 | # compute
170 | reduction_f = getattr(np, reduction)
171 | id_errors_sample_red = reduction_f(id_errors_sample, axis=0)
172 | ood_errors_sample_red = reduction_f(ood_errors_sample, axis=0)
173 | # save
174 | id_errors_cv[reduction].append(id_errors_sample_red)
175 | ood_errors_cv[reduction].append(ood_errors_sample_red)
176 | # print
177 | f.write(f"\tSeed: {seed} ID {reduction} of (MSE, MAE): {id_errors_sample_red.tolist()}\n")
178 | f.write(f"\tSeed: {seed} OOD {reduction} of (MSE, MAE) stats: {ood_errors_sample_red.tolist()}\n")
179 | # save
180 | id_likelihoods_cv.append(id_likelihood_sample)
181 | ood_likelihoods_cv.append(ood_likelihood_sample)
182 | id_cal_errors_cv.append(id_cal_errors_sample)
183 | ood_cal_errors_cv.append(ood_cal_errors_sample)
184 | # print
185 | f.write(f"\tSeed: {seed} ID likelihoods: {id_likelihood_sample}\n")
186 | f.write(f"\tSeed: {seed} OOD likelihoods: {ood_likelihood_sample}\n")
187 | f.write(f"\tSeed: {seed} ID calibration errors: {id_cal_errors_sample.tolist()}\n")
188 | f.write(f"\tSeed: {seed} OOD calibration errors: {ood_cal_errors_sample.tolist()}\n")
189 |
190 | # compute, save, and print results
191 | with open(study_file, "a") as f:
192 | for reduction in reductions:
193 | if reduction is not None:
194 | # compute
195 | id_errors_cv[reduction] = np.vstack(id_errors_cv[reduction])
196 | ood_errors_cv[reduction] = np.vstack(ood_errors_cv[reduction])
197 | id_errors_cv[reduction] = np.mean(id_errors_cv[reduction], axis=0)
198 | ood_errors_cv[reduction] = np.mean(ood_errors_cv[reduction], axis=0)
199 | # print
200 | f.write(f"ID {reduction} of (MSE, MAE): {id_errors_cv[reduction].tolist()}\n")
201 | f.write(f"OOD {reduction} of (MSE, MAE): {ood_errors_cv[reduction].tolist()}\n")
202 | # compute
203 | id_likelihoods_cv = np.mean(id_likelihoods_cv)
204 | ood_likelihoods_cv = np.mean(ood_likelihoods_cv)
205 | id_cal_errors_cv = np.vstack(id_cal_errors_cv)
206 | ood_cal_errors_cv = np.vstack(ood_cal_errors_cv)
207 | id_cal_errors_cv = np.mean(id_cal_errors_cv, axis=0)
208 | ood_cal_errors_cv = np.mean(ood_cal_errors_cv, axis=0)
209 | # print
210 | f.write(f"ID likelihoods: {id_likelihoods_cv}\n")
211 | f.write(f"OOD likelihoods: {ood_likelihoods_cv}\n")
212 | f.write(f"ID calibration errors: {id_cal_errors_cv.tolist()}\n")
213 | f.write(f"OOD calibration errors: {ood_cal_errors_cv.tolist()}\n")
--------------------------------------------------------------------------------
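For orientation, the evaluation loop in lib/linreg.py above reduces backtest errors in two stages: per seed, the array of (MSE, MAE) values over all forecasts is collapsed with np.mean and np.median, and the per-seed results are then averaged across the ten cross-validation seeds. The sketch below (synthetic error values and hypothetical shapes, not repository data) mirrors that scheme.

import numpy as np

reductions = ['mean', 'median']
# one synthetic (num_forecasts, 2) array of (MSE, MAE) values per seed
per_seed_errors = [np.abs(np.random.default_rng(seed).normal(size=(100, 2))) for seed in range(10, 20)]

cv = {r: [] for r in reductions}
for errors in per_seed_errors:                         # one entry per seed
    for r in reductions:
        cv[r].append(getattr(np, r)(errors, axis=0))   # reduce over forecasts -> shape (2,)
final = {r: np.mean(np.vstack(cv[r]), axis=0) for r in reductions}   # average over seeds
print(final['mean'], final['median'])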
/output/arima_colas.txt:
--------------------------------------------------------------------------------
1 | Optimization started at 2023-03-22 17:28:57.493837
2 | --------------------------------
3 | Loading column definition...
4 | Checking column definition...
5 | Loading data...
6 | Dropping columns / rows...
7 | Checking for NA values...
8 | Setting data types...
9 | Dropping columns / rows...
10 | Encoding data...
11 | Updated column definition:
12 | id: REAL_VALUED (ID)
13 | time: DATE (TIME)
14 | gl: REAL_VALUED (TARGET)
15 | gender: REAL_VALUED (STATIC_INPUT)
16 | age: REAL_VALUED (STATIC_INPUT)
17 | BMI: REAL_VALUED (STATIC_INPUT)
18 | glycaemia: REAL_VALUED (STATIC_INPUT)
19 | HbA1c: REAL_VALUED (STATIC_INPUT)
20 | follow.up: REAL_VALUED (STATIC_INPUT)
21 | T2DM: REAL_VALUED (STATIC_INPUT)
22 | time_year: REAL_VALUED (KNOWN_INPUT)
23 | time_month: REAL_VALUED (KNOWN_INPUT)
24 | time_day: REAL_VALUED (KNOWN_INPUT)
25 | time_hour: REAL_VALUED (KNOWN_INPUT)
26 | time_minute: REAL_VALUED (KNOWN_INPUT)
27 | Interpolating data...
28 | Dropped segments: 63
29 | Extracted segments: 205
30 | Interpolated values: 241
31 | Percent of values interpolated: 0.22%
32 | Splitting data...
33 | Train: 72275 (45.89%)
34 | Val: 35713 (22.68%)
35 | Test: 38253 (24.29%)
36 | Test OOD: 11242 (7.14%)
37 | Scaling data...
38 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
39 | Data formatting complete.
40 | --------------------------------
41 | Train: 72173 (45.75%)
42 | Val: 35885 (22.74%)
43 | Test: 38253 (24.25%)
44 | Test OOD: 11460 (7.26%)
45 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
46 | Seed: 10 ID mean of (MSE, MAE): [120.6773731841762, 6.6476186379323945]
47 | Seed: 10 OOD mean of (MSE, MAE) stats: [98.49733495015732, 5.962563865719668]
48 | Seed: 10 ID median of (MSE, MAE): [34.119099014633285, 4.8265690635532]
49 | Seed: 10 OOD median of (MSE, MAE) stats: [25.749051321984638, 4.250000388361499]
50 | Train: 71945 (45.73%)
51 | Val: 35713 (22.70%)
52 | Test: 38037 (24.18%)
53 | Test OOD: 11644 (7.40%)
54 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
55 | Seed: 11 ID mean of (MSE, MAE): [110.98849348314229, 6.512092322826006]
56 | Seed: 11 OOD mean of (MSE, MAE) stats: [143.8660811618338, 6.893749882596932]
57 | Seed: 11 ID median of (MSE, MAE): [33.800056814408784, 4.806109972796595]
58 | Seed: 11 OOD median of (MSE, MAE) stats: [31.027526450957808, 4.534814927485767]
59 | Train: 71565 (45.53%)
60 | Val: 35497 (22.58%)
61 | Test: 38037 (24.20%)
62 | Test OOD: 12096 (7.69%)
63 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
64 | Seed: 12 ID mean of (MSE, MAE): [116.43544050531845, 6.5224223472168985]
65 | Seed: 12 OOD mean of (MSE, MAE) stats: [172.39000789417685, 7.705362619675043]
66 | Seed: 12 ID median of (MSE, MAE): [33.086893503512236, 4.795547888826327]
67 | Seed: 12 OOD median of (MSE, MAE) stats: [45.878204185631716, 5.544957629357246]
68 | Train: 73201 (46.27%)
69 | Val: 36332 (22.97%)
70 | Test: 38469 (24.32%)
71 | Test OOD: 10201 (6.45%)
72 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
73 | Seed: 13 ID mean of (MSE, MAE): [120.10565412076012, 6.670143753397182]
74 | Seed: 13 OOD mean of (MSE, MAE) stats: [156.32090280031198, 7.36687788419089]
75 | Seed: 13 ID median of (MSE, MAE): [34.19618027031381, 4.847393085094296]
76 | Seed: 13 OOD median of (MSE, MAE) stats: [40.91801930218372, 5.3125520415657235]
77 | Train: 72721 (45.97%)
78 | Val: 36577 (23.12%)
79 | Test: 38240 (24.17%)
80 | Test OOD: 10665 (6.74%)
81 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
82 | Seed: 14 ID mean of (MSE, MAE): [119.96416492732897, 6.6778244052465086]
83 | Seed: 14 OOD mean of (MSE, MAE) stats: [116.36936191941496, 6.324921143713053]
84 | Seed: 14 ID median of (MSE, MAE): [34.350508699577546, 4.845801287754875]
85 | Seed: 14 OOD median of (MSE, MAE) stats: [28.11721839067325, 4.371174125759447]
86 | Train: 72280 (45.90%)
87 | Val: 35929 (22.81%)
88 | Test: 38037 (24.15%)
89 | Test OOD: 11237 (7.14%)
90 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
91 | Seed: 15 ID mean of (MSE, MAE): [115.49432265308754, 6.4833270353370915]
92 | Seed: 15 OOD mean of (MSE, MAE) stats: [141.04209618074884, 7.00637395501393]
93 | Seed: 15 ID median of (MSE, MAE): [32.41810861864848, 4.718699547205426]
94 | Seed: 15 OOD median of (MSE, MAE) stats: [32.61929274481187, 4.849921090236984]
95 | Train: 71826 (45.65%)
96 | Val: 35713 (22.70%)
97 | Test: 38037 (24.18%)
98 | Test OOD: 11763 (7.48%)
99 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
100 | Seed: 16 ID mean of (MSE, MAE): [118.60316998673208, 6.570685408763114]
101 | Seed: 16 OOD mean of (MSE, MAE) stats: [172.53892554528767, 7.4691531687181705]
102 | Seed: 16 ID median of (MSE, MAE): [32.9922624794985, 4.776514077411481]
103 | Seed: 16 OOD median of (MSE, MAE) stats: [37.722542194470094, 5.025782263597927]
104 | Train: 72187 (45.92%)
105 | Val: 35497 (22.58%)
106 | Test: 38037 (24.20%)
107 | Test OOD: 11474 (7.30%)
108 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
109 | Seed: 17 ID mean of (MSE, MAE): [118.4226276104948, 6.565824634002014]
110 | Seed: 17 OOD mean of (MSE, MAE) stats: [120.57096069178137, 6.718860547848144]
111 | Seed: 17 ID median of (MSE, MAE): [32.75335854245767, 4.776514077411481]
112 | Seed: 17 OOD median of (MSE, MAE) stats: [36.189480921506764, 4.916667081997722]
113 | Train: 71880 (45.73%)
114 | Val: 35497 (22.58%)
115 | Test: 38037 (24.20%)
116 | Test OOD: 11781 (7.49%)
117 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
118 | Seed: 18 ID mean of (MSE, MAE): [118.39438829343224, 6.60487045660232]
119 | Seed: 18 OOD mean of (MSE, MAE) stats: [146.7202339119919, 7.267020937497021]
120 | Seed: 18 ID median of (MSE, MAE): [34.3461620511888, 4.825534360244473]
121 | Seed: 18 OOD median of (MSE, MAE) stats: [34.70171129616214, 4.99999870546165]
122 | Train: 72349 (45.90%)
123 | Val: 36145 (22.93%)
124 | Test: 38037 (24.13%)
125 | Test OOD: 11096 (7.04%)
126 | Scaled columns: ['id', 'gl', 'gender', 'age', 'BMI', 'glycaemia', 'HbA1c', 'follow.up', 'T2DM', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
127 | Seed: 19 ID mean of (MSE, MAE): [121.40247352207783, 6.655633166667995]
128 | Seed: 19 OOD mean of (MSE, MAE) stats: [137.0253490270386, 6.895576622625327]
129 | Seed: 19 ID median of (MSE, MAE): [33.92787207032102, 4.81699414126529]
130 | Seed: 19 OOD median of (MSE, MAE) stats: [36.766653251293356, 4.943558197736419]
131 | ID mean of (MSE, MAE): [118.04881082865504, 6.591044216799152] +- [2.9257223155159435, 0.06704249579628511]
132 | OOD mean of (MSE, MAE): [140.53412540827432, 6.961046062759818] +- [22.582784874200478, 0.5041473132551306]
133 | ID median of (MSE, MAE): [33.59905020645601, 4.803567750156345] +- [0.6810816239186338, 0.036923417227094066]
134 | OOD median of (MSE, MAE): [34.968970005967535, 4.874942645156038] +- [5.654969833183597, 0.3802513881996165]
135 |
--------------------------------------------------------------------------------
/output/arima_dubosson.txt:
--------------------------------------------------------------------------------
1 | Optimization started at 2023-03-22 17:28:57.405260
2 | --------------------------------
3 | Loading column definition...
4 | Checking column definition...
5 | Loading data...
6 | Dropping columns / rows...
7 | Checking for NA values...
8 | Setting data types...
9 | Dropping columns / rows...
10 | Encoding data...
11 | Updated column definition:
12 | id: REAL_VALUED (ID)
13 | time: DATE (TIME)
14 | gl: REAL_VALUED (TARGET)
15 | fast_insulin: REAL_VALUED (OBSERVED_INPUT)
16 | slow_insulin: REAL_VALUED (OBSERVED_INPUT)
17 | calories: REAL_VALUED (OBSERVED_INPUT)
18 | balance: REAL_VALUED (OBSERVED_INPUT)
19 | quality: REAL_VALUED (OBSERVED_INPUT)
20 | HR: REAL_VALUED (OBSERVED_INPUT)
21 | BR: REAL_VALUED (OBSERVED_INPUT)
22 | Posture: REAL_VALUED (OBSERVED_INPUT)
23 | Activity: REAL_VALUED (OBSERVED_INPUT)
24 | HRV: REAL_VALUED (OBSERVED_INPUT)
25 | CoreTemp: REAL_VALUED (OBSERVED_INPUT)
26 | time_year: REAL_VALUED (KNOWN_INPUT)
27 | time_month: REAL_VALUED (KNOWN_INPUT)
28 | time_day: REAL_VALUED (KNOWN_INPUT)
29 | time_hour: REAL_VALUED (KNOWN_INPUT)
30 | time_minute: REAL_VALUED (KNOWN_INPUT)
31 | Interpolating data...
32 | Dropped segments: 1
33 | Extracted segments: 8
34 | Interpolated values: 0
35 | Percent of values interpolated: 0.00%
36 | Splitting data...
37 | Train: 4654 (47.17%)
38 | Val: 2016 (20.43%)
39 | Test: 2057 (20.85%)
40 | Test OOD: 1140 (11.55%)
41 | Scaling data...
42 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
43 | Data formatting complete.
44 | --------------------------------
45 | Train: 4825 (48.90%)
46 | Val: 2016 (20.43%)
47 | Test: 2057 (20.85%)
48 | Test OOD: 969 (9.82%)
49 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
50 | Seed: 10 ID mean of (MSE, MAE): [630.3234671862778, 14.982298035310993]
51 | Seed: 10 OOD mean of (MSE, MAE) stats: [2311.552668832363, 29.388185991283358]
52 | Seed: 10 ID median of (MSE, MAE): [148.09641012473497, 10.29883954624072]
53 | Seed: 10 OOD median of (MSE, MAE) stats: [686.3397895458669, 21.1499645439295]
54 | Train: 4825 (48.90%)
55 | Val: 2016 (20.43%)
56 | Test: 2057 (20.85%)
57 | Test OOD: 969 (9.82%)
58 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
59 | Seed: 11 ID mean of (MSE, MAE): [630.3234671862778, 14.982298035310993]
60 | Seed: 11 OOD mean of (MSE, MAE) stats: [2311.552668832363, 29.388185991283358]
61 | Seed: 11 ID median of (MSE, MAE): [148.09641012473497, 10.29883954624072]
62 | Seed: 11 OOD median of (MSE, MAE) stats: [686.3397895458669, 21.1499645439295]
63 | Train: 4514 (45.75%)
64 | Val: 2016 (20.43%)
65 | Test: 2057 (20.85%)
66 | Test OOD: 1280 (12.97%)
67 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
68 | Seed: 12 ID mean of (MSE, MAE): [1568.9926209773005, 19.97402098946009]
69 | Seed: 12 OOD mean of (MSE, MAE) stats: [390.67696842721256, 13.13029240748665]
70 | Seed: 12 ID median of (MSE, MAE): [239.79360040284143, 12.464939426312228]
71 | Seed: 12 OOD median of (MSE, MAE) stats: [168.3885649689339, 10.784758322794026]
72 | Train: 4738 (48.02%)
73 | Val: 2016 (20.43%)
74 | Test: 2057 (20.85%)
75 | Test OOD: 1056 (10.70%)
76 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
77 | Seed: 13 ID mean of (MSE, MAE): [1335.955292005022, 16.945752079471546]
78 | Seed: 13 OOD mean of (MSE, MAE) stats: [1267.9886795415546, 21.3231923989978]
79 | Seed: 13 ID median of (MSE, MAE): [146.7812227135625, 10.124277894741939]
80 | Seed: 13 OOD median of (MSE, MAE) stats: [307.56874380049413, 14.881323057095171]
81 | Train: 4654 (47.17%)
82 | Val: 2016 (20.43%)
83 | Test: 2057 (20.85%)
84 | Test OOD: 1140 (11.55%)
85 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
86 | Seed: 14 ID mean of (MSE, MAE): [1562.139924622624, 19.56357352526093]
87 | Seed: 14 OOD mean of (MSE, MAE) stats: [575.304073148924, 15.900844263774417]
88 | Seed: 14 ID median of (MSE, MAE): [202.41938206965023, 11.364012783638028]
89 | Seed: 14 OOD median of (MSE, MAE) stats: [243.24603852839633, 12.629752971974511]
90 | Train: 4885 (49.51%)
91 | Val: 2016 (20.43%)
92 | Test: 2057 (20.85%)
93 | Test OOD: 909 (9.21%)
94 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
95 | Seed: 15 ID mean of (MSE, MAE): [1529.236218668717, 19.214881713534723]
96 | Seed: 15 OOD mean of (MSE, MAE) stats: [306.96453496345606, 11.01309437759319]
97 | Seed: 15 ID median of (MSE, MAE): [171.44486633813915, 10.68677369459521]
98 | Seed: 15 OOD median of (MSE, MAE) stats: [92.69485559816981, 7.574119570527739]
99 | Train: 4825 (48.90%)
100 | Val: 2016 (20.43%)
101 | Test: 2057 (20.85%)
102 | Test OOD: 969 (9.82%)
103 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
104 | Seed: 16 ID mean of (MSE, MAE): [630.3234671862778, 14.982298035310993]
105 | Seed: 16 OOD mean of (MSE, MAE) stats: [2311.552668832363, 29.388185991283358]
106 | Seed: 16 ID median of (MSE, MAE): [148.09641012473497, 10.29883954624072]
107 | Seed: 16 OOD median of (MSE, MAE) stats: [686.3397895458669, 21.1499645439295]
108 | Train: 4514 (45.75%)
109 | Val: 2016 (20.43%)
110 | Test: 2057 (20.85%)
111 | Test OOD: 1280 (12.97%)
112 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
113 | Seed: 17 ID mean of (MSE, MAE): [1568.9926209773005, 19.97402098946009]
114 | Seed: 17 OOD mean of (MSE, MAE) stats: [390.67696842721256, 13.13029240748665]
115 | Seed: 17 ID median of (MSE, MAE): [239.79360040284143, 12.464939426312228]
116 | Seed: 17 OOD median of (MSE, MAE) stats: [168.3885649689339, 10.784758322794026]
117 | Train: 4514 (45.75%)
118 | Val: 2016 (20.43%)
119 | Test: 2057 (20.85%)
120 | Test OOD: 1280 (12.97%)
121 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
122 | Seed: 18 ID mean of (MSE, MAE): [1568.9926209773005, 19.97402098946009]
123 | Seed: 18 OOD mean of (MSE, MAE) stats: [390.67696842721256, 13.13029240748665]
124 | Seed: 18 ID median of (MSE, MAE): [239.79360040284143, 12.464939426312228]
125 | Seed: 18 OOD median of (MSE, MAE) stats: [168.3885649689339, 10.784758322794026]
126 | Train: 4738 (48.02%)
127 | Val: 2016 (20.43%)
128 | Test: 2057 (20.85%)
129 | Test OOD: 1056 (10.70%)
130 | Scaled columns: ['id', 'gl', 'fast_insulin', 'slow_insulin', 'calories', 'balance', 'quality', 'HR', 'BR', 'Posture', 'Activity', 'HRV', 'CoreTemp', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute']
131 | Seed: 19 ID mean of (MSE, MAE): [1335.955292005022, 16.945752079471546]
132 | Seed: 19 OOD mean of (MSE, MAE) stats: [1267.9886795415546, 21.3231923989978]
133 | Seed: 19 ID median of (MSE, MAE): [146.7812227135625, 10.124277894741939]
134 | Seed: 19 OOD median of (MSE, MAE) stats: [307.56874380049413, 14.881323057095171]
135 | ID mean of (MSE, MAE): [1236.123499179212, 17.7538916472052] +- [405.6502363774581, 2.1106334578285266]
136 | OOD mean of (MSE, MAE): [1152.4934878974216, 19.711575863567326] +- [827.4111971731463, 7.114984874764907]
137 | ID median of (MSE, MAE): [183.1096725417644, 11.059067918537597] +- [40.57762165059761, 0.9819635818277976]
138 | OOD median of (MSE, MAE): [351.52634452719565, 14.577068725686317] +- [227.92572769481512, 4.75154932415569]
139 |
--------------------------------------------------------------------------------
/output/arima_iglu.txt:
--------------------------------------------------------------------------------
1 | Optimization started at 2023-03-22 17:28:57.399282
2 | --------------------------------
3 | Loading column definition...
4 | Checking column definition...
5 | Loading data...
6 | Dropping columns / rows...
7 | Checking for NA values...
8 | Setting data types...
9 | Dropping columns / rows...
10 | Encoding data...
11 | Updated column definition:
12 | id: REAL_VALUED (ID)
13 | time: DATE (TIME)
14 | gl: REAL_VALUED (TARGET)
15 | time_year: REAL_VALUED (KNOWN_INPUT)
16 | time_month: REAL_VALUED (KNOWN_INPUT)
17 | time_day: REAL_VALUED (KNOWN_INPUT)
18 | time_hour: REAL_VALUED (KNOWN_INPUT)
19 | time_minute: REAL_VALUED (KNOWN_INPUT)
20 | time_second: REAL_VALUED (KNOWN_INPUT)
21 | Interpolating data...
22 | Dropped segments: 17
23 | Extracted segments: 15
24 | Interpolated values: 561
25 | Percent of values interpolated: 4.37%
26 | Splitting data...
27 | Train: 9056 (64.79%)
28 | Val: 1774 (12.69%)
29 | Test: 1848 (13.22%)
30 | Test OOD: 1300 (9.30%)
31 | Scaling data...
32 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second']
33 | Data formatting complete.
34 | --------------------------------
35 | Train: 9056 (64.79%)
36 | Val: 1774 (12.69%)
37 | Test: 1848 (13.22%)
38 | Test OOD: 1300 (9.30%)
39 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second']
40 | Seed: 10 ID mean of (MSE, MAE): [389.2651923831744, 13.158376410922166]
41 | Seed: 10 OOD mean of (MSE, MAE) stats: [444.1124199595163, 12.280007585027754]
42 | Seed: 10 ID median of (MSE, MAE): [137.96437105020195, 9.472417880109182]
43 | Seed: 10 OOD median of (MSE, MAE) stats: [73.18435077204008, 7.078929258345878]
44 | Train: 9056 (64.79%)
45 | Val: 1774 (12.69%)
46 | Test: 1848 (13.22%)
47 | Test OOD: 1300 (9.30%)
48 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second']
49 | Seed: 11 ID mean of (MSE, MAE): [389.2651923831744, 13.158376410922166]
50 | Seed: 11 OOD mean of (MSE, MAE) stats: [444.1124199595163, 12.280007585027754]
51 | Seed: 11 ID median of (MSE, MAE): [137.96437105020195, 9.472417880109182]
52 | Seed: 11 OOD median of (MSE, MAE) stats: [73.18435077204008, 7.078929258345878]
53 | Train: 8110 (59.66%)
54 | Val: 1342 (9.87%)
55 | Test: 2017 (14.84%)
56 | Test OOD: 2125 (15.63%)
57 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second']
58 | Seed: 12 ID mean of (MSE, MAE): [427.0167933906876, 12.995149726570737]
59 | Seed: 12 OOD mean of (MSE, MAE) stats: [531.8962502619843, 14.276775793295084]
60 | Seed: 12 ID median of (MSE, MAE): [122.93745002201426, 9.253760038373024]
61 | Seed: 12 OOD median of (MSE, MAE) stats: [144.73854388702725, 10.216134214504272]
62 | Train: 7661 (55.57%)
63 | Val: 1296 (9.40%)
64 | Test: 2017 (14.63%)
65 | Test OOD: 2812 (20.40%)
66 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second']
67 | Seed: 13 ID mean of (MSE, MAE): [433.3503441184686, 12.76526222735374]
68 | Seed: 13 OOD mean of (MSE, MAE) stats: [376.08592954422636, 13.031398874300812]
69 | Seed: 13 ID median of (MSE, MAE): [93.73748299466513, 8.118844711107416]
70 | Seed: 13 OOD median of (MSE, MAE) stats: [140.69631926285496, 10.14154025440078]
71 | Train: 7661 (55.57%)
72 | Val: 1296 (9.40%)
73 | Test: 2017 (14.63%)
74 | Test OOD: 2812 (20.40%)
75 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second']
76 | Seed: 14 ID mean of (MSE, MAE): [433.3503441184686, 12.76526222735374]
77 | Seed: 14 OOD mean of (MSE, MAE) stats: [376.08592954422636, 13.031398874300812]
78 | Seed: 14 ID median of (MSE, MAE): [93.73748299466513, 8.118844711107416]
79 | Seed: 14 OOD median of (MSE, MAE) stats: [140.69631926285496, 10.14154025440078]
80 | Train: 9056 (64.79%)
81 | Val: 1774 (12.69%)
82 | Test: 1848 (13.22%)
83 | Test OOD: 1300 (9.30%)
84 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second']
85 | Seed: 15 ID mean of (MSE, MAE): [389.2651923831744, 13.158376410922166]
86 | Seed: 15 OOD mean of (MSE, MAE) stats: [444.1124199595163, 12.280007585027754]
87 | Seed: 15 ID median of (MSE, MAE): [137.96437105020195, 9.472417880109182]
88 | Seed: 15 OOD median of (MSE, MAE) stats: [73.18435077204008, 7.078929258345878]
89 | Train: 8110 (59.66%)
90 | Val: 1342 (9.87%)
91 | Test: 2017 (14.84%)
92 | Test OOD: 2125 (15.63%)
93 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second']
94 | Seed: 16 ID mean of (MSE, MAE): [427.0167933906876, 12.995149726570737]
95 | Seed: 16 OOD mean of (MSE, MAE) stats: [531.8962502619843, 14.276775793295084]
96 | Seed: 16 ID median of (MSE, MAE): [122.93745002201426, 9.253760038373024]
97 | Seed: 16 OOD median of (MSE, MAE) stats: [144.73854388702725, 10.216134214504272]
98 | Train: 7643 (55.44%)
99 | Val: 1342 (9.73%)
100 | Test: 1897 (13.76%)
101 | Test OOD: 2904 (21.06%)
102 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second']
103 | Seed: 17 ID mean of (MSE, MAE): [415.69378849524696, 12.258175847193586]
104 | Seed: 17 OOD mean of (MSE, MAE) stats: [917.813246785934, 18.642029569388345]
105 | Seed: 17 ID median of (MSE, MAE): [83.99256708897218, 7.697683488797355]
106 | Seed: 17 OOD median of (MSE, MAE) stats: [224.85990324440925, 12.498202781002497]
107 | Train: 7643 (55.44%)
108 | Val: 1342 (9.73%)
109 | Test: 1897 (13.76%)
110 | Test OOD: 2904 (21.06%)
111 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second']
112 | Seed: 18 ID mean of (MSE, MAE): [415.69378849524696, 12.258175847193586]
113 | Seed: 18 OOD mean of (MSE, MAE) stats: [917.813246785934, 18.642029569388345]
114 | Seed: 18 ID median of (MSE, MAE): [83.99256708897218, 7.697683488797355]
115 | Seed: 18 OOD median of (MSE, MAE) stats: [224.85990324440925, 12.498202781002497]
116 | Train: 7661 (55.57%)
117 | Val: 1296 (9.40%)
118 | Test: 2017 (14.63%)
119 | Test OOD: 2812 (20.40%)
120 | Scaled columns: ['id', 'gl', 'time_year', 'time_month', 'time_day', 'time_hour', 'time_minute', 'time_second']
121 | Seed: 19 ID mean of (MSE, MAE): [433.3503441184686, 12.76526222735374]
122 | Seed: 19 OOD mean of (MSE, MAE) stats: [376.08592954422636, 13.031398874300812]
123 | Seed: 19 ID median of (MSE, MAE): [93.73748299466513, 8.118844711107416]
124 | Seed: 19 OOD median of (MSE, MAE) stats: [140.69631926285496, 10.14154025440078]
125 | ID mean of (MSE, MAE): [415.3267773276798, 12.827756706235636] +- [18.126631939072656, 0.32319195589326777]
126 | OOD mean of (MSE, MAE): [536.0014042607065, 14.177183010335256] +- [198.40461290827625, 2.337217034828997]
127 | ID median of (MSE, MAE): [110.89655963565743, 8.667667482799056] +- [21.952031091293335, 0.7358828558087035]
128 | OOD median of (MSE, MAE): [138.0838904367558, 9.709008252925353] +- [52.73049091547017, 1.933565723289778]
129 |
--------------------------------------------------------------------------------
/paper_results/parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import numpy as np
4 | from typing import Sequence, Dict, Tuple
5 |
6 | def avg_results(model_names: Sequence[str],
7 | model_names_with_covs: Sequence[str] = None,
8 | time_steps: int = 12) -> Tuple[Tuple[Dict, Dict], Tuple[Dict, Dict]]:
9 | """
10 | Function to load final model results: averaged across random seeds / folds
11 | ----------
12 | model_names: Sequence[str]
13 | paths to the models' results files
14 | model_names_with_covs: Sequence[str]
15 | paths to the models' results files with covariates
16 | time_steps: int
17 | number of time steps that were predicted
18 | NOTE: model_names and model_names_with_covs should be in the same order
19 |
20 | Output
21 | ------
22 | Computes the following set of dictionaries:
23 | dict1:
24 | Dictionary of MSE / MAE values for ID / OOD sets with and without covariates
25 | key1: id / ood, key2: covs / no_covs
26 | dict2:
27 | Dictionary of likelihood / calibration values for ID / OOD sets with and without covariates
28 | key1: id / ood, key2: covs / no_covs
29 | """
30 | def parser(model_names):
31 | arr_id_errors = np.full((len(model_names), 2), np.nan)
32 | arr_ood_errors = arr_id_errors.copy()
33 | arr_id_likelihoods = arr_id_errors.copy()
34 | arr_ood_likelihoods = arr_id_errors.copy()
35 | arr_id_errors_std = np.full((len(model_names), 2, 2), np.nan)
36 | arr_ood_errors_std = arr_id_errors_std.copy()
37 | arr_id_likelihoods_std = arr_id_errors_std.copy()
38 | arr_ood_likelihoods_std = arr_id_errors_std.copy()
39 | for model_name in model_names:
40 | if not os.path.isfile(model_name):
41 | continue
42 | with open(model_name, 'r') as f:
43 | for line in f:
44 | if line.startswith('ID median of (MSE, MAE):'):
45 | id_mse_mae = re.findall(r'\d+\.\d+(?:e-\d+)?', line)
46 | arr_id_errors[model_names.index(model_name), 0] = float(id_mse_mae[0])
47 | arr_id_errors[model_names.index(model_name), 1] = float(id_mse_mae[1])
48 | if len(id_mse_mae) > 2:
49 | arr_id_errors_std[model_names.index(model_name), 0, 0] = float(id_mse_mae[2])
50 | arr_id_errors_std[model_names.index(model_name), 0, 1] = float(id_mse_mae[3])
51 | if len(id_mse_mae) > 4:
52 | arr_id_errors_std[model_names.index(model_name), 1, 0] = float(id_mse_mae[4])
53 | arr_id_errors_std[model_names.index(model_name), 1, 1] = float(id_mse_mae[5])
54 | elif line.startswith('OOD median of (MSE, MAE):'):
55 | ood_mse_mae = re.findall(r'\d+\.\d+(?:e-\d+)?', line)
56 | arr_ood_errors[model_names.index(model_name), 0] = float(ood_mse_mae[0])
57 | arr_ood_errors[model_names.index(model_name), 1] = float(ood_mse_mae[1])
58 | if len(ood_mse_mae) > 2:
59 | arr_ood_errors_std[model_names.index(model_name), 0, 0] = float(ood_mse_mae[2])
60 | arr_ood_errors_std[model_names.index(model_name), 0, 1] = float(ood_mse_mae[3])
61 | if len(ood_mse_mae) > 4:
62 | arr_ood_errors_std[model_names.index(model_name), 1, 0] = float(ood_mse_mae[4])
63 | arr_ood_errors_std[model_names.index(model_name), 1, 1] = float(ood_mse_mae[5])
64 | elif line.startswith('ID likelihoods:'):
65 | id_likelihoods = re.findall(r'-?\d+\.\d+(?:e-\d+)?', line)
66 | arr_id_likelihoods[model_names.index(model_name), 0] = float(id_likelihoods[0])
67 | if len(id_likelihoods) > 1:
68 | arr_id_likelihoods_std[model_names.index(model_name), 0, 0] = float(id_likelihoods[1])
69 | if len(id_likelihoods) > 2:
70 | arr_id_likelihoods_std[model_names.index(model_name), 1, 0] = float(id_likelihoods[2])
71 | elif line.startswith('OOD likelihoods:'):
72 | ood_likelihoods = re.findall(r'-?\d+\.\d+(?:e-\d+)?', line)
73 | arr_ood_likelihoods[model_names.index(model_name), 0] = float(ood_likelihoods[0])
74 | if len(ood_likelihoods) > 1:
75 | arr_ood_likelihoods_std[model_names.index(model_name), 0, 0] = float(ood_likelihoods[1])
76 | if len(ood_likelihoods) > 2:
77 | arr_ood_likelihoods_std[model_names.index(model_name), 1, 0] = float(ood_likelihoods[2])
78 | elif line.startswith('ID calibration errors:'):
79 | id_calib = re.findall(r'-?\d+\.\d+(?:e-\d+)?', line)
80 | arr_id_likelihoods[model_names.index(model_name), 1] = np.mean([float(x) for x in id_calib[:time_steps]])
81 | if len(id_calib) > time_steps:
82 | arr_id_likelihoods_std[model_names.index(model_name), 0, 1] = np.mean([float(x) for x in id_calib[time_steps:]])
83 | if len(id_calib) > 2*time_steps:
84 | arr_id_likelihoods_std[model_names.index(model_name), 1, 1] = np.mean([float(x) for x in id_calib[2*time_steps:]])
85 | elif line.startswith('OOD calibration errors:'):
86 | ood_calib = re.findall(r'-?\d+\.\d+(?:e-\d+)?', line)
87 | arr_ood_likelihoods[model_names.index(model_name), 1] = np.mean([float(x) for x in ood_calib[:time_steps]])
88 | if len(ood_calib) > time_steps:
89 | arr_ood_likelihoods_std[model_names.index(model_name), 0, 1] = np.mean([float(x) for x in ood_calib[time_steps:]])
90 | if len(ood_calib) > 2*time_steps:
91 | arr_ood_likelihoods_std[model_names.index(model_name), 1, 1] = np.mean([float(x) for x in ood_calib[2*time_steps:]])
92 | return (arr_id_errors, arr_ood_errors, arr_id_likelihoods, arr_ood_likelihoods), \
93 | (arr_id_errors_std, arr_ood_errors_std, arr_id_likelihoods_std, arr_ood_likelihoods_std)
94 |
95 | dict_errors, dict_errors_std = {}, {}
96 | dict_likelihoods, dict_likelihoods_std = {}, {}
97 | error, error_std = parser(model_names)
98 | dict_errors['id'] = {'no_covs': error[0]}
99 | dict_errors['ood'] = {'no_covs': error[1]}
100 | dict_likelihoods['id'] = {'no_covs': error[2]}
101 | dict_likelihoods['ood'] = {'no_covs': error[3]}
102 | dict_errors_std['id'] = {'no_covs': error_std[0]}
103 | dict_errors_std['ood'] = {'no_covs': error_std[1]}
104 | dict_likelihoods_std['id'] = {'no_covs': error_std[2]}
105 | dict_likelihoods_std['ood'] = {'no_covs': error_std[3]}
106 |
107 | if model_names_with_covs is not None:
108 | error, error_std = parser(model_names_with_covs)
109 | dict_errors['id']['covs'] = error[0]
110 | dict_errors['ood']['covs'] = error[1]
111 | dict_likelihoods['id']['covs'] = error[2]
112 | dict_likelihoods['ood']['covs'] = error[3]
113 | dict_errors_std['id']['covs'] = error_std[0]
114 | dict_errors_std['ood']['covs'] = error_std[1]
115 | dict_likelihoods_std['id']['covs'] = error_std[2]
116 | dict_likelihoods_std['ood']['covs'] = error_std[3]
117 |
118 | return (dict_errors, dict_likelihoods), (dict_errors_std, dict_likelihoods_std)
--------------------------------------------------------------------------------
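A usage sketch for avg_results in paper_results/parser.py above. The file lists follow the naming convention of the output/ directory ({model}_{dataset}.txt and {model}_covariates_{dataset}.txt); models without a covariate run are skipped by the os.path.isfile check and keep NaN entries. The import assumes the repository root is on sys.path, and the model list and dataset choice are illustrative only.

from paper_results.parser import avg_results

models = ['arima', 'linreg', 'xgboost', 'nhits', 'tft', 'transformer']
paths = [f'./output/{m}_iglu.txt' for m in models]
paths_covs = [f'./output/{m}_covariates_iglu.txt' for m in models]

(errors, likelihoods), (errors_std, likelihoods_std) = avg_results(paths, paths_covs)
print(errors['id']['no_covs'])   # (num_models, 2) array of median (MSE, MAE) on the ID test set
print(errors['ood']['covs'])     # same, for the OOD test set with covariates (NaN where no covariate run exists)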
/paper_results/plots/figure2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure2.pdf
--------------------------------------------------------------------------------
/paper_results/plots/figure3.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure3.pdf
--------------------------------------------------------------------------------
/paper_results/plots/figure3_annot.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure3_annot.pptx
--------------------------------------------------------------------------------
/paper_results/plots/figure4.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure4.pdf
--------------------------------------------------------------------------------
/paper_results/plots/figure5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure5.pdf
--------------------------------------------------------------------------------
/paper_results/plots/figure6.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure6.pdf
--------------------------------------------------------------------------------
/paper_results/plots/figure6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/figure6.png
--------------------------------------------------------------------------------
/paper_results/plots/nhits_single_prediction.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/paper_results/plots/nhits_single_prediction.pdf
--------------------------------------------------------------------------------
/raw_data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/raw_data.zip
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | darts==0.29.0
2 | matplotlib==3.8.4
3 | numpy==1.26.4
4 | optuna==3.6.1
5 | pandas==2.2.2
6 | pillow==10.3.0
7 | pmdarima==2.0.4
8 | pytorch-lightning==2.2.4
9 | PyYAML==6.0.1
10 | requests==2.31.0
11 | scikit-learn==1.4.2
12 | scipy==1.13.0
13 | seaborn==0.13.2
14 | statsforecast==1.7.4
15 | statsmodels==0.14.2
16 | torch==2.3.0
17 | torchaudio==2.3.0
18 | torchmetrics==1.3.2
19 | torchvision==0.18.0
20 | tqdm==4.66.4
21 | xgboost==2.0.3
22 | jupyter==1.0.0
23 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IrinaStatsLab/GlucoBench/661d840a98b316df51faa13a7100430afcbbb5b7/utils/__init__.py
--------------------------------------------------------------------------------
/utils/darts_evaluation.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import yaml
4 | import random
5 | from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union
6 |
7 | import numpy as np
8 | from scipy import stats
9 | import pandas as pd
10 | import darts
11 |
12 | from darts import models
13 | from darts import metrics
14 | from darts import TimeSeries
15 |
16 | # import data formatter
17 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
18 | from data_formatter.base import *
19 | from utils.darts_processing import *
20 |
21 | def _get_values(
22 | series: TimeSeries, stochastic_quantile: Optional[float] = 0.5
23 | ) -> np.ndarray:
24 | """
25 | Returns the numpy values of a time series.
26 | For stochastic series, return either all sample values (with stochastic_quantile=None) or the sample values
27 | at the quantile given by stochastic_quantile (a float in [0, 1]).
28 | """
29 | if series.is_deterministic:
30 | series_values = series.univariate_values()
31 | else: # stochastic
32 | if stochastic_quantile is None:
33 | series_values = series.all_values(copy=False)
34 | else:
35 | series_values = series.quantile_timeseries(
36 | quantile=stochastic_quantile
37 | ).univariate_values()
38 | return series_values
39 |
40 | def _get_values_or_raise(
41 | series_a: TimeSeries,
42 | series_b: TimeSeries,
43 | intersect: bool,
44 | stochastic_quantile: Optional[float] = 0.5,
45 | remove_nan_union: bool = False,
46 | ) -> Tuple[np.ndarray, np.ndarray]:
47 | """Returns the processed numpy values of two time series. Processing can be customized with arguments
48 | `intersect, stochastic_quantile, remove_nan_union`.
49 |
50 | Raises a ValueError if the two time series (or their intersection) do not have the same time index.
51 |
52 | Parameters
53 | ----------
54 | series_a
55 | A univariate deterministic ``TimeSeries`` instance (the actual series).
56 | series_b
57 | A univariate (deterministic or stochastic) ``TimeSeries`` instance (the predicted series).
58 | intersect
59 | A boolean for whether or not to only consider the time intersection between `series_a` and `series_b`
60 | stochastic_quantile
61 | Optionally, for stochastic predicted series, return either all sample values with (`stochastic_quantile=None`)
62 | or any deterministic quantile sample values by setting `stochastic_quantile=quantile` {>=0,<=1}.
63 | remove_nan_union
64 | By setting `remove_nan_union` to True, remove all indices from `series_a` and `series_b` which have a NaN value
65 | in either of the two input series.
66 | """
67 | series_a_common = series_a.slice_intersect(series_b) if intersect else series_a
68 | series_b_common = series_b.slice_intersect(series_a) if intersect else series_b
69 |
70 | series_a_det = _get_values(series_a_common, stochastic_quantile=stochastic_quantile)
71 | series_b_det = _get_values(series_b_common, stochastic_quantile=stochastic_quantile)
72 |
73 | if not remove_nan_union:
74 | return series_a_det, series_b_det
75 |
76 | b_is_deterministic = bool(len(series_b_det.shape) == 1)
77 | if b_is_deterministic:
78 | isnan_mask = np.logical_or(np.isnan(series_a_det), np.isnan(series_b_det))
79 | else:
80 | isnan_mask = np.logical_or(
81 | np.isnan(series_a_det), np.isnan(series_b_det).any(axis=2).flatten()
82 | )
83 | return np.delete(series_a_det, isnan_mask), np.delete(
84 | series_b_det, isnan_mask, axis=0
85 | )
86 |
87 | def rescale_and_backtest(series: Union[TimeSeries,
88 | Sequence[TimeSeries]],
89 | forecasts: Union[TimeSeries,
90 | Sequence[TimeSeries],
91 | Sequence[Sequence[TimeSeries]]],
92 | metric: Union[
93 | Callable[[TimeSeries, TimeSeries], float],
94 | List[Callable[[TimeSeries, TimeSeries], float]],
95 | ],
96 | scaler: Callable[[TimeSeries], TimeSeries] = None,
97 | reduction: Union[Callable[[np.ndarray], float], None] = np.mean,
98 | likelihood: str = "GaussianMean",
99 | cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11),
100 | ):
101 | """
102 | Backtest the historical forecasts (as provided by Darts) on the series.
103 |
104 | Parameters
105 | ----------
106 | series
107 | The target time series.
108 | forecasts
109 | The forecasts.
110 | scaler
111 | The scaler used to scale the series.
112 | metric
113 | The metric or metrics to use for backtesting.
114 | reduction
115 | The reduction to apply to the metric.
116 | likelihood
117 | The likelihood to use for evaluating the model.
118 | cal_thresholds
119 | The thresholds to use for computing the calibration error.
120 |
121 | Returns
122 | -------
123 | np.ndarray
124 | Error array. If the reduction is none, array is of shape (n, p)
125 | where n is the total number of samples (forecasts) and p is the number of metrics.
126 | If the reduction is not none, array is of shape (k, p), where k is the number of series.
127 | float
128 | The estimated log-likelihood of the model on the data.
129 | np.ndarray
130 | The ECE for each time point in the forecast.
131 | """
132 | series = [series] if isinstance(series, TimeSeries) else series
133 | forecasts = [forecasts] if isinstance(forecasts, TimeSeries) else forecasts
134 | metric = [metric] if not isinstance(metric, list) else metric
135 |
136 | # compute errors: 1) reverse the scaling of forecasts and true values, 2) compute errors
137 | backtest_list = []
138 | for idx in range(len(series)):
139 | if scaler is not None:
140 | series[idx] = scaler.inverse_transform(series[idx])
141 | forecasts[idx] = [scaler.inverse_transform(f) for f in forecasts[idx]]
142 | errors = [
143 | [metric_f(series[idx], f) for metric_f in metric]
144 | if len(metric) > 1
145 | else metric[0](series[idx], f)
146 | for f in forecasts[idx]
147 | ]
148 | if reduction is None:
149 | backtest_list.append(np.array(errors))
150 | else:
151 | backtest_list.append(reduction(np.array(errors), axis=0))
152 | backtest_list = np.vstack(backtest_list)
153 |
154 | if likelihood == "GaussianMean":
155 | # compute likelihood
156 | est_var = []
157 | for idx, target_ts in enumerate(series):
158 | est_var += [metrics.mse(target_ts, f) for f in forecasts[idx]]
159 | est_var = np.mean(est_var)
160 | forecast_len = forecasts[0][0].n_timesteps
161 | log_likelihood = -0.5*forecast_len - 0.5*np.log(2*np.pi*est_var)
162 |
163 | # compute calibration error: 1) cdf values 2) compute calibration error
164 | # compute the cdf values
165 | cdf_vals = []
166 | for idx in range(len(series)):
167 | for forecast in forecasts[idx]:
168 | y_true, y_pred = _get_values_or_raise(series[idx],
169 | forecast,
170 | intersect=True,
171 | remove_nan_union=True)
172 | y_true, y_pred = y_true.flatten(), y_pred.flatten()
173 | cdf_vals.append(stats.norm.cdf(y_true, loc=y_pred, scale=np.sqrt(est_var)))
174 | cdf_vals = np.vstack(cdf_vals)
175 | # compute the prediction calibration
176 | cal_error = np.zeros(forecasts[0][0].n_timesteps)
177 | for p in cal_thresholds:
178 | est_p = (cdf_vals <= p).astype(float)
179 | est_p = np.mean(est_p, axis=0)
180 | cal_error += (est_p - p) ** 2
181 |
182 | return backtest_list, log_likelihood, cal_error
183 |
184 | def rescale_and_test(series: Union[TimeSeries,
185 | Sequence[TimeSeries]],
186 | forecasts: Union[TimeSeries,
187 | Sequence[TimeSeries]],
188 | metric: Union[
189 | Callable[[TimeSeries, TimeSeries], float],
190 | List[Callable[[TimeSeries, TimeSeries], float]],
191 | ],
192 | scaler: Callable[[TimeSeries], TimeSeries] = None,
193 | likelihood: str = "GaussianMean",
194 | cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11),
195 | ):
196 | """
197 | Test the forecasts on the series.
198 |
199 | Parameters
200 | ----------
201 | series
202 | The target time series.
203 | forecasts
204 | The forecasts.
205 | scaler
206 | The scaler used to scale the series.
207 | metric
208 | The metric or metrics to use for backtesting.
211 | likelihood
212 | The likelihood to use for evaluating the likelihood and calibration of model.
213 | cal_thresholds
214 | The thresholds to use for computing the calibration error.
215 |
216 | Returns
217 | -------
218 | np.ndarray
219 | Error array of shape (n, p), where n is the number of (series, forecast) pairs
220 | and p is the number of metrics.
222 | float
223 | The estimated log-likelihood of the model on the data.
224 | np.ndarray
225 | The ECE for each time point in the forecast.
226 | """
227 | series = [series] if isinstance(series, TimeSeries) else series
228 | forecasts = [forecasts] if isinstance(forecasts, TimeSeries) else forecasts
229 | metric = [metric] if not isinstance(metric, list) else metric
230 |
231 | # compute errors: 1) reverse the scaling of forecasts and true values, 2) compute errors
232 | series = scaler.inverse_transform(series)
233 | forecasts = scaler.inverse_transform(forecasts)
234 | errors = [
235 | [metric_f(t, f) for metric_f in metric]
236 | if len(metric) > 1
237 | else metric[0](t, f)
238 | for (t, f) in zip(series, forecasts)
239 | ]
240 | errors = np.array(errors)
241 |
242 | if likelihood == "GaussianMean":
243 | # compute likelihood
244 | est_var = [metrics.mse(t, f) for (t, f) in zip(series, forecasts)]
245 | est_var = np.mean(est_var)
246 | forecast_len = forecasts[0].n_timesteps
247 | log_likelihood = -0.5*forecast_len - 0.5*np.log(2*np.pi*est_var)
248 |
249 | # compute calibration error: 1) cdf values 2) compute calibration error
250 | # compute the cdf values
251 | cdf_vals = []
252 | for t, f in zip(series, forecasts):
253 | t, f = _get_values_or_raise(t, f, intersect=True, remove_nan_union=True)
254 | t, f = t.flatten(), f.flatten()
255 | cdf_vals.append(stats.norm.cdf(t, loc=f, scale=np.sqrt(est_var)))
256 | cdf_vals = np.vstack(cdf_vals)
257 | # compute the prediction calibration
258 | cal_error = np.zeros(forecasts[0].n_timesteps)
259 | for p in cal_thresholds:
260 | est_p = (cdf_vals <= p).astype(float)
261 | est_p = np.mean(est_p, axis=0)
262 | cal_error += (est_p - p) ** 2
263 |
264 | if likelihood == "Quantile":
265 | # no likelihood since we don't have a parametric model
266 | log_likelihood = 0
267 |
268 | # compute calibration error: 1) get quantiles 2) compute calibration error
269 | cal_error = np.zeros(forecasts[0].n_timesteps)
270 | for p in cal_thresholds:
271 | est_p = 0
272 | for t, f in zip(series, forecasts):
273 | q = f.quantile(p)
274 | t, q = _get_values_or_raise(t, q, intersect=True, remove_nan_union=True)
275 | t, q = t.flatten(), q.flatten()
276 | est_p += (t <= q).astype(float)
277 | est_p = (est_p / len(series)).flatten()
278 | cal_error += (est_p - p) ** 2
279 |
280 | return errors, log_likelihood, cal_error
--------------------------------------------------------------------------------
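The calibration error in rescale_and_backtest and rescale_and_test above is an expected-calibration-style score: for each threshold p in cal_thresholds it compares the empirical coverage of the Gaussian predictive CDF against the nominal level p and accumulates the squared gap per horizon step. A self-contained sketch of that computation on synthetic forecasts (hypothetical shapes, not GlucoBench data):

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
y_true = rng.normal(size=(200, 12))                         # 200 forecasts, 12-step horizon
y_pred = y_true + rng.normal(scale=2.0, size=y_true.shape)  # noisy point forecasts

est_var = np.mean((y_true - y_pred) ** 2)                   # pooled MSE as the Gaussian variance estimate
cdf_vals = stats.norm.cdf(y_true, loc=y_pred, scale=np.sqrt(est_var))

cal_error = np.zeros(y_true.shape[1])
for p in np.linspace(0, 1, 11):                             # the default cal_thresholds
    est_p = np.mean((cdf_vals <= p).astype(float), axis=0)  # empirical coverage per horizon step
    cal_error += (est_p - p) ** 2                           # squared calibration gap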
/utils/darts_training.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import yaml
4 | import random
5 | from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union
6 |
7 | import numpy as np
8 | from scipy import stats
9 | import pandas as pd
10 | import darts
11 |
12 | from darts import models
13 | from darts import metrics
14 | from darts import TimeSeries
15 | from pytorch_lightning.callbacks import Callback
16 | from darts.logging import get_logger, raise_if_not
17 |
18 | # for optuna callback
19 | import warnings
20 | import optuna
21 | from optuna.storages._cached_storage import _CachedStorage
22 | from optuna.storages._rdb.storage import RDBStorage
23 | # Define key names of `Trial.system_attrs`.
24 | _PRUNED_KEY = "ddp_pl:pruned"
25 | _EPOCH_KEY = "ddp_pl:epoch"
26 | with optuna._imports.try_import() as _imports:
27 | import pytorch_lightning as pl
28 | from pytorch_lightning import LightningModule
29 | from pytorch_lightning import Trainer
30 | from pytorch_lightning.callbacks import Callback
31 | if not _imports.is_successful():
32 | Callback = object # type: ignore # NOQA
33 | LightningModule = object # type: ignore # NOQA
34 | Trainer = object # type: ignore # NOQA
35 |
36 | def print_callback(study, trial, study_file=None):
37 | # write output to a file
38 | with open(study_file, "a") as f:
39 | f.write(f"Current value: {trial.value}, Current params: {trial.params}\n")
40 | f.write(f"Best value: {study.best_value}, Best params: {study.best_trial.params}\n")
41 |
42 | def early_stopping_check(study,
43 | trial,
44 | study_file,
45 | early_stopping_rounds=10):
46 | """
47 | Early stopping callback for Optuna.
48 | It stops the study when the best trial has not improved for `early_stopping_rounds` consecutive trials.
49 | """
50 | current_trial_number = trial.number
51 | best_trial_number = study.best_trial.number
52 | should_stop = (current_trial_number - best_trial_number) >= early_stopping_rounds
53 | if should_stop:
54 | with open(study_file, 'a') as f:
55 | f.write('\nEarly stopping at trial {} (best trial: {})'.format(current_trial_number, best_trial_number))
56 | study.stop()
57 |
58 | class LossLogger(Callback):
59 | def __init__(self):
60 | self.train_loss = []
61 | self.val_loss = []
62 |
63 | def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
64 | self.train_loss.append(float(trainer.callback_metrics["train_loss"]))
65 |
66 | def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
67 | self.val_loss.append(float(trainer.callback_metrics["val_loss"]))
68 |
69 | class PyTorchLightningPruningCallback(Callback):
70 | """PyTorch Lightning callback to prune unpromising trials.
71 | See the Optuna PyTorch Lightning pruning example
73 | if you want to add a pruning callback which observes accuracy.
74 | Args:
75 | trial:
76 | A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
77 | objective function.
78 | monitor:
79 | An evaluation metric for pruning, e.g., ``val_loss`` or
80 | ``val_acc``. The metrics are obtained from the returned dictionaries from e.g.
81 | ``pytorch_lightning.LightningModule.training_step`` or
82 | ``pytorch_lightning.LightningModule.validation_epoch_end`` and the names thus depend on
83 | how this dictionary is formatted.
84 | """
85 |
86 | def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
87 | super().__init__()
88 |
89 | self._trial = trial
90 | self.monitor = monitor
91 |
92 | def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
93 | # When the trainer calls `on_validation_end` for sanity check,
94 | # do not call `trial.report` to avoid calling `trial.report` multiple times
95 | # at epoch 0. The related page is
96 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391.
97 | if trainer.sanity_checking:
98 | return
99 |
100 | epoch = pl_module.current_epoch
101 |
102 | current_score = trainer.callback_metrics.get(self.monitor)
103 | if current_score is None:
104 | message = (
105 | "The metric '{}' is not in the evaluation logs for pruning. "
106 | "Please make sure you set the correct metric name.".format(self.monitor)
107 | )
108 | warnings.warn(message)
109 | return
110 |
111 | self._trial.report(current_score, step=epoch)
112 | if self._trial.should_prune():
113 | message = "Trial was pruned at epoch {}.".format(epoch)
114 | raise optuna.TrialPruned(message)
--------------------------------------------------------------------------------
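The callbacks defined in utils/darts_training.py above are wired into Optuna studies via functools.partial in the training scripts (lib/linreg.py, for example, binds print_callback with the study file before passing it to study.optimize). A minimal sketch with a stand-in objective and a hypothetical study file path, assuming the repository root is on sys.path:

from functools import partial
import optuna
from utils.darts_training import print_callback, early_stopping_check

study_file = './output/example_study.txt'   # hypothetical path, for illustration only
study = optuna.create_study(direction='minimize')
study.optimize(
    lambda trial: (trial.suggest_float('x', -10, 10) - 2.0) ** 2,    # stand-in objective
    n_trials=25,
    callbacks=[
        partial(print_callback, study_file=study_file),
        partial(early_stopping_check, study_file=study_file, early_stopping_rounds=10),
    ],
)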