├── .gitignore
├── README.md
├── conda_env.yml
├── config
│   ├── bimpgru
│   │   ├── air.yaml
│   │   ├── air36.yaml
│   │   ├── bay_block.yaml
│   │   ├── bay_point.yaml
│   │   ├── irish_block.yaml
│   │   ├── irish_point.yaml
│   │   ├── la_block.yaml
│   │   └── la_point.yaml
│   ├── brits
│   │   ├── air.yaml
│   │   ├── air36.yaml
│   │   ├── bay_block.yaml
│   │   ├── bay_point.yaml
│   │   ├── irish_block.yaml
│   │   ├── irish_point.yaml
│   │   ├── la_block.yaml
│   │   ├── la_point.yaml
│   │   └── synthetic.yaml
│   ├── grin
│   │   ├── air.yaml
│   │   ├── air36.yaml
│   │   ├── bay_block.yaml
│   │   ├── bay_point.yaml
│   │   ├── irish_block.yaml
│   │   ├── irish_point.yaml
│   │   ├── la_block.yaml
│   │   ├── la_point.yaml
│   │   └── synthetic.yaml
│   ├── mpgru
│   │   ├── air.yaml
│   │   ├── air36.yaml
│   │   ├── bay_block.yaml
│   │   ├── bay_point.yaml
│   │   ├── irish_block.yaml
│   │   ├── irish_point.yaml
│   │   ├── la_block.yaml
│   │   └── la_point.yaml
│   ├── rgain
│   │   ├── air.yaml
│   │   ├── air36.yaml
│   │   ├── bay_block.yaml
│   │   ├── bay_point.yaml
│   │   ├── irish_block.yaml
│   │   ├── irish_point.yaml
│   │   ├── la_block.yaml
│   │   └── la_point.yaml
│   └── var
│       ├── air.yaml
│       ├── air36.yaml
│       ├── bay_block.yaml
│       ├── bay_point.yaml
│       ├── irish_block.yaml
│       ├── irish_point.yaml
│       ├── la_block.yaml
│       └── la_point.yaml
├── grin.png
├── lib
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── datamodule
│   │   │   ├── __init__.py
│   │   │   └── spatiotemporal.py
│   │   ├── imputation_dataset.py
│   │   ├── preprocessing
│   │   │   ├── __init__.py
│   │   │   └── scalers.py
│   │   ├── spatiotemporal_dataset.py
│   │   └── temporal_dataset.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── air_quality.py
│   │   ├── metr_la.py
│   │   ├── pd_dataset.py
│   │   ├── pems_bay.py
│   │   └── synthetic.py
│   ├── fillers
│   │   ├── __init__.py
│   │   ├── britsfiller.py
│   │   ├── filler.py
│   │   ├── graphfiller.py
│   │   ├── multi_imputation_filler.py
│   │   └── rgainfiller.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── gcrnn.py
│   │   │   ├── gril.py
│   │   │   ├── imputation.py
│   │   │   ├── mpgru.py
│   │   │   ├── rits.py
│   │   │   ├── spatial_attention.py
│   │   │   └── spatial_conv.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── brits.py
│   │   │   ├── grin.py
│   │   │   ├── mpgru.py
│   │   │   ├── rgain.py
│   │   │   ├── rnn_imputers.py
│   │   │   └── var.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── metric_base.py
│   │       ├── metrics.py
│   │       └── ops.py
│   └── utils
│       ├── __init__.py
│       ├── numpy_metrics.py
│       ├── parser_utils.py
│       └── utils.py
├── requirements.txt
└── scripts
    ├── run_baselines.py
    ├── run_imputation.py
    └── run_synthetic.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.DS_STORE
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Filling the G_ap_s: Multivariate Time Series Imputation by Graph Neural Networks (ICLR 2022 - [open review](https://openreview.net/forum?id=kOu3-S3wJ7) - [pdf](https://openreview.net/pdf?id=kOu3-S3wJ7))
2 |
3 | [Open Review](https://openreview.net/forum?id=kOu3-S3wJ7)
4 | [PDF](https://openreview.net/pdf?id=kOu3-S3wJ7)
5 | [arXiv](https://arxiv.org/abs/2108.00298)
6 |
7 | This repository contains the code for the reproducibility of the experiments presented in the paper "Filling the G_ap_s: Multivariate Time Series Imputation by Graph Neural Networks" (ICLR 2022). In this paper, we propose a graph neural network architecture for multivariate time series imputation and achieve state-of-the-art results on several benchmarks.
8 |
9 | **Authors**: [Andrea Cini](mailto:andrea.cini@usi.ch), [Ivan Marisca](mailto:ivan.marisca@usi.ch), Cesare Alippi
10 |
11 |
12 | **‼️ PyG implementation of GRIN is now available inside [Torch Spatiotemporal](https://github.com/TorchSpatiotemporal/tsl), a library built to accelerate research on neural spatiotemporal data processing methods, with a focus on Graph Neural Networks.**
13 |
14 | ---
15 |
16 | 
17 | ## GRIN in a nutshell
18 | The [paper](https://arxiv.org/abs/2108.00298) introduces __GRIN__, a method and an architecture to exploit relational inductive biases to reconstruct missing values in multivariate time series coming from sensor networks. GRIN features a bidirectional recurrent GNN which learns __spatio-temporal node-level representations__ tailored to reconstruct observations at neighboring nodes.
19 |
20 | 
21 | ![GRIN](./grin.png)
22 | 
26 | ---
27 |
28 | ## Directory structure
29 |
30 | The directory is structured as follows:
31 |
32 | ```
33 | .
34 | ├── config
35 | │   ├── bimpgru
36 | │   ├── brits
37 | │   ├── grin
38 | │   ├── mpgru
39 | │   ├── rgain
40 | │   └── var
41 | ├── datasets
42 | │   ├── air_quality
43 | │   ├── metr_la
44 | │   ├── pems_bay
45 | │   └── synthetic
46 | ├── lib
47 | │   ├── __init__.py
48 | │   ├── data
49 | │   ├── datasets
50 | │   ├── fillers
51 | │   ├── nn
52 | │   └── utils
53 | ├── requirements.txt
54 | └── scripts
55 |     ├── run_baselines.py
56 |     ├── run_imputation.py
57 |     └── run_synthetic.py
58 |
59 | ```
60 | Note that, given the size of the files, the datasets are not included in the repository. See the next section for download instructions.
61 |
62 | ## Datasets
63 |
64 | All the datasets used in the experiments, except CER-E, are open and can be downloaded from this [link](https://mega.nz/folder/qwwG3Qba#c6qFTeT7apmZKKyEunCzSg). The CER-E dataset can be obtained free of charge for research purposes by following the instructions at this [link](https://www.ucd.ie/issda/data/commissionforenergyregulationcer/). We recommend storing the downloaded datasets in a folder named `datasets` inside this directory.
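The folder names expected by the code match the `datasets_path` entries in `lib/__init__.py`:

```
datasets
├── air_quality
├── metr_la
├── pems_bay
└── synthetic
```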
65 |
66 | ## Configuration files
67 |
68 | The `config` directory stores all the configuration files used to run the experiments, organized into one subfolder per model.
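For instance, the GRIN configuration for the 36-sensor air quality dataset (`config/grin/air36.yaml`) specifies the dataset, the preprocessing options, and the model hyperparameters; an excerpt:

```
dataset_name: 'air36'
window: 36
adj_threshold: 0.1
scale: True
scaling_axis: 'global'
epochs: 300
batch_size: 32
model_name: 'grin'
d_hidden: 64
```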
69 |
70 | ## Library
71 |
72 | The support code, including the models and the dataset readers, is packed in a Python library named `lib`. Should you need to change the paths to the dataset locations, edit the `__init__.py` file of the library, as shown below.
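In particular, the dataset locations are listed in the `datasets_path` dictionary of `lib/__init__.py` and are resolved relative to the repository root:

```
datasets_path = {
    'air': 'datasets/air_quality',
    'la': 'datasets/metr_la',
    'bay': 'datasets/pems_bay',
    'synthetic': 'datasets/synthetic'
}
```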
73 |
74 | ## Scripts
75 |
76 | The scripts used for the experiment in the paper are in the `scripts` folder.
77 |
78 | * `run_baselines.py` is used to compute the metrics for the `MEAN`, `KNN`, `MF` and `MICE` imputation methods. An example of usage is
79 |
80 | ```
81 | python ./scripts/run_baselines.py --datasets air36 air --imputers mean knn --k 10 --in-sample True --n-runs 5
82 | ```
83 |
84 | * `run_imputation.py` is used to compute the metrics for the deep imputation methods. An example of usage is
85 |
86 | ```
87 | python ./scripts/run_imputation.py --config config/grin/air36.yaml --in-sample False
88 | ```
89 |
90 | * `run_synthetic.py` is used for the experiments on the synthetic datasets. An example of usage is
91 |
92 | ```
93 | python ./scripts/run_synthetic.py --config config/grin/synthetic.yaml --static-adj False
94 | ```
95 |
96 | ## Requirements
97 |
98 | We ran all the experiments with `python 3.8`; see `requirements.txt` for the list of `pip` dependencies.
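For instance, the environment can be set up either with conda, using the provided `conda_env.yml` (which creates an environment named `grin`), or with pip:

```
conda env create -f conda_env.yml
conda activate grin
# or, with pip only
pip install -r requirements.txt
```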
99 |
100 | ## Bibtex reference
101 |
102 | If you find this code useful, please consider citing our paper:
103 |
104 | ```
105 | @inproceedings{cini2022filling,
106 | title={Filling the G\_ap\_s: Multivariate Time Series Imputation by Graph Neural Networks},
107 | author={Andrea Cini and Ivan Marisca and Cesare Alippi},
108 | booktitle={International Conference on Learning Representations},
109 | year={2022},
110 | url={https://openreview.net/forum?id=kOu3-S3wJ7}
111 | }
112 | ```
113 |
--------------------------------------------------------------------------------
/conda_env.yml:
--------------------------------------------------------------------------------
1 | name: grin
2 | channels:
3 | - defaults
4 | - pytorch
5 | - conda-forge
6 | dependencies:
7 | - pip
8 | - pytables
9 | - python=3.8
10 | - pytorch=1.8
11 | - torchvision
12 | - torchaudio
13 | - wheel
14 | - pip:
15 | - einops
16 | - fancyimpute==0.6
17 | - h5py
18 | - openpyxl
19 | - pandas
20 | - pytorch-lightning==1.4
21 | - pyyaml
22 | - scikit-learn
23 | - scipy
24 | - tensorboard
25 |
--------------------------------------------------------------------------------
/config/bimpgru/air.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'air'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'bimpgru'
17 | d_hidden: 64
18 | d_emb: 8
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
24 | merge: 'mlp'
--------------------------------------------------------------------------------
/config/bimpgru/air36.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'air36'
2 | window: 36
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'bimpgru'
17 | d_hidden: 64
18 | d_emb: 8
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
24 | merge: 'mlp'
--------------------------------------------------------------------------------
/config/bimpgru/bay_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'bay_block'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'bimpgru'
17 | d_hidden: 64
18 | d_emb: 8
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
24 | merge: 'mlp'
--------------------------------------------------------------------------------
/config/bimpgru/bay_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'bay_point'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'bimpgru'
17 | d_hidden: 64
18 | d_emb: 8
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
24 | merge: 'mlp'
--------------------------------------------------------------------------------
/config/bimpgru/irish_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'irish_block'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'bimpgru'
17 | d_hidden: 64
18 | d_emb: 8
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
24 | merge: 'mlp'
--------------------------------------------------------------------------------
/config/bimpgru/irish_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'irish_point'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'bimpgru'
17 | d_hidden: 64
18 | d_emb: 8
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
24 | merge: 'mlp'
--------------------------------------------------------------------------------
/config/bimpgru/la_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'la_block'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'bimpgru'
17 | d_hidden: 64
18 | d_emb: 8
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
24 | merge: 'mlp'
--------------------------------------------------------------------------------
/config/bimpgru/la_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'la_point'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'bimpgru'
17 | d_hidden: 64
18 | d_emb: 8
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
24 | merge: 'mlp'
--------------------------------------------------------------------------------
/config/brits/air.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'air'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | batch_size: 32
12 | aggregate_by: ['mean']
13 |
14 | model_name: 'brits'
15 | d_hidden: 128
--------------------------------------------------------------------------------
/config/brits/air36.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'air36'
2 | window: 36
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | batch_size: 32
12 | aggregate_by: ['mean']
13 |
14 | model_name: 'brits'
15 | d_hidden: 64
--------------------------------------------------------------------------------
/config/brits/bay_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'bay_block'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | batch_size: 32
12 | aggregate_by: ['mean']
13 |
14 | model_name: 'brits'
15 | d_hidden: 256
--------------------------------------------------------------------------------
/config/brits/bay_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'bay_point'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | batch_size: 32
12 | aggregate_by: ['mean']
13 |
14 | model_name: 'brits'
15 | d_hidden: 256
--------------------------------------------------------------------------------
/config/brits/irish_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'irish_block'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | batch_size: 32
12 | aggregate_by: ['mean']
13 |
14 | model_name: 'brits'
15 | d_hidden: 256
--------------------------------------------------------------------------------
/config/brits/irish_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'irish_point'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | batch_size: 32
12 | aggregate_by: ['mean']
13 |
14 | model_name: 'brits'
15 | d_hidden: 256
--------------------------------------------------------------------------------
/config/brits/la_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'la_block'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | batch_size: 32
12 | aggregate_by: ['mean']
13 |
14 | model_name: 'brits'
15 | d_hidden: 128
--------------------------------------------------------------------------------
/config/brits/la_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'la_point'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | batch_size: 32
12 | aggregate_by: ['mean']
13 |
14 | model_name: 'brits'
15 | d_hidden: 128
--------------------------------------------------------------------------------
/config/brits/synthetic.yaml:
--------------------------------------------------------------------------------
1 | window: 36
2 | p_block: 0.025
3 | p_point: 0.025
4 | min_seq: 4
5 | max_seq: 9
6 | use_exogenous: False
7 |
8 | epochs: 200
9 | batch_size: 32
10 |
11 | model_name: 'brits'
12 | d_hidden: 32
--------------------------------------------------------------------------------
/config/grin/air.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'air'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'grin'
17 | pred_loss_weight: 1
18 |
19 | d_hidden: 64
20 | d_emb: 8
21 | d_ff: 64
22 | ff_dropout: 0
23 | kernel_size: 2
24 | decoder_order: 1
25 | n_layers: 1
26 | layer_norm: false
27 | merge: 'mlp'
28 |
--------------------------------------------------------------------------------
/config/grin/air36.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'air36'
2 | window: 36
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'grin'
17 | pred_loss_weight: 1
18 |
19 | d_hidden: 64
20 | d_emb: 8
21 | d_ff: 64
22 | ff_dropout: 0
23 | kernel_size: 2
24 | decoder_order: 1
25 | n_layers: 1
26 | layer_norm: false
27 | merge: 'mlp'
28 |
--------------------------------------------------------------------------------
/config/grin/bay_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'bay_block'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'grin'
17 | pred_loss_weight: 1
18 |
19 | d_hidden: 64
20 | d_emb: 8
21 | d_ff: 64
22 | ff_dropout: 0
23 | kernel_size: 2
24 | decoder_order: 1
25 | n_layers: 1
26 | layer_norm: false
27 | merge: 'mlp'
28 |
--------------------------------------------------------------------------------
/config/grin/bay_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'bay_point'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'grin'
17 | pred_loss_weight: 1
18 |
19 | d_hidden: 64
20 | d_emb: 8
21 | d_ff: 64
22 | ff_dropout: 0
23 | kernel_size: 2
24 | decoder_order: 1
25 | n_layers: 1
26 | layer_norm: false
27 | merge: 'mlp'
28 |
--------------------------------------------------------------------------------
/config/grin/irish_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'irish_block'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'grin'
17 | pred_loss_weight: 1
18 |
19 | d_hidden: 64
20 | d_emb: 8
21 | d_ff: 64
22 | ff_dropout: 0
23 | kernel_size: 2
24 | decoder_order: 1
25 | n_layers: 1
26 | layer_norm: false
27 | merge: 'mlp'
28 |
--------------------------------------------------------------------------------
/config/grin/irish_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'irish_point'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'grin'
17 | pred_loss_weight: 1
18 |
19 | d_hidden: 64
20 | d_emb: 8
21 | d_ff: 64
22 | ff_dropout: 0
23 | kernel_size: 2
24 | decoder_order: 1
25 | n_layers: 1
26 | layer_norm: false
27 | merge: 'mlp'
28 |
--------------------------------------------------------------------------------
/config/grin/la_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'la_block'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'grin'
17 | pred_loss_weight: 1
18 |
19 | d_hidden: 64
20 | d_emb: 8
21 | d_ff: 64
22 | ff_dropout: 0
23 | kernel_size: 2
24 | decoder_order: 1
25 | n_layers: 1
26 | layer_norm: false
27 | merge: 'mlp'
28 |
--------------------------------------------------------------------------------
/config/grin/la_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'la_point'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'grin'
17 | pred_loss_weight: 1
18 |
19 | d_hidden: 64
20 | d_emb: 8
21 | d_ff: 64
22 | ff_dropout: 0
23 | kernel_size: 2
24 | decoder_order: 1
25 | n_layers: 1
26 | layer_norm: false
27 | merge: 'mlp'
28 |
--------------------------------------------------------------------------------
/config/grin/synthetic.yaml:
--------------------------------------------------------------------------------
1 | window: 36
2 | p_block: 0.025
3 | p_point: 0.025
4 | min_seq: 4
5 | max_seq: 9
6 | use_exogenous: False
7 |
8 | epochs: 200
9 | batch_size: 32
10 |
11 | model_name: 'grin'
12 | d_hidden: 16
13 | d_emb: 0
14 | d_ff: 16
15 | ff_dropout: 0
16 | kernel_size: 1
17 | decoder_order: 1
18 | n_layers: 1
19 | layer_norm: false
20 | merge: 'mlp'
21 |
--------------------------------------------------------------------------------
/config/mpgru/air.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'air'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'mpgru'
17 | pred_loss_weight: 1
18 | d_hidden: 64
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
--------------------------------------------------------------------------------
/config/mpgru/air36.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'air36'
2 | window: 36
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'mpgru'
17 | pred_loss_weight: 1
18 | d_hidden: 64
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
--------------------------------------------------------------------------------
/config/mpgru/bay_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'bay_block'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'mpgru'
17 | pred_loss_weight: 1
18 | d_hidden: 64
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
--------------------------------------------------------------------------------
/config/mpgru/bay_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'bay_point'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'mpgru'
17 | pred_loss_weight: 1
18 | d_hidden: 64
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
--------------------------------------------------------------------------------
/config/mpgru/irish_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'irish_block'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'mpgru'
17 | pred_loss_weight: 1
18 | d_hidden: 64
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
--------------------------------------------------------------------------------
/config/mpgru/irish_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'irish_point'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'mpgru'
17 | pred_loss_weight: 1
18 | d_hidden: 64
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
--------------------------------------------------------------------------------
/config/mpgru/la_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'la_block'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'mpgru'
17 | pred_loss_weight: 1
18 | d_hidden: 64
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
--------------------------------------------------------------------------------
/config/mpgru/la_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'la_point'
2 | window: 24
3 |
4 | adj_threshold: 0.1
5 |
6 | detrend: False
7 | scale: True
8 | scaling_axis: 'global' # ['channels', 'global']
9 | scaled_target: True
10 |
11 | epochs: 300
12 | samples_per_epoch: 5120 # 160 batch of 32
13 | batch_size: 32
14 | aggregate_by: ['mean']
15 |
16 | model_name: 'mpgru'
17 | pred_loss_weight: 1
18 | d_hidden: 64
19 | d_ff: 64
20 | dropout: 0
21 | kernel_size: 2
22 | n_layers: 1
23 | layer_norm: false
--------------------------------------------------------------------------------
/config/rgain/air.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'air'
2 | window: 24
3 | whiten_prob: 0.2
4 |
5 | detrend: False
6 | scale: True
7 | scaling_axis: 'channels'
8 | scaled_target: True
9 |
10 | epochs: 300
11 | samples_per_epoch: 5120 # 160 batch of 32
12 | batch_size: 32
13 | loss_fn: mse_loss
14 | consistency_loss: False
15 | use_lr_schedule: True
16 | grad_clip_val: -1
17 | aggregate_by: ['mean']
18 |
19 | model_name: 'gain'
20 | d_model: 128
21 | d_z: 4
22 | dropout: 0.2
23 | inject_noise: true
24 | alpha: 20
25 | g_train_freq: 3
26 | d_train_freq: 1
--------------------------------------------------------------------------------
/config/rgain/air36.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'air36'
2 | window: 36
3 | whiten_prob: 0.2
4 |
5 | detrend: False
6 | scale: True
7 | scaling_axis: 'channels'
8 | scaled_target: True
9 |
10 | epochs: 300
11 | samples_per_epoch: 5120 # 160 batch of 32
12 | batch_size: 32
13 | loss_fn: mse_loss
14 | consistency_loss: False
15 | use_lr_schedule: True
16 | grad_clip_val: -1
17 | aggregate_by: ['mean']
18 |
19 | model_name: 'gain'
20 | d_model: 64
21 | d_z: 4
22 | dropout: 0.1
23 | inject_noise: true
24 | alpha: 20
25 | g_train_freq: 3
26 | d_train_freq: 1
--------------------------------------------------------------------------------
/config/rgain/bay_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'bay_block'
2 | window: 24
3 | whiten_prob: 0.2
4 |
5 | detrend: False
6 | scale: True
7 | scaling_axis: 'channels'
8 | scaled_target: True
9 |
10 | epochs: 300
11 | samples_per_epoch: 5120 # 160 batch of 32
12 | batch_size: 32
13 | loss_fn: mse_loss
14 | consistency_loss: False
15 | use_lr_schedule: True
16 | grad_clip_val: -1
17 | aggregate_by: ['mean']
18 |
19 | model_name: 'gain'
20 | d_model: 256
21 | d_z: 4
22 | dropout: 0.2
23 | inject_noise: true
24 | alpha: 20
25 | g_train_freq: 3
26 | d_train_freq: 1
--------------------------------------------------------------------------------
/config/rgain/bay_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'bay_point'
2 | window: 24
3 | whiten_prob: 0.2
4 |
5 | detrend: False
6 | scale: True
7 | scaling_axis: 'channels'
8 | scaled_target: True
9 |
10 | epochs: 300
11 | samples_per_epoch: 5120 # 160 batch of 32
12 | batch_size: 32
13 | loss_fn: mse_loss
14 | consistency_loss: False
15 | use_lr_schedule: True
16 | grad_clip_val: -1
17 | aggregate_by: ['mean']
18 |
19 | model_name: 'gain'
20 | d_model: 256
21 | d_z: 4
22 | dropout: 0.2
23 | inject_noise: true
24 | alpha: 20
25 | g_train_freq: 3
26 | d_train_freq: 1
--------------------------------------------------------------------------------
/config/rgain/irish_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'irish_block'
2 | window: 24
3 | whiten_prob: 0.2
4 |
5 | detrend: False
6 | scale: True
7 | scaling_axis: 'channels'
8 | scaled_target: True
9 |
10 | epochs: 300
11 | samples_per_epoch: 5120 # 160 batch of 32
12 | batch_size: 32
13 | loss_fn: mse_loss
14 | consistency_loss: False
15 | use_lr_schedule: True
16 | grad_clip_val: -1
17 | aggregate_by: ['mean']
18 |
19 | model_name: 'gain'
20 | d_model: 256
21 | d_z: 4
22 | dropout: 0.2
23 | inject_noise: true
24 | alpha: 20
25 | g_train_freq: 3
26 | d_train_freq: 1
--------------------------------------------------------------------------------
/config/rgain/irish_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'irish_point'
2 | window: 24
3 | whiten_prob: 0.2
4 |
5 | detrend: False
6 | scale: True
7 | scaling_axis: 'channels'
8 | scaled_target: True
9 |
10 | epochs: 300
11 | samples_per_epoch: 5120 # 160 batch of 32
12 | batch_size: 32
13 | loss_fn: mse_loss
14 | consistency_loss: False
15 | use_lr_schedule: True
16 | grad_clip_val: -1
17 | aggregate_by: ['mean']
18 |
19 | model_name: 'gain'
20 | d_model: 256
21 | d_z: 4
22 | dropout: 0.2
23 | inject_noise: true
24 | alpha: 20
25 | g_train_freq: 3
26 | d_train_freq: 1
--------------------------------------------------------------------------------
/config/rgain/la_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'la_block'
2 | window: 24
3 | whiten_prob: 0.2
4 |
5 | detrend: False
6 | scale: True
7 | scaling_axis: 'channels'
8 | scaled_target: True
9 |
10 | epochs: 300
11 | samples_per_epoch: 5120 # 160 batch of 32
12 | batch_size: 32
13 | loss_fn: mse_loss
14 | consistency_loss: False
15 | use_lr_schedule: True
16 | grad_clip_val: -1
17 | aggregate_by: ['mean']
18 |
19 | model_name: 'gain'
20 | d_model: 128
21 | d_z: 4
22 | dropout: 0.2
23 | inject_noise: true
24 | alpha: 20
25 | g_train_freq: 3
26 | d_train_freq: 1
--------------------------------------------------------------------------------
/config/rgain/la_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'la_point'
2 | window: 24
3 | whiten_prob: 0.2
4 |
5 | detrend: False
6 | scale: True
7 | scaling_axis: 'channels'
8 | scaled_target: True
9 |
10 | epochs: 300
11 | samples_per_epoch: 5120 # 160 batch of 32
12 | batch_size: 32
13 | loss_fn: mse_loss
14 | consistency_loss: False
15 | use_lr_schedule: True
16 | grad_clip_val: -1
17 | aggregate_by: ['mean']
18 |
19 | model_name: 'gain'
20 | d_model: 128
21 | d_z: 4
22 | dropout: 0.2
23 | inject_noise: true
24 | alpha: 20
25 | g_train_freq: 3
26 | d_train_freq: 1
--------------------------------------------------------------------------------
/config/var/air.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'air'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | lr: 0.0005
12 | batch_size: 64
13 | aggregate_by: ['mean']
14 |
15 | model_name: 'var'
16 | order: 5
17 | padding: 'mean'
--------------------------------------------------------------------------------
/config/var/air36.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'air36'
2 | window: 36
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | lr: 0.0005
12 | batch_size: 64
13 | aggregate_by: ['mean']
14 |
15 | model_name: 'var'
16 | order: 5
17 | padding: 'mean'
--------------------------------------------------------------------------------
/config/var/bay_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'bay_block'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | lr: 0.0005
12 | batch_size: 64
13 | aggregate_by: ['mean']
14 |
15 | model_name: 'var'
16 | order: 5
17 | padding: 'mean'
--------------------------------------------------------------------------------
/config/var/bay_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'bay_point'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | lr: 0.0005
12 | batch_size: 64
13 | aggregate_by: ['mean']
14 |
15 | model_name: 'var'
16 | order: 5
17 | padding: 'mean'
--------------------------------------------------------------------------------
/config/var/irish_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'irish_block'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | lr: 0.0005
12 | batch_size: 64
13 | aggregate_by: ['mean']
14 |
15 | model_name: 'var'
16 | order: 5
17 | padding: 'mean'
--------------------------------------------------------------------------------
/config/var/irish_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'irish_point'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | lr: 0.0005
12 | batch_size: 64
13 | aggregate_by: ['mean']
14 |
15 | model_name: 'var'
16 | order: 5
17 | padding: 'mean'
--------------------------------------------------------------------------------
/config/var/la_block.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'la_block'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | lr: 0.0005
12 | batch_size: 64
13 | aggregate_by: ['mean']
14 |
15 | model_name: 'var'
16 | order: 5
17 | padding: 'mean'
--------------------------------------------------------------------------------
/config/var/la_point.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: 'la_point'
2 | window: 24
3 |
4 | detrend: False
5 | scale: True
6 | scaling_axis: 'channels'
7 | scaled_target: True
8 |
9 | epochs: 300
10 | samples_per_epoch: 5120 # 160 batch of 32
11 | lr: 0.0005
12 | batch_size: 64
13 | aggregate_by: ['mean']
14 |
15 | model_name: 'var'
16 | order: 5
17 | padding: 'mean'
--------------------------------------------------------------------------------
/grin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-Machine-Learning-Group/grin/4a28afbb092600b6e6abeeabaaf67e87dbd1ed6e/grin.png
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
4 |
5 | config = {
6 | 'logs': 'logs/'
7 | }
8 | datasets_path = {
9 | 'air': 'datasets/air_quality',
10 | 'la': 'datasets/metr_la',
11 | 'bay': 'datasets/pems_bay',
12 | 'synthetic': 'datasets/synthetic'
13 | }
14 | epsilon = 1e-8
15 |
16 | for k, v in config.items():
17 | config[k] = os.path.join(base_dir, v)
18 | for k, v in datasets_path.items():
19 | datasets_path[k] = os.path.join(base_dir, v)
20 |
--------------------------------------------------------------------------------
/lib/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .temporal_dataset import TemporalDataset
2 | from .spatiotemporal_dataset import SpatioTemporalDataset
3 |
--------------------------------------------------------------------------------
/lib/data/datamodule/__init__.py:
--------------------------------------------------------------------------------
1 | from .spatiotemporal import SpatioTemporalDataModule
2 |
--------------------------------------------------------------------------------
/lib/data/datamodule/spatiotemporal.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from torch.utils.data import DataLoader, Subset, RandomSampler
3 |
4 | from .. import TemporalDataset, SpatioTemporalDataset
5 | from ..preprocessing import StandardScaler, MinMaxScaler
6 | from ...utils import ensure_list
7 | from ...utils.parser_utils import str_to_bool
8 |
9 |
10 | class SpatioTemporalDataModule(pl.LightningDataModule):
11 | """
12 | Pytorch Lightning DataModule for TimeSeriesDatasets
13 | """
14 |
15 | def __init__(self, dataset: TemporalDataset,
16 | scale=True,
17 | scaling_axis='samples',
18 | scaling_type='std',
19 | scale_exogenous=None,
20 | train_idxs=None,
21 | val_idxs=None,
22 | test_idxs=None,
23 | batch_size=32,
24 | workers=1,
25 | samples_per_epoch=None):
26 | super(SpatioTemporalDataModule, self).__init__()
27 | self.torch_dataset = dataset
28 | # splitting
29 | self.trainset = Subset(self.torch_dataset, train_idxs if train_idxs is not None else [])
30 | self.valset = Subset(self.torch_dataset, val_idxs if val_idxs is not None else [])
31 | self.testset = Subset(self.torch_dataset, test_idxs if test_idxs is not None else [])
32 | # preprocessing
33 | self.scale = scale
34 | self.scaling_type = scaling_type
35 | self.scaling_axis = scaling_axis
36 | self.scale_exogenous = ensure_list(scale_exogenous) if scale_exogenous is not None else None
37 | # data loaders
38 | self.batch_size = batch_size
39 | self.workers = workers
40 | self.samples_per_epoch = samples_per_epoch
41 |
42 | @property
43 | def is_spatial(self):
44 | return isinstance(self.torch_dataset, SpatioTemporalDataset)
45 |
46 | @property
47 | def n_nodes(self):
48 | if not self.has_setup_fit:
49 | raise ValueError('You should initialize the datamodule first.')
50 | return self.torch_dataset.n_nodes if self.is_spatial else None
51 |
52 | @property
53 | def d_in(self):
54 | if not self.has_setup_fit:
55 | raise ValueError('You should initialize the datamodule first.')
56 | return self.torch_dataset.n_channels
57 |
58 | @property
59 | def d_out(self):
60 | if not self.has_setup_fit:
61 | raise ValueError('You should initialize the datamodule first.')
62 | return self.torch_dataset.horizon
63 |
64 | @property
65 | def train_slice(self):
66 | return self.torch_dataset.expand_indices(self.trainset.indices, merge=True)
67 |
68 | @property
69 | def val_slice(self):
70 | return self.torch_dataset.expand_indices(self.valset.indices, merge=True)
71 |
72 | @property
73 | def test_slice(self):
74 | return self.torch_dataset.expand_indices(self.testset.indices, merge=True)
75 |
76 | def get_scaling_axes(self, dim='global'):
77 | scaling_axis = tuple()
78 | if dim == 'global':
79 | scaling_axis = (0, 1, 2)
80 | elif dim == 'channels':
81 | scaling_axis = (0, 1)
82 | elif dim == 'nodes':
83 | scaling_axis = (0,)
84 | # Remove last dimension for temporal datasets
85 | if not self.is_spatial:
86 | scaling_axis = scaling_axis[:-1]
87 |
88 | if not len(scaling_axis):
89 | raise ValueError(f'Scaling axis "{dim}" not valid.')
90 |
91 | return scaling_axis
92 |
93 | def get_scaler(self):
94 | if self.scaling_type == 'std':
95 | return StandardScaler
96 | elif self.scaling_type == 'minmax':
97 | return MinMaxScaler
98 | else:
99 | raise NotImplementedError(f"Scaler type '{self.scaling_type}' is not supported.")
100 |
101 | def setup(self, stage=None):
102 |
103 | if self.scale:
104 | scaling_axis = self.get_scaling_axes(self.scaling_axis)
105 | train = self.torch_dataset.data.numpy()[self.train_slice]
106 | train_mask = self.torch_dataset.mask.numpy()[self.train_slice] if 'mask' in self.torch_dataset else None
107 | scaler = self.get_scaler()(scaling_axis).fit(train, mask=train_mask, keepdims=True).to_torch()
108 | self.torch_dataset.scaler = scaler
109 |
110 | if self.scale_exogenous is not None:
111 | for label in self.scale_exogenous:
112 | exo = getattr(self.torch_dataset, label)
113 | scaler = self.get_scaler()(scaling_axis)
114 | scaler.fit(exo[self.train_slice], keepdims=True).to_torch()
115 | setattr(self.torch_dataset, label, scaler.transform(exo))
116 |
117 | def _data_loader(self, dataset, shuffle=False, batch_size=None, **kwargs):
118 | batch_size = self.batch_size if batch_size is None else batch_size
119 | return DataLoader(dataset,
120 | shuffle=shuffle,
121 | batch_size=batch_size,
122 | num_workers=self.workers,
123 | **kwargs)
124 |
125 | def train_dataloader(self, shuffle=True, batch_size=None):
126 | if self.samples_per_epoch is not None:
127 | sampler = RandomSampler(self.trainset, replacement=True, num_samples=self.samples_per_epoch)
128 | return self._data_loader(self.trainset, False, batch_size, sampler=sampler, drop_last=True)
129 | return self._data_loader(self.trainset, shuffle, batch_size, drop_last=True)
130 |
131 | def val_dataloader(self, shuffle=False, batch_size=None):
132 | return self._data_loader(self.valset, shuffle, batch_size)
133 |
134 | def test_dataloader(self, shuffle=False, batch_size=None):
135 | return self._data_loader(self.testset, shuffle, batch_size)
136 |
137 | @staticmethod
138 | def add_argparse_args(parser, **kwargs):
139 | parser.add_argument('--batch-size', type=int, default=64)
140 | parser.add_argument('--scaling-axis', type=str, default="channels")
141 | parser.add_argument('--scaling-type', type=str, default="std")
142 | parser.add_argument('--scale', type=str_to_bool, nargs='?', const=True, default=True)
143 | parser.add_argument('--workers', type=int, default=0)
144 | parser.add_argument('--samples-per-epoch', type=int, default=None)
145 | return parser
146 |
--------------------------------------------------------------------------------
/lib/data/imputation_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from . import TemporalDataset, SpatioTemporalDataset
5 |
6 |
7 | class ImputationDataset(TemporalDataset):
8 |
9 | def __init__(self, data,
10 | index=None,
11 | mask=None,
12 | eval_mask=None,
13 | freq=None,
14 | trend=None,
15 | scaler=None,
16 | window=24,
17 | stride=1,
18 | exogenous=None):
19 | if mask is None:
20 | mask = np.ones_like(data)
21 | if exogenous is None:
22 | exogenous = dict()
23 | exogenous['mask_window'] = mask
24 | if eval_mask is not None:
25 | exogenous['eval_mask_window'] = eval_mask
26 | super(ImputationDataset, self).__init__(data,
27 | index=index,
28 | exogenous=exogenous,
29 | trend=trend,
30 | scaler=scaler,
31 | freq=freq,
32 | window=window,
33 | horizon=window,
34 | delay=-window,
35 | stride=stride)
36 |
37 | def get(self, item, preprocess=False):
38 | res, transform = super(ImputationDataset, self).get(item, preprocess)
39 | res['x'] = torch.where(res['mask'], res['x'], torch.zeros_like(res['x']))
40 | return res, transform
41 |
42 |
43 | class GraphImputationDataset(ImputationDataset, SpatioTemporalDataset):
44 | pass
45 |
--------------------------------------------------------------------------------
/lib/data/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from .scalers import *
2 |
--------------------------------------------------------------------------------
/lib/data/preprocessing/scalers.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import numpy as np
3 |
4 |
5 | class AbstractScaler(ABC):
6 |
7 | def __init__(self, **kwargs):
8 | for k, v in kwargs.items():
9 | setattr(self, k, v)
10 |
11 | def __repr__(self):
12 | params = ", ".join([f"{k}={str(v)}" for k, v in self.params().items()])
13 | return "{}({})".format(self.__class__.__name__, params)
14 |
15 | def __call__(self, *args, **kwargs):
16 | return self.transform(*args, **kwargs)
17 |
18 | def params(self):
19 | return {k: v for k, v in self.__dict__.items() if not callable(v) and not k.startswith("__")}
20 |
21 | @abstractmethod
22 | def fit(self, x):
23 | pass
24 |
25 | @abstractmethod
26 | def transform(self, x):
27 | pass
28 |
29 | @abstractmethod
30 | def inverse_transform(self, x):
31 | pass
32 |
33 | def fit_transform(self, x):
34 | self.fit(x)
35 | return self.transform(x)
36 |
37 | def to_torch(self):
38 | import torch
39 | for p in self.params():
40 | param = getattr(self, p)
41 | param = np.atleast_1d(param)
42 | param = torch.tensor(param).float()
43 | setattr(self, p, param)
44 | return self
45 |
46 |
47 | class Scaler(AbstractScaler):
48 | def __init__(self, offset=0., scale=1.):
49 | self.bias = offset
50 | self.scale = scale
51 | super(Scaler, self).__init__()
52 |
53 | def params(self):
54 | return dict(bias=self.bias, scale=self.scale)
55 |
56 | def fit(self, x, mask=None, keepdims=True):
57 | pass
58 |
59 | def transform(self, x):
60 | return (x - self.bias) / self.scale
61 |
62 | def inverse_transform(self, x):
63 | return x * self.scale + self.bias
64 |
65 | def fit_transform(self, x, mask=None, keepdims=True):
66 | self.fit(x, mask, keepdims)
67 | return self.transform(x)
68 |
69 |
70 | class StandardScaler(Scaler):
71 | def __init__(self, axis=0):
72 | self.axis = axis
73 | super(StandardScaler, self).__init__()
74 |
75 | def fit(self, x, mask=None, keepdims=True):
76 | if mask is not None:
77 | x = np.where(mask, x, np.nan)
78 | self.bias = np.nanmean(x, axis=self.axis, keepdims=keepdims)
79 | self.scale = np.nanstd(x, axis=self.axis, keepdims=keepdims)
80 | else:
81 | self.bias = x.mean(axis=self.axis, keepdims=keepdims)
82 | self.scale = x.std(axis=self.axis, keepdims=keepdims)
83 | return self
84 |
85 |
86 | class MinMaxScaler(Scaler):
87 | def __init__(self, axis=0):
88 | self.axis = axis
89 | super(MinMaxScaler, self).__init__()
90 |
91 | def fit(self, x, mask=None, keepdims=True):
92 | if mask is not None:
93 | x = np.where(mask, x, np.nan)
94 | self.bias = np.nanmin(x, axis=self.axis, keepdims=keepdims)
95 | self.scale = (np.nanmax(x, axis=self.axis, keepdims=keepdims) - self.bias)
96 | else:
97 | self.bias = x.min(axis=self.axis, keepdims=keepdims)
98 | self.scale = (x.max(axis=self.axis, keepdims=keepdims) - self.bias)
99 | return self
100 |
--------------------------------------------------------------------------------
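A minimal usage sketch of the scalers above, on hypothetical toy data shaped `[steps, nodes, channels]` (the masked fit mirrors how `SpatioTemporalDataModule.setup` uses them):

```
import numpy as np

from lib.data.preprocessing import StandardScaler

# hypothetical toy data: [steps, nodes, channels], with a boolean mask of valid entries
x = np.random.rand(100, 5, 1)
mask = np.random.rand(100, 5, 1) > 0.2

# 'global' scaling: statistics over all axes, computed on valid entries only
scaler = StandardScaler(axis=(0, 1, 2)).fit(x, mask=mask, keepdims=True)
x_scaled = scaler.transform(x)              # (x - bias) / scale
x_restored = scaler.inverse_transform(x_scaled)
```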
/lib/data/spatiotemporal_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from einops import rearrange
4 |
5 | from .temporal_dataset import TemporalDataset
6 |
7 |
8 | class SpatioTemporalDataset(TemporalDataset):
9 | def __init__(self, data,
10 | index=None,
11 | trend=None,
12 | scaler=None,
13 | freq=None,
14 | window=24,
15 | horizon=24,
16 | delay=0,
17 | stride=1,
18 | **exogenous):
19 | """
20 | Pytorch dataset for data that can be represented as a single TimeSeries
21 |
22 | :param data:
23 | raw target time series (ts) (can be multivariate), shape: [steps, (features), nodes]
24 | :param exog:
25 | global exogenous variables, shape: [steps, nodes]
26 | :param trend:
27 | trend time series to be removed from the ts, shape: [steps, (features), (nodes)]
28 | :param bias:
29 | bias to be removed from the ts (after de-trending), shape [steps, (features), (nodes)]
30 | :param scale:
31 | scaling factor to scale the ts (after de-trending), shape [steps, (features), (nodes)]
32 | :param mask:
33 | mask for valid data, 1 -> valid time step, 0 -> invalid. same shape of ts.
34 | :param target_exog:
35 | exogenous variables of the target, shape: [steps, nodes]
36 | :param window:
37 | length of windows returned by __getitem__
38 | :param horizon:
39 | length of prediction horizon returned by __getitem__
40 | :param delay:
41 | delay between input and prediction
42 | """
43 | super(SpatioTemporalDataset, self).__init__(data,
44 | index=index,
45 | trend=trend,
46 | scaler=scaler,
47 | freq=freq,
48 | window=window,
49 | horizon=horizon,
50 | delay=delay,
51 | stride=stride,
52 | **exogenous)
53 |
54 | def __repr__(self):
55 | return "{}(n_samples={}, n_nodes={})".format(self.__class__.__name__, len(self), self.n_nodes)
56 |
57 | @property
58 | def n_nodes(self):
59 | return self.data.shape[1]
60 |
61 | @staticmethod
62 | def check_dim(data):
63 | if data.ndim == 2: # [steps, nodes] -> [steps, nodes, features]
64 | data = rearrange(data, 's (n f) -> s n f', f=1)
65 | elif data.ndim == 1:
66 | data = rearrange(data, '(s n f) -> s n f', n=1, f=1)
67 | elif data.ndim == 3:
68 | pass
69 | else:
70 | raise ValueError(f'Invalid data dimensions {data.shape}')
71 | return data
72 |
73 | def dataframe(self):
74 | if self.n_channels == 1:
75 | return pd.DataFrame(data=np.squeeze(self.data, -1), index=self.index)
76 | raise NotImplementedError()
77 |
--------------------------------------------------------------------------------
/lib/data/temporal_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import torch
4 | from einops import rearrange
5 | from pandas import DatetimeIndex
6 | from torch.utils.data import Dataset
7 |
8 | from .preprocessing import AbstractScaler
9 |
10 |
11 | class TemporalDataset(Dataset):
12 | def __init__(self, data,
13 | index=None,
14 | freq=None,
15 | exogenous=None,
16 | trend=None,
17 | scaler=None,
18 | window=24,
19 | horizon=24,
20 | delay=0,
21 | stride=1):
22 | """Wrapper class for datasets whose entries depend on a sequence of temporal indices.
23 |
24 | Parameters
25 | ----------
26 | data : np.ndarray
27 | Data relative to the main signal.
28 | index : DatetimeIndex or None
29 | Temporal indices for the data.
30 | exogenous : dict or None
31 | Exogenous data and label paired with main signal (default is None).
32 | trend : np.ndarray or None
33 | Trend paired with main signal (default is None). Must have the same length as 'data'.
34 | scaler : AbstractScaler or None
35 | Scaler that must be used for data (default is None).
36 | freq : pd.DateTimeIndex.freq or str
37 | Frequency of the indices (default is index.freq).
38 | window : int
39 | Size of the sliding window in the past.
40 | horizon : int
41 | Size of the prediction horizon.
42 | delay : int
43 | Offset between end of window and start of horizon.
44 |
45 | Raises
46 | ----------
47 | ValueError
48 | If a frequency for the temporal indices is provided neither in `index` nor explicitly.
49 | If preprocess is True and data_scaler is None.
50 | """
51 | super(TemporalDataset, self).__init__()
52 | # Initialize signatures
53 | self.__exogenous_keys = dict()
54 | self.__reserved_signature = {'data', 'trend', 'x', 'y'}
55 | # Store data
56 | self.data = data
57 | if exogenous is not None:
58 | for name, value in exogenous.items():
59 | self.add_exogenous(value, name, for_window=True, for_horizon=True)
60 | # Store time information
61 | self.index = index
62 | try:
63 | freq = freq or index.freq or index.inferred_freq
64 | self.freq = pd.tseries.frequencies.to_offset(freq)
65 | except AttributeError:
66 | self.freq = None
67 | # Store offset information
68 | self.window = window
69 | self.delay = delay
70 | self.horizon = horizon
71 | self.stride = stride
72 | # Identify the indices of the samples
73 | self._indices = np.arange(self.data.shape[0] - self.sample_span + 1)[::self.stride]
74 | # Store preprocessing options
75 | self.trend = trend
76 | self.scaler = scaler
77 |
78 | def __getitem__(self, item):
79 | return self.get(item, self.preprocess)
80 |
81 | def __contains__(self, item):
82 | return item in self.__exogenous_keys
83 |
84 | def __len__(self):
85 | return len(self._indices)
86 |
87 | def __repr__(self):
88 | return "{}(n_samples={})".format(self.__class__.__name__, len(self))
89 |
90 | # Getter and setter for data
91 |
92 | @property
93 | def data(self):
94 | return self.__data
95 |
96 | @data.setter
97 | def data(self, value):
98 | assert value is not None
99 | self.__data = self.check_input(value)
100 |
101 | @property
102 | def trend(self):
103 | return self.__trend
104 |
105 | @trend.setter
106 | def trend(self, value):
107 | self.__trend = self.check_input(value)
108 |
109 | # Setter for exogenous data
110 |
111 | def add_exogenous(self, obj, name, for_window=True, for_horizon=False):
112 | assert isinstance(name, str)
113 | if name.endswith('_window'):
114 | name = name[:-7]
115 | for_window, for_horizon = True, False
116 | if name.endswith('_horizon'):
117 | name = name[:-8]
118 | for_window, for_horizon = False, True
119 | if name in self.__reserved_signature:
120 | raise ValueError("Channel '{0}' cannot be added in this way. Use obj.{0} instead.".format(name))
121 | if not (for_window or for_horizon):
122 | raise ValueError("Either for_window or for_horizon must be True.")
123 | obj = self.check_input(obj)
124 | setattr(self, name, obj)
125 | self.__exogenous_keys[name] = dict(for_window=for_window, for_horizon=for_horizon)
126 | return self
127 |
128 | # Dataset properties
129 |
130 | @property
131 | def horizon_offset(self):
132 | return self.window + self.delay
133 |
134 | @property
135 | def sample_span(self):
136 | return max(self.horizon_offset + self.horizon, self.window)
137 |
138 | @property
139 | def preprocess(self):
140 | return (self.trend is not None) or (self.scaler is not None)
141 |
142 | @property
143 | def n_steps(self):
144 | return self.data.shape[0]
145 |
146 | @property
147 | def n_channels(self):
148 | return self.data.shape[-1]
149 |
150 | @property
151 | def indices(self):
152 | return self._indices
153 |
154 | # Signature information
155 |
156 | @property
157 | def exo_window_keys(self):
158 | return {k for k, v in self.__exogenous_keys.items() if v['for_window']}
159 |
160 | @property
161 | def exo_horizon_keys(self):
162 | return {k for k, v in self.__exogenous_keys.items() if v['for_horizon']}
163 |
164 | @property
165 | def exo_common_keys(self):
166 | return self.exo_window_keys.intersection(self.exo_horizon_keys)
167 |
168 | @property
169 | def signature(self):
170 | attrs = []
171 | if self.window > 0:
172 | attrs.append('x')
173 | for attr in self.exo_window_keys:
174 | attrs.append(attr if attr not in self.exo_common_keys else (attr + '_window'))
175 | for attr in self.exo_horizon_keys:
176 | attrs.append(attr if attr not in self.exo_common_keys else (attr + '_horizon'))
177 | attrs.append('y')
178 | attrs = tuple(attrs)
179 | preprocess = []
180 | if self.trend is not None:
181 | preprocess.append('trend')
182 | if self.scaler is not None:
183 | preprocess.extend(self.scaler.params())
184 | preprocess = tuple(preprocess)
185 | return dict(data=attrs, preprocessing=preprocess)
186 |
187 | # Item getters
188 |
189 | def get(self, item, preprocess=False):
190 | idx = self._indices[item]
191 | res, transform = dict(), dict()
192 | if self.window > 0:
193 | res['x'] = self.data[idx:idx + self.window]
194 | for attr in self.exo_window_keys:
195 | key = attr if attr not in self.exo_common_keys else (attr + '_window')
196 | res[key] = getattr(self, attr)[idx:idx + self.window]
197 | for attr in self.exo_horizon_keys:
198 | key = attr if attr not in self.exo_common_keys else (attr + '_horizon')
199 | res[key] = getattr(self, attr)[idx + self.horizon_offset:idx + self.horizon_offset + self.horizon]
200 | res['y'] = self.data[idx + self.horizon_offset:idx + self.horizon_offset + self.horizon]
201 | if preprocess:
202 | if self.trend is not None:
203 | y_trend = self.trend[idx + self.horizon_offset:idx + self.horizon_offset + self.horizon]
204 | res['y'] = res['y'] - y_trend
205 | transform['trend'] = y_trend
206 | if 'x' in res:
207 | res['x'] = res['x'] - self.trend[idx:idx + self.window]
208 | if self.scaler is not None:
209 | transform.update(self.scaler.params())
210 | if 'x' in res:
211 | res['x'] = self.scaler.transform(res['x'])
212 | return res, transform
213 |
214 | def snapshot(self, indices=None, preprocess=True):
215 | if not self.preprocess:
216 | preprocess = False
217 | data, prep = [{k: [] for k in sign} for sign in self.signature.values()]
218 | indices = np.arange(len(self._indices)) if indices is None else indices
219 | for idx in indices:
220 | data_i, prep_i = self.get(idx, preprocess)
221 | [v.append(data_i[k]) for k, v in data.items()]
222 | if len(prep_i):
223 | [v.append(prep_i[k]) for k, v in prep.items()]
224 | data = {k: np.stack(ds) for k, ds in data.items() if len(ds)}
225 | if len(prep):
226 | prep = {k: np.stack(ds) if k == 'trend' else ds[0] for k, ds in prep.items() if len(ds)}
227 | return data, prep
228 |
229 | # Data utilities
230 |
231 | def expand_indices(self, indices=None, unique=False, merge=False):
232 | ds_indices = dict.fromkeys([time for time in ['window', 'horizon'] if getattr(self, time) > 0])
233 | indices = np.arange(len(self._indices)) if indices is None else indices
234 | if 'window' in ds_indices:
235 | w_idxs = [np.arange(idx, idx + self.window) for idx in self._indices[indices]]
236 | ds_indices['window'] = np.concatenate(w_idxs)
237 | if 'horizon' in ds_indices:
238 | h_idxs = [np.arange(idx + self.horizon_offset, idx + self.horizon_offset + self.horizon)
239 | for idx in self._indices[indices]]
240 | ds_indices['horizon'] = np.concatenate(h_idxs)
241 | if unique:
242 | ds_indices = {k: np.unique(v) for k, v in ds_indices.items()}
243 | if merge:
244 | ds_indices = np.unique(np.concatenate(list(ds_indices.values())))
245 | return ds_indices
246 |
247 | def overlapping_indices(self, idxs1, idxs2, synch_mode='window', as_mask=False):
248 | assert synch_mode in ['window', 'horizon']
249 | ts1 = self.data_timestamps(idxs1, flatten=False)[synch_mode]
250 | ts2 = self.data_timestamps(idxs2, flatten=False)[synch_mode]
251 | common_ts = np.intersect1d(np.unique(ts1), np.unique(ts2))
252 | is_overlapping = lambda sample: np.any(np.in1d(sample, common_ts))
253 | m1 = np.apply_along_axis(is_overlapping, 1, ts1)
254 | m2 = np.apply_along_axis(is_overlapping, 1, ts2)
255 | if as_mask:
256 | return m1, m2
257 | return np.sort(idxs1[m1]), np.sort(idxs2[m2])
258 |
259 | def data_timestamps(self, indices=None, flatten=True):
260 | ds_indices = self.expand_indices(indices, unique=False)
261 | ds_timestamps = {k: self.index[v] for k, v in ds_indices.items()}
262 | if not flatten:
263 | ds_timestamps = {k: np.array(v).reshape(-1, getattr(self, k)) for k, v in ds_timestamps.items()}
264 | return ds_timestamps
265 |
266 | def reduce_dataset(self, indices, inplace=False):
267 | if not inplace:
268 | from copy import deepcopy
269 | dataset = deepcopy(self)
270 | else:
271 | dataset = self
272 | old_index = dataset.index[dataset._indices[indices]]
273 | ds_indices = dataset.expand_indices(indices, merge=True)
274 | dataset.index = dataset.index[ds_indices]
275 | dataset.data = dataset.data[ds_indices]
276 | if getattr(dataset, 'mask', None) is not None:
277 | dataset.mask = dataset.mask[ds_indices]
278 | if dataset.trend is not None:
279 | dataset.trend = dataset.trend[ds_indices]
280 | for attr in dataset.exo_window_keys.union(dataset.exo_horizon_keys):
281 | if getattr(dataset, attr, None) is not None:
282 | setattr(dataset, attr, getattr(dataset, attr)[ds_indices])
283 | dataset._indices = np.flatnonzero(np.in1d(dataset.index, old_index))
284 | return dataset
285 |
286 | def check_input(self, data):
287 | if data is None:
288 | return data
289 | data = self.check_dim(data)
290 | data = data.clone().detach() if isinstance(data, torch.Tensor) else torch.tensor(data)
291 | # cast data
292 | if torch.is_floating_point(data):
293 | return data.float()
294 | elif data.dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64]:
295 | return data.int()
296 | return data
297 |
298 | # Class-specific methods (override in children)
299 |
300 | @staticmethod
301 | def check_dim(data):
302 | if data.ndim == 1: # [steps] -> [steps, features]
303 | data = rearrange(data, '(s f) -> s f', f=1)
304 | elif data.ndim != 2:
305 | raise ValueError(f'Invalid data dimensions {data.shape}')
306 | return data
307 |
308 | def dataframe(self):
309 | return pd.DataFrame(data=self.data, index=self.index)
310 |
311 | @staticmethod
312 | def add_argparse_args(parser, **kwargs):
313 | parser.add_argument('--window', type=int, default=24)
314 | parser.add_argument('--horizon', type=int, default=24)
315 | parser.add_argument('--delay', type=int, default=0)
316 | parser.add_argument('--stride', type=int, default=1)
317 | return parser
318 |
--------------------------------------------------------------------------------
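
The sliding-window logic above (`horizon_offset`, `sample_span`, `_indices` and the slicing in `get`) is easy to verify with a minimal sketch. The snippet below is not repository code; it just mirrors the same index arithmetic on a toy array with hypothetical sizes.

import numpy as np

data = np.arange(100).reshape(-1, 1)      # [steps, features]
window, delay, horizon, stride = 24, 0, 24, 1

horizon_offset = window + delay           # mirrors TemporalDataset.horizon_offset
sample_span = max(horizon_offset + horizon, window)
indices = np.arange(data.shape[0] - sample_span + 1)[::stride]

idx = indices[0]
x = data[idx:idx + window]                                      # past window
y = data[idx + horizon_offset:idx + horizon_offset + horizon]   # prediction horizon
print(x.shape, y.shape)  # (24, 1) (24, 1)
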
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .air_quality import AirQuality
2 | from .metr_la import MissingValuesMetrLA
3 | from .pems_bay import MissingValuesPemsBay
4 | from .synthetic import ChargedParticles
5 |
--------------------------------------------------------------------------------
/lib/datasets/air_quality.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from lib import datasets_path
7 | from .pd_dataset import PandasDataset
8 | from ..utils.utils import disjoint_months, infer_mask, compute_mean, geographical_distance, thresholded_gaussian_kernel
9 |
10 |
11 | class AirQuality(PandasDataset):
12 | SEED = 3210
13 |
14 | def __init__(self, impute_nans=False, small=False, freq='60T', masked_sensors=None):
15 | self.random = np.random.default_rng(self.SEED)
16 | self.test_months = [3, 6, 9, 12]
17 | self.infer_eval_from = 'next'
18 | self.eval_mask = None
19 | df, dist, mask = self.load(impute_nans=impute_nans, small=small, masked_sensors=masked_sensors)
20 | self.dist = dist
21 | if masked_sensors is None:
22 | self.masked_sensors = list()
23 | else:
24 | self.masked_sensors = list(masked_sensors)
25 | super().__init__(dataframe=df, u=None, mask=mask, name='air', freq=freq, aggr='nearest')
26 |
27 | def load_raw(self, small=False):
28 | if small:
29 | path = os.path.join(datasets_path['air'], 'small36.h5')
30 | eval_mask = pd.DataFrame(pd.read_hdf(path, 'eval_mask'))
31 | else:
32 | path = os.path.join(datasets_path['air'], 'full437.h5')
33 | eval_mask = None
34 | df = pd.DataFrame(pd.read_hdf(path, 'pm25'))
35 | stations = pd.DataFrame(pd.read_hdf(path, 'stations'))
36 | return df, stations, eval_mask
37 |
38 | def load(self, impute_nans=True, small=False, masked_sensors=None):
39 | # load readings and stations metadata
40 | df, stations, eval_mask = self.load_raw(small)
41 | # compute the masks
42 | mask = (~np.isnan(df.values)).astype('uint8') # 1 if value is not nan else 0
43 | if eval_mask is None:
44 | eval_mask = infer_mask(df, infer_from=self.infer_eval_from)
45 |
46 | eval_mask = eval_mask.values.astype('uint8')
47 | if masked_sensors is not None:
48 | eval_mask[:, masked_sensors] = np.where(mask[:, masked_sensors], 1, 0)
49 | self.eval_mask = eval_mask # 1 if value is ground-truth for imputation else 0
50 | # optionally replace NaNs with the weekly mean by hour
51 | if impute_nans:
52 | df = df.fillna(compute_mean(df))
53 | # compute distances from latitude and longitude degrees
54 | st_coord = stations.loc[:, ['latitude', 'longitude']]
55 | dist = geographical_distance(st_coord, to_rad=True).values
56 | return df, dist, mask
57 |
58 | def splitter(self, dataset, val_len=1., in_sample=False, window=0):
59 | nontest_idxs, test_idxs = disjoint_months(dataset, months=self.test_months, synch_mode='horizon')
60 | if in_sample:
61 | train_idxs = np.arange(len(dataset))
62 | val_months = [(m - 1) % 12 for m in self.test_months]
63 | _, val_idxs = disjoint_months(dataset, months=val_months, synch_mode='horizon')
64 | else:
65 | # take equal number of samples before each month of testing
66 | val_len = (int(val_len * len(nontest_idxs)) if val_len < 1 else val_len) // len(self.test_months)
67 | # get indices of first day of each testing month
68 | delta_idxs = np.diff(test_idxs)
69 | end_month_idxs = test_idxs[1:][np.flatnonzero(delta_idxs > delta_idxs.min())]
70 | if len(end_month_idxs) < len(self.test_months):
71 | end_month_idxs = np.insert(end_month_idxs, 0, test_idxs[0])
72 | # expand month indices
73 | month_val_idxs = [np.arange(v_idx - val_len, v_idx) - window for v_idx in end_month_idxs]
74 | val_idxs = np.concatenate(month_val_idxs) % len(dataset)
75 | # remove overlapping indices from training set
76 | ovl_idxs, _ = dataset.overlapping_indices(nontest_idxs, val_idxs, synch_mode='horizon', as_mask=True)
77 | train_idxs = nontest_idxs[~ovl_idxs]
78 | return [train_idxs, val_idxs, test_idxs]
79 |
80 | def get_similarity(self, thr=0.1, include_self=False, force_symmetric=False, sparse=False, **kwargs):
81 | theta = np.std(self.dist[:36, :36]) # use same theta for both air and air36
82 | adj = thresholded_gaussian_kernel(self.dist, theta=theta, threshold=thr)
83 | if not include_self:
84 | adj[np.diag_indices_from(adj)] = 0.
85 | if force_symmetric:
86 | adj = np.maximum.reduce([adj, adj.T])
87 | if sparse:
88 | import scipy.sparse as sps
89 | adj = sps.coo_matrix(adj)
90 | return adj
91 |
92 | @property
93 | def mask(self):
94 | return self._mask
95 |
96 | @property
97 | def training_mask(self):
98 | return self._mask if self.eval_mask is None else (self._mask & (1 - self.eval_mask))
99 |
100 | def test_interval_mask(self, dtype=bool, squeeze=True):
101 | m = np.in1d(self.df.index.month, self.test_months).astype(dtype)
102 | if squeeze:
103 | return m
104 | return m[:, None]
105 |
--------------------------------------------------------------------------------
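
The three masks used by AirQuality follow a simple convention: `mask` marks observed values, `eval_mask` marks observed values held out as imputation ground truth, and `training_mask` is what the model actually sees. The following toy sketch (not repository code) reproduces the same expression used in the `training_mask` property.

import numpy as np

mask      = np.array([[1, 1, 0, 1]], dtype='uint8')  # 1 = value is not NaN
eval_mask = np.array([[0, 1, 0, 0]], dtype='uint8')  # 1 = used only for evaluation

training_mask = mask & (1 - eval_mask)
print(training_mask)  # [[1 0 0 1]] -> the held-out value is hidden during training
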
/lib/datasets/metr_la.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from lib import datasets_path
7 | from .pd_dataset import PandasDataset
8 | from ..utils import sample_mask
9 |
10 |
11 | class MetrLA(PandasDataset):
12 | def __init__(self, impute_zeros=False, freq='5T'):
13 |
14 | df, dist, mask = self.load(impute_zeros=impute_zeros)
15 | self.dist = dist
16 | super().__init__(dataframe=df, u=None, mask=mask, name='la', freq=freq, aggr='nearest')
17 |
18 | def load(self, impute_zeros=True):
19 | path = os.path.join(datasets_path['la'], 'metr_la.h5')
20 | df = pd.read_hdf(path)
21 | datetime_idx = sorted(df.index)
22 | date_range = pd.date_range(datetime_idx[0], datetime_idx[-1], freq='5T')
23 | df = df.reindex(index=date_range)
24 | mask = ~np.isnan(df.values)
25 | if impute_zeros:
26 | mask = mask * (df.values != 0.).astype('uint8')
27 | df = df.replace(to_replace=0., method='ffill')
28 | else:
29 | mask = None
30 | dist = self.load_distance_matrix()
31 | return df, dist, mask
32 |
33 | def load_distance_matrix(self):
34 | path = os.path.join(datasets_path['la'], 'metr_la_dist.npy')
35 | try:
36 | dist = np.load(path)
37 | except:
38 | distances = pd.read_csv(os.path.join(datasets_path['la'], 'distances_la.csv'))
39 | with open(os.path.join(datasets_path['la'], 'sensor_ids_la.txt')) as f:
40 | ids = f.read().strip().split(',')
41 | num_sensors = len(ids)
42 | dist = np.ones((num_sensors, num_sensors), dtype=np.float32) * np.inf
43 | # Builds sensor id to index map.
44 | sensor_id_to_ind = {int(sensor_id): i for i, sensor_id in enumerate(ids)}
45 |
46 | # Fills cells in the matrix with distances.
47 | for row in distances.values:
48 | if row[0] not in sensor_id_to_ind or row[1] not in sensor_id_to_ind:
49 | continue
50 | dist[sensor_id_to_ind[row[0]], sensor_id_to_ind[row[1]]] = row[2]
51 | np.save(path, dist)
52 | return dist
53 |
54 | def get_similarity(self, thr=0.1, force_symmetric=False, sparse=False):
55 | finite_dist = self.dist.reshape(-1)
56 | finite_dist = finite_dist[~np.isinf(finite_dist)]
57 | sigma = finite_dist.std()
58 | adj = np.exp(-np.square(self.dist / sigma))
59 | adj[adj < thr] = 0.
60 | if force_symmetric:
61 | adj = np.maximum.reduce([adj, adj.T])
62 | if sparse:
63 | import scipy.sparse as sps
64 | adj = sps.coo_matrix(adj)
65 | return adj
66 |
67 | @property
68 | def mask(self):
69 | return self._mask
70 |
71 |
72 | class MissingValuesMetrLA(MetrLA):
73 | SEED = 9101112
74 |
75 | def __init__(self, p_fault=0.0015, p_noise=0.05):
76 | super(MissingValuesMetrLA, self).__init__(impute_zeros=True)
77 | self.rng = np.random.default_rng(self.SEED)
78 | self.p_fault = p_fault
79 | self.p_noise = p_noise
80 | eval_mask = sample_mask(self.numpy().shape,
81 | p=p_fault,
82 | p_noise=p_noise,
83 | min_seq=12,
84 | max_seq=12 * 4,
85 | rng=self.rng)
86 | self.eval_mask = (eval_mask & self.mask).astype('uint8')
87 |
88 | @property
89 | def training_mask(self):
90 | return self.mask if self.eval_mask is None else (self.mask & (1 - self.eval_mask))
91 |
92 | def splitter(self, dataset, val_len=0, test_len=0, window=0):
93 | idx = np.arange(len(dataset))
94 | if test_len < 1:
95 | test_len = int(test_len * len(idx))
96 | if val_len < 1:
97 | val_len = int(val_len * (len(idx) - test_len))
98 | test_start = len(idx) - test_len
99 | val_start = test_start - val_len
100 | return [idx[:val_start - window], idx[val_start:test_start - window], idx[test_start:]]
--------------------------------------------------------------------------------
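
The splitter above carves out the last `test_len` samples for testing, the preceding `val_len` samples for validation, and leaves a `window`-sized gap before each boundary so that sliding windows do not overlap across splits. A quick sketch with hypothetical sizes (not repository code):

import numpy as np

n_samples, val_len, test_len, window = 1000, 0.1, 0.2, 24
idx = np.arange(n_samples)
test_len = int(test_len * len(idx))
val_len = int(val_len * (len(idx) - test_len))
test_start = len(idx) - test_len
val_start = test_start - val_len

train = idx[:val_start - window]
val = idx[val_start:test_start - window]
test = idx[test_start:]
print(len(train), len(val), len(test))  # 696 56 200
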
/lib/datasets/pd_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import torch
4 |
5 |
6 | class PandasDataset:
7 | def __init__(self, dataframe: pd.DataFrame, u: pd.DataFrame = None, name='pd-dataset', mask=None, freq=None,
8 | aggr='sum', **kwargs):
9 | """
10 | Initialize a tsl dataset from a pandas dataframe.
11 |
12 |
13 | :param dataframe: dataframe containing the data, shape: n_steps, n_nodes
14 | :param u: dataframe with exog variables
15 | :param name: optional name of the dataset
16 | :param mask: mask for valid data (1:valid, 0:not valid)
17 | :param freq: force a frequency (possibly by resampling)
18 | :param aggr: aggregation method after resampling
19 | """
20 | super().__init__()
21 | self.name = name
22 |
23 | # set dataset dataframe
24 | self.df = dataframe
25 |
26 | # set optional exog_variable dataframe
27 | # make sure to consider only the overlapping part of the two dataframes
28 | # assumption u.index \in df.index
29 | idx = sorted(self.df.index)
30 | self.start = idx[0]
31 | self.end = idx[-1]
32 |
33 | if u is not None:
34 | self.u = u[self.start:self.end]
35 | else:
36 | self.u = None
37 |
38 | if mask is not None:
39 | mask = np.asarray(mask).astype('uint8')
40 | self._mask = mask
41 |
42 | if freq is not None:
43 | self.resample_(freq=freq, aggr=aggr)
44 | else:
45 | self.freq = self.df.index.inferred_freq
46 | # make sure that all the dataframes are aligned
47 | self.resample_(self.freq, aggr=aggr)
48 |
49 | assert 'T' in self.freq
50 | self.samples_per_day = int(60 / int(self.freq[:-1]) * 24)
51 |
52 | def __repr__(self):
53 | return "{}(nodes={}, length={})".format(self.__class__.__name__, self.n_nodes, self.length)
54 |
55 | @property
56 | def has_mask(self):
57 | return self._mask is not None
58 |
59 | @property
60 | def has_u(self):
61 | return self.u is not None
62 |
63 | def resample_(self, freq, aggr):
64 | resampler = self.df.resample(freq)
65 | idx = self.df.index
66 | if aggr == 'sum':
67 | self.df = resampler.sum()
68 | elif aggr == 'mean':
69 | self.df = resampler.mean()
70 | elif aggr == 'nearest':
71 | self.df = resampler.nearest()
72 | else:
73 | raise ValueError(f'{aggr} is not a valid aggregation method.')
74 |
75 | if self.has_mask:
76 | resampler = pd.DataFrame(self._mask, index=idx).resample(freq)
77 | self._mask = resampler.min().to_numpy()
78 |
79 | if self.has_u:
80 | resampler = self.u.resample(freq)
81 | self.u = resampler.nearest()
82 | self.freq = freq
83 |
84 | def dataframe(self) -> pd.DataFrame:
85 | return self.df.copy()
86 |
87 | @property
88 | def length(self):
89 | return self.df.values.shape[0]
90 |
91 | @property
92 | def n_nodes(self):
93 | return self.df.values.shape[1]
94 |
95 | @property
96 | def mask(self):
97 | if self._mask is None:
98 | return np.ones_like(self.df.values).astype('uint8')
99 | return self._mask
100 |
101 | def numpy(self, return_idx=False):
102 | if return_idx:
103 | return self.numpy(), self.df.index
104 | return self.df.values
105 |
106 | def pytorch(self):
107 | data = self.numpy()
108 | return torch.FloatTensor(data)
109 |
110 | def __len__(self):
111 | return self.length
112 |
113 | @staticmethod
114 | def build():
115 | raise NotImplementedError
116 |
117 | def load_raw(self):
118 | raise NotImplementedError
119 |
120 | def load(self):
121 | raise NotImplementedError
122 |
--------------------------------------------------------------------------------
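
PandasDataset expects a minute-based frequency string (e.g. '5T') and derives the number of samples per day from it with the somewhat terse expression `int(60 / int(self.freq[:-1]) * 24)`. A small sketch with an assumed frequency shows what it computes:

freq = '5T'                                   # five-minute sampling
assert 'T' in freq
samples_per_day = int(60 / int(freq[:-1]) * 24)
print(samples_per_day)                        # 288 samples per day at 5-minute resolution
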
/lib/datasets/pems_bay.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from lib import datasets_path
7 | from .pd_dataset import PandasDataset
8 | from ..utils import sample_mask
9 |
10 |
11 | class PemsBay(PandasDataset):
12 | def __init__(self):
13 | df, dist, mask = self.load()
14 | self.dist = dist
15 | super().__init__(dataframe=df, u=None, mask=mask, name='bay', freq='5T', aggr='nearest')
16 |
17 | def load(self, impute_zeros=True):
18 | path = os.path.join(datasets_path['bay'], 'pems_bay.h5')
19 | df = pd.read_hdf(path)
20 | datetime_idx = sorted(df.index)
21 | date_range = pd.date_range(datetime_idx[0], datetime_idx[-1], freq='5T')
22 | df = df.reindex(index=date_range)
23 | mask = ~np.isnan(df.values)
24 | df.fillna(method='ffill', axis=0, inplace=True)
25 | dist = self.load_distance_matrix(list(df.columns))
26 | return df.astype('float32'), dist, mask.astype('uint8')
27 |
28 | def load_distance_matrix(self, ids):
29 | path = os.path.join(datasets_path['bay'], 'pems_bay_dist.npy')
30 | try:
31 | dist = np.load(path)
32 | except:
33 | distances = pd.read_csv(os.path.join(datasets_path['bay'], 'distances_bay.csv'))
34 | num_sensors = len(ids)
35 | dist = np.ones((num_sensors, num_sensors), dtype=np.float32) * np.inf
36 | # Builds sensor id to index map.
37 | sensor_id_to_ind = {int(sensor_id): i for i, sensor_id in enumerate(ids)}
38 |
39 | # Fills cells in the matrix with distances.
40 | for row in distances.values:
41 | if row[0] not in sensor_id_to_ind or row[1] not in sensor_id_to_ind:
42 | continue
43 | dist[sensor_id_to_ind[row[0]], sensor_id_to_ind[row[1]]] = row[2]
44 | np.save(path, dist)
45 | return dist
46 |
47 | def get_similarity(self, type='dcrnn', thr=0.1, force_symmetric=False, sparse=False):
48 | """
49 | Return similarity matrix among nodes. Implemented to match DCRNN.
50 |
51 | :param type: type of similarity matrix.
52 | :param thr: threshold to increase sparseness.
53 | :param sparse: whether to return the adjacency as a sparse (COO) matrix.
54 | :param force_symmetric: force the result to be symmetric.
55 | :return: an NxN array representing the similarity among nodes.
56 | """
57 | if type == 'dcrnn':
58 | finite_dist = self.dist.reshape(-1)
59 | finite_dist = finite_dist[~np.isinf(finite_dist)]
60 | sigma = finite_dist.std()
61 | adj = np.exp(-np.square(self.dist / sigma))
62 | elif type == 'stcn':
63 | sigma = 10
64 | adj = np.exp(-np.square(self.dist) / sigma)
65 | else:
66 | raise NotImplementedError
67 | adj[adj < thr] = 0.
68 | if force_symmetric:
69 | adj = np.maximum.reduce([adj, adj.T])
70 | if sparse:
71 | import scipy.sparse as sps
72 | adj = sps.coo_matrix(adj)
73 | return adj
74 |
75 | @property
76 | def mask(self):
77 | if self._mask is None:
78 | return self.df.values != 0.
79 | return self._mask
80 |
81 |
82 | class MissingValuesPemsBay(PemsBay):
83 | SEED = 56789
84 |
85 | def __init__(self, p_fault=0.0015, p_noise=0.05):
86 | super(MissingValuesPemsBay, self).__init__()
87 | self.rng = np.random.default_rng(self.SEED)
88 | self.p_fault = p_fault
89 | self.p_noise = p_noise
90 | eval_mask = sample_mask(self.numpy().shape,
91 | p=p_fault,
92 | p_noise=p_noise,
93 | min_seq=12,
94 | max_seq=12 * 4,
95 | rng=self.rng)
96 | self.eval_mask = (eval_mask & self.mask).astype('uint8')
97 |
98 | @property
99 | def training_mask(self):
100 | return self.mask if self.eval_mask is None else (self.mask & (1 - self.eval_mask))
101 |
102 | def splitter(self, dataset, val_len=0, test_len=0, window=0):
103 | idx = np.arange(len(dataset))
104 | if test_len < 1:
105 | test_len = int(test_len * len(idx))
106 | if val_len < 1:
107 | val_len = int(val_len * (len(idx) - test_len))
108 | test_start = len(idx) - test_len
109 | val_start = test_start - val_len
110 | return [idx[:val_start - window], idx[val_start:test_start - window], idx[test_start:]]
111 |
--------------------------------------------------------------------------------
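
The 'dcrnn' branch of `get_similarity` builds a Gaussian kernel of the pairwise distances and zeroes out entries below `thr` to sparsify the graph. The sketch below (toy distance matrix, not taken from the dataset) applies the same formula:

import numpy as np

dist = np.array([[0., 1., 5.],
                 [1., 0., 2.],
                 [5., 2., 0.]])
finite = dist[~np.isinf(dist)]
sigma = finite.std()                    # kernel bandwidth from finite distances
adj = np.exp(-np.square(dist / sigma))
adj[adj < 0.1] = 0.                     # thr = 0.1
print(np.round(adj, 3))
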
/lib/datasets/synthetic.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import numpy as np
4 | import torch
5 | from einops import rearrange
6 | from torch.utils.data import Dataset, DataLoader, Subset
7 |
8 | from lib import datasets_path
9 |
10 |
11 | def generate_mask(shape, p_block=0.01, p_point=0.01, max_seq=1, min_seq=1, rng=None):
12 | """Generate mask in which 1 denotes valid values, 0 missing ones. Assuming shape=(steps, ...)."""
13 | if rng is None:
14 | rand = np.random.random
15 | randint = np.random.randint
16 | else:
17 | rand = rng.random
18 | randint = rng.integers
19 | # init mask
20 | mask = np.ones(shape, dtype='uint8')
21 | # block missing
22 | if p_block > 0:
23 | assert max_seq >= min_seq
24 | for col in range(shape[1]):
25 | i = 0
26 | while i < shape[0]:
27 | if rand() > p_block:
28 | i += 1
29 | else:
30 | fault_len = int(randint(min_seq, max_seq + 1))
31 | mask[i:i + fault_len, col] = 0
32 | i += fault_len + 1 # at least one valid value between two blocks
33 | # point missing
34 | # keep the values right before and after a missing block always valid
35 | diff = np.zeros(mask.shape, dtype='uint8')
36 | diff[:-1] |= np.diff(mask.astype(int), axis=0) < 0
37 | diff[1:] |= np.diff(mask.astype(int), axis=0) > 0
38 | mask = np.where(mask - diff, rand(shape) > p_point, mask)
39 | return mask
40 |
41 |
42 | class SyntheticDataset(Dataset):
43 | SEED: int
44 |
45 | def __init__(self, filename,
46 | window=None,
47 | p_block=0.05,
48 | p_point=0.05,
49 | max_seq=6,
50 | min_seq=4,
51 | use_exogenous=True,
52 | mask_exogenous=True,
53 | graph_mode=True):
54 | super(SyntheticDataset, self).__init__()
55 | self.mask_exogenous = mask_exogenous
56 | self.use_exogenous = use_exogenous
57 | self.graph_mode = graph_mode
58 | # fetch data
59 | content = self.load(filename)
60 | self.window = window if window is not None else content['loc'].shape[1]
61 | self.loc = torch.tensor(content['loc'][:, :self.window]).float()
62 | self.vel = torch.tensor(content['vel'][:, :self.window]).float()
63 | self.adj = content['adj']
64 | self.SEED = content['seed'].item()
65 | # compute masks
66 | self.rng = np.random.default_rng(self.SEED)
67 | mask_shape = (len(self), self.window, self.n_nodes, 1)
68 | mask = generate_mask(mask_shape,
69 | p_block=p_block,
70 | p_point=p_point,
71 | max_seq=max_seq,
72 | min_seq=min_seq,
73 | rng=self.rng).repeat(self.n_channels, -1)
74 | eval_mask = 1 - generate_mask(mask_shape,
75 | p_block=p_block,
76 | p_point=p_point,
77 | max_seq=max_seq,
78 | min_seq=min_seq,
79 | rng=self.rng).repeat(self.n_channels, -1)
80 | self.mask = torch.tensor(mask).byte()
81 | self.eval_mask = torch.tensor(eval_mask).byte() & self.mask
82 | # store splitting indices
83 | self.train_idxs = None
84 | self.val_idxs = None
85 | self.test_idxs = None
86 |
87 | def __len__(self):
88 | return self.loc.size(0)
89 |
90 | def __getitem__(self, index):
91 | eval_mask = self.eval_mask[index]
92 | mask = self.training_mask[index]
93 | x = mask * self.loc[index]
94 | res = dict(x=x, mask=mask, eval_mask=eval_mask)
95 | if self.use_exogenous:
96 | u = self.vel[index]
97 | if self.mask_exogenous:
98 | u *= mask.all(-1, keepdim=True)
99 | res.update(u=u)
100 | res.update(y=self.loc[index])
101 | if not self.graph_mode:
102 | res = {k: rearrange(v, 's n f -> s (n f)') for k, v in res.items()}
103 | return res
104 |
105 | @property
106 | def n_channels(self):
107 | return self.loc.size(-1)
108 |
109 | @property
110 | def n_nodes(self):
111 | return self.loc.size(-2)
112 |
113 | @property
114 | def n_exogenous(self):
115 | return self.vel.size(-1) if self.use_exogenous else 0
116 |
117 | @property
118 | def training_mask(self):
119 | return self.mask if self.eval_mask is None else (self.mask & (1 - self.eval_mask))
120 |
121 | @staticmethod
122 | def load(filename):
123 | return np.load(filename)
124 |
125 | def get_similarity(self, sparse=False):
126 | return self.adj
127 |
128 | # Splitting options
129 |
130 | def split(self, val_len=0, test_len=0):
131 | idx = np.arange(len(self))
132 | if test_len < 1:
133 | test_len = int(test_len * len(idx))
134 | if val_len < 1:
135 | val_len = int(val_len * (len(idx) - test_len))
136 | test_start = len(idx) - test_len
137 | val_start = test_start - val_len
138 | # split dataset
139 | self.train_idxs = idx[:val_start]
140 | self.val_idxs = idx[val_start:test_start]
141 | self.test_idxs = idx[test_start:]
142 |
143 | def train_dataloader(self, shuffle=True, batch_size=32):
144 | return DataLoader(Subset(self, self.train_idxs), shuffle=shuffle, batch_size=batch_size, drop_last=True)
145 |
146 | def val_dataloader(self, shuffle=False, batch_size=32):
147 | return DataLoader(Subset(self, self.val_idxs), shuffle=shuffle, batch_size=batch_size)
148 |
149 | def test_dataloader(self, shuffle=False, batch_size=32):
150 | return DataLoader(Subset(self, self.test_idxs), shuffle=shuffle, batch_size=batch_size)
151 |
152 |
153 | class ChargedParticles(SyntheticDataset):
154 |
155 | def __init__(self, static_adj=False,
156 | window=None,
157 | p_block=0.05,
158 | p_point=0.05,
159 | max_seq=6,
160 | min_seq=4,
161 | use_exogenous=True,
162 | mask_exogenous=True,
163 | graph_mode=True):
164 | if static_adj:
165 | filename = os.path.join(datasets_path['synthetic'], 'charged_static.npz')
166 | else:
167 | filename = os.path.join(datasets_path['synthetic'], 'charged_varying.npz')
168 | self.static_adj = static_adj
169 | super(ChargedParticles, self).__init__(filename, window,
170 | p_block=p_block,
171 | p_point=p_point,
172 | max_seq=max_seq,
173 | min_seq=min_seq,
174 | use_exogenous=use_exogenous,
175 | mask_exogenous=mask_exogenous,
176 | graph_mode=graph_mode)
177 | charges = self.load(filename)['charges']
178 | self.charges = torch.tensor(charges).float()
179 |
180 | def __getitem__(self, item):
181 | res = super(ChargedParticles, self).__getitem__(item)
182 | # add charges as exogenous features
183 | if self.use_exogenous:
184 | charges = self.charges[item] if not self.static_adj else self.charges
185 | stacked_charges = charges[None].expand(self.window, -1, -1)
186 | if not self.graph_mode:
187 | stacked_charges = rearrange(stacked_charges, 's n f -> s (n f)')
188 | res.update(u=torch.cat([res['u'], stacked_charges], -1))
189 | return res
190 |
191 | def get_similarity(self, sparse=False):
192 | return np.ones((self.n_nodes, self.n_nodes)) - np.eye(self.n_nodes)
193 |
194 | @property
195 | def n_exogenous(self):
196 | if self.use_exogenous:
197 | return super(ChargedParticles, self).n_exogenous + 1 # add charges to features
198 | return 0
199 |
--------------------------------------------------------------------------------
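
A quick usage sketch of `generate_mask` from lib/datasets/synthetic.py, assuming the repository root is on the Python path; the printed missing rate depends on `p_block` and `p_point`:

import numpy as np
from lib.datasets.synthetic import generate_mask

rng = np.random.default_rng(0)
mask = generate_mask((100, 4), p_block=0.05, p_point=0.05,
                     max_seq=6, min_seq=4, rng=rng)
print(mask.shape, 1 - mask.mean())      # (100, 4) and the simulated missing rate
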
/lib/fillers/__init__.py:
--------------------------------------------------------------------------------
1 | from .filler import Filler
2 | from .britsfiller import BRITSFiller
3 | from .graphfiller import GraphFiller
4 | from .rgainfiller import RGAINFiller
5 | from .multi_imputation_filler import MultiImputationFiller
6 |
--------------------------------------------------------------------------------
/lib/fillers/britsfiller.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from . import Filler
4 | from ..nn import BRITS
5 |
6 |
7 | class BRITSFiller(Filler):
8 |
9 | def training_step(self, batch, batch_idx):
10 | # Unpack batch
11 | batch_data, batch_preprocessing = self._unpack_batch(batch)
12 |
13 | # Extract mask and target
14 | mask = batch_data['mask'].clone().detach()
15 | batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()
16 | eval_mask = batch_data.pop('eval_mask', None)
17 | y = batch_data.pop('y')
18 |
19 | # Compute predictions and compute loss
20 | out, imputations, predictions = self.predict_batch(batch, preprocess=False, postprocess=False)
21 |
22 | if self.scaled_target:
23 | target = self._preprocess(y, batch_preprocessing)
24 | else:
25 | target = y
26 | imputations = [self._postprocess(imp, batch_preprocessing) for imp in imputations]
27 | predictions = [self._postprocess(prd, batch_preprocessing) for prd in predictions]
28 |
29 | loss = sum([self.loss_fn(pred, target, mask) for pred in predictions])
30 | loss += BRITS.consistency_loss(*imputations)
31 |
32 | # Logging
33 | metrics_mask = (mask | eval_mask) - batch_data['mask'] # all unseen data
34 | out = self._postprocess(out, batch_preprocessing)
35 | self.train_metrics.update(out.detach(), y, metrics_mask)
36 | self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
37 | self.log('train_loss', loss, on_step=False, on_epoch=True, logger=True, prog_bar=False)
38 | return loss
39 |
40 | def validation_step(self, batch, batch_idx):
41 | # Unpack batch
42 | batch_data, batch_preprocessing = self._unpack_batch(batch)
43 |
44 | # Extract mask and target
45 | mask = batch_data.get('mask')
46 | eval_mask = batch_data.pop('eval_mask', None)
47 | y = batch_data.pop('y')
48 |
49 | # Compute predictions and compute loss
50 | out, imputations, predictions = self.predict_batch(batch, preprocess=False, postprocess=False)
51 |
52 | if self.scaled_target:
53 | target = self._preprocess(y, batch_preprocessing)
54 | else:
55 | target = y
56 | predictions = [self._postprocess(prd, batch_preprocessing) for prd in predictions]
57 |
58 | val_loss = sum([self.loss_fn(pred, target, mask) for pred in predictions])
59 |
60 | # Logging
61 | out = self._postprocess(out, batch_preprocessing)
62 | self.val_metrics.update(out.detach(), y, eval_mask)
63 | self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
64 | self.log('val_loss', val_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
65 | return val_loss
66 |
67 | def test_step(self, batch, batch_idx):
68 | # Unpack batch
69 | batch_data, batch_preprocessing = self._unpack_batch(batch)
70 |
71 | # Extract mask and target
72 | eval_mask = batch_data.pop('eval_mask', None)
73 | y = batch_data.pop('y')
74 |
75 | # Compute outputs and rescale
76 | imputation, *_ = self.predict_batch(batch, preprocess=False, postprocess=True)
77 | test_loss = self.loss_fn(imputation, y, eval_mask)
78 |
79 | # Logging
80 | self.test_metrics.update(imputation.detach(), y, eval_mask)
81 | self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
82 | self.log('test_loss', test_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
83 | return test_loss
84 |
--------------------------------------------------------------------------------
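
The training step above adds `BRITS.consistency_loss(*imputations)` to the supervised loss. The actual implementation lives in lib/nn/layers/rits.py and is not shown here; the snippet below is only an illustrative stand-in for a BRITS-style consistency term, which penalizes the discrepancy between the forward and backward imputations (shapes are hypothetical).

import torch

def consistency_loss_sketch(fwd_imputation, bwd_imputation):
    # mean absolute discrepancy between the two unidirectional imputations
    return torch.abs(fwd_imputation - bwd_imputation).mean()

fwd = torch.randn(8, 24, 36)   # [batch, steps, channels]
bwd = torch.randn(8, 24, 36)
print(consistency_loss_sketch(fwd, bwd))
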
/lib/fillers/graphfiller.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from . import Filler
4 | from ..nn.models import MPGRUNet, GRINet, BiMPGRUNet
5 |
6 |
7 | class GraphFiller(Filler):
8 |
9 | def __init__(self,
10 | model_class,
11 | model_kwargs,
12 | optim_class,
13 | optim_kwargs,
14 | loss_fn,
15 | scaled_target=False,
16 | whiten_prob=0.05,
17 | pred_loss_weight=1.,
18 | warm_up=0,
19 | metrics=None,
20 | scheduler_class=None,
21 | scheduler_kwargs=None):
22 | super(GraphFiller, self).__init__(model_class=model_class,
23 | model_kwargs=model_kwargs,
24 | optim_class=optim_class,
25 | optim_kwargs=optim_kwargs,
26 | loss_fn=loss_fn,
27 | scaled_target=scaled_target,
28 | whiten_prob=whiten_prob,
29 | metrics=metrics,
30 | scheduler_class=scheduler_class,
31 | scheduler_kwargs=scheduler_kwargs)
32 |
33 | self.tradeoff = pred_loss_weight
34 | if model_class is MPGRUNet:
35 | self.trimming = (warm_up, 0)
36 | elif model_class in [GRINet, BiMPGRUNet]:
37 | self.trimming = (warm_up, warm_up)
38 |
39 | def trim_seq(self, *seq):
40 | seq = [s[:, self.trimming[0]:s.size(1) - self.trimming[1]] for s in seq]
41 | if len(seq) == 1:
42 | return seq[0]
43 | return seq
44 |
45 | def training_step(self, batch, batch_idx):
46 | # Unpack batch
47 | batch_data, batch_preprocessing = self._unpack_batch(batch)
48 |
49 | # Compute masks
50 | mask = batch_data['mask'].clone().detach()
51 | batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()
52 | eval_mask = batch_data.pop('eval_mask', None)
53 | eval_mask = (mask | eval_mask) - batch_data['mask'] # all unseen data
54 |
55 | y = batch_data.pop('y')
56 |
57 | # Compute predictions and compute loss
58 | res = self.predict_batch(batch, preprocess=False, postprocess=False)
59 | imputation, predictions = (res[0], res[1:]) if isinstance(res, (list, tuple)) else (res, [])
60 |
61 | # trim to imputation horizon len
62 | imputation, mask, eval_mask, y = self.trim_seq(imputation, mask, eval_mask, y)
63 | predictions = self.trim_seq(*predictions)
64 |
65 | if self.scaled_target:
66 | target = self._preprocess(y, batch_preprocessing)
67 | else:
68 | target = y
69 | imputation = self._postprocess(imputation, batch_preprocessing)
70 | for i, _ in enumerate(predictions):
71 | predictions[i] = self._postprocess(predictions[i], batch_preprocessing)
72 |
73 | loss = self.loss_fn(imputation, target, mask)
74 | for pred in predictions:
75 | loss += self.tradeoff * self.loss_fn(pred, target, mask)
76 |
77 | # Logging
78 | if self.scaled_target:
79 | imputation = self._postprocess(imputation, batch_preprocessing)
80 | self.train_metrics.update(imputation.detach(), y, eval_mask) # all unseen data
81 | self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
82 | self.log('train_loss', loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
83 | return loss
84 |
85 | def validation_step(self, batch, batch_idx):
86 | # Unpack batch
87 | batch_data, batch_preprocessing = self._unpack_batch(batch)
88 |
89 | # Extract mask and target
90 | mask = batch_data.get('mask')
91 | eval_mask = batch_data.pop('eval_mask', None)
92 | y = batch_data.pop('y')
93 |
94 | # Compute predictions and compute loss
95 | imputation = self.predict_batch(batch, preprocess=False, postprocess=False)
96 |
97 | # trim to imputation horizon len
98 | imputation, mask, eval_mask, y = self.trim_seq(imputation, mask, eval_mask, y)
99 |
100 | if self.scaled_target:
101 | target = self._preprocess(y, batch_preprocessing)
102 | else:
103 | target = y
104 | imputation = self._postprocess(imputation, batch_preprocessing)
105 |
106 | val_loss = self.loss_fn(imputation, target, eval_mask)
107 |
108 | # Logging
109 | if self.scaled_target:
110 | imputation = self._postprocess(imputation, batch_preprocessing)
111 | self.val_metrics.update(imputation.detach(), y, eval_mask)
112 | self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
113 | self.log('val_loss', val_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
114 | return val_loss
115 |
116 | def test_step(self, batch, batch_idx):
117 | # Unpack batch
118 | batch_data, batch_preprocessing = self._unpack_batch(batch)
119 |
120 | # Extract mask and target
121 | eval_mask = batch_data.pop('eval_mask', None)
122 | y = batch_data.pop('y')
123 |
124 | # Compute outputs and rescale
125 | imputation = self.predict_batch(batch, preprocess=False, postprocess=True)
126 | test_loss = self.loss_fn(imputation, y, eval_mask)
127 |
128 | # Logging
129 | self.test_metrics.update(imputation.detach(), y, eval_mask)
130 | self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
131 | self.log('test_loss', test_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
132 | return test_loss
133 |
--------------------------------------------------------------------------------
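
All the fillers above whiten part of the observed input during training: each observed entry is kept with probability `keep_prob` (presumably 1 - whiten_prob, defined in the base Filler class, which is not shown here), so the model is also supervised on values it cannot see. A standalone sketch of that masking trick, not repository code:

import torch

mask = torch.ones(2, 4, 3, dtype=torch.uint8)     # 1 = observed
keep_prob = 0.95                                  # i.e. whiten_prob = 0.05
input_mask = torch.bernoulli(mask.float() * keep_prob).byte()
whitened = mask - input_mask                      # observed values hidden from the model
print(input_mask.float().mean().item(), whitened.sum().item())
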
/lib/fillers/multi_imputation_filler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_lightning.core.decorators import auto_move_data
3 |
4 | from . import Filler
5 |
6 |
7 | class MultiImputationFiller(Filler):
8 | """
9 | Filler with multiple imputation outputs
10 | """
11 |
12 | def __init__(self,
13 | model_class,
14 | model_kwargs,
15 | optim_class,
16 | optim_kwargs,
17 | loss_fn,
18 | consistency_loss=False,
19 | scaled_target=False,
20 | whiten_prob=0.05,
21 | metrics=None,
22 | scheduler_class=None,
23 | scheduler_kwargs=None):
24 |
25 | super().__init__(model_class,
26 | model_kwargs,
27 | optim_class,
28 | optim_kwargs,
29 | loss_fn,
30 | scaled_target,
31 | whiten_prob,
32 | metrics,
33 | scheduler_class,
34 | scheduler_kwargs)
35 | self.consistency_loss = consistency_loss
36 |
37 | @auto_move_data
38 | def forward(self, *args, **kwargs):
39 | out = self.model(*args, **kwargs)
40 | assert isinstance(out, (list, tuple))
41 | if self.training:
42 | return out
43 | return out[0] # we assume that the final imputation is the first one
44 |
45 | def _consistency_loss(self, imputations, mask):
46 | from itertools import combinations
47 | return sum([self.loss_fn(imp1, imp2, mask) for imp1, imp2 in combinations(imputations, 2)])
48 |
49 | def training_step(self, batch, batch_idx):
50 | # Unpack batch
51 | batch_data, batch_preprocessing = self._unpack_batch(batch)
52 |
53 | # Extract mask and target
54 | mask = batch_data['mask'].clone().detach()
55 | batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()
56 | eval_mask = batch_data.pop('eval_mask', None)
57 | y = batch_data.pop('y')
58 |
59 | # Compute predictions and compute loss
60 | imputations = self.predict_batch(batch, preprocess=False, postprocess=False)
61 |
62 | if self.scaled_target:
63 | target = self._preprocess(y, batch_preprocessing)
64 | else:
65 | target = y
66 | imputations = [self._postprocess(imp, batch_preprocessing) for imp in imputations]
67 |
68 | loss = sum([self.loss_fn(imp, target, mask) for imp in imputations])
69 | if self.consistency_loss:
70 | loss += self._consistency_loss(imputations, mask)
71 |
72 | # Logging
73 | metrics_mask = (mask | eval_mask) - batch_data['mask'] # all unseen data
74 |
75 | x_hat = imputations[0]
76 | x_hat = self._postprocess(x_hat, batch_preprocessing)
77 | self.train_metrics.update(x_hat.detach(), y, metrics_mask)
78 | self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
79 | self.log('train_loss', loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
80 | return loss
81 |
--------------------------------------------------------------------------------
/lib/fillers/rgainfiller.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | from .multi_imputation_filler import MultiImputationFiller
5 | from ..nn.utils.metric_base import MaskedMetric
6 |
7 |
8 | class MaskedBCEWithLogits(MaskedMetric):
9 | def __init__(self,
10 | mask_nans=False,
11 | mask_inf=False,
12 | compute_on_step=True,
13 | dist_sync_on_step=False,
14 | process_group=None,
15 | dist_sync_fn=None,
16 | at=None):
17 | super(MaskedBCEWithLogits, self).__init__(metric_fn=F.binary_cross_entropy_with_logits,
18 | mask_nans=mask_nans,
19 | mask_inf=mask_inf,
20 | compute_on_step=compute_on_step,
21 | dist_sync_on_step=dist_sync_on_step,
22 | process_group=process_group,
23 | dist_sync_fn=dist_sync_fn,
24 | metric_kwargs={'reduction': 'none'},
25 | at=at)
26 |
27 |
28 | class RGAINFiller(MultiImputationFiller):
29 | def __init__(self,
30 | model_class,
31 | model_kwargs,
32 | optim_class,
33 | optim_kwargs,
34 | loss_fn,
35 | g_train_freq=1,
36 | d_train_freq=5,
37 | consistency_loss=False,
38 | scaled_target=True,
39 | whiten_prob=0.05,
40 | hint_rate=0.7,
41 | alpha=10.,
42 | metrics=None,
43 | scheduler_class=None,
44 | scheduler_kwargs=None):
45 | super(RGAINFiller, self).__init__(model_class=model_class,
46 | model_kwargs=model_kwargs,
47 | optim_class=optim_class,
48 | optim_kwargs=optim_kwargs,
49 | loss_fn=loss_fn,
50 | scaled_target=scaled_target,
51 | whiten_prob=whiten_prob,
52 | metrics=metrics,
53 | consistency_loss=consistency_loss,
54 | scheduler_class=scheduler_class,
55 | scheduler_kwargs=scheduler_kwargs)
56 | # discriminator training params
57 | self.alpha = alpha
58 | self.g_train_freq = g_train_freq
59 | self.d_train_freq = d_train_freq
60 | self.masked_bce_loss = MaskedBCEWithLogits(compute_on_step=True)
61 | # activate manual optimization
62 | self.automatic_optimization = False
63 | self.hint_rate = hint_rate
64 |
65 | def training_step(self, batch, batch_idx):
66 | # Unpack batch
67 | batch_data, batch_preprocessing = self._unpack_batch(batch)
68 | g_opt, d_opt = self.optimizers()
69 | schedulers = self.lr_schedulers()
70 |
71 | # Extract mask and target
72 | x = batch_data.pop('x')
73 | mask = batch_data['mask'].clone().detach()
74 | training_mask = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()
75 | eval_mask = batch_data.pop('eval_mask', None)
76 | y = batch_data.pop('y')
77 |
78 | ##########################
79 | # generate imputations
80 | ##########################
81 |
82 | imputations = self.model.generator(x, training_mask)
83 | imputed_seq = imputations[0]
84 | target = self._preprocess(y, batch_preprocessing)
85 | y_hat = self._postprocess(imputed_seq, batch_preprocessing)
86 |
87 | x_in = training_mask * x + (1 - training_mask) * imputed_seq
88 | hint = torch.rand_like(training_mask, dtype=torch.float) < self.hint_rate
89 | hint = hint.byte()
90 | hint = hint * training_mask + (1 - hint) * 0.5
91 |
92 | #########################
93 | # train generator
94 | #########################
95 | if (batch_idx % self.g_train_freq) == 0:
96 |
97 | g_opt.zero_grad()
98 |
99 | rec_loss = sum([torch.sqrt(self.loss_fn(imp, target, mask)) for imp in imputations])
100 | if self.consistency_loss:
101 | rec_loss += self._consistency_loss(imputations, mask)
102 |
103 | logits = self.model.discriminator(x_in, hint)
104 | # maximize logit
105 | adv_loss = self.masked_bce_loss(logits, torch.ones_like(logits), 1 - training_mask)
106 |
107 | g_loss = self.alpha * rec_loss + adv_loss
108 |
109 | self.manual_backward(g_loss)
110 | g_opt.step()
111 |
112 | # Logging
113 | metrics_mask = (mask | eval_mask) - training_mask
114 | self.train_metrics.update(y_hat.detach(), y, metrics_mask) # all unseen data
115 | self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
116 | self.log('gen_loss', adv_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
117 | self.log('imp_loss', rec_loss.detach(), on_step=True, on_epoch=True, logger=True, prog_bar=True)
118 |
119 | ###########################
120 | # train discriminator
121 | ###########################
122 |
123 | if (batch_idx % self.d_train_freq) == 0:
124 | d_opt.zero_grad()
125 |
126 | logits = self.model.discriminator(x_in.detach(), hint)
127 | d_loss = self.masked_bce_loss(logits, training_mask.to(logits.dtype))
128 |
129 | self.manual_backward(d_loss)
130 | d_opt.step()
131 | self.log('d_loss', d_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
132 |
133 | if (schedulers is not None) and self.trainer.is_last_batch:
134 | for sch in schedulers:
135 | sch.step()
136 |
137 | def configure_optimizers(self):
138 | opt_g = self.optim_class(self.model.generator.parameters(), **self.optim_kwargs)
139 | opt_d = self.optim_class(self.model.discriminator.parameters(), **self.optim_kwargs)
140 | optimizers = [opt_g, opt_d]
141 | if self.scheduler_class is not None:
142 | metric = self.scheduler_kwargs.pop('monitor', None)
143 | schedulers = [{"scheduler": self.scheduler_class(opt, **self.scheduler_kwargs), "monitor": metric}
144 | for opt in optimizers]
145 | return optimizers, schedulers
146 | return optimizers
147 |
--------------------------------------------------------------------------------
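
The GAIN-style hint built in `RGAINFiller.training_step` reveals the true mask entry to the discriminator with probability `hint_rate` and feeds the uninformative value 0.5 otherwise. A standalone sketch with hypothetical shapes (not repository code):

import torch

hint_rate = 0.7
training_mask = torch.randint(0, 2, (2, 4, 3)).byte()
reveal = (torch.rand_like(training_mask, dtype=torch.float) < hint_rate).byte()
hint = reveal * training_mask + (1 - reveal) * 0.5   # float tensor with values in {0, 0.5, 1}
print(hint)
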
/lib/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .layers import *
2 |
--------------------------------------------------------------------------------
/lib/nn/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .rits import RITS, BRITS
2 | from .gril import GRIL, BiGRIL
3 | from .spatial_conv import SpatialConvOrderK
4 | from .mpgru import MPGRUImputer
5 |
--------------------------------------------------------------------------------
/lib/nn/layers/gcrnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .spatial_conv import SpatialConvOrderK
5 |
6 |
7 | class GCGRUCell(nn.Module):
8 | """
9 | Graph Convolution Gated Recurrent Unit Cell.
10 | """
11 |
12 | def __init__(self, d_in, num_units, support_len, order, activation='tanh'):
13 | """
14 | :param num_units: the hidden dim of rnn
15 | :param support_len: number of supports (adjacency matrices) used by the spatial convolution
16 | :param order: the max diffusion step
17 | :param activation: if None, don't do activation for cell state
18 | """
19 | super(GCGRUCell, self).__init__()
20 | self.activation_fn = getattr(torch, activation)
21 |
22 | self.forget_gate = SpatialConvOrderK(c_in=d_in + num_units, c_out=num_units, support_len=support_len,
23 | order=order)
24 | self.update_gate = SpatialConvOrderK(c_in=d_in + num_units, c_out=num_units, support_len=support_len,
25 | order=order)
26 | self.c_gate = SpatialConvOrderK(c_in=d_in + num_units, c_out=num_units, support_len=support_len, order=order)
27 |
28 | def forward(self, x, h, adj):
29 | """
30 | :param x: (B, input_dim, num_nodes)
31 | :param h: (B, num_units, num_nodes)
32 | :param adj: (num_nodes, num_nodes)
33 | :return:
34 | """
35 | # compute reset and update gates from the concatenated input and hidden state
36 | x_gates = torch.cat([x, h], dim=1)
37 | r = torch.sigmoid(self.forget_gate(x_gates, adj))
38 | u = torch.sigmoid(self.update_gate(x_gates, adj))
39 | x_c = torch.cat([x, r * h], dim=1)
40 | c = self.c_gate(x_c, adj)  # candidate state, shape [batch, num_units, num_nodes]
41 | c = self.activation_fn(c)
42 | return u * h + (1. - u) * c
43 |
44 |
45 | class GCRNN(nn.Module):
46 | def __init__(self,
47 | d_in,
48 | d_model,
49 | d_out,
50 | n_layers,
51 | support_len,
52 | kernel_size=2):
53 | super(GCRNN, self).__init__()
54 | self.d_in = d_in
55 | self.d_model = d_model
56 | self.d_out = d_out
57 | self.n_layers = n_layers
58 | self.ks = kernel_size
59 | self.support_len = support_len
60 | self.rnn_cells = nn.ModuleList()
61 | for i in range(self.n_layers):
62 | self.rnn_cells.append(GCGRUCell(d_in=self.d_in if i == 0 else self.d_model,
63 | num_units=self.d_model, support_len=self.support_len, order=self.ks))
64 | self.output_layer = nn.Conv2d(self.d_model, self.d_out, kernel_size=1)
65 |
66 | def init_hidden_states(self, x):
67 | return [torch.zeros(size=(x.shape[0], self.d_model, x.shape[2])).to(x.device) for _ in range(self.n_layers)]
68 |
69 | def single_pass(self, x, h, adj):
70 | out = x
71 | for l, layer in enumerate(self.rnn_cells):
72 | out = h[l] = layer(out, h[l], adj)
73 | return out, h
74 |
75 | def forward(self, x, adj, h=None):
76 | # x:[batch, features, nodes, steps]
77 | *_, steps = x.size()
78 | if h is None:
79 | h = self.init_hidden_states(x)
80 | # temporal conv
81 | for step in range(steps):
82 | out, h = self.single_pass(x[..., step], h, adj)
83 |
84 | return self.output_layer(out[..., None])
85 |
--------------------------------------------------------------------------------
/lib/nn/layers/gril.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 |
5 | from .spatial_conv import SpatialConvOrderK
6 | from .gcrnn import GCGRUCell
7 | from .spatial_attention import SpatialAttention
8 | from ..utils.ops import reverse_tensor
9 |
10 |
11 | class SpatialDecoder(nn.Module):
12 | def __init__(self, d_in, d_model, d_out, support_len, order=1, attention_block=False, nheads=2, dropout=0.):
13 | super(SpatialDecoder, self).__init__()
14 | self.order = order
15 | self.lin_in = nn.Conv1d(d_in, d_model, kernel_size=1)
16 | self.graph_conv = SpatialConvOrderK(c_in=d_model, c_out=d_model,
17 | support_len=support_len * order, order=1, include_self=False)
18 | if attention_block:
19 | self.spatial_att = SpatialAttention(d_in=d_model,
20 | d_model=d_model,
21 | nheads=nheads,
22 | dropout=dropout)
23 | self.lin_out = nn.Conv1d(3 * d_model, d_model, kernel_size=1)
24 | else:
25 | self.register_parameter('spatial_att', None)
26 | self.lin_out = nn.Conv1d(2 * d_model, d_model, kernel_size=1)
27 | self.read_out = nn.Conv1d(2 * d_model, d_out, kernel_size=1)
28 | self.activation = nn.PReLU()
29 | self.adj = None
30 |
31 | def forward(self, x, m, h, u, adj, cached_support=False):
32 | # [batch, channels, nodes]
33 | x_in = [x, m, h] if u is None else [x, m, u, h]
34 | x_in = torch.cat(x_in, 1)
35 | if self.order > 1:
36 | if cached_support and (self.adj is not None):
37 | adj = self.adj
38 | else:
39 | adj = SpatialConvOrderK.compute_support_orderK(adj, self.order, include_self=False, device=x_in.device)
40 | self.adj = adj if cached_support else None
41 |
42 | x_in = self.lin_in(x_in)
43 | out = self.graph_conv(x_in, adj)
44 | if self.spatial_att is not None:
45 | # [batch, channels, nodes] -> [batch, steps, nodes, features]
46 | x_in = rearrange(x_in, 'b f n -> b 1 n f')
47 | out_att = self.spatial_att(x_in, torch.eye(x_in.size(2), dtype=torch.bool, device=x_in.device))
48 | out_att = rearrange(out_att, 'b s n f -> b f (n s)')
49 | out = torch.cat([out, out_att], 1)
50 | out = torch.cat([out, h], 1)
51 | out = self.activation(self.lin_out(out))
52 | # out = self.lin_out(out)
53 | out = torch.cat([out, h], 1)
54 | return self.read_out(out), out
55 |
56 |
57 | class GRIL(nn.Module):
58 | def __init__(self,
59 | input_size,
60 | hidden_size,
61 | u_size=None,
62 | n_layers=1,
63 | dropout=0.,
64 | kernel_size=2,
65 | decoder_order=1,
66 | global_att=False,
67 | support_len=2,
68 | n_nodes=None,
69 | layer_norm=False):
70 | super(GRIL, self).__init__()
71 | self.input_size = int(input_size)
72 | self.hidden_size = int(hidden_size)
73 | self.u_size = int(u_size) if u_size is not None else 0
74 | self.n_layers = int(n_layers)
75 | rnn_input_size = 2 * self.input_size + self.u_size  # input + mask + (optional) exogenous
76 |
77 | # Spatio-temporal encoder (rnn_input_size -> hidden_size)
78 | self.cells = nn.ModuleList()
79 | self.norms = nn.ModuleList()
80 | for i in range(self.n_layers):
81 | self.cells.append(GCGRUCell(d_in=rnn_input_size if i == 0 else self.hidden_size,
82 | num_units=self.hidden_size, support_len=support_len, order=kernel_size))
83 | if layer_norm:
84 | self.norms.append(nn.GroupNorm(num_groups=1, num_channels=self.hidden_size))
85 | else:
86 | self.norms.append(nn.Identity())
87 | self.dropout = nn.Dropout(dropout) if dropout > 0. else None
88 |
89 | # First stage readout
90 | self.first_stage = nn.Conv1d(in_channels=self.hidden_size, out_channels=self.input_size, kernel_size=1)
91 |
92 | # Spatial decoder (rnn_input_size + hidden_size -> hidden_size)
93 | self.spatial_decoder = SpatialDecoder(d_in=rnn_input_size + self.hidden_size,
94 | d_model=self.hidden_size,
95 | d_out=self.input_size,
96 | support_len=2,
97 | order=decoder_order,
98 | attention_block=global_att)
99 |
100 | # Hidden state initialization embedding
101 | if n_nodes is not None:
102 | self.h0 = self.init_hidden_states(n_nodes)
103 | else:
104 | self.register_parameter('h0', None)
105 |
106 | def init_hidden_states(self, n_nodes):
107 | h0 = []
108 | for l in range(self.n_layers):
109 | std = 1. / torch.sqrt(torch.tensor(self.hidden_size, dtype=torch.float))
110 | vals = torch.distributions.Normal(0, std).sample((self.hidden_size, n_nodes))
111 | h0.append(nn.Parameter(vals))
112 | return nn.ParameterList(h0)
113 |
114 | def get_h0(self, x):
115 | if self.h0 is not None:
116 | return [h.expand(x.shape[0], -1, -1) for h in self.h0]
117 | return [torch.zeros(size=(x.shape[0], self.hidden_size, x.shape[2])).to(x.device)] * self.n_layers
118 |
119 | def update_state(self, x, h, adj):
120 | rnn_in = x
121 | for layer, (cell, norm) in enumerate(zip(self.cells, self.norms)):
122 | rnn_in = h[layer] = norm(cell(rnn_in, h[layer], adj))
123 | if self.dropout is not None and layer < (self.n_layers - 1):
124 | rnn_in = self.dropout(rnn_in)
125 | return h
126 |
127 | def forward(self, x, adj, mask=None, u=None, h=None, cached_support=False):
128 | # x:[batch, features, nodes, steps]
129 | *_, steps = x.size()
130 |
131 | # infer all valid if mask is None
132 | if mask is None:
133 | mask = torch.ones_like(x, dtype=torch.uint8)
134 |
135 | # init hidden state using node embedding or the empty state
136 | if h is None:
137 | h = self.get_h0(x)
138 | elif not isinstance(h, list):
139 | h = [*h]
140 |
141 | # Temporal conv
142 | predictions, imputations, states = [], [], []
143 | representations = []
144 | for step in range(steps):
145 | x_s = x[..., step]
146 | m_s = mask[..., step]
147 | h_s = h[-1]
148 | u_s = u[..., step] if u is not None else None
149 | # firstly impute missing values with predictions from state
150 | xs_hat_1 = self.first_stage(h_s)
151 | # fill missing values in input with prediction
152 | x_s = torch.where(m_s, x_s, xs_hat_1)
153 | # prepare inputs
154 | # retrieve maximum information from neighbors
155 | xs_hat_2, repr_s = self.spatial_decoder(x=x_s, m=m_s, h=h_s, u=u_s, adj=adj,
156 | cached_support=cached_support) # receive messages from neighbors (no self-loop!)
157 | # readout of imputation state + mask to retrieve imputations
158 | # prepare inputs
159 | x_s = torch.where(m_s, x_s, xs_hat_2)
160 | inputs = [x_s, m_s]
161 | if u_s is not None:
162 | inputs.append(u_s)
163 | inputs = torch.cat(inputs, dim=1) # x_hat_2 + mask + exogenous
164 | # update state with original sequence filled using imputations
165 | h = self.update_state(inputs, h, adj)
166 | # store imputations and states
167 | imputations.append(xs_hat_2)
168 | predictions.append(xs_hat_1)
169 | states.append(torch.stack(h, dim=0))
170 | representations.append(repr_s)
171 |
172 | # Aggregate outputs -> [batch, features, nodes, steps]
173 | imputations = torch.stack(imputations, dim=-1)
174 | predictions = torch.stack(predictions, dim=-1)
175 | states = torch.stack(states, dim=-1)
176 | representations = torch.stack(representations, dim=-1)
177 |
178 | return imputations, predictions, representations, states
179 |
180 |
181 | class BiGRIL(nn.Module):
182 | def __init__(self,
183 | input_size,
184 | hidden_size,
185 | ff_size,
186 | ff_dropout,
187 | n_layers=1,
188 | dropout=0.,
189 | n_nodes=None,
190 | support_len=2,
191 | kernel_size=2,
192 | decoder_order=1,
193 | global_att=False,
194 | u_size=0,
195 | embedding_size=0,
196 | layer_norm=False,
197 | merge='mlp'):
198 | super(BiGRIL, self).__init__()
199 | self.fwd_rnn = GRIL(input_size=input_size,
200 | hidden_size=hidden_size,
201 | n_layers=n_layers,
202 | dropout=dropout,
203 | n_nodes=n_nodes,
204 | support_len=support_len,
205 | kernel_size=kernel_size,
206 | decoder_order=decoder_order,
207 | global_att=global_att,
208 | u_size=u_size,
209 | layer_norm=layer_norm)
210 | self.bwd_rnn = GRIL(input_size=input_size,
211 | hidden_size=hidden_size,
212 | n_layers=n_layers,
213 | dropout=dropout,
214 | n_nodes=n_nodes,
215 | support_len=support_len,
216 | kernel_size=kernel_size,
217 | decoder_order=decoder_order,
218 | global_att=global_att,
219 | u_size=u_size,
220 | layer_norm=layer_norm)
221 |
222 | if n_nodes is None:
223 | embedding_size = 0
224 | if embedding_size > 0:
225 | self.emb = nn.Parameter(torch.empty(embedding_size, n_nodes))
226 | nn.init.kaiming_normal_(self.emb, nonlinearity='relu')
227 | else:
228 | self.register_parameter('emb', None)
229 |
230 | if merge == 'mlp':
231 | self._impute_from_states = True
232 | self.out = nn.Sequential(
233 | nn.Conv2d(in_channels=4 * hidden_size + input_size + embedding_size,
234 | out_channels=ff_size, kernel_size=1),
235 | nn.ReLU(),
236 | nn.Dropout(ff_dropout),
237 | nn.Conv2d(in_channels=ff_size, out_channels=input_size, kernel_size=1)
238 | )
239 | elif merge in ['mean', 'sum', 'min', 'max']:
240 | self._impute_from_states = False
241 | self.out = getattr(torch, merge)
242 | else:
243 | raise ValueError("Merge option %s not allowed." % merge)
244 | self.supp = None
245 |
246 | def forward(self, x, adj, mask=None, u=None, cached_support=False):
247 | if cached_support and (self.supp is not None):
248 | supp = self.supp
249 | else:
250 | supp = SpatialConvOrderK.compute_support(adj, x.device)
251 | self.supp = supp if cached_support else None
252 | # Forward
253 | fwd_out, fwd_pred, fwd_repr, _ = self.fwd_rnn(x, supp, mask=mask, u=u, cached_support=cached_support)
254 | # Backward
255 | rev_x, rev_mask, rev_u = [reverse_tensor(tens) for tens in (x, mask, u)]
256 | *bwd_res, _ = self.bwd_rnn(rev_x, supp, mask=rev_mask, u=rev_u, cached_support=cached_support)
257 | bwd_out, bwd_pred, bwd_repr = [reverse_tensor(res) for res in bwd_res]
258 |
259 | if self._impute_from_states:
260 | inputs = [fwd_repr, bwd_repr, mask]
261 | if self.emb is not None:
262 | b, *_, s = fwd_repr.shape # fwd_h: [batches, channels, nodes, steps]
263 | inputs += [self.emb.view(1, *self.emb.shape, 1).expand(b, -1, -1, s)] # stack emb for batches and steps
264 | imputation = torch.cat(inputs, dim=1)
265 | imputation = self.out(imputation)
266 | else:
267 | imputation = torch.stack([fwd_out, bwd_out], dim=1)
268 | imputation = self.out(imputation, dim=1)
269 |
270 | predictions = torch.stack([fwd_out, bwd_out, fwd_pred, bwd_pred], dim=0)
271 |
272 | return imputation, predictions
273 |
--------------------------------------------------------------------------------
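As a quick sanity check of the interfaces above, here is a minimal sketch (not part of the repository) that runs `BiGRIL` on random data. It assumes the repository root is on the Python path with the pinned requirements installed; the batch/node/step sizes are chosen here purely for illustration, and the shapes follow the comments in `forward` (`x` is `[batch, channels, nodes, steps]`, `adj` a dense `[nodes, nodes]` matrix).

```python
import torch

from lib.nn.layers import BiGRIL

batch, channels, n_nodes, steps = 2, 1, 8, 16
layer = BiGRIL(input_size=channels, hidden_size=16, ff_size=16, ff_dropout=0.,
               n_nodes=n_nodes)

x = torch.randn(batch, channels, n_nodes, steps)   # input sequence
mask = torch.rand_like(x) > 0.25                   # boolean mask, True = observed entry
adj = torch.rand(n_nodes, n_nodes)                 # dense weighted adjacency

# imputation: [batch, channels, nodes, steps]
# predictions: [4, batch, channels, nodes, steps] (fwd/bwd outputs and one-step predictions)
imputation, predictions = layer(x, adj, mask=mask)
print(imputation.shape, predictions.shape)
```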
/lib/nn/layers/imputation.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class ImputationLayer(nn.Module):
9 | def __init__(self, d_in, bias=True):
10 | super(ImputationLayer, self).__init__()
11 | self.W = nn.Parameter(torch.Tensor(d_in, d_in))
12 | if bias:
13 | self.b = nn.Parameter(torch.Tensor(d_in))
14 | else:
15 | self.register_buffer('b', None)
16 | mask = 1. - torch.eye(d_in)
17 | self.register_buffer('mask', mask)
18 | self.reset_parameters()
19 |
20 | def reset_parameters(self):
21 | nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))
22 | if self.b is not None:
23 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W)
24 | bound = 1 / math.sqrt(fan_in)
25 | nn.init.uniform_(self.b, -bound, bound)
26 |
27 | def forward(self, x):
28 | # batch, features
29 | return F.linear(x, self.mask * self.W, self.b)
--------------------------------------------------------------------------------
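`ImputationLayer` is a linear map whose weight diagonal is zeroed out, so each feature is reconstructed only from the other features (the same idea as the feature-regression block of BRITS). A small sketch, with sizes chosen here for illustration, that checks the absence of self-dependency through the gradient:

```python
import torch

from lib.nn.layers.imputation import ImputationLayer

layer = ImputationLayer(d_in=4)
x = torch.randn(1, 4, requires_grad=True)
y = layer(x)

# the gradient of output feature 0 w.r.t. input feature 0 is exactly zero,
# because the diagonal of W is masked out
grad = torch.autograd.grad(y[0, 0], x)[0]
print(grad[0, 0])   # tensor(0.)
```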
/lib/nn/layers/mpgru.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from .gcrnn import GCGRUCell
5 |
6 |
7 | class MPGRUImputer(nn.Module):
8 | def __init__(self,
9 | input_size,
10 | hidden_size,
11 | ff_size=None,
12 | u_size=None,
13 | n_layers=1,
14 | dropout=0.,
15 | kernel_size=2,
16 | support_len=2,
17 | n_nodes=None,
18 | layer_norm=False,
19 | autoencoder_mode=False):
20 | super(MPGRUImputer, self).__init__()
21 | self.input_size = int(input_size)
22 | self.hidden_size = int(hidden_size)
23 | self.ff_size = int(ff_size) if ff_size is not None else 0
24 | self.u_size = int(u_size) if u_size is not None else 0
25 | self.n_layers = int(n_layers)
26 | rnn_input_size = 2 * self.input_size + self.u_size # input + mask + (eventually) exogenous
27 |
28 | # Spatio-temporal encoder (rnn_input_size -> hidden_size)
29 | self.cells = nn.ModuleList()
30 | self.norms = nn.ModuleList()
31 | for i in range(self.n_layers):
32 | self.cells.append(GCGRUCell(d_in=rnn_input_size if i == 0 else self.hidden_size,
33 | num_units=self.hidden_size, support_len=support_len, order=kernel_size))
34 | if layer_norm:
35 | self.norms.append(nn.GroupNorm(num_groups=1, num_channels=self.hidden_size))
36 | else:
37 | self.norms.append(nn.Identity())
38 | self.dropout = nn.Dropout(dropout) if dropout > 0. else None
39 |
40 | # Readout
41 | if self.ff_size:
42 | self.pred_readout = nn.Sequential(
43 | nn.Conv1d(in_channels=self.hidden_size, out_channels=self.ff_size, kernel_size=1),
44 | nn.PReLU(),
45 | nn.Conv1d(in_channels=self.ff_size, out_channels=self.input_size, kernel_size=1)
46 | )
47 | else:
48 | self.pred_readout = nn.Conv1d(in_channels=self.hidden_size, out_channels=self.input_size, kernel_size=1)
49 |
50 | # Hidden state initialization embedding
51 | if n_nodes is not None:
52 | self.h0 = self.init_hidden_states(n_nodes)
53 | else:
54 | self.register_parameter('h0', None)
55 |
56 | self.autoencoder_mode = autoencoder_mode
57 |
58 | def init_hidden_states(self, n_nodes):
59 | h0 = []
60 | for l in range(self.n_layers):
61 | std = 1. / torch.sqrt(torch.tensor(self.hidden_size, dtype=torch.float))
62 | vals = torch.distributions.Normal(0, std).sample((self.hidden_size, n_nodes))
63 | h0.append(nn.Parameter(vals))
64 | return nn.ParameterList(h0)
65 |
66 | def get_h0(self, x):
67 | if self.h0 is not None:
68 | return [h.expand(x.shape[0], -1, -1) for h in self.h0]
69 | return [torch.zeros(size=(x.shape[0], self.hidden_size, x.shape[2])).to(x.device)] * self.n_layers
70 |
71 | def update_state(self, x, h, adj):
72 | rnn_in = x
73 | for layer, (cell, norm) in enumerate(zip(self.cells, self.norms)):
74 | rnn_in = h[layer] = norm(cell(rnn_in, h[layer], adj))
75 | if self.dropout is not None and layer < (self.n_layers - 1):
76 | rnn_in = self.dropout(rnn_in)
77 | return h
78 |
79 | def forward(self, x, adj, mask=None, u=None, h=None):
80 | # x:[batch, features, nodes, steps]
81 | *_, steps = x.size()
82 |
83 | # infer all valid if mask is None
84 | if mask is None:
85 | mask = torch.ones_like(x, dtype=torch.uint8)
86 |
87 | # init hidden state using node embedding or the empty state
88 | if h is None:
89 | h = self.get_h0(x)
90 | elif not isinstance(h, list):
91 | h = [*h]
92 |
93 | # Temporal conv
94 | predictions, states = [], []
95 | for step in range(steps):
96 | x_s = x[..., step]
97 | m_s = mask[..., step]
98 | h_s = h[-1]
99 | u_s = u[..., step] if u is not None else None
100 | # impute missing values with predictions from state
101 | x_s_hat = self.pred_readout(h_s)
102 | # store imputations and state
103 | predictions.append(x_s_hat)
104 | states.append(torch.stack(h, dim=0))
105 | # fill missing values in input with prediction
106 | x_s = torch.where(m_s, x_s, x_s_hat)
107 | inputs = [x_s, m_s]
108 | if u_s is not None:
109 | inputs.append(u_s)
110 | inputs = torch.cat(inputs, dim=1) # x_hat complemented + mask + exogenous
111 | # update state with original sequence filled using imputations
112 | h = self.update_state(inputs, h, adj)
113 |
114 | # In autoencoder mode use states after input processing
115 | if self.autoencoder_mode:
116 | states = states[1:] + [torch.stack(h, dim=0)]
117 |
118 | # Aggregate outputs -> [batch, features, nodes, steps]
119 | predictions = torch.stack(predictions, dim=-1)
120 | states = torch.stack(states, dim=-1)
121 |
122 | return predictions, states
123 |
--------------------------------------------------------------------------------
/lib/nn/layers/rits.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 | from torch.nn.parameter import Parameter
8 |
9 | from ..utils.ops import reverse_tensor
10 |
11 |
12 | class FeatureRegression(nn.Module):
13 | def __init__(self, input_size):
14 | super(FeatureRegression, self).__init__()
15 | self.W = Parameter(torch.Tensor(input_size, input_size))
16 | self.b = Parameter(torch.Tensor(input_size))
17 |
18 | m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size)
19 | self.register_buffer('m', m)
20 |
21 | self.reset_parameters()
22 |
23 | def reset_parameters(self):
24 | stdv = 1. / math.sqrt(self.W.shape[0])
25 | self.W.data.uniform_(-stdv, stdv)
26 | if self.b is not None:
27 | self.b.data.uniform_(-stdv, stdv)
28 |
29 | def forward(self, x):
30 | z_h = F.linear(x, self.W * Variable(self.m), self.b)
31 | return z_h
32 |
33 |
34 | class TemporalDecay(nn.Module):
35 | def __init__(self, d_in, d_out, diag=False):
36 | super(TemporalDecay, self).__init__()
37 | self.diag = diag
38 | self.W = Parameter(torch.Tensor(d_out, d_in))
39 | self.b = Parameter(torch.Tensor(d_out))
40 |
41 | if self.diag:
42 | assert (d_in == d_out)
43 | m = torch.eye(d_in, d_in)
44 | self.register_buffer('m', m)
45 |
46 | self.reset_parameters()
47 |
48 | def reset_parameters(self):
49 | stdv = 1. / math.sqrt(self.W.shape[0])
50 | self.W.data.uniform_(-stdv, stdv)
51 | if self.b is not None:
52 | self.b.data.uniform_(-stdv, stdv)
53 |
54 | @staticmethod
55 | def compute_delta(mask, freq=1):
56 | delta = torch.zeros_like(mask).float()
57 | one_step = torch.tensor(freq, dtype=delta.dtype, device=delta.device)
58 | for i in range(1, delta.shape[-2]):
59 | m = mask[..., i - 1, :]
60 | delta[..., i, :] = m * one_step + (1 - m) * torch.add(delta[..., i - 1, :], freq)
61 | return delta
62 |
63 | def forward(self, d):
64 | if self.diag:
65 | gamma = F.relu(F.linear(d, self.W * Variable(self.m), self.b))
66 | else:
67 | gamma = F.relu(F.linear(d, self.W, self.b))
68 | gamma = torch.exp(-gamma)
69 | return gamma
70 |
71 |
72 | class RITS(nn.Module):
73 | def __init__(self,
74 | input_size,
75 | hidden_size=64):
76 | super(RITS, self).__init__()
77 | self.input_size = int(input_size)
78 | self.hidden_size = int(hidden_size)
79 |
80 | self.rnn_cell = nn.LSTMCell(2 * self.input_size, self.hidden_size)
81 |
82 | self.temp_decay_h = TemporalDecay(d_in=self.input_size, d_out=self.hidden_size, diag=False)
83 | self.temp_decay_x = TemporalDecay(d_in=self.input_size, d_out=self.input_size, diag=True)
84 |
85 | self.hist_reg = nn.Linear(self.hidden_size, self.input_size)
86 | self.feat_reg = FeatureRegression(self.input_size)
87 |
88 | self.weight_combine = nn.Linear(2 * self.input_size, self.input_size)
89 |
90 | def init_hidden_states(self, x):
91 | return Variable(torch.zeros((x.shape[0], self.hidden_size))).to(x.device)
92 |
93 | def forward(self, x, mask=None, delta=None):
94 | # x : [batch, steps, features]
95 | steps = x.shape[-2]
96 |
97 | if mask is None:
98 | mask = torch.ones_like(x, dtype=torch.uint8)
99 | if delta is None:
100 | delta = TemporalDecay.compute_delta(mask)
101 |
102 | # init rnn states
103 | h = self.init_hidden_states(x)
104 | c = self.init_hidden_states(x)
105 |
106 | imputation = []
107 | predictions = []
108 | for step in range(steps):
109 | d = delta[:, step, :]
110 | m = mask[:, step, :]
111 | x_s = x[:, step, :]
112 |
113 | gamma_h = self.temp_decay_h(d)
114 |
115 | # history prediction
116 | x_h = self.hist_reg(h)
117 | x_c = m * x_s + (1 - m) * x_h
118 | h = h * gamma_h
119 |
120 | # feature prediction
121 | z_h = self.feat_reg(x_c)
122 |
123 | # predictions combination
124 | gamma_x = self.temp_decay_x(d)
125 | alpha = self.weight_combine(torch.cat([gamma_x, m], dim=1))
126 | alpha = torch.sigmoid(alpha)
127 | c_h = alpha * z_h + (1 - alpha) * x_h
128 |
129 | c_c = m * x_s + (1 - m) * c_h
130 | inputs = torch.cat([c_c, m], dim=1)
131 | h, c = self.rnn_cell(inputs, (h, c))
132 |
133 | imputation.append(c_c)
134 | predictions.append(torch.stack((c_h, z_h, x_h), dim=0))
135 |
136 | # imputation -> [batch, steps, features]
137 | imputation = torch.stack(imputation, dim=-2)
138 | # predictions -> [predictions, batch, steps, features]
139 | predictions = torch.stack(predictions, dim=-2)
140 | c_h, z_h, x_h = predictions
141 |
142 | return imputation, (c_h, z_h, x_h)
143 |
144 |
145 | class BRITS(nn.Module):
146 |
147 | def __init__(self, input_size, hidden_size):
148 | super().__init__()
149 | self.rits_fwd = RITS(input_size, hidden_size)
150 | self.rits_bwd = RITS(input_size, hidden_size)
151 |
152 | def forward(self, x, mask=None):
153 | # x: [batches, steps, features]
154 | # forward
155 | imp_fwd, pred_fwd = self.rits_fwd(x, mask)
156 | # backward
157 | x_bwd = reverse_tensor(x, axis=1)
158 | mask_bwd = reverse_tensor(mask, axis=1) if mask is not None else None
159 | imp_bwd, pred_bwd = self.rits_bwd(x_bwd, mask_bwd)
160 | imp_bwd, pred_bwd = reverse_tensor(imp_bwd, axis=1), [reverse_tensor(pb, axis=1) for pb in pred_bwd]
161 | # stack into shape = [batch, directions, steps, features]
162 | imputation = torch.stack([imp_fwd, imp_bwd], dim=1)
163 | predictions = [torch.stack([pf, pb], dim=1) for pf, pb in zip(pred_fwd, pred_bwd)]
164 | c_h, z_h, x_h = predictions
165 |
166 | return imputation, (c_h, z_h, x_h)
167 |
168 | @staticmethod
169 | def consistency_loss(imp_fwd, imp_bwd):
170 | loss = 0.1 * torch.abs(imp_fwd - imp_bwd).mean()
171 | return loss
172 |
--------------------------------------------------------------------------------
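`TemporalDecay.compute_delta` produces, for every step and feature, the time elapsed since the last observed value; these deltas drive the decay gates inside RITS. A hand-checkable sketch (mask values chosen here for illustration):

```python
import torch

from lib.nn.layers.rits import TemporalDecay

# one sequence, one feature, observed at steps 0, 1 and 4 -> [batch, steps, features]
mask = torch.tensor([1., 1., 0., 0., 1.]).view(1, 5, 1)
delta = TemporalDecay.compute_delta(mask)
print(delta.squeeze())   # tensor([0., 1., 1., 2., 3.])
```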
/lib/nn/layers/spatial_attention.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from einops import rearrange
4 |
5 |
6 | class SpatialAttention(nn.Module):
7 | def __init__(self, d_in, d_model, nheads, dropout=0.):
8 | super(SpatialAttention, self).__init__()
9 | self.lin_in = nn.Linear(d_in, d_model)
10 | self.self_attn = nn.MultiheadAttention(d_model, nheads, dropout=dropout)
11 |
12 | def forward(self, x, att_mask=None, **kwargs):
13 |         r"""Apply multi-head self-attention across the node (spatial) dimension.
14 | 
15 |         Args:
16 |             x: input tensor of shape [batch, steps, nodes, features].
17 |             att_mask: optional attention mask of shape [nodes, nodes]; entries set
18 |                 to True mark pairs that are not allowed to attend to each other.
19 | 
20 |         Shape:
21 |             output: [batch, steps, nodes, d_model].
22 |         """
23 | b, s, n, f = x.size()
24 | x = rearrange(x, 'b s n f -> n (b s) f')
25 | x = self.lin_in(x)
26 | x = self.self_attn(x, x, x, attn_mask=att_mask)[0]
27 | x = rearrange(x, 'n (b s) f -> b s n f', b=b, s=s)
28 | return x
29 |
--------------------------------------------------------------------------------
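A minimal sketch of `SpatialAttention` on random data (sizes chosen here for illustration): attention is computed across the node dimension, and a boolean identity matrix as mask removes self-attention, mirroring how the layer is called from the spatial decoder.

```python
import torch

from lib.nn.layers.spatial_attention import SpatialAttention

att = SpatialAttention(d_in=4, d_model=8, nheads=2)
x = torch.randn(2, 3, 5, 4)                  # [batch, steps, nodes, features]
no_self = torch.eye(5, dtype=torch.bool)     # True = pair not allowed to attend

out = att(x, att_mask=no_self)
print(out.shape)                             # torch.Size([2, 3, 5, 8])
```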
/lib/nn/layers/spatial_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from ... import epsilon
5 |
6 |
7 | class SpatialConvOrderK(nn.Module):
8 | """
9 | Spatial convolution of order K with possibly different diffusion matrices (useful for directed graphs)
10 |
11 |     Efficient implementation inspired by the Graph WaveNet codebase
12 | """
13 |
14 | def __init__(self, c_in, c_out, support_len=3, order=2, include_self=True):
15 | super(SpatialConvOrderK, self).__init__()
16 | self.include_self = include_self
17 | c_in = (order * support_len + (1 if include_self else 0)) * c_in
18 | self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1)
19 | self.order = order
20 |
21 | @staticmethod
22 | def compute_support(adj, device=None):
23 | if device is not None:
24 | adj = adj.to(device)
25 | adj_bwd = adj.T
26 | adj_fwd = adj / (adj.sum(1, keepdims=True) + epsilon)
27 | adj_bwd = adj_bwd / (adj_bwd.sum(1, keepdims=True) + epsilon)
28 | support = [adj_fwd, adj_bwd]
29 | return support
30 |
31 | @staticmethod
32 | def compute_support_orderK(adj, k, include_self=False, device=None):
33 | if isinstance(adj, (list, tuple)):
34 | support = adj
35 | else:
36 | support = SpatialConvOrderK.compute_support(adj, device)
37 | supp_k = []
38 | for a in support:
39 | ak = a
40 | for i in range(k - 1):
41 | ak = torch.matmul(ak, a.T)
42 | if not include_self:
43 | ak.fill_diagonal_(0.)
44 | supp_k.append(ak)
45 | return support + supp_k
46 |
47 | def forward(self, x, support):
48 | # [batch, features, nodes, steps]
49 | if x.dim() < 4:
50 | squeeze = True
51 | x = torch.unsqueeze(x, -1)
52 | else:
53 | squeeze = False
54 | out = [x] if self.include_self else []
55 | if (type(support) is not list):
56 | support = [support]
57 | for a in support:
58 | x1 = torch.einsum('ncvl,wv->ncwl', (x, a)).contiguous()
59 | out.append(x1)
60 | for k in range(2, self.order + 1):
61 | x2 = torch.einsum('ncvl,wv->ncwl', (x1, a)).contiguous()
62 | out.append(x2)
63 | x1 = x2
64 |
65 | out = torch.cat(out, dim=1)
66 | out = self.mlp(out)
67 | if squeeze:
68 | out = out.squeeze(-1)
69 | return out
70 |
--------------------------------------------------------------------------------
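The sketch below (sizes chosen here for illustration) shows the intended usage: a dense, possibly asymmetric adjacency is turned into forward/backward transition matrices with `compute_support`, and the layer then diffuses the node features over `order` hops of each support.

```python
import torch

from lib.nn.layers import SpatialConvOrderK

n_nodes, c_in, c_out = 6, 3, 8
adj = torch.rand(n_nodes, n_nodes)                   # dense weighted adjacency
support = SpatialConvOrderK.compute_support(adj)     # [forward, backward] transition matrices

conv = SpatialConvOrderK(c_in=c_in, c_out=c_out, support_len=len(support), order=2)
x = torch.randn(4, c_in, n_nodes, 10)                # [batch, features, nodes, steps]
out = conv(x, support)
print(out.shape)                                     # torch.Size([4, 8, 6, 10])
```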
/lib/nn/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .grin import GRINet
2 | from .brits import BRITSNet
3 | from .mpgru import MPGRUNet, BiMPGRUNet
4 | from .var import VARImputer
5 | from .rgain import RGAINNet
6 | from .rnn_imputers import BiRNNImputer, RNNImputer
7 |
--------------------------------------------------------------------------------
/lib/nn/models/brits.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from ..layers import BRITS
5 |
6 |
7 | class BRITSNet(nn.Module):
8 | def __init__(self,
9 | d_in,
10 | d_hidden=64):
11 | super(BRITSNet, self).__init__()
12 | self.birits = BRITS(input_size=d_in,
13 | hidden_size=d_hidden)
14 |
15 | def forward(self, x, mask=None, **kwargs):
16 | # x: [batches, steps, features]
17 | imputations, predictions = self.birits(x, mask=mask)
18 | # predictions: [batch, directions, steps, features] x 3
19 | out = torch.mean(imputations, dim=1) # -> [batch, steps, features]
20 | predictions = torch.cat(predictions, dim=1) # -> [batch, directions * n_predictions, steps, features]
21 | # reshape
22 | imputations = torch.transpose(imputations, 0, 1) # rearrange(imputations, 'b d s f -> d b s f')
23 | predictions = torch.transpose(predictions, 0, 1) # rearrange(predictions, 'b d s f -> d b s f')
24 | return out, imputations, predictions
25 |
26 | @staticmethod
27 | def add_model_specific_args(parser):
28 | parser.add_argument('--d-in', type=int)
29 | parser.add_argument('--d-hidden', type=int, default=64)
30 | return parser
31 |
--------------------------------------------------------------------------------
/lib/nn/models/grin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 | from torch import nn
4 |
5 | from ..layers import BiGRIL
6 | from ...utils.parser_utils import str_to_bool
7 |
8 |
9 | class GRINet(nn.Module):
10 | def __init__(self,
11 | adj,
12 | d_in,
13 | d_hidden,
14 | d_ff,
15 | ff_dropout,
16 | n_layers=1,
17 | kernel_size=2,
18 | decoder_order=1,
19 | global_att=False,
20 | d_u=0,
21 | d_emb=0,
22 | layer_norm=False,
23 | merge='mlp',
24 | impute_only_holes=True):
25 | super(GRINet, self).__init__()
26 | self.d_in = d_in
27 | self.d_hidden = d_hidden
28 | self.d_u = int(d_u) if d_u is not None else 0
29 | self.d_emb = int(d_emb) if d_emb is not None else 0
30 | self.register_buffer('adj', torch.tensor(adj).float())
31 | self.impute_only_holes = impute_only_holes
32 |
33 | self.bigrill = BiGRIL(input_size=self.d_in,
34 | ff_size=d_ff,
35 | ff_dropout=ff_dropout,
36 | hidden_size=self.d_hidden,
37 | embedding_size=self.d_emb,
38 | n_nodes=self.adj.shape[0],
39 | n_layers=n_layers,
40 | kernel_size=kernel_size,
41 | decoder_order=decoder_order,
42 | global_att=global_att,
43 | u_size=self.d_u,
44 | layer_norm=layer_norm,
45 | merge=merge)
46 |
47 | def forward(self, x, mask=None, u=None, **kwargs):
48 | # x: [batches, steps, nodes, channels] -> [batches, channels, nodes, steps]
49 | x = rearrange(x, 'b s n c -> b c n s')
50 | if mask is not None:
51 | mask = rearrange(mask, 'b s n c -> b c n s')
52 |
53 | if u is not None:
54 | u = rearrange(u, 'b s n c -> b c n s')
55 |
56 | # imputation: [batches, channels, nodes, steps] prediction: [4, batches, channels, nodes, steps]
57 | imputation, prediction = self.bigrill(x, self.adj, mask=mask, u=u, cached_support=self.training)
58 | # In evaluation stage impute only missing values
59 | if self.impute_only_holes and not self.training:
60 | imputation = torch.where(mask, x, imputation)
61 | # out: [batches, channels, nodes, steps] -> [batches, steps, nodes, channels]
62 | imputation = torch.transpose(imputation, -3, -1)
63 | prediction = torch.transpose(prediction, -3, -1)
64 | if self.training:
65 | return imputation, prediction
66 | return imputation
67 |
68 | @staticmethod
69 | def add_model_specific_args(parser):
70 | parser.add_argument('--d-hidden', type=int, default=64)
71 | parser.add_argument('--d-ff', type=int, default=64)
72 |         parser.add_argument('--ff-dropout', type=float, default=0.)
73 | parser.add_argument('--n-layers', type=int, default=1)
74 | parser.add_argument('--kernel-size', type=int, default=2)
75 | parser.add_argument('--decoder-order', type=int, default=1)
76 | parser.add_argument('--d-u', type=int, default=0)
77 | parser.add_argument('--d-emb', type=int, default=8)
78 | parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False)
79 | parser.add_argument('--global-att', type=str_to_bool, nargs='?', const=True, default=False)
80 | parser.add_argument('--merge', type=str, default='mlp')
81 | parser.add_argument('--impute-only-holes', type=str_to_bool, nargs='?', const=True, default=True)
82 | return parser
83 |
--------------------------------------------------------------------------------
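An end-to-end sketch of how `GRINet` is typically called (shapes and values chosen here for illustration; the experiments build the model through `scripts/run_imputation.py`). The model expects `x` as `[batch, steps, nodes, channels]` and a dense adjacency; in eval mode, with `impute_only_holes=True`, observed values are passed through unchanged.

```python
import numpy as np
import torch

from lib.nn.models import GRINet

n_nodes, steps, channels = 12, 24, 1
adj = np.random.rand(n_nodes, n_nodes).astype('float32')   # dense weighted adjacency
model = GRINet(adj=adj, d_in=channels, d_hidden=32, d_ff=32, ff_dropout=0.)

x = torch.randn(4, steps, n_nodes, channels)               # [batch, steps, nodes, channels]
mask = torch.rand_like(x) > 0.2                            # True = observed entry

model.eval()
with torch.no_grad():
    y_hat = model(x, mask=mask)                            # [batch, steps, nodes, channels]
print(y_hat.shape)
```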
/lib/nn/models/mpgru.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 | from torch import nn
4 |
5 | from ..layers import MPGRUImputer, SpatialConvOrderK
6 | from ..utils.ops import reverse_tensor
7 | from ...utils.parser_utils import str_to_bool
8 |
9 |
10 | class MPGRUNet(nn.Module):
11 | def __init__(self,
12 | adj,
13 | d_in,
14 | d_hidden,
15 | d_ff=0,
16 | d_u=0,
17 | n_layers=1,
18 | dropout=0.,
19 | kernel_size=2,
20 | support_len=2,
21 | layer_norm=False,
22 | impute_only_holes=True):
23 | super(MPGRUNet, self).__init__()
24 | self.register_buffer('adj', torch.tensor(adj).float())
25 | n_nodes = adj.shape[0]
26 | self.gcgru = MPGRUImputer(input_size=d_in,
27 | hidden_size=d_hidden,
28 | ff_size=d_ff,
29 | u_size=d_u,
30 | n_layers=n_layers,
31 | dropout=dropout,
32 | kernel_size=kernel_size,
33 | support_len=support_len,
34 | layer_norm=layer_norm,
35 | n_nodes=n_nodes)
36 | self.impute_only_holes = impute_only_holes
37 |
38 | def forward(self, x, mask=None, u=None, h=None):
39 | # x: [batches, steps, nodes, channels] -> [batches, channels, nodes, steps]
40 | x = rearrange(x, 'b s n c -> b c n s')
41 | if mask is not None:
42 | mask = rearrange(mask, 'b s n c -> b c n s')
43 | if u is not None:
44 | u = rearrange(u, 'b s n c -> b c n s')
45 |
46 | adj = SpatialConvOrderK.compute_support(self.adj, x.device)
47 | imputation, _ = self.gcgru(x, adj, mask=mask, u=u, h=h)
48 |
49 | # In evaluation stage impute only missing values
50 | if self.impute_only_holes and not self.training:
51 | imputation = torch.where(mask, x, imputation)
52 |
53 | # out: [batches, channels, nodes, steps] -> [batches, steps, nodes, channels]
54 | imputation = rearrange(imputation, 'b c n s -> b s n c')
55 |
56 | return imputation
57 |
58 | @staticmethod
59 | def add_model_specific_args(parser):
60 | parser.add_argument('--d-hidden', type=int, default=64)
61 | parser.add_argument('--d-ff', type=int, default=64)
62 | parser.add_argument('--n-layers', type=int, default=1)
63 | parser.add_argument('--kernel-size', type=int, default=2)
64 | parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False)
65 | parser.add_argument('--impute-only-holes', type=str_to_bool, nargs='?', const=True, default=True)
66 | parser.add_argument('--dropout', type=float, default=0.)
67 | return parser
68 |
69 |
70 | class BiMPGRUNet(nn.Module):
71 | def __init__(self,
72 | adj,
73 | d_in,
74 | d_hidden,
75 | d_ff=0,
76 | d_u=0,
77 | n_layers=1,
78 | dropout=0.,
79 | kernel_size=2,
80 | support_len=2,
81 | layer_norm=False,
82 | embedding_size=0,
83 | merge='mlp',
84 | impute_only_holes=True,
85 | autoencoder_mode=False):
86 | super(BiMPGRUNet, self).__init__()
87 | self.register_buffer('adj', torch.tensor(adj).float())
88 | n_nodes = adj.shape[0]
89 | self.gcgru_fwd = MPGRUImputer(input_size=d_in,
90 | hidden_size=d_hidden,
91 | u_size=d_u,
92 | n_layers=n_layers,
93 | dropout=dropout,
94 | kernel_size=kernel_size,
95 | support_len=support_len,
96 | layer_norm=layer_norm,
97 | n_nodes=n_nodes,
98 | autoencoder_mode=autoencoder_mode)
99 | self.gcgru_bwd = MPGRUImputer(input_size=d_in,
100 | hidden_size=d_hidden,
101 | u_size=d_u,
102 | n_layers=n_layers,
103 | dropout=dropout,
104 | kernel_size=kernel_size,
105 | support_len=support_len,
106 | layer_norm=layer_norm,
107 | n_nodes=n_nodes,
108 | autoencoder_mode=autoencoder_mode)
109 | self.impute_only_holes = impute_only_holes
110 |
111 | if n_nodes is None:
112 | embedding_size = 0
113 | if embedding_size > 0:
114 | self.emb = nn.Parameter(torch.empty(embedding_size, n_nodes))
115 | nn.init.kaiming_normal_(self.emb, nonlinearity='relu')
116 | else:
117 | self.register_parameter('emb', None)
118 |
119 | if merge == 'mlp':
120 | self._impute_from_states = True
121 | self.out = nn.Sequential(
122 | nn.Conv2d(in_channels=2 * d_hidden + d_in + embedding_size,
123 | out_channels=d_ff, kernel_size=1),
124 | nn.ReLU(),
125 | nn.Conv2d(in_channels=d_ff, out_channels=d_in, kernel_size=1)
126 | )
127 | elif merge in ['mean', 'sum', 'min', 'max']:
128 | self._impute_from_states = False
129 | self.out = getattr(torch, merge)
130 | else:
131 | raise ValueError("Merge option %s not allowed." % merge)
132 |
133 | def forward(self, x, mask=None, u=None, h=None):
134 | # x: [batches, steps, nodes, channels] -> [batches, channels, nodes, steps]
135 | x = rearrange(x, 'b s n c -> b c n s')
136 | if mask is not None:
137 | mask = rearrange(mask, 'b s n c -> b c n s')
138 | if u is not None:
139 | u = rearrange(u, 'b s n c -> b c n s')
140 |
141 | adj = SpatialConvOrderK.compute_support(self.adj, x.device)
142 |
143 | # Forward
144 | fwd_pred, fwd_states = self.gcgru_fwd(x, adj, mask=mask, u=u)
145 | # Backward
146 | rev_x, rev_mask, rev_u = [reverse_tensor(tens, axis=-1) for tens in (x, mask, u)]
147 | bwd_res = self.gcgru_bwd(rev_x, adj, mask=rev_mask, u=rev_u)
148 | bwd_pred, bwd_states = [reverse_tensor(res, axis=-1) for res in bwd_res]
149 |
150 | if self._impute_from_states:
151 | inputs = [fwd_states[-1], bwd_states[-1], mask] # take only state of last gcgru layer
152 | if self.emb is not None:
153 | b, *_, s = x.shape # fwd_h: [batches, channels, nodes, steps]
154 | inputs += [self.emb.view(1, *self.emb.shape, 1).expand(b, -1, -1, s)] # stack emb for batches and steps
155 | imputation = torch.cat(inputs, dim=1)
156 | imputation = self.out(imputation)
157 | else:
158 | imputation = torch.stack([fwd_pred, bwd_pred], dim=1)
159 | imputation = self.out(imputation, dim=1)
160 |
161 | # In evaluation stage impute only missing values
162 | if self.impute_only_holes and not self.training:
163 | imputation = torch.where(mask, x, imputation)
164 |
165 | # out: [batches, channels, nodes, steps] -> [batches, steps, nodes, channels]
166 | imputation = rearrange(imputation, 'b c n s -> b s n c')
167 |
168 | return imputation
169 |
170 | @staticmethod
171 | def add_model_specific_args(parser):
172 | parser.add_argument('--d-hidden', type=int, default=64)
173 | parser.add_argument('--d-ff', type=int, default=64)
174 | parser.add_argument('--n-layers', type=int, default=1)
175 | parser.add_argument('--kernel-size', type=int, default=2)
176 | parser.add_argument('--d-emb', type=int, default=8)
177 | parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False)
178 | parser.add_argument('--merge', type=str, default='mlp')
179 | parser.add_argument('--impute-only-holes', type=str_to_bool, nargs='?', const=True, default=True)
180 | parser.add_argument('--dropout', type=float, default=0.)
181 | parser.add_argument('--autoencoder-mode', type=str_to_bool, nargs='?', const=True, default=False)
182 | return parser
183 |
--------------------------------------------------------------------------------
/lib/nn/models/rgain.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from .rnn_imputers import BiRNNImputer
5 | from ...utils.parser_utils import str_to_bool
6 |
7 |
8 | class Generator(nn.Module):
9 | def __init__(self, d_in, d_model, d_z, dropout=0., inject_noise=True):
10 | super(Generator, self).__init__()
11 | self.inject_noise = inject_noise
12 | self.d_z = d_z if inject_noise else 0
13 | self.birnn = BiRNNImputer(d_in,
14 | d_model,
15 |                                   d_u=self.d_z,  # 0 when noise injection is disabled
16 | concat_mask=True,
17 | detach_inputs=False,
18 | dropout=dropout,
19 | state_init='zero')
20 |
21 | def forward(self, x, mask):
22 | if self.inject_noise:
23 | z = torch.rand(x.size(0), x.size(1), self.d_z, device=x.device) * 0.1
24 | else:
25 | z = None
26 | return self.birnn(x, mask, u=z)
27 |
28 |
29 | class Discriminator(torch.nn.Module):
30 | def __init__(self, d_in, d_model, dropout=0.):
31 | super(Discriminator, self).__init__()
32 | self.birnn = nn.GRU(2 * d_in, d_model, bidirectional=True, batch_first=True)
33 | self.dropout = nn.Dropout(dropout)
34 | self.read_out = nn.Linear(2 * d_model, d_in)
35 |
36 | def forward(self, x, h):
37 | x_in = torch.cat([x, h], dim=-1)
38 | out, _ = self.birnn(x_in)
39 | logits = self.read_out(self.dropout(out))
40 | return logits
41 |
42 |
43 | class RGAINNet(torch.nn.Module):
44 | def __init__(self, d_in, d_model, d_z, dropout=0., inject_noise=False, k=5):
45 | super(RGAINNet, self).__init__()
46 | self.inject_noise = inject_noise
47 | self.k = k
48 | self.generator = Generator(d_in, d_model, d_z=d_z, dropout=dropout, inject_noise=inject_noise)
49 | self.discriminator = Discriminator(d_in, d_model, dropout)
50 |
51 | def forward(self, x, mask, **kwargs):
52 | if not self.training and self.inject_noise:
53 | res = []
54 | for _ in range(self.k):
55 | res.append(self.generator(x, mask)[0])
56 | return torch.stack(res, 0).mean(0),
57 |
58 | return self.generator(x, mask)
59 |
60 | @staticmethod
61 | def add_model_specific_args(parser):
62 | parser.add_argument('--d-in', type=int)
63 | parser.add_argument('--d-model', type=int, default=None)
64 | parser.add_argument('--d-z', type=int, default=8)
65 | parser.add_argument('--k', type=int, default=5)
66 | parser.add_argument('--inject-noise', type=str_to_bool, nargs='?', const=True, default=False)
67 | parser.add_argument('--dropout', type=float, default=0.)
68 | return parser
69 |
--------------------------------------------------------------------------------
/lib/nn/models/rnn_imputers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from ..utils.ops import reverse_tensor
5 |
6 |
7 | class RNNImputer(nn.Module):
8 | """Fill the blanks with a 1-step-ahead GRU predictor."""
9 |
10 | def __init__(self, d_in, d_model, concat_mask=True, detach_inputs=False, state_init='zero', d_u=0):
11 | super(RNNImputer, self).__init__()
12 | self.concat_mask = concat_mask
13 | self.detach_inputs = detach_inputs
14 | self.state_init = state_init
15 | self.d_model = d_model
16 | self.input_dim = d_in + d_u if not concat_mask else 2 * d_in + d_u
17 | self.rnn_cell = nn.GRUCell(self.input_dim, d_model)
18 | self.read_out = nn.Linear(d_model, d_in)
19 |
20 | def init_hidden_state(self, x):
21 | if self.state_init == 'zero':
22 | return torch.zeros((x.size(0), self.d_model), device=x.device, dtype=x.dtype)
23 | if self.state_init == 'noise':
24 | return torch.randn(x.size(0), self.d_model, device=x.device, dtype=x.dtype)
25 |
26 | def _preprocess_input(self, x, x_hat, m, u):
27 | if self.detach_inputs:
28 | x_p = torch.where(m, x, x_hat.detach())
29 | else:
30 | x_p = torch.where(m, x, x_hat)
31 |
32 | if u is not None:
33 | x_p = torch.cat([x_p, u], -1)
34 | if self.concat_mask:
35 | x_p = torch.cat([x_p, m], -1)
36 | return x_p
37 |
38 | def forward(self, x, mask, u=None, return_hidden=False):
39 | # x: [batches, steps, features]
40 | steps = x.size(1)
41 | # ensure masked values are not visible
42 | x = torch.where(mask, x, torch.zeros_like(x))
43 |
44 | h = self.init_hidden_state(x)
45 | x_hat = self.read_out(h)
46 | hs = [h]
47 | preds = [x_hat]
48 | for s in range(steps - 1):
49 | u_t = None if u is None else u[:, s]
50 | x_t = self._preprocess_input(x[:, s], x_hat, mask[:, s], u_t)
51 | h = self.rnn_cell(x_t, h)
52 | x_hat = self.read_out(h)
53 | hs.append(h)
54 | preds.append(x_hat)
55 |
56 | x_hat = torch.stack(preds, 1)
57 | h = torch.stack(hs, 1)
58 | if return_hidden:
59 | return x_hat, h
60 | return x_hat
61 |
62 | @staticmethod
63 | def add_model_specific_args(parser):
64 | parser.add_argument('--d-in', type=int)
65 | parser.add_argument('--d-model', type=int, default=None)
66 | return parser
67 |
68 |
69 | class BiRNNImputer(nn.Module):
70 |     """Fill the blanks with a bidirectional 1-step-ahead GRU predictor."""
71 |
72 | def __init__(self, d_in, d_model, dropout=0., concat_mask=True, detach_inputs=False, state_init='zero', d_u=0):
73 | super(BiRNNImputer, self).__init__()
74 | self.d_model = d_model
75 | self.fwd_rnn = RNNImputer(d_in, d_model, concat_mask, detach_inputs=detach_inputs, state_init=state_init,
76 | d_u=d_u)
77 | self.bwd_rnn = RNNImputer(d_in, d_model, concat_mask, detach_inputs=detach_inputs, state_init=state_init,
78 | d_u=d_u)
79 | self.dropout = nn.Dropout(dropout)
80 | self.read_out = nn.Linear(2 * d_model, d_in)
81 |
82 | def forward(self, x, mask, u=None, return_hidden=False):
83 | # x: [batches, steps, features]
84 | x_hat_fwd, h_fwd = self.fwd_rnn(x, mask, u=u, return_hidden=True)
85 | x_hat_bwd, h_bwd = self.bwd_rnn(reverse_tensor(x, 1),
86 | reverse_tensor(mask, 1),
87 | u=reverse_tensor(u, 1) if u is not None else None,
88 | return_hidden=True)
89 | x_hat_bwd = reverse_tensor(x_hat_bwd, 1)
90 | h_bwd = reverse_tensor(h_bwd, 1)
91 | h = self.dropout(torch.cat([h_fwd, h_bwd], -1))
92 | x_hat = self.read_out(h)
93 | if return_hidden:
94 | return (x_hat, x_hat_fwd, x_hat_bwd), h
95 | return x_hat, x_hat_fwd, x_hat_bwd
96 |
97 | @staticmethod
98 | def add_model_specific_args(parser):
99 | parser.add_argument('--d-in', type=int)
100 | parser.add_argument('--d-model', type=int, default=None)
101 | parser.add_argument('--dropout', type=float, default=0.)
102 | return parser
103 |
--------------------------------------------------------------------------------
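A minimal sketch of the bidirectional RNN baseline on random data (sizes chosen here for illustration; repository root assumed to be on the Python path):

```python
import torch

from lib.nn.models import BiRNNImputer

model = BiRNNImputer(d_in=3, d_model=16)
x = torch.randn(8, 20, 3)                 # [batch, steps, features]
mask = torch.rand_like(x) > 0.3           # True = observed entry

x_hat, x_hat_fwd, x_hat_bwd = model(x, mask)
print(x_hat.shape)                        # torch.Size([8, 20, 3])
```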
/lib/nn/models/var.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 | from torch import nn
4 |
5 | from lib import epsilon
6 |
7 |
8 | class VAR(nn.Module):
9 | def __init__(self, order, d_in, d_out=None, steps_ahead=1, bias=True):
10 | super(VAR, self).__init__()
11 | self.order = order
12 | self.d_in = d_in
13 | self.d_out = d_out if d_out is not None else d_in
14 | self.steps_ahead = steps_ahead
15 | self.lin = nn.Linear(order * d_in, steps_ahead * self.d_out, bias=bias)
16 |
17 | def forward(self, x):
18 | # x: [batches, steps, features]
19 | x = rearrange(x, 'b s f -> b (s f)')
20 | out = self.lin(x)
21 | out = rearrange(out, 'b (s f) -> b s f', s=self.steps_ahead, f=self.d_out)
22 | return out
23 |
24 | @staticmethod
25 | def add_model_specific_args(parser):
26 | parser.add_argument('--order', type=int)
27 | parser.add_argument('--d-in', type=int)
28 | parser.add_argument('--d-out', type=int, default=None)
29 | parser.add_argument('--steps-ahead', type=int, default=1)
30 | return parser
31 |
32 |
33 | class VARImputer(nn.Module):
34 | """Fill the blanks with a 1-step-ahead VAR predictor."""
35 |
36 | def __init__(self, order, d_in, padding='mean'):
37 | super(VARImputer, self).__init__()
38 | assert padding in ['mean', 'zero']
39 | self.order = order
40 | self.padding = padding
41 | self.predictor = VAR(order, d_in, d_out=d_in, steps_ahead=1)
42 |
43 | def forward(self, x, mask=None):
44 | # x: [batches, steps, features]
45 | batch_size, steps, n_feats = x.shape
46 | if mask is None:
47 | mask = torch.ones_like(x, dtype=torch.uint8)
48 | x = x * mask
49 | # pad input sequence to start filling from first step
50 | if self.padding == 'mean':
51 | mean = torch.sum(x, 1) / (torch.sum(mask, 1) + epsilon)
52 | pad = torch.repeat_interleave(mean.unsqueeze(1), self.order, 1)
53 | elif self.padding == 'zero':
54 | pad = torch.zeros((batch_size, self.order, n_feats)).to(x.device)
55 | x = torch.cat([pad, x], 1)
56 | # x: [batch, order + steps, features]
57 | x = [x[:, i] for i in range(x.shape[1])]
58 | for s in range(steps):
59 | x_hat = self.predictor(torch.stack(x[s:s + self.order], 1))
60 | x_hat = x_hat[:, 0]
61 | x[s + self.order] = torch.where(mask[:, s], x[s + self.order], x_hat)
62 | x = torch.stack(x[self.order:], 1) # remove padding
63 | return x
64 |
65 | @staticmethod
66 | def add_model_specific_args(parser):
67 | parser.add_argument('--order', type=int)
68 | parser.add_argument('--d-in', type=int)
69 | parser.add_argument("--padding", type=str, default='mean')
70 | return parser
71 |
--------------------------------------------------------------------------------
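A short usage sketch of the VAR baseline (order and sizes chosen here for illustration): the first `order` steps are padded with the per-feature mean of the observed values, then missing entries are filled step by step with the 1-step-ahead prediction.

```python
import torch

from lib.nn.models import VARImputer

imputer = VARImputer(order=3, d_in=2, padding='mean')
x = torch.randn(4, 30, 2)                 # [batch, steps, features]
mask = torch.rand_like(x) > 0.2           # True = observed entry

x_filled = imputer(x, mask=mask)
print(x_filled.shape)                     # torch.Size([4, 30, 2])
```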
/lib/nn/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-Machine-Learning-Group/grin/4a28afbb092600b6e6abeeabaaf67e87dbd1ed6e/lib/nn/utils/__init__.py
--------------------------------------------------------------------------------
/lib/nn/utils/metric_base.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | from pytorch_lightning.metrics import Metric
5 | from torchmetrics.utilities.checks import _check_same_shape
6 |
7 |
8 | class MaskedMetric(Metric):
9 | def __init__(self,
10 | metric_fn,
11 | mask_nans=False,
12 | mask_inf=False,
13 | compute_on_step=True,
14 | dist_sync_on_step=False,
15 | process_group=None,
16 | dist_sync_fn=None,
17 | metric_kwargs=None,
18 | at=None):
19 | super(MaskedMetric, self).__init__(compute_on_step=compute_on_step,
20 | dist_sync_on_step=dist_sync_on_step,
21 | process_group=process_group,
22 | dist_sync_fn=dist_sync_fn)
23 |
24 | if metric_kwargs is None:
25 | metric_kwargs = dict()
26 | self.metric_fn = partial(metric_fn, **metric_kwargs)
27 | self.mask_nans = mask_nans
28 | self.mask_inf = mask_inf
29 | if at is None:
30 | self.at = slice(None)
31 | else:
32 | self.at = slice(at, at + 1)
33 | self.add_state('value', dist_reduce_fx='sum', default=torch.tensor(0.).float())
34 | self.add_state('numel', dist_reduce_fx='sum', default=torch.tensor(0))
35 |
36 | def _check_mask(self, mask, val):
37 | if mask is None:
38 | mask = torch.ones_like(val).byte()
39 | else:
40 | _check_same_shape(mask, val)
41 | if self.mask_nans:
42 | mask = mask * ~torch.isnan(val)
43 | if self.mask_inf:
44 | mask = mask * ~torch.isinf(val)
45 | return mask
46 |
47 | def _compute_masked(self, y_hat, y, mask):
48 | _check_same_shape(y_hat, y)
49 | val = self.metric_fn(y_hat, y)
50 | mask = self._check_mask(mask, val)
51 | val = torch.where(mask, val, torch.tensor(0., device=val.device).float())
52 | return val.sum(), mask.sum()
53 |
54 | def _compute_std(self, y_hat, y):
55 | _check_same_shape(y_hat, y)
56 | val = self.metric_fn(y_hat, y)
57 | return val.sum(), val.numel()
58 |
59 | def is_masked(self, mask):
60 | return self.mask_inf or self.mask_nans or (mask is not None)
61 |
62 | def update(self, y_hat, y, mask=None):
63 | y_hat = y_hat[:, self.at]
64 | y = y[:, self.at]
65 | if mask is not None:
66 | mask = mask[:, self.at]
67 | if self.is_masked(mask):
68 | val, numel = self._compute_masked(y_hat, y, mask)
69 | else:
70 | val, numel = self._compute_std(y_hat, y)
71 | self.value += val
72 | self.numel += numel
73 |
74 | def compute(self):
75 | if self.numel > 0:
76 | return self.value / self.numel
77 | return self.value
78 |
--------------------------------------------------------------------------------
/lib/nn/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from .metric_base import MaskedMetric
2 | from .ops import mape
3 | from torch.nn import functional as F
4 | import torch
5 |
6 | from torchmetrics.utilities.checks import _check_same_shape
7 |
8 | from ... import epsilon
9 |
10 |
11 | class MaskedMAE(MaskedMetric):
12 | def __init__(self,
13 | mask_nans=False,
14 | mask_inf=False,
15 | compute_on_step=True,
16 | dist_sync_on_step=False,
17 | process_group=None,
18 | dist_sync_fn=None,
19 | at=None):
20 | super(MaskedMAE, self).__init__(metric_fn=F.l1_loss,
21 | mask_nans=mask_nans,
22 | mask_inf=mask_inf,
23 | compute_on_step=compute_on_step,
24 | dist_sync_on_step=dist_sync_on_step,
25 | process_group=process_group,
26 | dist_sync_fn=dist_sync_fn,
27 | metric_kwargs={'reduction': 'none'},
28 | at=at)
29 |
30 |
31 | class MaskedMAPE(MaskedMetric):
32 | def __init__(self,
33 | mask_nans=False,
34 | compute_on_step=True,
35 | dist_sync_on_step=False,
36 | process_group=None,
37 | dist_sync_fn=None,
38 | at=None):
39 | super(MaskedMAPE, self).__init__(metric_fn=mape,
40 | mask_nans=mask_nans,
41 | mask_inf=True,
42 | compute_on_step=compute_on_step,
43 | dist_sync_on_step=dist_sync_on_step,
44 | process_group=process_group,
45 | dist_sync_fn=dist_sync_fn,
46 | at=at)
47 |
48 |
49 | class MaskedMSE(MaskedMetric):
50 | def __init__(self,
51 | mask_nans=False,
52 | compute_on_step=True,
53 | dist_sync_on_step=False,
54 | process_group=None,
55 | dist_sync_fn=None,
56 | at=None):
57 | super(MaskedMSE, self).__init__(metric_fn=F.mse_loss,
58 | mask_nans=mask_nans,
59 | mask_inf=True,
60 | compute_on_step=compute_on_step,
61 | dist_sync_on_step=dist_sync_on_step,
62 | process_group=process_group,
63 | dist_sync_fn=dist_sync_fn,
64 | metric_kwargs={'reduction': 'none'},
65 | at=at)
66 |
67 |
68 | class MaskedMRE(MaskedMetric):
69 | def __init__(self,
70 | mask_nans=False,
71 | mask_inf=False,
72 | compute_on_step=True,
73 | dist_sync_on_step=False,
74 | process_group=None,
75 | dist_sync_fn=None,
76 | at=None):
77 | super(MaskedMRE, self).__init__(metric_fn=F.l1_loss,
78 | mask_nans=mask_nans,
79 | mask_inf=mask_inf,
80 | compute_on_step=compute_on_step,
81 | dist_sync_on_step=dist_sync_on_step,
82 | process_group=process_group,
83 | dist_sync_fn=dist_sync_fn,
84 | metric_kwargs={'reduction': 'none'},
85 | at=at)
86 | self.add_state('tot', dist_reduce_fx='sum', default=torch.tensor(0., dtype=torch.float))
87 |
88 | def _compute_masked(self, y_hat, y, mask):
89 | _check_same_shape(y_hat, y)
90 | val = self.metric_fn(y_hat, y)
91 | mask = self._check_mask(mask, val)
92 | val = torch.where(mask, val, torch.tensor(0., device=y.device, dtype=torch.float))
93 | y_masked = torch.where(mask, y, torch.tensor(0., device=y.device, dtype=torch.float))
94 | return val.sum(), mask.sum(), y_masked.sum()
95 |
96 | def _compute_std(self, y_hat, y):
97 | _check_same_shape(y_hat, y)
98 | val = self.metric_fn(y_hat, y)
99 | return val.sum(), val.numel(), y.sum()
100 |
101 | def compute(self):
102 | if self.tot > epsilon:
103 | return self.value / self.tot
104 | return self.value
105 |
106 | def update(self, y_hat, y, mask=None):
107 | y_hat = y_hat[:, self.at]
108 | y = y[:, self.at]
109 | if mask is not None:
110 | mask = mask[:, self.at]
111 | if self.is_masked(mask):
112 | val, numel, tot = self._compute_masked(y_hat, y, mask)
113 | else:
114 | val, numel, tot = self._compute_std(y_hat, y)
115 | self.value += val
116 | self.numel += numel
117 | self.tot += tot
118 |
119 |
120 |
--------------------------------------------------------------------------------
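A hand-checkable sketch of the masked metrics (values chosen here for illustration; since `MaskedMetric` relies on the legacy `pytorch_lightning.metrics.Metric` API, it assumes the pinned versions from `requirements.txt`). Only entries selected by the mask contribute to the average, and `mask_nans` additionally drops NaN targets.

```python
import torch

from lib.nn.utils.metrics import MaskedMAE

metric = MaskedMAE(mask_nans=True)
y_hat = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y = torch.tensor([[1.5, 2.0], [float('nan'), 4.0]])
mask = torch.tensor([[1, 1], [1, 0]], dtype=torch.bool)   # entries to evaluate

metric.update(y_hat, y, mask)
print(metric.compute())   # tensor(0.2500): |1.0-1.5| and |2.0-2.0| averaged over 2 valid entries
```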
/lib/nn/utils/ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import reduce
4 | from torch.autograd import Variable
5 |
6 | from ... import epsilon
7 |
8 |
9 | def mae(y_hat, y, reduction='none'):
10 | return F.l1_loss(y_hat, y, reduction=reduction)
11 |
12 |
13 | def mape(y_hat, y):
14 | return torch.abs((y_hat - y) / y)
15 |
16 |
17 | def wape_loss(y_hat, y):
18 | l = torch.abs(y_hat - y)
19 | return l.sum() / (y.sum() + epsilon)
20 |
21 |
22 | def smape_loss(y_hat, y):
23 | c = torch.abs(y) > epsilon
24 | l_minus = torch.abs(y_hat - y)
25 | l_plus = torch.abs(y_hat + y) + epsilon
26 | l = 2 * l_minus / l_plus * c.float()
27 | return l.sum() / c.sum()
28 |
29 |
30 | def peak_prediction_loss(y_hat, y, reduction='none'):
31 | y_max = reduce(y, 'b s n 1 -> b 1 n 1', 'max')
32 | y_min = reduce(y, 'b s n 1 -> b 1 n 1', 'min')
33 | target = torch.cat([y_max, y_min], dim=1)
34 | return F.mse_loss(y_hat, target, reduction=reduction)
35 |
36 |
37 | def wrap_loss_fn(base_loss):
38 | def loss_fn(y_hat, y_true, mask=None):
39 | scaling = 1.
40 | if mask is not None:
41 | try:
42 | loss = base_loss(y_hat, y_true, reduction='none')
43 | except TypeError:
44 | loss = base_loss(y_hat, y_true)
45 | loss = loss * mask
46 | loss = loss.sum() / (mask.sum() + epsilon)
47 | # scaling = mask.sum() / torch.numel(mask)
48 | else:
49 | loss = base_loss(y_hat, y_true).mean()
50 | return scaling * loss
51 |
52 | return loss_fn
53 |
54 |
55 | def rbf_sim(x, gamma, device='cpu'):
56 | n = x.size()[0]
57 | a = torch.exp(-gamma * F.pdist(x, 2) ** 2)
58 | row_idx, col_idx = torch.triu_indices(n, n, 1)
59 | A = 0.5 * torch.eye(n, n).to(device)
60 | A[row_idx, col_idx] = a
61 | return A + A.T
62 |
63 |
64 | def reverse_tensor(tensor=None, axis=-1):
65 | if tensor is None:
66 | return None
67 | if tensor.dim() <= 1:
68 | return tensor
69 | indices = range(tensor.size()[axis])[::-1]
70 | indices = Variable(torch.LongTensor(indices), requires_grad=False).to(tensor.device)
71 | return tensor.index_select(axis, indices)
72 |
--------------------------------------------------------------------------------
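Two small, hand-checkable examples of the helpers above (values chosen here for illustration, repository root assumed on the Python path): `reverse_tensor` flips a tensor along one axis, as used for the backward passes of the bidirectional models, while `wrap_loss_fn` turns a point-wise loss into a masked mean.

```python
import torch
import torch.nn.functional as F

from lib.nn.utils.ops import reverse_tensor, wrap_loss_fn

t = torch.arange(6.).view(1, 3, 2)
print(reverse_tensor(t, axis=1)[0, :, 0])   # tensor([4., 2., 0.])

masked_mae = wrap_loss_fn(F.l1_loss)
y_hat, y = torch.zeros(2, 3), torch.ones(2, 3)
mask = torch.tensor([[1., 1., 0.], [0., 0., 0.]])
print(masked_mae(y_hat, y, mask))           # ~tensor(1.): mean |error| over the 2 masked-in entries
```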
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 |
--------------------------------------------------------------------------------
/lib/utils/numpy_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def mae(y_hat, y):
5 | return np.abs(y_hat - y).mean()
6 |
7 |
8 | def nmae(y_hat, y):
9 | delta = np.max(y) - np.min(y) + 1e-8
10 | return mae(y_hat, y) * 100 / delta
11 |
12 |
13 | def mape(y_hat, y):
14 | return 100 * np.abs((y_hat - y) / (y + 1e-8)).mean()
15 |
16 |
17 | def mse(y_hat, y):
18 | return np.square(y_hat - y).mean()
19 |
20 |
21 | def rmse(y_hat, y):
22 | return np.sqrt(mse(y_hat, y))
23 |
24 |
25 | def nrmse(y_hat, y):
26 | delta = np.max(y) - np.min(y) + 1e-8
27 | return rmse(y_hat, y) * 100 / delta
28 |
29 |
30 | def nrmse_2(y_hat, y):
31 | nrmse_ = np.sqrt(np.square(y_hat - y).sum() / np.square(y).sum())
32 | return nrmse_ * 100
33 |
34 |
35 | def r2(y_hat, y):
36 | return 1. - np.square(y_hat - y).sum() / (np.square(y.mean(0) - y).sum())
37 |
38 |
39 | def masked_mae(y_hat, y, mask):
40 | err = np.abs(y_hat - y) * mask
41 | return err.sum() / mask.sum()
42 |
43 |
44 | def masked_mape(y_hat, y, mask):
45 | err = np.abs((y_hat - y) / (y + 1e-8)) * mask
46 | return err.sum() / mask.sum()
47 |
48 |
49 | def masked_mse(y_hat, y, mask):
50 | err = np.square(y_hat - y) * mask
51 | return err.sum() / mask.sum()
52 |
53 |
54 | def masked_rmse(y_hat, y, mask):
55 | err = np.square(y_hat - y) * mask
56 | return np.sqrt(err.sum() / mask.sum())
57 |
58 |
59 | def masked_mre(y_hat, y, mask):
60 | err = np.abs(y_hat - y) * mask
61 | return err.sum() / ((y * mask).sum() + 1e-8)
62 |
--------------------------------------------------------------------------------
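A hand-checkable sketch of the masked NumPy metrics (values chosen here for illustration); the mask selects the entries held out for evaluation.

```python
import numpy as np

from lib.utils import numpy_metrics as npm

y = np.array([[1.0, 2.0], [3.0, 4.0]])
y_hat = np.array([[1.0, 2.5], [2.0, 4.0]])
mask = np.array([[0, 1], [1, 0]])            # 1 = evaluate this entry

print(npm.masked_mae(y_hat, y, mask))        # (0.5 + 1.0) / 2 = 0.75
print(npm.masked_rmse(y_hat, y, mask))       # sqrt((0.25 + 1.0) / 2) ~= 0.79
```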
/lib/utils/parser_utils.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from argparse import Namespace, ArgumentParser
3 | from typing import Union
4 |
5 |
6 | def str_to_bool(value):
7 | if isinstance(value, bool):
8 | return value
9 | if value.lower() in {'false', 'f', '0', 'no', 'n', 'off'}:
10 | return False
11 | elif value.lower() in {'true', 't', '1', 'yes', 'y', 'on'}:
12 | return True
13 | raise ValueError(f'{value} is not a valid boolean value')
14 |
15 |
16 | def config_dict_from_args(args):
17 | """
18 | Extract a dictionary with the experiment configuration from arguments (necessary to filter TestTube arguments)
19 |
20 | :param args: TTNamespace
21 |     :return: hparams dict
22 | """
23 | keys_to_remove = {'hpc_exp_number', 'trials', 'optimize_parallel', 'optimize_parallel_gpu',
24 | 'optimize_parallel_cpu', 'generate_trials', 'optimize_trials_parallel_gpu'}
25 | hparams = {key: v for key, v in args.__dict__.items() if key not in keys_to_remove}
26 | return hparams
27 |
28 |
29 | def update_from_config(args: Namespace, config: dict):
30 | assert set(config.keys()) <= set(vars(args)), f'{set(config.keys()).difference(vars(args))} not in args.'
31 | args.__dict__.update(config)
32 | return args
33 |
34 |
35 | def parse_by_group(parser):
36 | """
37 | Create a nested namespace using the groups defined in the argument parser.
38 | Adapted from https://stackoverflow.com/a/56631542/6524027
39 |
40 |     :param parser: the argument parser
41 |     :return: a Namespace with the original arguments under `flat`, the positional/optional arguments at the top
42 |         level, and a nested Namespace for each additional argument group
43 | """
44 | assert isinstance(parser, ArgumentParser)
45 | args = parser.parse_args()
46 |
47 | # the first two argument groups are 'positional_arguments' and 'optional_arguments'
48 | pos_group, optional_group = parser._action_groups[0], parser._action_groups[1]
49 | args_dict = args._get_kwargs()
50 | pos_optional_arg_names = [arg.dest for arg in pos_group._group_actions] + [arg.dest for arg in
51 | optional_group._group_actions]
52 | pos_optional_args = {name: value for name, value in args_dict if name in pos_optional_arg_names}
53 | other_group_args = dict()
54 |
55 | # If there are additional argument groups, add them as nested namespaces
56 | if len(parser._action_groups) > 2:
57 | for group in parser._action_groups[2:]:
58 | group_arg_names = [arg.dest for arg in group._group_actions]
59 | other_group_args[group.title] = Namespace(
60 | **{name: value for name, value in args_dict if name in group_arg_names})
61 |
62 |     # combine the positional/optional args and the group args
63 | combined_args = pos_optional_args
64 | combined_args.update(other_group_args)
65 | return Namespace(flat=args, **combined_args)
66 |
67 |
68 | def filter_args(args: Union[Namespace, dict], target_cls, return_dict=False):
69 | argspec = inspect.getfullargspec(target_cls.__init__)
70 | target_args = argspec.args
71 | if isinstance(args, Namespace):
72 | args = vars(args)
73 | filtered_args = {k: args[k] for k in target_args if k in args}
74 | if return_dict:
75 | return filtered_args
76 | return Namespace(**filtered_args)
77 |
78 |
79 | def filter_function_args(args: Union[Namespace, dict], function, return_dict=False):
80 | argspec = inspect.getfullargspec(function)
81 | target_args = argspec.args
82 | if isinstance(args, Namespace):
83 | args = vars(args)
84 | filtered_args = {k: args[k] for k in target_args if k in args}
85 | if return_dict:
86 | return filtered_args
87 | return Namespace(**filtered_args)
88 |
--------------------------------------------------------------------------------
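A short sketch of two of the helpers above (the `Dummy` class below is hypothetical and only used to show the filtering): `str_to_bool` lets boolean flags accept explicit values on the command line, and `filter_args` keeps only the arguments that a constructor actually accepts.

```python
from argparse import ArgumentParser

from lib.utils.parser_utils import filter_args, str_to_bool

parser = ArgumentParser()
parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False)
parser.add_argument('--lr', type=float, default=1e-3)
args = parser.parse_args(['--layer-norm', 'true'])
print(args.layer_norm)                               # True


class Dummy:                                         # hypothetical target class
    def __init__(self, lr):
        self.lr = lr


print(filter_args(args, Dummy, return_dict=True))    # {'lr': 0.001}
```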
/lib/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | from sklearn.metrics.pairwise import haversine_distances
5 |
6 |
7 | def sample_mask(shape, p=0.002, p_noise=0., max_seq=1, min_seq=1, rng=None):
8 | if rng is None:
9 | rand = np.random.random
10 | randint = np.random.randint
11 | else:
12 | rand = rng.random
13 | randint = rng.integers
14 | mask = rand(shape) < p
15 | for col in range(mask.shape[1]):
16 | idxs = np.flatnonzero(mask[:, col])
17 | if not len(idxs):
18 | continue
19 | fault_len = min_seq
20 | if max_seq > min_seq:
21 | fault_len = fault_len + int(randint(max_seq - min_seq))
22 | idxs_ext = np.concatenate([np.arange(i, i + fault_len) for i in idxs])
23 | idxs = np.unique(idxs_ext)
24 | idxs = np.clip(idxs, 0, shape[0] - 1)
25 | mask[idxs, col] = True
26 | mask = mask | (rand(mask.shape) < p_noise)
27 | return mask.astype('uint8')
28 |
29 |
30 | def compute_mean(x, index=None):
31 |     """Compute the mean value for each datetime. The mean is first computed hourly over the same week of the year,
32 |     then remaining NaNs are filled with the hourly mean over the same month (first within the same year, then across
33 |     years), and finally with the mean over the same hour of the day alone, assuming every NaN datetime has at least
34 |     one non-NaN entry at the same hour somewhere in the dataset. Any NaNs still left are forward- and backward-filled."""
35 | if isinstance(x, np.ndarray) and index is not None:
36 | shape = x.shape
37 | x = x.reshape((shape[0], -1))
38 | df_mean = pd.DataFrame(x, index=index)
39 | else:
40 | df_mean = x.copy()
41 | cond0 = [df_mean.index.year, df_mean.index.isocalendar().week, df_mean.index.hour]
42 | cond1 = [df_mean.index.year, df_mean.index.month, df_mean.index.hour]
43 | conditions = [cond0, cond1, cond1[1:], cond1[2:]]
44 | while df_mean.isna().values.sum() and len(conditions):
45 | nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
46 | df_mean = df_mean.fillna(nan_mean)
47 | conditions = conditions[1:]
48 | if df_mean.isna().values.sum():
49 | df_mean = df_mean.fillna(method='ffill')
50 | df_mean = df_mean.fillna(method='bfill')
51 | if isinstance(x, np.ndarray):
52 | df_mean = df_mean.values.reshape(shape)
53 | return df_mean
54 |
55 |
56 | def geographical_distance(x=None, to_rad=True):
57 | """
58 | Compute the as-the-crow-flies distance between every pair of samples in `x`. The first dimension of each point is
59 |     assumed to be the latitude, the second the longitude. The input is assumed to be in degrees; if that is not
60 |     the case, `to_rad` must be set to False. Each point must have exactly 2 dimensions.
61 |
62 | Parameters
63 | ----------
64 | x : pd.DataFrame or np.ndarray
65 |         array-like structure of shape (n_samples, 2).
66 | to_rad : bool
67 | whether to convert inputs to radians (provided that they are in degrees).
68 |
69 | Returns
70 | -------
71 |     distances : pd.DataFrame or np.ndarray
72 |         The matrix of pairwise distances between the points, in kilometers.
73 | """
74 | _AVG_EARTH_RADIUS_KM = 6371.0088
75 |
76 |     # Extract values of x if it is a DataFrame, else assume it is a 2-dim array of lat-lon pairs
77 |     latlon_pairs = x.values if isinstance(x, pd.DataFrame) else x
78 | 
79 |     # If the input values are in degrees, convert them to radians
80 |     if to_rad:
81 |         latlon_pairs = np.radians(latlon_pairs)
82 |
83 | distances = haversine_distances(latlon_pairs) * _AVG_EARTH_RADIUS_KM
84 |
85 | # Cast response
86 | if isinstance(x, pd.DataFrame):
87 | res = pd.DataFrame(distances, x.index, x.index)
88 | else:
89 | res = distances
90 |
91 | return res
92 |
93 |
94 | def infer_mask(df, infer_from='next'):
95 | """Infer evaluation mask from DataFrame. In the evaluation mask a value is 1 if it is present in the DataFrame and
96 | absent in the `infer_from` month.
97 |
98 | @param pd.DataFrame df: the DataFrame.
99 | @param str infer_from: denotes from which month the evaluation value must be inferred.
100 | Can be either `previous` or `next`.
101 | @return: pd.DataFrame eval_mask: the evaluation mask for the DataFrame
102 | """
103 | mask = (~df.isna()).astype('uint8')
104 | eval_mask = pd.DataFrame(index=mask.index, columns=mask.columns, data=0).astype('uint8')
105 | if infer_from == 'previous':
106 | offset = -1
107 | elif infer_from == 'next':
108 | offset = 1
109 | else:
110 | raise ValueError('infer_from can only be one of %s' % ['previous', 'next'])
111 | months = sorted(set(zip(mask.index.year, mask.index.month)))
112 | length = len(months)
113 | for i in range(length):
114 | j = (i + offset) % length
115 | year_i, month_i = months[i]
116 | year_j, month_j = months[j]
117 | mask_j = mask[(mask.index.year == year_j) & (mask.index.month == month_j)]
118 | mask_i = mask_j.shift(1, pd.DateOffset(months=12 * (year_i - year_j) + (month_i - month_j)))
119 | mask_i = mask_i[~mask_i.index.duplicated(keep='first')]
120 | mask_i = mask_i[np.in1d(mask_i.index, mask.index)]
121 | eval_mask.loc[mask_i.index] = ~mask_i.loc[mask_i.index] & mask.loc[mask_i.index]
122 | return eval_mask
123 |
124 |
125 | def prediction_dataframe(y, index, columns=None, aggregate_by='mean'):
126 | """Aggregate batched predictions in a single DataFrame.
127 |
128 | @param (list or np.ndarray) y: the list of predictions.
129 | @param (list or np.ndarray) index: the list of time indexes coupled with the predictions.
130 | @param (list or pd.Index) columns: the columns of the returned DataFrame.
131 |     @param (str or list) aggregate_by: how to aggregate the predictions when there is more than one for a time step.
132 | - `mean`: take the mean of the predictions
133 | - `central`: take the prediction at the central position, assuming that the predictions are ordered chronologically
134 | - `smooth_central`: average the predictions weighted by a gaussian signal with std=1
135 | - `last`: take the last prediction
136 |     @return: (pd.DataFrame or list) df: the DataFrame (or list of DataFrames, one per aggregation method) of aggregated predictions
137 | """
138 | dfs = [pd.DataFrame(data=data.reshape(data.shape[:2]), index=idx, columns=columns) for data, idx in zip(y, index)]
139 | df = pd.concat(dfs)
140 | preds_by_step = df.groupby(df.index)
141 |     # aggregate according to the passed methods
142 | aggr_methods = ensure_list(aggregate_by)
143 | dfs = []
144 | for aggr_by in aggr_methods:
145 | if aggr_by == 'mean':
146 | dfs.append(preds_by_step.mean())
147 | elif aggr_by == 'central':
148 | dfs.append(preds_by_step.aggregate(lambda x: x[int(len(x) // 2)]))
149 | elif aggr_by == 'smooth_central':
150 | from scipy.signal import gaussian
151 | dfs.append(preds_by_step.aggregate(lambda x: np.average(x, weights=gaussian(len(x), 1))))
152 | elif aggr_by == 'last':
153 |             dfs.append(preds_by_step.aggregate(lambda x: x[0]))  # the earliest window is the one where this step sits in the last position
154 | else:
155 |             raise ValueError('aggregate_by can only be one of %s' % ['mean', 'central', 'smooth_central', 'last'])
156 | if isinstance(aggregate_by, str):
157 | return dfs[0]
158 | return dfs
159 |
160 |
161 | def ensure_list(obj):
162 | if isinstance(obj, (list, tuple)):
163 | return list(obj)
164 | else:
165 | return [obj]
166 |
167 |
168 | def missing_val_lens(mask):
169 | m = np.concatenate([np.zeros((1, mask.shape[1])),
170 | (~mask.astype('bool')).astype('int'),
171 | np.zeros((1, mask.shape[1]))])
172 | mdiff = np.diff(m, axis=0)
173 | lens = []
174 | for c in range(m.shape[1]):
175 | mj, = mdiff[:, c].nonzero()
176 | diff = np.diff(mj)[::2]
177 | lens.extend(list(diff))
178 | return lens
179 |
180 |
181 | def disjoint_months(dataset, months=None, synch_mode='window'):
182 | idxs = np.arange(len(dataset))
183 | months = ensure_list(months)
184 | # divide indices according to window or horizon
185 | if synch_mode == 'window':
186 | start, end = 0, dataset.window - 1
187 | elif synch_mode == 'horizon':
188 | start, end = dataset.horizon_offset, dataset.horizon_offset + dataset.horizon - 1
189 | else:
190 | raise ValueError('synch_mode can only be one of %s' % ['window', 'horizon'])
191 |     # indices of windows falling entirely within `months`
192 | start_in_months = np.in1d(dataset.index[dataset._indices + start].month, months)
193 | end_in_months = np.in1d(dataset.index[dataset._indices + end].month, months)
194 | idxs_in_months = start_in_months & end_in_months
195 | after_idxs = idxs[idxs_in_months]
196 |     # indices of windows falling entirely within the remaining months
197 | months = np.setdiff1d(np.arange(1, 13), months)
198 | start_in_months = np.in1d(dataset.index[dataset._indices + start].month, months)
199 | end_in_months = np.in1d(dataset.index[dataset._indices + end].month, months)
200 | idxs_in_months = start_in_months & end_in_months
201 | prev_idxs = idxs[idxs_in_months]
202 | return prev_idxs, after_idxs
203 |
204 |
205 | def thresholded_gaussian_kernel(x, theta=None, threshold=None, threshold_on_input=False):
206 | if theta is None:
207 | theta = np.std(x)
208 | weights = np.exp(-np.square(x / theta))
209 | if threshold is not None:
210 | mask = x > threshold if threshold_on_input else weights < threshold
211 | weights[mask] = 0.
212 | return weights
213 |
--------------------------------------------------------------------------------
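For reference, a minimal sketch of how `geographical_distance` and `thresholded_gaussian_kernel` can be combined to build a graph adjacency matrix from sensor coordinates; the coordinates and the 0.1 threshold below are illustrative only.

# Illustrative only: three made-up (lat, lon) pairs in degrees.
import numpy as np

from lib.utils.utils import geographical_distance, thresholded_gaussian_kernel

coords = np.array([[45.46, 9.19],
                   [46.00, 8.95],
                   [47.38, 8.54]])
dist = geographical_distance(coords, to_rad=True)                         # pairwise distances in km
adj = thresholded_gaussian_kernel(dist, theta=dist.std(), threshold=0.1)  # Gaussian kernel, weights < 0.1 zeroed
np.fill_diagonal(adj, 0.)                                                 # the scripts also drop self-loops
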
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | fancyimpute==0.6
3 | h5py
4 | openpyxl
5 | numpy
6 | pandas
7 | pytorch-lightning==1.4
8 | pyyaml
9 | scikit-learn
10 | scipy
11 | tables
12 | tensorboard
13 | tensorflow==2.5.0
14 | tensorflow-gpu==2.4.0
15 | torch==1.8
16 | torchvision
17 | torchaudio
18 | torchmetrics==0.5
19 |
--------------------------------------------------------------------------------
/scripts/run_baselines.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import numpy as np
4 | from fancyimpute import MatrixFactorization, IterativeImputer
5 | from sklearn.neighbors import kneighbors_graph
6 |
7 | from lib import datasets
8 | from lib.utils import numpy_metrics
9 | from lib.utils.parser_utils import str_to_bool
10 |
11 | metrics = {
12 | 'mae': numpy_metrics.masked_mae,
13 | 'mse': numpy_metrics.masked_mse,
14 | 'mre': numpy_metrics.masked_mre,
15 | 'mape': numpy_metrics.masked_mape
16 | }
17 |
18 |
19 | def parse_args():
20 | parser = ArgumentParser()
21 | # experiment setting
22 | parser.add_argument('--datasets', nargs='+', type=str, default=['all'])
23 | parser.add_argument('--imputers', nargs='+', type=str, default=['all'])
24 | parser.add_argument('--n-runs', type=int, default=5)
25 | parser.add_argument('--in-sample', type=str_to_bool, nargs='?', const=True, default=True)
26 | # SpatialKNNImputer params
27 | parser.add_argument('--k', type=int, default=10)
28 | # MFImputer params
29 | parser.add_argument('--rank', type=int, default=10)
30 | # MICEImputer params
31 | parser.add_argument('--mice-iterations', type=int, default=100)
32 | parser.add_argument('--mice-n-features', type=int, default=None)
33 | args = parser.parse_args()
34 | # parse dataset
35 | if args.datasets[0] == 'all':
36 | args.datasets = ['air36', 'air', 'bay', 'irish', 'la', 'bay_noise', 'irish_noise', 'la_noise']
37 | # parse imputers
38 | if args.imputers[0] == 'all':
39 | args.imputers = ['mean', 'knn', 'mf', 'mice']
40 | if not args.in_sample:
41 | args.imputers = [name for name in args.imputers if name in ['mean', 'mice']]
42 | return args
43 |
44 |
45 | class Imputer:
46 | short_name: str
47 |
48 | def __init__(self, method=None, is_deterministic=True, in_sample=True):
49 | self.name = self.__class__.__name__
50 | self.method = method
51 | self.is_deterministic = is_deterministic
52 | self.in_sample = in_sample
53 |
54 | def fit(self, x, mask):
55 | if not self.in_sample:
56 | x_hat = np.where(mask, x, np.nan)
57 | return self.method.fit(x_hat)
58 |
59 | def predict(self, x, mask):
60 | x_hat = np.where(mask, x, np.nan)
61 | if self.in_sample:
62 | return self.method.fit_transform(x_hat)
63 | else:
64 | return self.method.transform(x_hat)
65 |
66 | def params(self):
67 | return dict()
68 |
69 |
70 | class SpatialKNNImputer(Imputer):
71 | short_name = 'knn'
72 |
73 | def __init__(self, adj, k=20):
74 | super(SpatialKNNImputer, self).__init__()
75 | self.k = k
76 | # normalize sim between [0, 1]
77 | sim = (adj + adj.min()) / (adj.max() + adj.min())
78 | knns = kneighbors_graph(1 - sim,
79 | n_neighbors=self.k,
80 | include_self=False,
81 | metric='precomputed').toarray()
82 | self.knns = knns
83 |
84 | def fit(self, x, mask):
85 | pass
86 |
87 | def predict(self, x, mask):
88 | x = np.where(mask, x, 0)
89 | with np.errstate(divide='ignore', invalid='ignore'):
90 | y_hat = (x @ self.knns.T) / (mask @ self.knns.T)
91 | y_hat[~np.isfinite(y_hat)] = x.mean()
92 | return np.where(mask, x, y_hat)
93 |
94 | def params(self):
95 | return dict(k=self.k)
96 |
97 |
98 | class MeanImputer(Imputer):
99 | short_name = 'mean'
100 |
101 | def fit(self, x, mask):
102 | d = np.where(mask, x, np.nan)
103 | self.means = np.nanmean(d, axis=0, keepdims=True)
104 |
105 | def predict(self, x, mask):
106 | if self.in_sample:
107 | d = np.where(mask, x, np.nan)
108 | means = np.nanmean(d, axis=0, keepdims=True)
109 | else:
110 | means = self.means
111 | return np.where(mask, x, means)
112 |
113 |
114 | class MatrixFactorizationImputer(Imputer):
115 | short_name = 'mf'
116 |
117 | def __init__(self, rank=10, loss='mae', verbose=0):
118 | method = MatrixFactorization(rank=rank, loss=loss, verbose=verbose)
119 | super(MatrixFactorizationImputer, self).__init__(method, is_deterministic=False, in_sample=True)
120 |
121 | def params(self):
122 | return dict(rank=self.method.rank)
123 |
124 |
125 | class MICEImputer(Imputer):
126 | short_name = 'mice'
127 |
128 | def __init__(self, max_iter=100, n_nearest_features=None, in_sample=True, verbose=False):
129 | method = IterativeImputer(max_iter=max_iter, n_nearest_features=n_nearest_features, verbose=verbose)
130 | is_deterministic = n_nearest_features is None
131 | super(MICEImputer, self).__init__(method, is_deterministic=is_deterministic, in_sample=in_sample)
132 |
133 | def params(self):
134 | return dict(max_iter=self.method.max_iter, k=self.method.n_nearest_features or -1)
135 |
136 |
137 | def get_dataset(dataset_name):
138 | if dataset_name[:3] == 'air':
139 | dataset = datasets.AirQuality(impute_nans=True, small=dataset_name[3:] == '36')
140 | elif dataset_name == 'bay':
141 | dataset = datasets.MissingValuesPemsBay()
142 | elif dataset_name == 'la':
143 | dataset = datasets.MissingValuesMetrLA()
144 | elif dataset_name == 'la_noise':
145 | dataset = datasets.MissingValuesMetrLA(p_fault=0., p_noise=0.25)
146 | elif dataset_name == 'bay_noise':
147 | dataset = datasets.MissingValuesPemsBay(p_fault=0., p_noise=0.25)
148 | else:
149 | raise ValueError(f"Dataset {dataset_name} not available in this setting.")
150 |
151 | # split in train/test
152 | if isinstance(dataset, datasets.AirQuality):
153 | test_slice = np.in1d(dataset.df.index.month, dataset.test_months)
154 | train_slice = ~test_slice
155 | else:
156 | train_slice = np.zeros(len(dataset)).astype(bool)
157 | train_slice[:-int(0.2 * len(dataset))] = True
158 |
159 |     # zero the eval mask over the training period so the held-out values are available for fitting
160 | dataset.eval_mask[train_slice] = 0
161 |
162 | return dataset, train_slice
163 |
164 |
165 | def get_imputer(imputer_name, args):
166 | if imputer_name == 'mean':
167 | imputer = MeanImputer(in_sample=args.in_sample)
168 | elif imputer_name == 'knn':
169 | imputer = SpatialKNNImputer(adj=args.adj, k=args.k)
170 | elif imputer_name == 'mf':
171 | imputer = MatrixFactorizationImputer(rank=args.rank)
172 | elif imputer_name == 'mice':
173 | imputer = MICEImputer(max_iter=args.mice_iterations,
174 | n_nearest_features=args.mice_n_features,
175 | in_sample=args.in_sample)
176 | else:
177 | raise ValueError(f"Imputer {imputer_name} not available in this setting.")
178 | return imputer
179 |
180 |
181 | def run(imputer, dataset, train_slice):
182 | test_slice = ~train_slice
183 | if args.in_sample:
184 | x_train, mask_train = dataset.numpy(), dataset.training_mask
185 | y_hat = imputer.predict(x_train, mask_train)[test_slice]
186 | else:
187 | x_train, mask_train = dataset.numpy()[train_slice], dataset.training_mask[train_slice]
188 | imputer.fit(x_train, mask_train)
189 | x_test, mask_test = dataset.numpy()[test_slice], dataset.training_mask[test_slice]
190 | y_hat = imputer.predict(x_test, mask_test)
191 |
192 | # Evaluate model
193 | y_true = dataset.numpy()[test_slice]
194 | eval_mask = dataset.eval_mask[test_slice]
195 |
196 | for metric, metric_fn in metrics.items():
197 | error = metric_fn(y_hat, y_true, eval_mask)
198 | print(f'{imputer.name} on {ds_name} {metric}: {error:.4f}')
199 |
200 |
201 | if __name__ == '__main__':
202 |
203 | args = parse_args()
204 | print(args.__dict__)
205 |
206 | for ds_name in args.datasets:
207 |
208 | dataset, train_slice = get_dataset(ds_name)
209 | args.adj = dataset.get_similarity(thr=0.1)
210 |
211 | # Instantiate imputers
212 | imputers = [get_imputer(name, args) for name in args.imputers]
213 |
214 | for imputer in imputers:
215 | n_runs = 1 if imputer.is_deterministic else args.n_runs
216 | for _ in range(n_runs):
217 | run(imputer, dataset, train_slice)
218 |
--------------------------------------------------------------------------------
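The masked neighbor average in `SpatialKNNImputer.predict` is easier to follow on a toy example; the numbers below are made up.

# Toy example of the masked neighbor average used in SpatialKNNImputer.predict.
import numpy as np

x = np.array([[1., 0., 3.]])      # one time step, three nodes (node 1 is missing)
mask = np.array([[1, 0, 1]])      # 1 = observed, 0 = missing
knns = np.array([[0, 1, 1],       # knns[i, j] = 1 if node j is a neighbor of node i
                 [1, 0, 1],
                 [1, 1, 0]])

with np.errstate(divide='ignore', invalid='ignore'):
    y_hat = (x @ knns.T) / (mask @ knns.T)   # sum of observed neighbors / count of observed neighbors
print(np.where(mask, x, y_hat))   # node 1 is imputed with (1 + 3) / 2 = 2
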
/scripts/run_imputation.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import datetime
3 | import os
4 | import pathlib
5 | from argparse import ArgumentParser
6 |
7 | import numpy as np
8 | import pytorch_lightning as pl
9 | import torch
10 | import torch.nn.functional as F
11 | import yaml
12 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
13 | from pytorch_lightning.loggers import TensorBoardLogger
14 | from torch.optim.lr_scheduler import CosineAnnealingLR
15 |
16 | from lib import fillers, datasets, config
17 | from lib.data.datamodule import SpatioTemporalDataModule
18 | from lib.data.imputation_dataset import ImputationDataset, GraphImputationDataset
19 | from lib.nn import models
20 | from lib.nn.utils.metric_base import MaskedMetric
21 | from lib.nn.utils.metrics import MaskedMAE, MaskedMAPE, MaskedMSE, MaskedMRE
22 | from lib.utils import parser_utils, numpy_metrics, ensure_list, prediction_dataframe
23 | from lib.utils.parser_utils import str_to_bool
24 |
25 |
26 | def has_graph_support(model_cls):
27 | return model_cls in [models.GRINet, models.MPGRUNet, models.BiMPGRUNet]
28 |
29 |
30 | def get_model_classes(model_str):
31 | if model_str == 'brits':
32 | model, filler = models.BRITSNet, fillers.BRITSFiller
33 | elif model_str == 'grin':
34 | model, filler = models.GRINet, fillers.GraphFiller
35 | elif model_str == 'mpgru':
36 | model, filler = models.MPGRUNet, fillers.GraphFiller
37 | elif model_str == 'bimpgru':
38 | model, filler = models.BiMPGRUNet, fillers.GraphFiller
39 | elif model_str == 'var':
40 | model, filler = models.VARImputer, fillers.Filler
41 | elif model_str == 'gain':
42 | model, filler = models.RGAINNet, fillers.RGAINFiller
43 | elif model_str == 'birnn':
44 | model, filler = models.BiRNNImputer, fillers.MultiImputationFiller
45 | elif model_str == 'rnn':
46 | model, filler = models.RNNImputer, fillers.Filler
47 | else:
48 | raise ValueError(f'Model {model_str} not available.')
49 | return model, filler
50 |
51 |
52 | def get_dataset(dataset_name):
53 | if dataset_name[:3] == 'air':
54 | dataset = datasets.AirQuality(impute_nans=True, small=dataset_name[3:] == '36')
55 | elif dataset_name == 'bay_block':
56 | dataset = datasets.MissingValuesPemsBay()
57 | elif dataset_name == 'la_block':
58 | dataset = datasets.MissingValuesMetrLA()
59 | elif dataset_name == 'la_point':
60 | dataset = datasets.MissingValuesMetrLA(p_fault=0., p_noise=0.25)
61 | elif dataset_name == 'bay_point':
62 | dataset = datasets.MissingValuesPemsBay(p_fault=0., p_noise=0.25)
63 | else:
64 | raise ValueError(f"Dataset {dataset_name} not available in this setting.")
65 | return dataset
66 |
67 |
68 | def parse_args():
69 | # Argument parser
70 | parser = ArgumentParser()
71 | parser.add_argument('--seed', type=int, default=-1)
72 | parser.add_argument("--model-name", type=str, default='brits')
73 | parser.add_argument("--dataset-name", type=str, default='air36')
74 | parser.add_argument("--config", type=str, default=None)
75 | # Splitting/aggregation params
76 | parser.add_argument('--in-sample', type=str_to_bool, nargs='?', const=True, default=False)
77 | parser.add_argument('--val-len', type=float, default=0.1)
78 | parser.add_argument('--test-len', type=float, default=0.2)
79 | parser.add_argument('--aggregate-by', type=str, default='mean')
80 | # Training params
81 | parser.add_argument('--lr', type=float, default=0.001)
82 | parser.add_argument('--epochs', type=int, default=300)
83 | parser.add_argument('--patience', type=int, default=40)
84 | parser.add_argument('--l2-reg', type=float, default=0.)
85 | parser.add_argument('--scaled-target', type=str_to_bool, nargs='?', const=True, default=True)
86 | parser.add_argument('--grad-clip-val', type=float, default=5.)
87 | parser.add_argument('--grad-clip-algorithm', type=str, default='norm')
88 | parser.add_argument('--loss-fn', type=str, default='l1_loss')
89 | parser.add_argument('--use-lr-schedule', type=str_to_bool, nargs='?', const=True, default=True)
90 | parser.add_argument('--consistency-loss', type=str_to_bool, nargs='?', const=True, default=False)
91 | parser.add_argument('--whiten-prob', type=float, default=0.05)
92 | parser.add_argument('--pred-loss-weight', type=float, default=1.0)
93 | parser.add_argument('--warm-up', type=int, default=0)
94 | # graph params
95 | parser.add_argument("--adj-threshold", type=float, default=0.1)
96 | # gain hparams
97 | parser.add_argument('--alpha', type=float, default=10.)
98 | parser.add_argument('--hint-rate', type=float, default=0.7)
99 | parser.add_argument('--g-train-freq', type=int, default=1)
100 | parser.add_argument('--d-train-freq', type=int, default=5)
101 |
102 | known_args, _ = parser.parse_known_args()
103 | model_cls, _ = get_model_classes(known_args.model_name)
104 | parser = model_cls.add_model_specific_args(parser)
105 | parser = SpatioTemporalDataModule.add_argparse_args(parser)
106 | parser = ImputationDataset.add_argparse_args(parser)
107 |
108 | args = parser.parse_args()
109 | if args.config is not None:
110 | with open(args.config, 'r') as fp:
111 | config_args = yaml.load(fp, Loader=yaml.FullLoader)
112 | for arg in config_args:
113 | setattr(args, arg, config_args[arg])
114 |
115 | return args
116 |
117 |
118 | def run_experiment(args):
119 | # Set configuration and seed
120 | args = copy.deepcopy(args)
121 | if args.seed < 0:
122 | args.seed = np.random.randint(1e9)
123 | torch.set_num_threads(1)
124 | pl.seed_everything(args.seed)
125 |
126 | model_cls, filler_cls = get_model_classes(args.model_name)
127 | dataset = get_dataset(args.dataset_name)
128 |
129 | ########################################
130 | # create logdir and save configuration #
131 | ########################################
132 |
133 | exp_name = f"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{args.seed}"
134 | logdir = os.path.join(config['logs'], args.dataset_name, args.model_name, exp_name)
135 | # save config for logging
136 | pathlib.Path(logdir).mkdir(parents=True)
137 | with open(os.path.join(logdir, 'config.yaml'), 'w') as fp:
138 | yaml.dump(parser_utils.config_dict_from_args(args), fp, indent=4, sort_keys=True)
139 |
140 | ########################################
141 | # data module #
142 | ########################################
143 |
144 | # instantiate dataset
145 | dataset_cls = GraphImputationDataset if has_graph_support(model_cls) else ImputationDataset
146 | torch_dataset = dataset_cls(*dataset.numpy(return_idx=True),
147 | mask=dataset.training_mask,
148 | eval_mask=dataset.eval_mask,
149 | window=args.window,
150 | stride=args.stride)
151 |
152 | # get train/val/test indices
153 | split_conf = parser_utils.filter_function_args(args, dataset.splitter, return_dict=True)
154 | train_idxs, val_idxs, test_idxs = dataset.splitter(torch_dataset, **split_conf)
155 |
156 | # configure datamodule
157 | data_conf = parser_utils.filter_args(args, SpatioTemporalDataModule, return_dict=True)
158 | dm = SpatioTemporalDataModule(torch_dataset, train_idxs=train_idxs, val_idxs=val_idxs, test_idxs=test_idxs,
159 | **data_conf)
160 | dm.setup()
161 |
162 |     # if running out-of-sample on the air datasets, add the values held out for evaluation back into the training mask
163 | if not args.in_sample and args.dataset_name[:3] == 'air':
164 | dm.torch_dataset.mask[dm.train_slice] |= dm.torch_dataset.eval_mask[dm.train_slice]
165 |
166 | # get adjacency matrix
167 | adj = dataset.get_similarity(thr=args.adj_threshold)
168 |     # remove self-loops from the adjacency matrix
169 | np.fill_diagonal(adj, 0.)
170 |
171 | ########################################
172 | # predictor #
173 | ########################################
174 |
175 | # model's inputs
176 | additional_model_hparams = dict(adj=adj, d_in=dm.d_in, n_nodes=dm.n_nodes)
177 | model_kwargs = parser_utils.filter_args(args={**vars(args), **additional_model_hparams},
178 | target_cls=model_cls,
179 | return_dict=True)
180 |
181 | # loss and metrics
182 | loss_fn = MaskedMetric(metric_fn=getattr(F, args.loss_fn),
183 | compute_on_step=True,
184 | metric_kwargs={'reduction': 'none'})
185 |
186 | metrics = {'mae': MaskedMAE(compute_on_step=False),
187 | 'mape': MaskedMAPE(compute_on_step=False),
188 | 'mse': MaskedMSE(compute_on_step=False),
189 | 'mre': MaskedMRE(compute_on_step=False)}
190 |
191 | # filler's inputs
192 | scheduler_class = CosineAnnealingLR if args.use_lr_schedule else None
193 | additional_filler_hparams = dict(model_class=model_cls,
194 | model_kwargs=model_kwargs,
195 | optim_class=torch.optim.Adam,
196 | optim_kwargs={'lr': args.lr,
197 | 'weight_decay': args.l2_reg},
198 | loss_fn=loss_fn,
199 | metrics=metrics,
200 | scheduler_class=scheduler_class,
201 | scheduler_kwargs={
202 | 'eta_min': 0.0001,
203 | 'T_max': args.epochs
204 | },
205 | alpha=args.alpha,
206 | hint_rate=args.hint_rate,
207 | g_train_freq=args.g_train_freq,
208 | d_train_freq=args.d_train_freq)
209 | filler_kwargs = parser_utils.filter_args(args={**vars(args), **additional_filler_hparams},
210 | target_cls=filler_cls,
211 | return_dict=True)
212 | filler = filler_cls(**filler_kwargs)
213 |
214 | ########################################
215 | # training #
216 | ########################################
217 |
218 | # callbacks
219 | early_stop_callback = EarlyStopping(monitor='val_mae', patience=args.patience, mode='min')
220 | checkpoint_callback = ModelCheckpoint(dirpath=logdir, save_top_k=1, monitor='val_mae', mode='min')
221 |
222 | logger = TensorBoardLogger(logdir, name="model")
223 |
224 | trainer = pl.Trainer(max_epochs=args.epochs,
225 | logger=logger,
226 | default_root_dir=logdir,
227 | gpus=1 if torch.cuda.is_available() else None,
228 | gradient_clip_val=args.grad_clip_val,
229 | gradient_clip_algorithm=args.grad_clip_algorithm,
230 | callbacks=[early_stop_callback, checkpoint_callback])
231 |
232 | trainer.fit(filler, datamodule=dm)
233 |
234 | ########################################
235 | # testing #
236 | ########################################
237 |
238 | filler.load_state_dict(torch.load(checkpoint_callback.best_model_path,
239 | lambda storage, loc: storage)['state_dict'])
240 | filler.freeze()
241 | trainer.test()
242 | filler.eval()
243 |
244 | if torch.cuda.is_available():
245 | filler.cuda()
246 |
247 | with torch.no_grad():
248 | y_true, y_hat, mask = filler.predict_loader(dm.test_dataloader(), return_mask=True)
249 |     y_hat = y_hat.detach().cpu().numpy().reshape(y_hat.shape[:3])  # reshape to (possibly) squeeze the channel dimension
250 |
251 |     # Test imputation on the whole series
252 | eval_mask = dataset.eval_mask[dm.test_slice]
253 | df_true = dataset.df.iloc[dm.test_slice]
254 | metrics = {
255 | 'mae': numpy_metrics.masked_mae,
256 | 'mse': numpy_metrics.masked_mse,
257 | 'mre': numpy_metrics.masked_mre,
258 | 'mape': numpy_metrics.masked_mape
259 | }
260 | # Aggregate predictions in dataframes
261 | index = dm.torch_dataset.data_timestamps(dm.testset.indices, flatten=False)['horizon']
262 | aggr_methods = ensure_list(args.aggregate_by)
263 | df_hats = prediction_dataframe(y_hat, index, dataset.df.columns, aggregate_by=aggr_methods)
264 | df_hats = dict(zip(aggr_methods, df_hats))
265 | for aggr_by, df_hat in df_hats.items():
266 | # Compute error
267 | print(f'- AGGREGATE BY {aggr_by.upper()}')
268 | for metric_name, metric_fn in metrics.items():
269 | error = metric_fn(df_hat.values, df_true.values, eval_mask).item()
270 | print(f' {metric_name}: {error:.4f}')
271 |
272 | return y_true, y_hat, mask
273 |
274 |
275 | if __name__ == '__main__':
276 | args = parse_args()
277 | run_experiment(args)
278 |
--------------------------------------------------------------------------------
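Since the test step above merges predictions from overlapping windows with `prediction_dataframe`, here is a minimal sketch of that behavior on toy data: two length-2 windows over three hourly timestamps, a single sensor.

# Toy data: two overlapping windows; the shared timestamp is averaged with aggregate_by='mean'.
import numpy as np
import pandas as pd

from lib.utils import prediction_dataframe

index = [pd.date_range('2020-01-01 00:00', periods=2, freq='H'),
         pd.date_range('2020-01-01 01:00', periods=2, freq='H')]
y = [np.array([[[1.0]], [[2.0]]]),   # predictions of the first window
     np.array([[[4.0]], [[3.0]]])]   # predictions of the second window
df_hat = prediction_dataframe(y, index, columns=['sensor_0'], aggregate_by='mean')
print(df_hat)                        # 01:00 appears in both windows -> (2.0 + 4.0) / 2 = 3.0
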
/scripts/run_synthetic.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import datetime
3 | import os
4 | import pathlib
5 | from argparse import ArgumentParser
6 |
7 | import numpy as np
8 | import pytorch_lightning as pl
9 | import torch
10 | import torch.nn.functional as F
11 | import yaml
12 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
13 | from pytorch_lightning.loggers import TensorBoardLogger
14 | from torch.optim.lr_scheduler import CosineAnnealingLR
15 |
16 | from lib import fillers, config
17 | from lib.datasets import ChargedParticles
18 | from lib.nn import models
19 | from lib.nn.utils.metric_base import MaskedMetric
20 | from lib.nn.utils.metrics import MaskedMAE, MaskedMAPE, MaskedMSE, MaskedMRE
21 | from lib.utils import parser_utils
22 | from lib.utils.parser_utils import str_to_bool
23 |
24 |
25 | def has_graph_support(model_cls):
26 | return model_cls is models.GRINet
27 |
28 |
29 | def get_model_classes(model_str):
30 | if model_str == 'brits':
31 | model, filler = models.BRITSNet, fillers.BRITSFiller
32 | elif model_str == 'grin':
33 | model, filler = models.GRINet, fillers.GraphFiller
34 | else:
35 | raise ValueError(f'Model {model_str} not available.')
36 | return model, filler
37 |
38 |
39 | def parse_args():
40 | # Argument parser
41 | parser = ArgumentParser()
42 | parser.add_argument('--seed', type=int, default=-1)
43 | parser.add_argument("--model-name", type=str, default='bigrill')
44 | parser.add_argument("--config", type=str, default=None)
45 | # Dataset params
46 | parser.add_argument('--static-adj', type=str_to_bool, nargs='?', const=True, default=False)
47 | parser.add_argument('--window', type=int, default=50)
48 | parser.add_argument('--p-block', type=float, default=0.025)
49 | parser.add_argument('--p-point', type=float, default=0.025)
50 | parser.add_argument('--min-seq', type=int, default=5)
51 | parser.add_argument('--max-seq', type=int, default=10)
52 | parser.add_argument('--use-exogenous', type=str_to_bool, nargs='?', const=True, default=True)
53 | # Splitting/aggregation params
54 | parser.add_argument('--val-len', type=float, default=0.1)
55 | parser.add_argument('--test-len', type=float, default=0.2)
56 | # Training params
57 | parser.add_argument('--lr', type=float, default=0.001)
58 | parser.add_argument('--epochs', type=int, default=300)
59 | parser.add_argument('--patience', type=int, default=40)
60 | parser.add_argument('--l2-reg', type=float, default=0.)
61 | parser.add_argument('--scaled-target', type=str_to_bool, nargs='?', const=True, default=False)
62 | parser.add_argument('--grad-clip-val', type=float, default=5.)
63 | parser.add_argument('--grad-clip-algorithm', type=str, default='norm')
64 | parser.add_argument('--loss-fn', type=str, default='mse_loss')
65 | parser.add_argument('--use-lr-schedule', type=str_to_bool, nargs='?', const=True, default=True)
66 | parser.add_argument('--whiten-prob', type=float, default=0.05)
67 | parser.add_argument('--pred-loss-weight', type=float, default=1.0)
68 | parser.add_argument('--warm-up', type=int, default=0)
69 | # graph params
70 | parser.add_argument("--adj-threshold", type=float, default=0.1)
71 |
72 | known_args, _ = parser.parse_known_args()
73 | model_cls, _ = get_model_classes(known_args.model_name)
74 | parser = model_cls.add_model_specific_args(parser)
75 |
76 | args = parser.parse_args()
77 | if args.config is not None:
78 | with open(args.config, 'r') as fp:
79 | config_args = yaml.load(fp, Loader=yaml.FullLoader)
80 | for arg in config_args:
81 | setattr(args, arg, config_args[arg])
82 |
83 | return args
84 |
85 |
86 | def run_experiment(args):
87 | # Set configuration and seed
88 | args = copy.deepcopy(args)
89 | if args.seed < 0:
90 | args.seed = np.random.randint(1e9)
91 | torch.set_num_threads(1)
92 | pl.seed_everything(args.seed)
93 |
94 | ########################################
95 | # load dataset and model #
96 | ########################################
97 |
98 | model_cls, filler_cls = get_model_classes(args.model_name)
99 |
100 | dataset = ChargedParticles(static_adj=args.static_adj,
101 | window=args.window,
102 | p_block=args.p_block,
103 | p_point=args.p_point,
104 | max_seq=args.max_seq,
105 | min_seq=args.min_seq,
106 | use_exogenous=args.use_exogenous,
107 | graph_mode=has_graph_support(model_cls))
108 |
109 | dataset.split(args.val_len, args.test_len)
110 |
111 | # get adjacency matrix
112 | adj = dataset.get_similarity()
113 |     np.fill_diagonal(adj, 0.)  # remove self-loops from the adjacency matrix
114 |
115 | ########################################
116 | # create logdir and save configuration #
117 | ########################################
118 |
119 | exp_name = f"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{args.seed}"
120 | logdir = os.path.join(config['logs'], 'synthetic', args.model_name, exp_name)
121 | # save config for logging
122 | pathlib.Path(logdir).mkdir(parents=True)
123 | with open(os.path.join(logdir, 'config.yaml'), 'w') as fp:
124 | yaml.dump(parser_utils.config_dict_from_args(args), fp, indent=4, sort_keys=True)
125 |
126 | ########################################
127 | # predictor #
128 | ########################################
129 |
130 | # model's inputs
131 | if has_graph_support(model_cls):
132 | model_params = dict(adj=adj, d_in=dataset.n_channels, d_u=dataset.n_exogenous, n_nodes=dataset.n_nodes)
133 | else:
134 | model_params = dict(d_in=(dataset.n_channels * dataset.n_nodes), d_u=(dataset.n_channels * dataset.n_exogenous))
135 | model_kwargs = parser_utils.filter_args(args={**vars(args), **model_params},
136 | target_cls=model_cls,
137 | return_dict=True)
138 |
139 | # loss and metrics
140 | loss_fn = MaskedMetric(metric_fn=getattr(F, args.loss_fn),
141 | compute_on_step=True,
142 | metric_kwargs={'reduction': 'none'})
143 |
144 | metrics = {'mae': MaskedMAE(compute_on_step=False),
145 | 'mape': MaskedMAPE(compute_on_step=False),
146 | 'mse': MaskedMSE(compute_on_step=False),
147 | 'mre': MaskedMRE(compute_on_step=False)}
148 |
149 | # filler's inputs
150 | scheduler_class = CosineAnnealingLR if args.use_lr_schedule else None
151 | additional_filler_hparams = dict(model_class=model_cls,
152 | model_kwargs=model_kwargs,
153 | optim_class=torch.optim.Adam,
154 | optim_kwargs={'lr': args.lr,
155 | 'weight_decay': args.l2_reg},
156 | loss_fn=loss_fn,
157 | metrics=metrics,
158 | scheduler_class=scheduler_class,
159 | scheduler_kwargs={
160 | 'eta_min': 0.0001,
161 | 'T_max': args.epochs
162 |                                      })
167 | filler_kwargs = parser_utils.filter_args(args={**vars(args), **additional_filler_hparams},
168 | target_cls=filler_cls,
169 | return_dict=True)
170 | filler = filler_cls(**filler_kwargs)
171 |
172 | ########################################
173 | # logging options #
174 | ########################################
175 |
176 | # log number of parameters
177 | args.trainable_parameters = filler.trainable_parameters
178 |
179 | # log statistics on masks
180 | for mask_type in ['mask', 'eval_mask', 'training_mask']:
181 | mask_type_mean = getattr(dataset, mask_type).float().mean().item()
182 | setattr(args, mask_type, mask_type_mean)
183 |
184 | print(args)
185 |
186 | ########################################
187 | # training #
188 | ########################################
189 |
190 | # callbacks
191 | early_stop_callback = EarlyStopping(monitor='val_mse', patience=args.patience, mode='min')
192 | checkpoint_callback = ModelCheckpoint(dirpath=logdir, save_top_k=1, monitor='val_mse', mode='min')
193 |
194 | logger = TensorBoardLogger(logdir, name="model")
195 |
196 | trainer = pl.Trainer(max_epochs=args.epochs,
197 | default_root_dir=logdir,
198 | logger=logger,
199 | gpus=1 if torch.cuda.is_available() else None,
200 | gradient_clip_val=args.grad_clip_val,
201 | gradient_clip_algorithm=args.grad_clip_algorithm,
202 | callbacks=[early_stop_callback, checkpoint_callback])
203 |
204 | trainer.fit(filler,
205 | train_dataloader=dataset.train_dataloader(batch_size=args.batch_size),
206 | val_dataloaders=dataset.val_dataloader(batch_size=args.batch_size))
207 |
208 | ########################################
209 | # testing #
210 | ########################################
211 |
212 | filler.load_state_dict(torch.load(checkpoint_callback.best_model_path,
213 | lambda storage, loc: storage)['state_dict'])
214 | filler.freeze()
215 | trainer.test(filler, test_dataloaders=dataset.test_dataloader(batch_size=args.batch_size))
216 | filler.eval()
217 |
218 |
219 | if __name__ == '__main__':
220 | args = parse_args()
221 | run_experiment(args)
222 |
--------------------------------------------------------------------------------
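The `--p-block`/`--p-point`/`--min-seq`/`--max-seq` options of run_synthetic.py presumably drive a fault simulation of the kind implemented by `sample_mask` in lib/utils/utils.py. A minimal sketch with illustrative parameters and a seeded generator:

# Sketch: simulating a missing-data mask with block failures and point noise.
import numpy as np

from lib.utils.utils import sample_mask

rng = np.random.default_rng(42)
mask = sample_mask(shape=(500, 10),   # 500 time steps, 10 nodes
                   p=0.025,           # probability of starting a block (fault)
                   p_noise=0.025,     # probability of an isolated missing point
                   min_seq=5,         # minimum fault length
                   max_seq=10,        # maximum fault length
                   rng=rng)
print(mask.mean())                    # fraction of simulated missing entries
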