├── .gitignore ├── README.md ├── conda_env.yml ├── config ├── bimpgru │ ├── air.yaml │ ├── air36.yaml │ ├── bay_block.yaml │ ├── bay_point.yaml │ ├── irish_block.yaml │ ├── irish_point.yaml │ ├── la_block.yaml │ └── la_point.yaml ├── brits │ ├── air.yaml │ ├── air36.yaml │ ├── bay_block.yaml │ ├── bay_point.yaml │ ├── irish_block.yaml │ ├── irish_point.yaml │ ├── la_block.yaml │ ├── la_point.yaml │ └── synthetic.yaml ├── grin │ ├── air.yaml │ ├── air36.yaml │ ├── bay_block.yaml │ ├── bay_point.yaml │ ├── irish_block.yaml │ ├── irish_point.yaml │ ├── la_block.yaml │ ├── la_point.yaml │ └── synthetic.yaml ├── mpgru │ ├── air.yaml │ ├── air36.yaml │ ├── bay_block.yaml │ ├── bay_point.yaml │ ├── irish_block.yaml │ ├── irish_point.yaml │ ├── la_block.yaml │ └── la_point.yaml ├── rgain │ ├── air.yaml │ ├── air36.yaml │ ├── bay_block.yaml │ ├── bay_point.yaml │ ├── irish_block.yaml │ ├── irish_point.yaml │ ├── la_block.yaml │ └── la_point.yaml └── var │ ├── air.yaml │ ├── air36.yaml │ ├── bay_block.yaml │ ├── bay_point.yaml │ ├── irish_block.yaml │ ├── irish_point.yaml │ ├── la_block.yaml │ └── la_point.yaml ├── grin.png ├── lib ├── __init__.py ├── data │ ├── __init__.py │ ├── datamodule │ │ ├── __init__.py │ │ └── spatiotemporal.py │ ├── imputation_dataset.py │ ├── preprocessing │ │ ├── __init__.py │ │ └── scalers.py │ ├── spatiotemporal_dataset.py │ └── temporal_dataset.py ├── datasets │ ├── __init__.py │ ├── air_quality.py │ ├── metr_la.py │ ├── pd_dataset.py │ ├── pems_bay.py │ └── synthetic.py ├── fillers │ ├── __init__.py │ ├── britsfiller.py │ ├── filler.py │ ├── graphfiller.py │ ├── multi_imputation_filler.py │ └── rgainfiller.py ├── nn │ ├── __init__.py │ ├── layers │ │ ├── __init__.py │ │ ├── gcrnn.py │ │ ├── gril.py │ │ ├── imputation.py │ │ ├── mpgru.py │ │ ├── rits.py │ │ ├── spatial_attention.py │ │ └── spatial_conv.py │ ├── models │ │ ├── __init__.py │ │ ├── brits.py │ │ ├── grin.py │ │ ├── mpgru.py │ │ ├── rgain.py │ │ ├── rnn_imputers.py │ │ └── var.py │ └── utils │ │ ├── __init__.py │ │ ├── metric_base.py │ │ ├── metrics.py │ │ └── ops.py └── utils │ ├── __init__.py │ ├── numpy_metrics.py │ ├── parser_utils.py │ └── utils.py ├── requirements.txt └── scripts ├── run_baselines.py ├── run_imputation.py └── run_synthetic.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_STORE 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Filling the G_ap_s: Multivariate Time Series Imputation by Graph Neural Networks (ICLR 2022 - [open review](https://openreview.net/forum?id=kOu3-S3wJ7) - [pdf](https://openreview.net/pdf?id=kOu3-S3wJ7)) 2 | 3 | [![ICLR](https://img.shields.io/badge/ICLR-2022-blue.svg?style=flat-square)](https://openreview.net/forum?id=kOu3-S3wJ7) 4 | [![PDF](https://img.shields.io/badge/%E2%87%A9-PDF-orange.svg?style=flat-square)](https://openreview.net/pdf?id=kOu3-S3wJ7) 5 | [![arXiv](https://img.shields.io/badge/arXiv-2108.00298-b31b1b.svg?style=flat-square)](https://arxiv.org/abs/2108.00298) 6 | 7 | This repository contains the code for the reproducibility of the experiments presented in the paper "Filling the G_ap_s: Multivariate Time Series Imputation by Graph Neural Networks" (ICLR 2022). In this paper, we propose a graph neural network architecture for multivariate time series imputation and achieve state-of-the-art results on several benchmarks. 
8 | 9 | **Authors**: [Andrea Cini](mailto:andrea.cini@usi.ch), [Ivan Marisca](mailto:ivan.marisca@usi.ch), Cesare Alippi 10 | 11 | 12 | **‼️ PyG implementation of GRIN is now available inside [Torch Spatiotemporal](https://github.com/TorchSpatiotemporal/tsl), a library built to accelerate research on neural spatiotemporal data processing methods, with a focus on Graph Neural Networks.** 13 | 14 | --- 15 | 16 |
### GRIN in a nutshell
17 | 18 | The [paper](https://arxiv.org/abs/2108.00298) introduces __GRIN__, a method and an architecture to exploit relational inductive biases to reconstruct missing values in multivariate time series coming from sensor networks. GRIN features a bidirectional recurrent GNN which learns __spatio-temporal node-level representations__ tailored to reconstruct observations at neighboring nodes. 19 | 20 |
21 | 22 | ![Logo](./grin.png) 23 | 24 |
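To make the idea more concrete, here is a minimal, illustrative sketch of bidirectional recurrent imputation with a simple graph aggregation step. It is not the actual GRIN implementation (that lives in `lib/nn/models/grin.py`); the class name, tensor shapes and aggregation scheme below are simplifications invented only to illustrate the mechanism described above.

```
# Toy sketch, NOT the paper's GRIN model: a forward and a backward GRU cell,
# each fed with the masked input plus an adjacency-weighted aggregation of the
# neighbors' hidden states; a linear readout fills in the missing values.
import torch
import torch.nn as nn


class ToyBiGraphImputer(nn.Module):
    def __init__(self, d_in=1, d_hidden=32):
        super().__init__()
        self.fwd_cell = nn.GRUCell(2 * d_in + d_hidden, d_hidden)
        self.bwd_cell = nn.GRUCell(2 * d_in + d_hidden, d_hidden)
        self.readout = nn.Linear(2 * d_hidden, d_in)
        self.d_hidden = d_hidden

    def _run(self, cell, x, mask, adj, reverse=False):
        # x, mask: [batch, steps, nodes, channels], adj: [nodes, nodes]
        b, steps, n, c = x.shape
        h = x.new_zeros(b * n, self.d_hidden)
        out = [None] * steps
        time_steps = range(steps - 1, -1, -1) if reverse else range(steps)
        for t in time_steps:
            # Spatial step: aggregate the neighbors' hidden states.
            h_agg = torch.einsum('ij,bjk->bik', adj, h.view(b, n, -1))
            # Temporal step: update the state with masked input + neighbor info.
            inp = torch.cat([(x[:, t] * mask[:, t]).reshape(b * n, c),
                             mask[:, t].reshape(b * n, c),
                             h_agg.reshape(b * n, -1)], dim=-1)
            h = cell(inp, h)
            out[t] = h.view(b, n, self.d_hidden)
        return torch.stack(out, dim=1)  # [batch, steps, nodes, d_hidden]

    def forward(self, x, mask, adj):
        h_fwd = self._run(self.fwd_cell, x, mask, adj)
        h_bwd = self._run(self.bwd_cell, x, mask, adj, reverse=True)
        x_hat = self.readout(torch.cat([h_fwd, h_bwd], dim=-1))
        # Keep the observed values, impute only the missing entries.
        return torch.where(mask.bool(), x, x_hat)
```

Refer to `lib/nn/layers/gril.py` and `lib/nn/models/grin.py` for the actual architecture.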
25 | 26 | --- 27 | 28 | ## Directory structure 29 | 30 | The directory is structured as follows: 31 | 32 | ``` 33 | . 34 | ├── config 35 | │   ├── bimpgru 36 | │   ├── brits 37 | │   ├── grin 38 | │   ├── mpgru 39 | │   ├── rgain 40 | │   └── var 41 | ├── datasets 42 | │   ├── air_quality 43 | │   ├── metr_la 44 | │   ├── pems_bay 45 | │   └── synthetic 46 | ├── lib 47 | │   ├── __init__.py 48 | │   ├── data 49 | │   ├── datasets 50 | │   ├── fillers 51 | │   ├── nn 52 | │   └── utils 53 | ├── requirements.txt 54 | └── scripts 55 | ├── run_baselines.py 56 | ├── run_imputation.py 57 | └── run_synthetic.py 58 | 59 | ``` 60 | Note that, given the size of the files, the datasets are not readily available in the folder. See the next section for download instructions. 61 | 62 | ## Datasets 63 | 64 | All the datasets used in the experiments, except CER-E, are open and can be downloaded from this [link](https://mega.nz/folder/qwwG3Qba#c6qFTeT7apmZKKyEunCzSg). The CER-E dataset can be obtained free of charge for research purposes following the instructions at this [link](https://www.ucd.ie/issda/data/commissionforenergyregulationcer/). We recommend storing the downloaded datasets in a folder named `datasets` inside this directory. 65 | 66 | ## Configuration files 67 | 68 | The `config` directory stores all the configuration files used to run the experiments. They are organized into folders, one for each model. 69 | 70 | ## Library 71 | 72 | The support code, including the models and the dataset readers, is packed in a Python library named `lib`. If you need to change the paths to the dataset locations, edit the `__init__.py` file of the library. 73 | 74 | ## Scripts 75 | 76 | The scripts used for the experiments in the paper are in the `scripts` folder. 77 | 78 | * `run_baselines.py` is used to compute the metrics for the `MEAN`, `KNN`, `MF` and `MICE` imputation methods. An example of usage is 79 | 80 | ``` 81 | python ./scripts/run_baselines.py --datasets air36 air --imputers mean knn --k 10 --in-sample True --n-runs 5 82 | ``` 83 | 84 | * `run_imputation.py` is used to compute the metrics for the deep imputation methods. An example of usage is 85 | 86 | ``` 87 | python ./scripts/run_imputation.py --config config/grin/air36.yaml --in-sample False 88 | ``` 89 | 90 | * `run_synthetic.py` is used for the experiments on the synthetic datasets. An example of usage is 91 | 92 | ``` 93 | python ./scripts/run_synthetic.py --config config/grin/synthetic.yaml --static-adj False 94 | ``` 95 | 96 | ## Requirements 97 | 98 | We ran all the experiments with Python 3.8; see `requirements.txt` for the list of `pip` dependencies.
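For reference, one possible way to set up the environment is through the provided `conda_env.yml` (which creates an environment named `grin`), or directly with `pip` inside an existing Python 3.8 environment; the commands below are just the standard conda/pip invocations and are not prescribed by the repository:

```
conda env create -f conda_env.yml
conda activate grin
# or, with pip only:
pip install -r requirements.txt
```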
99 | 100 | ## BibTeX reference 101 | 102 | If you find this code useful, please consider citing our paper: 103 | 104 | ``` 105 | @inproceedings{cini2022filling, 106 | title={Filling the G\_ap\_s: Multivariate Time Series Imputation by Graph Neural Networks}, 107 | author={Andrea Cini and Ivan Marisca and Cesare Alippi}, 108 | booktitle={International Conference on Learning Representations}, 109 | year={2022}, 110 | url={https://openreview.net/forum?id=kOu3-S3wJ7} 111 | } 112 | ``` 113 | -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: grin 2 | channels: 3 | - defaults 4 | - pytorch 5 | - conda-forge 6 | dependencies: 7 | - pip 8 | - pytables 9 | - python=3.8 10 | - pytorch=1.8 11 | - torchvision 12 | - torchaudio 13 | - wheel 14 | - pip: 15 | - einops 16 | - fancyimpute==0.6 17 | - h5py 18 | - openpyxl 19 | - pandas 20 | - pytorch-lightning==1.4 21 | - pyyaml 22 | - scikit-learn 23 | - scipy 24 | - tensorboard 25 | -------------------------------------------------------------------------------- /config/bimpgru/air.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'air' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'bimpgru' 17 | d_hidden: 64 18 | d_emb: 8 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false 24 | merge: 'mlp' -------------------------------------------------------------------------------- /config/bimpgru/air36.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'air36' 2 | window: 36 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'bimpgru' 17 | d_hidden: 64 18 | d_emb: 8 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false 24 | merge: 'mlp' -------------------------------------------------------------------------------- /config/bimpgru/bay_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'bay_block' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'bimpgru' 17 | d_hidden: 64 18 | d_emb: 8 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false 24 | merge: 'mlp' -------------------------------------------------------------------------------- /config/bimpgru/bay_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'bay_point' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: 
['mean'] 15 | 16 | model_name: 'bimpgru' 17 | d_hidden: 64 18 | d_emb: 8 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false 24 | merge: 'mlp' -------------------------------------------------------------------------------- /config/bimpgru/irish_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'irish_block' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'bimpgru' 17 | d_hidden: 64 18 | d_emb: 8 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false 24 | merge: 'mlp' -------------------------------------------------------------------------------- /config/bimpgru/irish_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'irish_point' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'bimpgru' 17 | d_hidden: 64 18 | d_emb: 8 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false 24 | merge: 'mlp' -------------------------------------------------------------------------------- /config/bimpgru/la_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'la_block' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'bimpgru' 17 | d_hidden: 64 18 | d_emb: 8 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false 24 | merge: 'mlp' -------------------------------------------------------------------------------- /config/bimpgru/la_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'la_point' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'bimpgru' 17 | d_hidden: 64 18 | d_emb: 8 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false 24 | merge: 'mlp' -------------------------------------------------------------------------------- /config/brits/air.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'air' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | batch_size: 32 12 | aggregate_by: ['mean'] 13 | 14 | model_name: 'brits' 15 | d_hidden: 128 -------------------------------------------------------------------------------- /config/brits/air36.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 
'air36' 2 | window: 36 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | batch_size: 32 12 | aggregate_by: ['mean'] 13 | 14 | model_name: 'brits' 15 | d_hidden: 64 -------------------------------------------------------------------------------- /config/brits/bay_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'bay_block' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | batch_size: 32 12 | aggregate_by: ['mean'] 13 | 14 | model_name: 'brits' 15 | d_hidden: 256 -------------------------------------------------------------------------------- /config/brits/bay_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'bay_point' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | batch_size: 32 12 | aggregate_by: ['mean'] 13 | 14 | model_name: 'brits' 15 | d_hidden: 256 -------------------------------------------------------------------------------- /config/brits/irish_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'irish_block' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | batch_size: 32 12 | aggregate_by: ['mean'] 13 | 14 | model_name: 'brits' 15 | d_hidden: 256 -------------------------------------------------------------------------------- /config/brits/irish_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'irish_point' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | batch_size: 32 12 | aggregate_by: ['mean'] 13 | 14 | model_name: 'brits' 15 | d_hidden: 256 -------------------------------------------------------------------------------- /config/brits/la_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'la_block' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | batch_size: 32 12 | aggregate_by: ['mean'] 13 | 14 | model_name: 'brits' 15 | d_hidden: 128 -------------------------------------------------------------------------------- /config/brits/la_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'la_point' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | batch_size: 32 12 | aggregate_by: ['mean'] 13 | 14 | model_name: 'brits' 15 | d_hidden: 128 -------------------------------------------------------------------------------- /config/brits/synthetic.yaml: -------------------------------------------------------------------------------- 1 | window: 36 2 | p_block: 0.025 3 | p_point: 0.025 4 | min_seq: 4 5 | max_seq: 9 6 | use_exogenous: False 7 | 8 | 
epochs: 200 9 | batch_size: 32 10 | 11 | model_name: 'brits' 12 | d_hidden: 32 -------------------------------------------------------------------------------- /config/grin/air.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'air' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'grin' 17 | pred_loss_weight: 1 18 | 19 | d_hidden: 64 20 | d_emb: 8 21 | d_ff: 64 22 | ff_dropout: 0 23 | kernel_size: 2 24 | decoder_order: 1 25 | n_layers: 1 26 | layer_norm: false 27 | merge: 'mlp' 28 | -------------------------------------------------------------------------------- /config/grin/air36.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'air36' 2 | window: 36 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'grin' 17 | pred_loss_weight: 1 18 | 19 | d_hidden: 64 20 | d_emb: 8 21 | d_ff: 64 22 | ff_dropout: 0 23 | kernel_size: 2 24 | decoder_order: 1 25 | n_layers: 1 26 | layer_norm: false 27 | merge: 'mlp' 28 | -------------------------------------------------------------------------------- /config/grin/bay_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'bay_block' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'grin' 17 | pred_loss_weight: 1 18 | 19 | d_hidden: 64 20 | d_emb: 8 21 | d_ff: 64 22 | ff_dropout: 0 23 | kernel_size: 2 24 | decoder_order: 1 25 | n_layers: 1 26 | layer_norm: false 27 | merge: 'mlp' 28 | -------------------------------------------------------------------------------- /config/grin/bay_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'bay_point' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'grin' 17 | pred_loss_weight: 1 18 | 19 | d_hidden: 64 20 | d_emb: 8 21 | d_ff: 64 22 | ff_dropout: 0 23 | kernel_size: 2 24 | decoder_order: 1 25 | n_layers: 1 26 | layer_norm: false 27 | merge: 'mlp' 28 | -------------------------------------------------------------------------------- /config/grin/irish_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'irish_block' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'grin' 17 | pred_loss_weight: 1 18 | 19 | d_hidden: 64 20 | d_emb: 8 21 | 
d_ff: 64 22 | ff_dropout: 0 23 | kernel_size: 2 24 | decoder_order: 1 25 | n_layers: 1 26 | layer_norm: false 27 | merge: 'mlp' 28 | -------------------------------------------------------------------------------- /config/grin/irish_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'irish_point' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'grin' 17 | pred_loss_weight: 1 18 | 19 | d_hidden: 64 20 | d_emb: 8 21 | d_ff: 64 22 | ff_dropout: 0 23 | kernel_size: 2 24 | decoder_order: 1 25 | n_layers: 1 26 | layer_norm: false 27 | merge: 'mlp' 28 | -------------------------------------------------------------------------------- /config/grin/la_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'la_block' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'grin' 17 | pred_loss_weight: 1 18 | 19 | d_hidden: 64 20 | d_emb: 8 21 | d_ff: 64 22 | ff_dropout: 0 23 | kernel_size: 2 24 | decoder_order: 1 25 | n_layers: 1 26 | layer_norm: false 27 | merge: 'mlp' 28 | -------------------------------------------------------------------------------- /config/grin/la_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'la_point' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'grin' 17 | pred_loss_weight: 1 18 | 19 | d_hidden: 64 20 | d_emb: 8 21 | d_ff: 64 22 | ff_dropout: 0 23 | kernel_size: 2 24 | decoder_order: 1 25 | n_layers: 1 26 | layer_norm: false 27 | merge: 'mlp' 28 | -------------------------------------------------------------------------------- /config/grin/synthetic.yaml: -------------------------------------------------------------------------------- 1 | window: 36 2 | p_block: 0.025 3 | p_point: 0.025 4 | min_seq: 4 5 | max_seq: 9 6 | use_exogenous: False 7 | 8 | epochs: 200 9 | batch_size: 32 10 | 11 | model_name: 'grin' 12 | d_hidden: 16 13 | d_emb: 0 14 | d_ff: 16 15 | ff_dropout: 0 16 | kernel_size: 1 17 | decoder_order: 1 18 | n_layers: 1 19 | layer_norm: false 20 | merge: 'mlp' 21 | -------------------------------------------------------------------------------- /config/mpgru/air.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'air' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'mpgru' 17 | pred_loss_weight: 1 18 | d_hidden: 64 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false 
-------------------------------------------------------------------------------- /config/mpgru/air36.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'air36' 2 | window: 36 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'mpgru' 17 | pred_loss_weight: 1 18 | d_hidden: 64 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false -------------------------------------------------------------------------------- /config/mpgru/bay_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'bay_block' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'mpgru' 17 | pred_loss_weight: 1 18 | d_hidden: 64 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false -------------------------------------------------------------------------------- /config/mpgru/bay_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'bay_point' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'mpgru' 17 | pred_loss_weight: 1 18 | d_hidden: 64 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false -------------------------------------------------------------------------------- /config/mpgru/irish_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'irish_block' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'mpgru' 17 | pred_loss_weight: 1 18 | d_hidden: 64 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false -------------------------------------------------------------------------------- /config/mpgru/irish_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'irish_point' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'mpgru' 17 | pred_loss_weight: 1 18 | d_hidden: 64 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false -------------------------------------------------------------------------------- /config/mpgru/la_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'la_block' 2 | window: 24 3 | 4 | adj_threshold: 
0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'mpgru' 17 | pred_loss_weight: 1 18 | d_hidden: 64 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false -------------------------------------------------------------------------------- /config/mpgru/la_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'la_point' 2 | window: 24 3 | 4 | adj_threshold: 0.1 5 | 6 | detrend: False 7 | scale: True 8 | scaling_axis: 'global' # ['channels', 'global'] 9 | scaled_target: True 10 | 11 | epochs: 300 12 | samples_per_epoch: 5120 # 160 batch of 32 13 | batch_size: 32 14 | aggregate_by: ['mean'] 15 | 16 | model_name: 'mpgru' 17 | pred_loss_weight: 1 18 | d_hidden: 64 19 | d_ff: 64 20 | dropout: 0 21 | kernel_size: 2 22 | n_layers: 1 23 | layer_norm: false -------------------------------------------------------------------------------- /config/rgain/air.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'air' 2 | window: 24 3 | whiten_prob: 0.2 4 | 5 | detrend: False 6 | scale: True 7 | scaling_axis: 'channels' 8 | scaled_target: True 9 | 10 | epochs: 300 11 | samples_per_epoch: 5120 # 160 batch of 32 12 | batch_size: 32 13 | loss_fn: mse_loss 14 | consistency_loss: False 15 | use_lr_schedule: True 16 | grad_clip_val: -1 17 | aggregate_by: ['mean'] 18 | 19 | model_name: 'gain' 20 | d_model: 128 21 | d_z: 4 22 | dropout: 0.2 23 | inject_noise: true 24 | alpha: 20 25 | g_train_freq: 3 26 | d_train_freq: 1 -------------------------------------------------------------------------------- /config/rgain/air36.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'air36' 2 | window: 36 3 | whiten_prob: 0.2 4 | 5 | detrend: False 6 | scale: True 7 | scaling_axis: 'channels' 8 | scaled_target: True 9 | 10 | epochs: 300 11 | samples_per_epoch: 5120 # 160 batch of 32 12 | batch_size: 32 13 | loss_fn: mse_loss 14 | consistency_loss: False 15 | use_lr_schedule: True 16 | grad_clip_val: -1 17 | aggregate_by: ['mean'] 18 | 19 | model_name: 'gain' 20 | d_model: 64 21 | d_z: 4 22 | dropout: 0.1 23 | inject_noise: true 24 | alpha: 20 25 | g_train_freq: 3 26 | d_train_freq: 1 -------------------------------------------------------------------------------- /config/rgain/bay_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'bay_block' 2 | window: 24 3 | whiten_prob: 0.2 4 | 5 | detrend: False 6 | scale: True 7 | scaling_axis: 'channels' 8 | scaled_target: True 9 | 10 | epochs: 300 11 | samples_per_epoch: 5120 # 160 batch of 32 12 | batch_size: 32 13 | loss_fn: mse_loss 14 | consistency_loss: False 15 | use_lr_schedule: True 16 | grad_clip_val: -1 17 | aggregate_by: ['mean'] 18 | 19 | model_name: 'gain' 20 | d_model: 256 21 | d_z: 4 22 | dropout: 0.2 23 | inject_noise: true 24 | alpha: 20 25 | g_train_freq: 3 26 | d_train_freq: 1 -------------------------------------------------------------------------------- /config/rgain/bay_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'bay_point' 2 | window: 24 3 | whiten_prob: 0.2 4 | 5 | detrend: False 6 | scale: True 7 | scaling_axis: 'channels' 8 | 
scaled_target: True 9 | 10 | epochs: 300 11 | samples_per_epoch: 5120 # 160 batch of 32 12 | batch_size: 32 13 | loss_fn: mse_loss 14 | consistency_loss: False 15 | use_lr_schedule: True 16 | grad_clip_val: -1 17 | aggregate_by: ['mean'] 18 | 19 | model_name: 'gain' 20 | d_model: 256 21 | d_z: 4 22 | dropout: 0.2 23 | inject_noise: true 24 | alpha: 20 25 | g_train_freq: 3 26 | d_train_freq: 1 -------------------------------------------------------------------------------- /config/rgain/irish_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'irish_block' 2 | window: 24 3 | whiten_prob: 0.2 4 | 5 | detrend: False 6 | scale: True 7 | scaling_axis: 'channels' 8 | scaled_target: True 9 | 10 | epochs: 300 11 | samples_per_epoch: 5120 # 160 batch of 32 12 | batch_size: 32 13 | loss_fn: mse_loss 14 | consistency_loss: False 15 | use_lr_schedule: True 16 | grad_clip_val: -1 17 | aggregate_by: ['mean'] 18 | 19 | model_name: 'gain' 20 | d_model: 256 21 | d_z: 4 22 | dropout: 0.2 23 | inject_noise: true 24 | alpha: 20 25 | g_train_freq: 3 26 | d_train_freq: 1 -------------------------------------------------------------------------------- /config/rgain/irish_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'irish_point' 2 | window: 24 3 | whiten_prob: 0.2 4 | 5 | detrend: False 6 | scale: True 7 | scaling_axis: 'channels' 8 | scaled_target: True 9 | 10 | epochs: 300 11 | samples_per_epoch: 5120 # 160 batch of 32 12 | batch_size: 32 13 | loss_fn: mse_loss 14 | consistency_loss: False 15 | use_lr_schedule: True 16 | grad_clip_val: -1 17 | aggregate_by: ['mean'] 18 | 19 | model_name: 'gain' 20 | d_model: 256 21 | d_z: 4 22 | dropout: 0.2 23 | inject_noise: true 24 | alpha: 20 25 | g_train_freq: 3 26 | d_train_freq: 1 -------------------------------------------------------------------------------- /config/rgain/la_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'la_block' 2 | window: 24 3 | whiten_prob: 0.2 4 | 5 | detrend: False 6 | scale: True 7 | scaling_axis: 'channels' 8 | scaled_target: True 9 | 10 | epochs: 300 11 | samples_per_epoch: 5120 # 160 batch of 32 12 | batch_size: 32 13 | loss_fn: mse_loss 14 | consistency_loss: False 15 | use_lr_schedule: True 16 | grad_clip_val: -1 17 | aggregate_by: ['mean'] 18 | 19 | model_name: 'gain' 20 | d_model: 128 21 | d_z: 4 22 | dropout: 0.2 23 | inject_noise: true 24 | alpha: 20 25 | g_train_freq: 3 26 | d_train_freq: 1 -------------------------------------------------------------------------------- /config/rgain/la_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'la_point' 2 | window: 24 3 | whiten_prob: 0.2 4 | 5 | detrend: False 6 | scale: True 7 | scaling_axis: 'channels' 8 | scaled_target: True 9 | 10 | epochs: 300 11 | samples_per_epoch: 5120 # 160 batch of 32 12 | batch_size: 32 13 | loss_fn: mse_loss 14 | consistency_loss: False 15 | use_lr_schedule: True 16 | grad_clip_val: -1 17 | aggregate_by: ['mean'] 18 | 19 | model_name: 'gain' 20 | d_model: 128 21 | d_z: 4 22 | dropout: 0.2 23 | inject_noise: true 24 | alpha: 20 25 | g_train_freq: 3 26 | d_train_freq: 1 -------------------------------------------------------------------------------- /config/var/air.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'air' 2 | window: 24 3 | 4 | detrend: False 5 | 
scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | lr: 0.0005 12 | batch_size: 64 13 | aggregate_by: ['mean'] 14 | 15 | model_name: 'var' 16 | order: 5 17 | padding: 'mean' -------------------------------------------------------------------------------- /config/var/air36.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'air36' 2 | window: 36 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | lr: 0.0005 12 | batch_size: 64 13 | aggregate_by: ['mean'] 14 | 15 | model_name: 'var' 16 | order: 5 17 | padding: 'mean' -------------------------------------------------------------------------------- /config/var/bay_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'bay_block' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | lr: 0.0005 12 | batch_size: 64 13 | aggregate_by: ['mean'] 14 | 15 | model_name: 'var' 16 | order: 5 17 | padding: 'mean' -------------------------------------------------------------------------------- /config/var/bay_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'bay_point' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | lr: 0.0005 12 | batch_size: 64 13 | aggregate_by: ['mean'] 14 | 15 | model_name: 'var' 16 | order: 5 17 | padding: 'mean' -------------------------------------------------------------------------------- /config/var/irish_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'irish_block' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | lr: 0.0005 12 | batch_size: 64 13 | aggregate_by: ['mean'] 14 | 15 | model_name: 'var' 16 | order: 5 17 | padding: 'mean' -------------------------------------------------------------------------------- /config/var/irish_point.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'irish_point' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | lr: 0.0005 12 | batch_size: 64 13 | aggregate_by: ['mean'] 14 | 15 | model_name: 'var' 16 | order: 5 17 | padding: 'mean' -------------------------------------------------------------------------------- /config/var/la_block.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'la_block' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | lr: 0.0005 12 | batch_size: 64 13 | aggregate_by: ['mean'] 14 | 15 | model_name: 'var' 16 | order: 5 17 | padding: 'mean' -------------------------------------------------------------------------------- /config/var/la_point.yaml: 
-------------------------------------------------------------------------------- 1 | dataset_name: 'la_point' 2 | window: 24 3 | 4 | detrend: False 5 | scale: True 6 | scaling_axis: 'channels' 7 | scaled_target: True 8 | 9 | epochs: 300 10 | samples_per_epoch: 5120 # 160 batch of 32 11 | lr: 0.0005 12 | batch_size: 64 13 | aggregate_by: ['mean'] 14 | 15 | model_name: 'var' 16 | order: 5 17 | padding: 'mean' -------------------------------------------------------------------------------- /grin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-Machine-Learning-Group/grin/4a28afbb092600b6e6abeeabaaf67e87dbd1ed6e/grin.png -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 4 | 5 | config = { 6 | 'logs': 'logs/' 7 | } 8 | datasets_path = { 9 | 'air': 'datasets/air_quality', 10 | 'la': 'datasets/metr_la', 11 | 'bay': 'datasets/pems_bay', 12 | 'synthetic': 'datasets/synthetic' 13 | } 14 | epsilon = 1e-8 15 | 16 | for k, v in config.items(): 17 | config[k] = os.path.join(base_dir, v) 18 | for k, v in datasets_path.items(): 19 | datasets_path[k] = os.path.join(base_dir, v) 20 | -------------------------------------------------------------------------------- /lib/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .temporal_dataset import TemporalDataset 2 | from .spatiotemporal_dataset import SpatioTemporalDataset 3 | -------------------------------------------------------------------------------- /lib/data/datamodule/__init__.py: -------------------------------------------------------------------------------- 1 | from .spatiotemporal import SpatioTemporalDataModule 2 | -------------------------------------------------------------------------------- /lib/data/datamodule/spatiotemporal.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from torch.utils.data import DataLoader, Subset, RandomSampler 3 | 4 | from .. 
import TemporalDataset, SpatioTemporalDataset 5 | from ..preprocessing import StandardScaler, MinMaxScaler 6 | from ...utils import ensure_list 7 | from ...utils.parser_utils import str_to_bool 8 | 9 | 10 | class SpatioTemporalDataModule(pl.LightningDataModule): 11 | """ 12 | Pytorch Lightning DataModule for TimeSeriesDatasets 13 | """ 14 | 15 | def __init__(self, dataset: TemporalDataset, 16 | scale=True, 17 | scaling_axis='samples', 18 | scaling_type='std', 19 | scale_exogenous=None, 20 | train_idxs=None, 21 | val_idxs=None, 22 | test_idxs=None, 23 | batch_size=32, 24 | workers=1, 25 | samples_per_epoch=None): 26 | super(SpatioTemporalDataModule, self).__init__() 27 | self.torch_dataset = dataset 28 | # splitting 29 | self.trainset = Subset(self.torch_dataset, train_idxs if train_idxs is not None else []) 30 | self.valset = Subset(self.torch_dataset, val_idxs if val_idxs is not None else []) 31 | self.testset = Subset(self.torch_dataset, test_idxs if test_idxs is not None else []) 32 | # preprocessing 33 | self.scale = scale 34 | self.scaling_type = scaling_type 35 | self.scaling_axis = scaling_axis 36 | self.scale_exogenous = ensure_list(scale_exogenous) if scale_exogenous is not None else None 37 | # data loaders 38 | self.batch_size = batch_size 39 | self.workers = workers 40 | self.samples_per_epoch = samples_per_epoch 41 | 42 | @property 43 | def is_spatial(self): 44 | return isinstance(self.torch_dataset, SpatioTemporalDataset) 45 | 46 | @property 47 | def n_nodes(self): 48 | if not self.has_setup_fit: 49 | raise ValueError('You should initialize the datamodule first.') 50 | return self.torch_dataset.n_nodes if self.is_spatial else None 51 | 52 | @property 53 | def d_in(self): 54 | if not self.has_setup_fit: 55 | raise ValueError('You should initialize the datamodule first.') 56 | return self.torch_dataset.n_channels 57 | 58 | @property 59 | def d_out(self): 60 | if not self.has_setup_fit: 61 | raise ValueError('You should initialize the datamodule first.') 62 | return self.torch_dataset.horizon 63 | 64 | @property 65 | def train_slice(self): 66 | return self.torch_dataset.expand_indices(self.trainset.indices, merge=True) 67 | 68 | @property 69 | def val_slice(self): 70 | return self.torch_dataset.expand_indices(self.valset.indices, merge=True) 71 | 72 | @property 73 | def test_slice(self): 74 | return self.torch_dataset.expand_indices(self.testset.indices, merge=True) 75 | 76 | def get_scaling_axes(self, dim='global'): 77 | scaling_axis = tuple() 78 | if dim == 'global': 79 | scaling_axis = (0, 1, 2) 80 | elif dim == 'channels': 81 | scaling_axis = (0, 1) 82 | elif dim == 'nodes': 83 | scaling_axis = (0,) 84 | # Remove last dimension for temporal datasets 85 | if not self.is_spatial: 86 | scaling_axis = scaling_axis[:-1] 87 | 88 | if not len(scaling_axis): 89 | raise ValueError(f'Scaling axis "{dim}" not valid.') 90 | 91 | return scaling_axis 92 | 93 | def get_scaler(self): 94 | if self.scaling_type == 'std': 95 | return StandardScaler 96 | elif self.scaling_type == 'minmax': 97 | return MinMaxScaler 98 | else: 99 | return NotImplementedError 100 | 101 | def setup(self, stage=None): 102 | 103 | if self.scale: 104 | scaling_axis = self.get_scaling_axes(self.scaling_axis) 105 | train = self.torch_dataset.data.numpy()[self.train_slice] 106 | train_mask = self.torch_dataset.mask.numpy()[self.train_slice] if 'mask' in self.torch_dataset else None 107 | scaler = self.get_scaler()(scaling_axis).fit(train, mask=train_mask, keepdims=True).to_torch() 108 | self.torch_dataset.scaler = scaler 
109 | 110 | if self.scale_exogenous is not None: 111 | for label in self.scale_exogenous: 112 | exo = getattr(self.torch_dataset, label) 113 | scaler = self.get_scaler()(scaling_axis) 114 | scaler.fit(exo[self.train_slice], keepdims=True).to_torch() 115 | setattr(self.torch_dataset, label, scaler.transform(exo)) 116 | 117 | def _data_loader(self, dataset, shuffle=False, batch_size=None, **kwargs): 118 | batch_size = self.batch_size if batch_size is None else batch_size 119 | return DataLoader(dataset, 120 | shuffle=shuffle, 121 | batch_size=batch_size, 122 | num_workers=self.workers, 123 | **kwargs) 124 | 125 | def train_dataloader(self, shuffle=True, batch_size=None): 126 | if self.samples_per_epoch is not None: 127 | sampler = RandomSampler(self.trainset, replacement=True, num_samples=self.samples_per_epoch) 128 | return self._data_loader(self.trainset, False, batch_size, sampler=sampler, drop_last=True) 129 | return self._data_loader(self.trainset, shuffle, batch_size, drop_last=True) 130 | 131 | def val_dataloader(self, shuffle=False, batch_size=None): 132 | return self._data_loader(self.valset, shuffle, batch_size) 133 | 134 | def test_dataloader(self, shuffle=False, batch_size=None): 135 | return self._data_loader(self.testset, shuffle, batch_size) 136 | 137 | @staticmethod 138 | def add_argparse_args(parser, **kwargs): 139 | parser.add_argument('--batch-size', type=int, default=64) 140 | parser.add_argument('--scaling-axis', type=str, default="channels") 141 | parser.add_argument('--scaling-type', type=str, default="std") 142 | parser.add_argument('--scale', type=str_to_bool, nargs='?', const=True, default=True) 143 | parser.add_argument('--workers', type=int, default=0) 144 | parser.add_argument('--samples-per-epoch', type=int, default=None) 145 | return parser 146 | -------------------------------------------------------------------------------- /lib/data/imputation_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from . 
import TemporalDataset, SpatioTemporalDataset 5 | 6 | 7 | class ImputationDataset(TemporalDataset): 8 | 9 | def __init__(self, data, 10 | index=None, 11 | mask=None, 12 | eval_mask=None, 13 | freq=None, 14 | trend=None, 15 | scaler=None, 16 | window=24, 17 | stride=1, 18 | exogenous=None): 19 | if mask is None: 20 | mask = np.ones_like(data) 21 | if exogenous is None: 22 | exogenous = dict() 23 | exogenous['mask_window'] = mask 24 | if eval_mask is not None: 25 | exogenous['eval_mask_window'] = eval_mask 26 | super(ImputationDataset, self).__init__(data, 27 | index=index, 28 | exogenous=exogenous, 29 | trend=trend, 30 | scaler=scaler, 31 | freq=freq, 32 | window=window, 33 | horizon=window, 34 | delay=-window, 35 | stride=stride) 36 | 37 | def get(self, item, preprocess=False): 38 | res, transform = super(ImputationDataset, self).get(item, preprocess) 39 | res['x'] = torch.where(res['mask'], res['x'], torch.zeros_like(res['x'])) 40 | return res, transform 41 | 42 | 43 | class GraphImputationDataset(ImputationDataset, SpatioTemporalDataset): 44 | pass 45 | -------------------------------------------------------------------------------- /lib/data/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .scalers import * 2 | -------------------------------------------------------------------------------- /lib/data/preprocessing/scalers.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import numpy as np 3 | 4 | 5 | class AbstractScaler(ABC): 6 | 7 | def __init__(self, **kwargs): 8 | for k, v in kwargs.items(): 9 | setattr(self, k, v) 10 | 11 | def __repr__(self): 12 | params = ", ".join([f"{k}={str(v)}" for k, v in self.params().items()]) 13 | return "{}({})".format(self.__class__.__name__, params) 14 | 15 | def __call__(self, *args, **kwargs): 16 | return self.transform(*args, **kwargs) 17 | 18 | def params(self): 19 | return {k: v for k, v in self.__dict__.items() if not callable(v) and not k.startswith("__")} 20 | 21 | @abstractmethod 22 | def fit(self, x): 23 | pass 24 | 25 | @abstractmethod 26 | def transform(self, x): 27 | pass 28 | 29 | @abstractmethod 30 | def inverse_transform(self, x): 31 | pass 32 | 33 | def fit_transform(self, x): 34 | self.fit(x) 35 | return self.transform(x) 36 | 37 | def to_torch(self): 38 | import torch 39 | for p in self.params(): 40 | param = getattr(self, p) 41 | param = np.atleast_1d(param) 42 | param = torch.tensor(param).float() 43 | setattr(self, p, param) 44 | return self 45 | 46 | 47 | class Scaler(AbstractScaler): 48 | def __init__(self, offset=0., scale=1.): 49 | self.bias = offset 50 | self.scale = scale 51 | super(Scaler, self).__init__() 52 | 53 | def params(self): 54 | return dict(bias=self.bias, scale=self.scale) 55 | 56 | def fit(self, x, mask=None, keepdims=True): 57 | pass 58 | 59 | def transform(self, x): 60 | return (x - self.bias) / self.scale 61 | 62 | def inverse_transform(self, x): 63 | return x * self.scale + self.bias 64 | 65 | def fit_transform(self, x, mask=None, keepdims=True): 66 | self.fit(x, mask, keepdims) 67 | return self.transform(x) 68 | 69 | 70 | class StandardScaler(Scaler): 71 | def __init__(self, axis=0): 72 | self.axis = axis 73 | super(StandardScaler, self).__init__() 74 | 75 | def fit(self, x, mask=None, keepdims=True): 76 | if mask is not None: 77 | x = np.where(mask, x, np.nan) 78 | self.bias = np.nanmean(x, axis=self.axis, keepdims=keepdims) 79 | self.scale = np.nanstd(x, 
axis=self.axis, keepdims=keepdims) 80 | else: 81 | self.bias = x.mean(axis=self.axis, keepdims=keepdims) 82 | self.scale = x.std(axis=self.axis, keepdims=keepdims) 83 | return self 84 | 85 | 86 | class MinMaxScaler(Scaler): 87 | def __init__(self, axis=0): 88 | self.axis = axis 89 | super(MinMaxScaler, self).__init__() 90 | 91 | def fit(self, x, mask=None, keepdims=True): 92 | if mask is not None: 93 | x = np.where(mask, x, np.nan) 94 | self.bias = np.nanmin(x, axis=self.axis, keepdims=keepdims) 95 | self.scale = (np.nanmax(x, axis=self.axis, keepdims=keepdims) - self.bias) 96 | else: 97 | self.bias = x.min(axis=self.axis, keepdims=keepdims) 98 | self.scale = (x.max(axis=self.axis, keepdims=keepdims) - self.bias) 99 | return self 100 | -------------------------------------------------------------------------------- /lib/data/spatiotemporal_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from einops import rearrange 4 | 5 | from .temporal_dataset import TemporalDataset 6 | 7 | 8 | class SpatioTemporalDataset(TemporalDataset): 9 | def __init__(self, data, 10 | index=None, 11 | trend=None, 12 | scaler=None, 13 | freq=None, 14 | window=24, 15 | horizon=24, 16 | delay=0, 17 | stride=1, 18 | **exogenous): 19 | """ 20 | Pytorch dataset for data that can be represented as a single TimeSeries 21 | 22 | :param data: 23 | raw target time series (ts) (can be multivariate), shape: [steps, (features), nodes] 24 | :param exog: 25 | global exogenous variables, shape: [steps, nodes] 26 | :param trend: 27 | trend time series to be removed from the ts, shape: [steps, (features), (nodes)] 28 | :param bias: 29 | bias to be removed from the ts (after de-trending), shape [steps, (features), (nodes)] 30 | :param scale: r 31 | scaling factor to scale the ts (after de-trending), shape [steps, (features), (nodes)] 32 | :param mask: 33 | mask for valid data, 1 -> valid time step, 0 -> invalid. same shape of ts. 
34 | :param target_exog: 35 | exogenous variables of the target, shape: [steps, nodes] 36 | :param window: 37 | length of windows returned by __get_intem__ 38 | :param horizon: 39 | length of prediction horizon returned by __get_intem__ 40 | :param delay: 41 | delay between input and prediction 42 | """ 43 | super(SpatioTemporalDataset, self).__init__(data, 44 | index=index, 45 | trend=trend, 46 | scaler=scaler, 47 | freq=freq, 48 | window=window, 49 | horizon=horizon, 50 | delay=delay, 51 | stride=stride, 52 | **exogenous) 53 | 54 | def __repr__(self): 55 | return "{}(n_samples={}, n_nodes={})".format(self.__class__.__name__, len(self), self.n_nodes) 56 | 57 | @property 58 | def n_nodes(self): 59 | return self.data.shape[1] 60 | 61 | @staticmethod 62 | def check_dim(data): 63 | if data.ndim == 2: # [steps, nodes] -> [steps, nodes, features] 64 | data = rearrange(data, 's (n f) -> s n f', f=1) 65 | elif data.ndim == 1: 66 | data = rearrange(data, '(s n f) -> s n f', n=1, f=1) 67 | elif data.ndim == 3: 68 | pass 69 | else: 70 | raise ValueError(f'Invalid data dimensions {data.shape}') 71 | return data 72 | 73 | def dataframe(self): 74 | if self.n_channels == 1: 75 | return pd.DataFrame(data=np.squeeze(self.data, -1), index=self.index) 76 | raise NotImplementedError() 77 | -------------------------------------------------------------------------------- /lib/data/temporal_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | from einops import rearrange 5 | from pandas import DatetimeIndex 6 | from torch.utils.data import Dataset 7 | 8 | from .preprocessing import AbstractScaler 9 | 10 | 11 | class TemporalDataset(Dataset): 12 | def __init__(self, data, 13 | index=None, 14 | freq=None, 15 | exogenous=None, 16 | trend=None, 17 | scaler=None, 18 | window=24, 19 | horizon=24, 20 | delay=0, 21 | stride=1): 22 | """Wrapper class for dataset whose entry are dependent from a sequence of temporal indices. 23 | 24 | Parameters 25 | ---------- 26 | data : np.ndarray 27 | Data relative to the main signal. 28 | index : DatetimeIndex or None 29 | Temporal indices for the data. 30 | exogenous : dict or None 31 | Exogenous data and label paired with main signal (default is None). 32 | trend : np.ndarray or None 33 | Trend paired with main signal (default is None). Must be of the same length of 'data'. 34 | scaler : AbstractScaler or None 35 | Scaler that must be used for data (default is None). 36 | freq : pd.DateTimeIndex.freq or str 37 | Frequency of the indices (defaults is indices.freq). 38 | window : int 39 | Size of the sliding window in the past. 40 | horizon : int 41 | Size of the prediction horizon. 42 | delay : int 43 | Offset between end of window and start of horizon. 44 | 45 | Raises 46 | ---------- 47 | ValueError 48 | If a frequency for the temporal indices is not provided neither in indices nor explicitly. 49 | If preprocess is True and data_scaler is None. 
50 | """ 51 | super(TemporalDataset, self).__init__() 52 | # Initialize signatures 53 | self.__exogenous_keys = dict() 54 | self.__reserved_signature = {'data', 'trend', 'x', 'y'} 55 | # Store data 56 | self.data = data 57 | if exogenous is not None: 58 | for name, value in exogenous.items(): 59 | self.add_exogenous(value, name, for_window=True, for_horizon=True) 60 | # Store time information 61 | self.index = index 62 | try: 63 | freq = freq or index.freq or index.inferred_freq 64 | self.freq = pd.tseries.frequencies.to_offset(freq) 65 | except AttributeError: 66 | self.freq = None 67 | # Store offset information 68 | self.window = window 69 | self.delay = delay 70 | self.horizon = horizon 71 | self.stride = stride 72 | # Identify the indices of the samples 73 | self._indices = np.arange(self.data.shape[0] - self.sample_span + 1)[::self.stride] 74 | # Store preprocessing options 75 | self.trend = trend 76 | self.scaler = scaler 77 | 78 | def __getitem__(self, item): 79 | return self.get(item, self.preprocess) 80 | 81 | def __contains__(self, item): 82 | return item in self.__exogenous_keys 83 | 84 | def __len__(self): 85 | return len(self._indices) 86 | 87 | def __repr__(self): 88 | return "{}(n_samples={})".format(self.__class__.__name__, len(self)) 89 | 90 | # Getter and setter for data 91 | 92 | @property 93 | def data(self): 94 | return self.__data 95 | 96 | @data.setter 97 | def data(self, value): 98 | assert value is not None 99 | self.__data = self.check_input(value) 100 | 101 | @property 102 | def trend(self): 103 | return self.__trend 104 | 105 | @trend.setter 106 | def trend(self, value): 107 | self.__trend = self.check_input(value) 108 | 109 | # Setter for exogenous data 110 | 111 | def add_exogenous(self, obj, name, for_window=True, for_horizon=False): 112 | assert isinstance(name, str) 113 | if name.endswith('_window'): 114 | name = name[:-7] 115 | for_window, for_horizon = True, False 116 | if name.endswith('_horizon'): 117 | name = name[:-8] 118 | for_window, for_horizon = False, True 119 | if name in self.__reserved_signature: 120 | raise ValueError("Channel '{0}' cannot be added in this way. 
Use obj.{0} instead.".format(name)) 121 | if not (for_window or for_horizon): 122 | raise ValueError("Either for_window or for_horizon must be True.") 123 | obj = self.check_input(obj) 124 | setattr(self, name, obj) 125 | self.__exogenous_keys[name] = dict(for_window=for_window, for_horizon=for_horizon) 126 | return self 127 | 128 | # Dataset properties 129 | 130 | @property 131 | def horizon_offset(self): 132 | return self.window + self.delay 133 | 134 | @property 135 | def sample_span(self): 136 | return max(self.horizon_offset + self.horizon, self.window) 137 | 138 | @property 139 | def preprocess(self): 140 | return (self.trend is not None) or (self.scaler is not None) 141 | 142 | @property 143 | def n_steps(self): 144 | return self.data.shape[0] 145 | 146 | @property 147 | def n_channels(self): 148 | return self.data.shape[-1] 149 | 150 | @property 151 | def indices(self): 152 | return self._indices 153 | 154 | # Signature information 155 | 156 | @property 157 | def exo_window_keys(self): 158 | return {k for k, v in self.__exogenous_keys.items() if v['for_window']} 159 | 160 | @property 161 | def exo_horizon_keys(self): 162 | return {k for k, v in self.__exogenous_keys.items() if v['for_horizon']} 163 | 164 | @property 165 | def exo_common_keys(self): 166 | return self.exo_window_keys.intersection(self.exo_horizon_keys) 167 | 168 | @property 169 | def signature(self): 170 | attrs = [] 171 | if self.window > 0: 172 | attrs.append('x') 173 | for attr in self.exo_window_keys: 174 | attrs.append(attr if attr not in self.exo_common_keys else (attr + '_window')) 175 | for attr in self.exo_horizon_keys: 176 | attrs.append(attr if attr not in self.exo_common_keys else (attr + '_horizon')) 177 | attrs.append('y') 178 | attrs = tuple(attrs) 179 | preprocess = [] 180 | if self.trend is not None: 181 | preprocess.append('trend') 182 | if self.scaler is not None: 183 | preprocess.extend(self.scaler.params()) 184 | preprocess = tuple(preprocess) 185 | return dict(data=attrs, preprocessing=preprocess) 186 | 187 | # Item getters 188 | 189 | def get(self, item, preprocess=False): 190 | idx = self._indices[item] 191 | res, transform = dict(), dict() 192 | if self.window > 0: 193 | res['x'] = self.data[idx:idx + self.window] 194 | for attr in self.exo_window_keys: 195 | key = attr if attr not in self.exo_common_keys else (attr + '_window') 196 | res[key] = getattr(self, attr)[idx:idx + self.window] 197 | for attr in self.exo_horizon_keys: 198 | key = attr if attr not in self.exo_common_keys else (attr + '_horizon') 199 | res[key] = getattr(self, attr)[idx + self.horizon_offset:idx + self.horizon_offset + self.horizon] 200 | res['y'] = self.data[idx + self.horizon_offset:idx + self.horizon_offset + self.horizon] 201 | if preprocess: 202 | if self.trend is not None: 203 | y_trend = self.trend[idx + self.horizon_offset:idx + self.horizon_offset + self.horizon] 204 | res['y'] = res['y'] - y_trend 205 | transform['trend'] = y_trend 206 | if 'x' in res: 207 | res['x'] = res['x'] - self.trend[idx:idx + self.window] 208 | if self.scaler is not None: 209 | transform.update(self.scaler.params()) 210 | if 'x' in res: 211 | res['x'] = self.scaler.transform(res['x']) 212 | return res, transform 213 | 214 | def snapshot(self, indices=None, preprocess=True): 215 | if not self.preprocess: 216 | preprocess = False 217 | data, prep = [{k: [] for k in sign} for sign in self.signature.values()] 218 | indices = np.arange(len(self._indices)) if indices is None else indices 219 | for idx in indices: 220 | data_i, prep_i = 
self.get(idx, preprocess) 221 | [v.append(data_i[k]) for k, v in data.items()] 222 | if len(prep_i): 223 | [v.append(prep_i[k]) for k, v in prep.items()] 224 | data = {k: np.stack(ds) for k, ds in data.items() if len(ds)} 225 | if len(prep): 226 | prep = {k: np.stack(ds) if k == 'trend' else ds[0] for k, ds in prep.items() if len(ds)} 227 | return data, prep 228 | 229 | # Data utilities 230 | 231 | def expand_indices(self, indices=None, unique=False, merge=False): 232 | ds_indices = dict.fromkeys([time for time in ['window', 'horizon'] if getattr(self, time) > 0]) 233 | indices = np.arange(len(self._indices)) if indices is None else indices 234 | if 'window' in ds_indices: 235 | w_idxs = [np.arange(idx, idx + self.window) for idx in self._indices[indices]] 236 | ds_indices['window'] = np.concatenate(w_idxs) 237 | if 'horizon' in ds_indices: 238 | h_idxs = [np.arange(idx + self.horizon_offset, idx + self.horizon_offset + self.horizon) 239 | for idx in self._indices[indices]] 240 | ds_indices['horizon'] = np.concatenate(h_idxs) 241 | if unique: 242 | ds_indices = {k: np.unique(v) for k, v in ds_indices.items()} 243 | if merge: 244 | ds_indices = np.unique(np.concatenate(list(ds_indices.values()))) 245 | return ds_indices 246 | 247 | def overlapping_indices(self, idxs1, idxs2, synch_mode='window', as_mask=False): 248 | assert synch_mode in ['window', 'horizon'] 249 | ts1 = self.data_timestamps(idxs1, flatten=False)[synch_mode] 250 | ts2 = self.data_timestamps(idxs2, flatten=False)[synch_mode] 251 | common_ts = np.intersect1d(np.unique(ts1), np.unique(ts2)) 252 | is_overlapping = lambda sample: np.any(np.in1d(sample, common_ts)) 253 | m1 = np.apply_along_axis(is_overlapping, 1, ts1) 254 | m2 = np.apply_along_axis(is_overlapping, 1, ts2) 255 | if as_mask: 256 | return m1, m2 257 | return np.sort(idxs1[m1]), np.sort(idxs2[m2]) 258 | 259 | def data_timestamps(self, indices=None, flatten=True): 260 | ds_indices = self.expand_indices(indices, unique=False) 261 | ds_timestamps = {k: self.index[v] for k, v in ds_indices.items()} 262 | if not flatten: 263 | ds_timestamps = {k: np.array(v).reshape(-1, getattr(self, k)) for k, v in ds_timestamps.items()} 264 | return ds_timestamps 265 | 266 | def reduce_dataset(self, indices, inplace=False): 267 | if not inplace: 268 | from copy import deepcopy 269 | dataset = deepcopy(self) 270 | else: 271 | dataset = self 272 | old_index = dataset.index[dataset._indices[indices]] 273 | ds_indices = dataset.expand_indices(indices, merge=True) 274 | dataset.index = dataset.index[ds_indices] 275 | dataset.data = dataset.data[ds_indices] 276 | if dataset.mask is not None: 277 | dataset.mask = dataset.mask[ds_indices] 278 | if dataset.trend is not None: 279 | dataset.trend = dataset.trend[ds_indices] 280 | for attr in dataset.exo_window_keys.union(dataset.exo_horizon_keys): 281 | if getattr(dataset, attr, None) is not None: 282 | setattr(dataset, attr, getattr(dataset, attr)[ds_indices]) 283 | dataset._indices = np.flatnonzero(np.in1d(dataset.index, old_index)) 284 | return dataset 285 | 286 | def check_input(self, data): 287 | if data is None: 288 | return data 289 | data = self.check_dim(data) 290 | data = data.clone().detach() if isinstance(data, torch.Tensor) else torch.tensor(data) 291 | # cast data 292 | if torch.is_floating_point(data): 293 | return data.float() 294 | elif data.dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64]: 295 | return data.int() 296 | return data 297 | 298 | # Class-specific methods (override in children) 299 | 300 | 
@staticmethod 301 | def check_dim(data): 302 | if data.ndim == 1: # [steps] -> [steps, features] 303 | data = rearrange(data, '(s f) -> s f', f=1) 304 | elif data.ndim != 2: 305 | raise ValueError(f'Invalid data dimensions {data.shape}') 306 | return data 307 | 308 | def dataframe(self): 309 | return pd.DataFrame(data=self.data, index=self.index) 310 | 311 | @staticmethod 312 | def add_argparse_args(parser, **kwargs): 313 | parser.add_argument('--window', type=int, default=24) 314 | parser.add_argument('--horizon', type=int, default=24) 315 | parser.add_argument('--delay', type=int, default=0) 316 | parser.add_argument('--stride', type=int, default=1) 317 | return parser 318 | -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .air_quality import AirQuality 2 | from .metr_la import MissingValuesMetrLA 3 | from .pems_bay import MissingValuesPemsBay 4 | from .synthetic import ChargedParticles 5 | -------------------------------------------------------------------------------- /lib/datasets/air_quality.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from lib import datasets_path 7 | from .pd_dataset import PandasDataset 8 | from ..utils.utils import disjoint_months, infer_mask, compute_mean, geographical_distance, thresholded_gaussian_kernel 9 | 10 | 11 | class AirQuality(PandasDataset): 12 | SEED = 3210 13 | 14 | def __init__(self, impute_nans=False, small=False, freq='60T', masked_sensors=None): 15 | self.random = np.random.default_rng(self.SEED) 16 | self.test_months = [3, 6, 9, 12] 17 | self.infer_eval_from = 'next' 18 | self.eval_mask = None 19 | df, dist, mask = self.load(impute_nans=impute_nans, small=small, masked_sensors=masked_sensors) 20 | self.dist = dist 21 | if masked_sensors is None: 22 | self.masked_sensors = list() 23 | else: 24 | self.masked_sensors = list(masked_sensors) 25 | super().__init__(dataframe=df, u=None, mask=mask, name='air', freq=freq, aggr='nearest') 26 | 27 | def load_raw(self, small=False): 28 | if small: 29 | path = os.path.join(datasets_path['air'], 'small36.h5') 30 | eval_mask = pd.DataFrame(pd.read_hdf(path, 'eval_mask')) 31 | else: 32 | path = os.path.join(datasets_path['air'], 'full437.h5') 33 | eval_mask = None 34 | df = pd.DataFrame(pd.read_hdf(path, 'pm25')) 35 | stations = pd.DataFrame(pd.read_hdf(path, 'stations')) 36 | return df, stations, eval_mask 37 | 38 | def load(self, impute_nans=True, small=False, masked_sensors=None): 39 | # load readings and stations metadata 40 | df, stations, eval_mask = self.load_raw(small) 41 | # compute the masks 42 | mask = (~np.isnan(df.values)).astype('uint8') # 1 if value is not nan else 0 43 | if eval_mask is None: 44 | eval_mask = infer_mask(df, infer_from=self.infer_eval_from) 45 | 46 | eval_mask = eval_mask.values.astype('uint8') 47 | if masked_sensors is not None: 48 | eval_mask[:, masked_sensors] = np.where(mask[:, masked_sensors], 1, 0) 49 | self.eval_mask = eval_mask # 1 if value is ground-truth for imputation else 0 50 | # eventually replace nans with weekly mean by hour 51 | if impute_nans: 52 | df = df.fillna(compute_mean(df)) 53 | # compute distances from latitude and longitude degrees 54 | st_coord = stations.loc[:, ['latitude', 'longitude']] 55 | dist = geographical_distance(st_coord, to_rad=True).values 56 | return df, dist, mask 57 | 58 | def 
splitter(self, dataset, val_len=1., in_sample=False, window=0): 59 | nontest_idxs, test_idxs = disjoint_months(dataset, months=self.test_months, synch_mode='horizon') 60 | if in_sample: 61 | train_idxs = np.arange(len(dataset)) 62 | val_months = [(m - 1) % 12 for m in self.test_months] 63 | _, val_idxs = disjoint_months(dataset, months=val_months, synch_mode='horizon') 64 | else: 65 | # take equal number of samples before each month of testing 66 | val_len = (int(val_len * len(nontest_idxs)) if val_len < 1 else val_len) // len(self.test_months) 67 | # get indices of first day of each testing month 68 | delta_idxs = np.diff(test_idxs) 69 | end_month_idxs = test_idxs[1:][np.flatnonzero(delta_idxs > delta_idxs.min())] 70 | if len(end_month_idxs) < len(self.test_months): 71 | end_month_idxs = np.insert(end_month_idxs, 0, test_idxs[0]) 72 | # expand month indices 73 | month_val_idxs = [np.arange(v_idx - val_len, v_idx) - window for v_idx in end_month_idxs] 74 | val_idxs = np.concatenate(month_val_idxs) % len(dataset) 75 | # remove overlapping indices from training set 76 | ovl_idxs, _ = dataset.overlapping_indices(nontest_idxs, val_idxs, synch_mode='horizon', as_mask=True) 77 | train_idxs = nontest_idxs[~ovl_idxs] 78 | return [train_idxs, val_idxs, test_idxs] 79 | 80 | def get_similarity(self, thr=0.1, include_self=False, force_symmetric=False, sparse=False, **kwargs): 81 | theta = np.std(self.dist[:36, :36]) # use same theta for both air and air36 82 | adj = thresholded_gaussian_kernel(self.dist, theta=theta, threshold=thr) 83 | if not include_self: 84 | adj[np.diag_indices_from(adj)] = 0. 85 | if force_symmetric: 86 | adj = np.maximum.reduce([adj, adj.T]) 87 | if sparse: 88 | import scipy.sparse as sps 89 | adj = sps.coo_matrix(adj) 90 | return adj 91 | 92 | @property 93 | def mask(self): 94 | return self._mask 95 | 96 | @property 97 | def training_mask(self): 98 | return self._mask if self.eval_mask is None else (self._mask & (1 - self.eval_mask)) 99 | 100 | def test_interval_mask(self, dtype=bool, squeeze=True): 101 | m = np.in1d(self.df.index.month, self.test_months).astype(dtype) 102 | if squeeze: 103 | return m 104 | return m[:, None] 105 | -------------------------------------------------------------------------------- /lib/datasets/metr_la.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from lib import datasets_path 7 | from .pd_dataset import PandasDataset 8 | from ..utils import sample_mask 9 | 10 | 11 | class MetrLA(PandasDataset): 12 | def __init__(self, impute_zeros=False, freq='5T'): 13 | 14 | df, dist, mask = self.load(impute_zeros=impute_zeros) 15 | self.dist = dist 16 | super().__init__(dataframe=df, u=None, mask=mask, name='la', freq=freq, aggr='nearest') 17 | 18 | def load(self, impute_zeros=True): 19 | path = os.path.join(datasets_path['la'], 'metr_la.h5') 20 | df = pd.read_hdf(path) 21 | datetime_idx = sorted(df.index) 22 | date_range = pd.date_range(datetime_idx[0], datetime_idx[-1], freq='5T') 23 | df = df.reindex(index=date_range) 24 | mask = ~np.isnan(df.values) 25 | if impute_zeros: 26 | mask = mask * (df.values != 0.).astype('uint8') 27 | df = df.replace(to_replace=0., method='ffill') 28 | else: 29 | mask = None 30 | dist = self.load_distance_matrix() 31 | return df, dist, mask 32 | 33 | def load_distance_matrix(self): 34 | path = os.path.join(datasets_path['la'], 'metr_la_dist.npy') 35 | try: 36 | dist = np.load(path) 37 | except: 38 | distances = 
pd.read_csv(os.path.join(datasets_path['la'], 'distances_la.csv')) 39 | with open(os.path.join(datasets_path['la'], 'sensor_ids_la.txt')) as f: 40 | ids = f.read().strip().split(',') 41 | num_sensors = len(ids) 42 | dist = np.ones((num_sensors, num_sensors), dtype=np.float32) * np.inf 43 | # Builds sensor id to index map. 44 | sensor_id_to_ind = {int(sensor_id): i for i, sensor_id in enumerate(ids)} 45 | 46 | # Fills cells in the matrix with distances. 47 | for row in distances.values: 48 | if row[0] not in sensor_id_to_ind or row[1] not in sensor_id_to_ind: 49 | continue 50 | dist[sensor_id_to_ind[row[0]], sensor_id_to_ind[row[1]]] = row[2] 51 | np.save(path, dist) 52 | return dist 53 | 54 | def get_similarity(self, thr=0.1, force_symmetric=False, sparse=False): 55 | finite_dist = self.dist.reshape(-1) 56 | finite_dist = finite_dist[~np.isinf(finite_dist)] 57 | sigma = finite_dist.std() 58 | adj = np.exp(-np.square(self.dist / sigma)) 59 | adj[adj < thr] = 0. 60 | if force_symmetric: 61 | adj = np.maximum.reduce([adj, adj.T]) 62 | if sparse: 63 | import scipy.sparse as sps 64 | adj = sps.coo_matrix(adj) 65 | return adj 66 | 67 | @property 68 | def mask(self): 69 | return self._mask 70 | 71 | 72 | class MissingValuesMetrLA(MetrLA): 73 | SEED = 9101112 74 | 75 | def __init__(self, p_fault=0.0015, p_noise=0.05): 76 | super(MissingValuesMetrLA, self).__init__(impute_zeros=True) 77 | self.rng = np.random.default_rng(self.SEED) 78 | self.p_fault = p_fault 79 | self.p_noise = p_noise 80 | eval_mask = sample_mask(self.numpy().shape, 81 | p=p_fault, 82 | p_noise=p_noise, 83 | min_seq=12, 84 | max_seq=12 * 4, 85 | rng=self.rng) 86 | self.eval_mask = (eval_mask & self.mask).astype('uint8') 87 | 88 | @property 89 | def training_mask(self): 90 | return self.mask if self.eval_mask is None else (self.mask & (1 - self.eval_mask)) 91 | 92 | def splitter(self, dataset, val_len=0, test_len=0, window=0): 93 | idx = np.arange(len(dataset)) 94 | if test_len < 1: 95 | test_len = int(test_len * len(idx)) 96 | if val_len < 1: 97 | val_len = int(val_len * (len(idx) - test_len)) 98 | test_start = len(idx) - test_len 99 | val_start = test_start - val_len 100 | return [idx[:val_start - window], idx[val_start:test_start - window], idx[test_start:]] -------------------------------------------------------------------------------- /lib/datasets/pd_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | 5 | 6 | class PandasDataset: 7 | def __init__(self, dataframe: pd.DataFrame, u: pd.DataFrame = None, name='pd-dataset', mask=None, freq=None, 8 | aggr='sum', **kwargs): 9 | """ 10 | Initialize a tsl dataset from a pandas dataframe. 
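        A minimal construction sketch (values are illustrative only; a DatetimeIndex with a
        minute-based frequency is assumed, since `samples_per_day` is derived from it):

            idx = pd.date_range('2020-01-01', periods=288, freq='5T')
            df = pd.DataFrame(np.random.rand(288, 4), index=idx)
            ds = PandasDataset(dataframe=df, mask=np.ones(df.shape), name='demo', freq='5T', aggr='mean')
            ds.samples_per_day  # 288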
11 | 12 | 13 | :param dataframe: dataframe containing the data, shape: n_steps, n_nodes 14 | :param u: dataframe with exog variables 15 | :param name: optional name of the dataset 16 | :param mask: mask for valid data (1:valid, 0:not valid) 17 | :param freq: force a frequency (possibly by resampling) 18 | :param aggr: aggregation method after resampling 19 | """ 20 | super().__init__() 21 | self.name = name 22 | 23 | # set dataset dataframe 24 | self.df = dataframe 25 | 26 | # set optional exog_variable dataframe 27 | # make sure to consider only the overlapping part of the two dataframes 28 | # assumption u.index \in df.index 29 | idx = sorted(self.df.index) 30 | self.start = idx[0] 31 | self.end = idx[-1] 32 | 33 | if u is not None: 34 | self.u = u[self.start:self.end] 35 | else: 36 | self.u = None 37 | 38 | if mask is not None: 39 | mask = np.asarray(mask).astype('uint8') 40 | self._mask = mask 41 | 42 | if freq is not None: 43 | self.resample_(freq=freq, aggr=aggr) 44 | else: 45 | self.freq = self.df.index.inferred_freq 46 | # make sure that all the dataframes are aligned 47 | self.resample_(self.freq, aggr=aggr) 48 | 49 | assert 'T' in self.freq 50 | self.samples_per_day = int(60 / int(self.freq[:-1]) * 24) 51 | 52 | def __repr__(self): 53 | return "{}(nodes={}, length={})".format(self.__class__.__name__, self.n_nodes, self.length) 54 | 55 | @property 56 | def has_mask(self): 57 | return self._mask is not None 58 | 59 | @property 60 | def has_u(self): 61 | return self.u is not None 62 | 63 | def resample_(self, freq, aggr): 64 | resampler = self.df.resample(freq) 65 | idx = self.df.index 66 | if aggr == 'sum': 67 | self.df = resampler.sum() 68 | elif aggr == 'mean': 69 | self.df = resampler.mean() 70 | elif aggr == 'nearest': 71 | self.df = resampler.nearest() 72 | else: 73 | raise ValueError(f'{aggr} if not a valid aggregation method.') 74 | 75 | if self.has_mask: 76 | resampler = pd.DataFrame(self._mask, index=idx).resample(freq) 77 | self._mask = resampler.min().to_numpy() 78 | 79 | if self.has_u: 80 | resampler = self.u.resample(freq) 81 | self.u = resampler.nearest() 82 | self.freq = freq 83 | 84 | def dataframe(self) -> pd.DataFrame: 85 | return self.df.copy() 86 | 87 | @property 88 | def length(self): 89 | return self.df.values.shape[0] 90 | 91 | @property 92 | def n_nodes(self): 93 | return self.df.values.shape[1] 94 | 95 | @property 96 | def mask(self): 97 | if self._mask is None: 98 | return np.ones_like(self.df.values).astype('uint8') 99 | return self._mask 100 | 101 | def numpy(self, return_idx=False): 102 | if return_idx: 103 | return self.numpy(), self.df.index 104 | return self.df.values 105 | 106 | def pytorch(self): 107 | data = self.numpy() 108 | return torch.FloatTensor(data) 109 | 110 | def __len__(self): 111 | return self.length 112 | 113 | @staticmethod 114 | def build(): 115 | raise NotImplementedError 116 | 117 | def load_raw(self): 118 | raise NotImplementedError 119 | 120 | def load(self): 121 | raise NotImplementedError 122 | -------------------------------------------------------------------------------- /lib/datasets/pems_bay.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from lib import datasets_path 7 | from .pd_dataset import PandasDataset 8 | from ..utils import sample_mask 9 | 10 | 11 | class PemsBay(PandasDataset): 12 | def __init__(self): 13 | df, dist, mask = self.load() 14 | self.dist = dist 15 | super().__init__(dataframe=df, u=None, 
mask=mask, name='bay', freq='5T', aggr='nearest') 16 | 17 | def load(self, impute_zeros=True): 18 | path = os.path.join(datasets_path['bay'], 'pems_bay.h5') 19 | df = pd.read_hdf(path) 20 | datetime_idx = sorted(df.index) 21 | date_range = pd.date_range(datetime_idx[0], datetime_idx[-1], freq='5T') 22 | df = df.reindex(index=date_range) 23 | mask = ~np.isnan(df.values) 24 | df.fillna(method='ffill', axis=0, inplace=True) 25 | dist = self.load_distance_matrix(list(df.columns)) 26 | return df.astype('float32'), dist, mask.astype('uint8') 27 | 28 | def load_distance_matrix(self, ids): 29 | path = os.path.join(datasets_path['bay'], 'pems_bay_dist.npy') 30 | try: 31 | dist = np.load(path) 32 | except: 33 | distances = pd.read_csv(os.path.join(datasets_path['bay'], 'distances_bay.csv')) 34 | num_sensors = len(ids) 35 | dist = np.ones((num_sensors, num_sensors), dtype=np.float32) * np.inf 36 | # Builds sensor id to index map. 37 | sensor_id_to_ind = {int(sensor_id): i for i, sensor_id in enumerate(ids)} 38 | 39 | # Fills cells in the matrix with distances. 40 | for row in distances.values: 41 | if row[0] not in sensor_id_to_ind or row[1] not in sensor_id_to_ind: 42 | continue 43 | dist[sensor_id_to_ind[row[0]], sensor_id_to_ind[row[1]]] = row[2] 44 | np.save(path, dist) 45 | return dist 46 | 47 | def get_similarity(self, type='dcrnn', thr=0.1, force_symmetric=False, sparse=False): 48 | """ 49 | Return similarity matrix among nodes. Implemented to match DCRNN. 50 | 51 | :param type: type of similarity matrix. 52 | :param thr: threshold to increase sparseness. 53 | :param force_symmetric: force the result to be symmetric. 54 | :param sparse: whether to return the result as a scipy sparse matrix. 55 | :return: an NxN array representing similarity among nodes. 56 | """ 57 | if type == 'dcrnn': 58 | finite_dist = self.dist.reshape(-1) 59 | finite_dist = finite_dist[~np.isinf(finite_dist)] 60 | sigma = finite_dist.std() 61 | adj = np.exp(-np.square(self.dist / sigma)) 62 | elif type == 'stcn': 63 | sigma = 10 64 | adj = np.exp(-np.square(self.dist) / sigma) 65 | else: 66 | raise NotImplementedError 67 | adj[adj < thr] = 0. 68 | if force_symmetric: 69 | adj = np.maximum.reduce([adj, adj.T]) 70 | if sparse: 71 | import scipy.sparse as sps 72 | adj = sps.coo_matrix(adj) 73 | return adj 74 | 75 | @property 76 | def mask(self): 77 | if self._mask is None: 78 | return self.df.values != 0.
79 | return self._mask 80 | 81 | 82 | class MissingValuesPemsBay(PemsBay): 83 | SEED = 56789 84 | 85 | def __init__(self, p_fault=0.0015, p_noise=0.05): 86 | super(MissingValuesPemsBay, self).__init__() 87 | self.rng = np.random.default_rng(self.SEED) 88 | self.p_fault = p_fault 89 | self.p_noise = p_noise 90 | eval_mask = sample_mask(self.numpy().shape, 91 | p=p_fault, 92 | p_noise=p_noise, 93 | min_seq=12, 94 | max_seq=12 * 4, 95 | rng=self.rng) 96 | self.eval_mask = (eval_mask & self.mask).astype('uint8') 97 | 98 | @property 99 | def training_mask(self): 100 | return self.mask if self.eval_mask is None else (self.mask & (1 - self.eval_mask)) 101 | 102 | def splitter(self, dataset, val_len=0, test_len=0, window=0): 103 | idx = np.arange(len(dataset)) 104 | if test_len < 1: 105 | test_len = int(test_len * len(idx)) 106 | if val_len < 1: 107 | val_len = int(val_len * (len(idx) - test_len)) 108 | test_start = len(idx) - test_len 109 | val_start = test_start - val_len 110 | return [idx[:val_start - window], idx[val_start:test_start - window], idx[test_start:]] 111 | -------------------------------------------------------------------------------- /lib/datasets/synthetic.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import numpy as np 4 | import torch 5 | from einops import rearrange 6 | from torch.utils.data import Dataset, DataLoader, Subset 7 | 8 | from lib import datasets_path 9 | 10 | 11 | def generate_mask(shape, p_block=0.01, p_point=0.01, max_seq=1, min_seq=1, rng=None): 12 | """Generate mask in which 1 denotes valid values, 0 missing ones. Assuming shape=(steps, ...).""" 13 | if rng is None: 14 | rand = np.random.random 15 | randint = np.random.randint 16 | else: 17 | rand = rng.random 18 | randint = rng.integers 19 | # init mask 20 | mask = np.ones(shape, dtype='uint8') 21 | # block missing 22 | if p_block > 0: 23 | assert max_seq >= min_seq 24 | for col in range(shape[1]): 25 | i = 0 26 | while i < shape[0]: 27 | if rand() > p_block: 28 | i += 1 29 | else: 30 | fault_len = int(randint(min_seq, max_seq + 1)) 31 | mask[i:i + fault_len, col] = 0 32 | i += fault_len + 1 # at least one valid value between two blocks 33 | # point missing 34 | # let values before and after block missing always valid 35 | diff = np.zeros(mask.shape, dtype='uint8') 36 | diff[:-1] |= np.diff(mask, axis=0) < 0 37 | diff[1:] |= np.diff(mask, axis=0) > 0 38 | mask = np.where(mask - diff, rand(shape) > p_point, mask) 39 | return mask 40 | 41 | 42 | class SyntheticDataset(Dataset): 43 | SEED: int 44 | 45 | def __init__(self, filename, 46 | window=None, 47 | p_block=0.05, 48 | p_point=0.05, 49 | max_seq=6, 50 | min_seq=4, 51 | use_exogenous=True, 52 | mask_exogenous=True, 53 | graph_mode=True): 54 | super(SyntheticDataset, self).__init__() 55 | self.mask_exogenous = mask_exogenous 56 | self.use_exogenous = use_exogenous 57 | self.graph_mode = graph_mode 58 | # fetch data 59 | content = self.load(filename) 60 | self.window = window if window is not None else content['loc'].shape[1] 61 | self.loc = torch.tensor(content['loc'][:, :self.window]).float() 62 | self.vel = torch.tensor(content['vel'][:, :self.window]).float() 63 | self.adj = content['adj'] 64 | self.SEED = content['seed'].item() 65 | # compute masks 66 | self.rng = np.random.default_rng(self.SEED) 67 | mask_shape = (len(self), self.window, self.n_nodes, 1) 68 | mask = generate_mask(mask_shape, 69 | p_block=p_block, 70 | p_point=p_point, 71 | max_seq=max_seq, 72 | min_seq=min_seq, 73 | 
rng=self.rng).repeat(self.n_channels, -1) 74 | eval_mask = 1 - generate_mask(mask_shape, 75 | p_block=p_block, 76 | p_point=p_point, 77 | max_seq=max_seq, 78 | min_seq=min_seq, 79 | rng=self.rng).repeat(self.n_channels, -1) 80 | self.mask = torch.tensor(mask).byte() 81 | self.eval_mask = torch.tensor(eval_mask).byte() & self.mask 82 | # store splitting indices 83 | self.train_idxs = None 84 | self.val_idxs = None 85 | self.test_idxs = None 86 | 87 | def __len__(self): 88 | return self.loc.size(0) 89 | 90 | def __getitem__(self, index): 91 | eval_mask = self.eval_mask[index] 92 | mask = self.training_mask[index] 93 | x = mask * self.loc[index] 94 | res = dict(x=x, mask=mask, eval_mask=eval_mask) 95 | if self.use_exogenous: 96 | u = self.vel[index] 97 | if self.mask_exogenous: 98 | u *= mask.all(-1, keepdims=True) 99 | res.update(u=u) 100 | res.update(y=self.loc[index]) 101 | if not self.graph_mode: 102 | res = {k: rearrange(v, 's n f -> s (n f)') for k, v in res.items()} 103 | return res 104 | 105 | @property 106 | def n_channels(self): 107 | return self.loc.size(-1) 108 | 109 | @property 110 | def n_nodes(self): 111 | return self.loc.size(-2) 112 | 113 | @property 114 | def n_exogenous(self): 115 | return self.vel.size(-1) if self.use_exogenous else 0 116 | 117 | @property 118 | def training_mask(self): 119 | return self.mask if self.eval_mask is None else (self.mask & (1 - self.eval_mask)) 120 | 121 | @staticmethod 122 | def load(filename): 123 | return np.load(filename) 124 | 125 | def get_similarity(self, sparse=False): 126 | return self.adj 127 | 128 | # Splitting options 129 | 130 | def split(self, val_len=0, test_len=0): 131 | idx = np.arange(len(self)) 132 | if test_len < 1: 133 | test_len = int(test_len * len(idx)) 134 | if val_len < 1: 135 | val_len = int(val_len * (len(idx) - test_len)) 136 | test_start = len(idx) - test_len 137 | val_start = test_start - val_len 138 | # split dataset 139 | self.train_idxs = idx[:val_start] 140 | self.val_idxs = idx[val_start:test_start] 141 | self.test_idxs = idx[test_start:] 142 | 143 | def train_dataloader(self, shuffle=True, batch_size=32): 144 | return DataLoader(Subset(self, self.train_idxs), shuffle=shuffle, batch_size=batch_size, drop_last=True) 145 | 146 | def val_dataloader(self, shuffle=False, batch_size=32): 147 | return DataLoader(Subset(self, self.val_idxs), shuffle=shuffle, batch_size=batch_size) 148 | 149 | def test_dataloader(self, shuffle=False, batch_size=32): 150 | return DataLoader(Subset(self, self.test_idxs), shuffle=shuffle, batch_size=batch_size) 151 | 152 | 153 | class ChargedParticles(SyntheticDataset): 154 | 155 | def __init__(self, static_adj=False, 156 | window=None, 157 | p_block=0.05, 158 | p_point=0.05, 159 | max_seq=6, 160 | min_seq=4, 161 | use_exogenous=True, 162 | mask_exogenous=True, 163 | graph_mode=True): 164 | if static_adj: 165 | filename = os.path.join(datasets_path['synthetic'], 'charged_static.npz') 166 | else: 167 | filename = os.path.join(datasets_path['synthetic'], 'charged_varying.npz') 168 | self.static_adj = static_adj 169 | super(ChargedParticles, self).__init__(filename, window, 170 | p_block=p_block, 171 | p_point=p_point, 172 | max_seq=max_seq, 173 | min_seq=min_seq, 174 | use_exogenous=use_exogenous, 175 | mask_exogenous=mask_exogenous, 176 | graph_mode=graph_mode) 177 | charges = self.load(filename)['charges'] 178 | self.charges = torch.tensor(charges).float() 179 | 180 | def __getitem__(self, item): 181 | res = super(ChargedParticles, self).__getitem__(item) 182 | # add charges as 
exogenous features 183 | if self.use_exogenous: 184 | charges = self.charges[item] if not self.static_adj else self.charges 185 | stacked_charges = charges[None].expand(self.window, -1, -1) 186 | if not self.graph_mode: 187 | stacked_charges = rearrange(stacked_charges, 's n f -> s (n f)') 188 | res.update(u=torch.cat([res['u'], stacked_charges], -1)) 189 | return res 190 | 191 | def get_similarity(self, sparse=False): 192 | return np.ones((self.n_nodes, self.n_nodes)) - np.eye(self.n_nodes) 193 | 194 | @property 195 | def n_exogenous(self): 196 | if self.use_exogenous: 197 | return super(ChargedParticles, self).n_exogenous + 1 # add charges to features 198 | return 0 199 | -------------------------------------------------------------------------------- /lib/fillers/__init__.py: -------------------------------------------------------------------------------- 1 | from .filler import Filler 2 | from .britsfiller import BRITSFiller 3 | from .graphfiller import GraphFiller 4 | from .rgainfiller import RGAINFiller 5 | from .multi_imputation_filler import MultiImputationFiller 6 | -------------------------------------------------------------------------------- /lib/fillers/britsfiller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . import Filler 4 | from ..nn import BRITS 5 | 6 | 7 | class BRITSFiller(Filler): 8 | 9 | def training_step(self, batch, batch_idx): 10 | # Unpack batch 11 | batch_data, batch_preprocessing = self._unpack_batch(batch) 12 | 13 | # Extract mask and target 14 | mask = batch_data['mask'].clone().detach() 15 | batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte() 16 | eval_mask = batch_data.pop('eval_mask', None) 17 | y = batch_data.pop('y') 18 | 19 | # Compute predictions and compute loss 20 | out, imputations, predictions = self.predict_batch(batch, preprocess=False, postprocess=False) 21 | 22 | if self.scaled_target: 23 | target = self._preprocess(y, batch_preprocessing) 24 | else: 25 | target = y 26 | imputations = [self._postprocess(imp, batch_preprocessing) for imp in imputations] 27 | predictions = [self._postprocess(prd, batch_preprocessing) for prd in predictions] 28 | 29 | loss = sum([self.loss_fn(pred, target, mask) for pred in predictions]) 30 | loss += BRITS.consistency_loss(*imputations) 31 | 32 | # Logging 33 | metrics_mask = (mask | eval_mask) - batch_data['mask'] # all unseen data 34 | out = self._postprocess(out, batch_preprocessing) 35 | self.train_metrics.update(out.detach(), y, metrics_mask) 36 | self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True) 37 | self.log('train_loss', loss, on_step=False, on_epoch=True, logger=True, prog_bar=False) 38 | return loss 39 | 40 | def validation_step(self, batch, batch_idx): 41 | # Unpack batch 42 | batch_data, batch_preprocessing = self._unpack_batch(batch) 43 | 44 | # Extract mask and target 45 | mask = batch_data.get('mask') 46 | eval_mask = batch_data.pop('eval_mask', None) 47 | y = batch_data.pop('y') 48 | 49 | # Compute predictions and compute loss 50 | out, imputations, predictions = self.predict_batch(batch, preprocess=False, postprocess=False) 51 | 52 | if self.scaled_target: 53 | target = self._preprocess(y, batch_preprocessing) 54 | else: 55 | target = y 56 | predictions = [self._postprocess(prd, batch_preprocessing) for prd in predictions] 57 | 58 | val_loss = sum([self.loss_fn(pred, target, mask) for pred in predictions]) 59 | 60 | # Logging 61 | out = 
self._postprocess(out, batch_preprocessing) 62 | self.val_metrics.update(out.detach(), y, eval_mask) 63 | self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True) 64 | self.log('val_loss', val_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False) 65 | return val_loss 66 | 67 | def test_step(self, batch, batch_idx): 68 | # Unpack batch 69 | batch_data, batch_preprocessing = self._unpack_batch(batch) 70 | 71 | # Extract mask and target 72 | eval_mask = batch_data.pop('eval_mask', None) 73 | y = batch_data.pop('y') 74 | 75 | # Compute outputs and rescale 76 | imputation, *_ = self.predict_batch(batch, preprocess=False, postprocess=True) 77 | test_loss = self.loss_fn(imputation, y, eval_mask) 78 | 79 | # Logging 80 | self.test_metrics.update(imputation.detach(), y, eval_mask) 81 | self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True) 82 | self.log('test_loss', test_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False) 83 | return test_loss 84 | -------------------------------------------------------------------------------- /lib/fillers/graphfiller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . import Filler 4 | from ..nn.models import MPGRUNet, GRINet, BiMPGRUNet 5 | 6 | 7 | class GraphFiller(Filler): 8 | 9 | def __init__(self, 10 | model_class, 11 | model_kwargs, 12 | optim_class, 13 | optim_kwargs, 14 | loss_fn, 15 | scaled_target=False, 16 | whiten_prob=0.05, 17 | pred_loss_weight=1., 18 | warm_up=0, 19 | metrics=None, 20 | scheduler_class=None, 21 | scheduler_kwargs=None): 22 | super(GraphFiller, self).__init__(model_class=model_class, 23 | model_kwargs=model_kwargs, 24 | optim_class=optim_class, 25 | optim_kwargs=optim_kwargs, 26 | loss_fn=loss_fn, 27 | scaled_target=scaled_target, 28 | whiten_prob=whiten_prob, 29 | metrics=metrics, 30 | scheduler_class=scheduler_class, 31 | scheduler_kwargs=scheduler_kwargs) 32 | 33 | self.tradeoff = pred_loss_weight 34 | if model_class is MPGRUNet: 35 | self.trimming = (warm_up, 0) 36 | elif model_class in [GRINet, BiMPGRUNet]: 37 | self.trimming = (warm_up, warm_up) 38 | 39 | def trim_seq(self, *seq): 40 | seq = [s[:, self.trimming[0]:s.size(1) - self.trimming[1]] for s in seq] 41 | if len(seq) == 1: 42 | return seq[0] 43 | return seq 44 | 45 | def training_step(self, batch, batch_idx): 46 | # Unpack batch 47 | batch_data, batch_preprocessing = self._unpack_batch(batch) 48 | 49 | # Compute masks 50 | mask = batch_data['mask'].clone().detach() 51 | batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte() 52 | eval_mask = batch_data.pop('eval_mask', None) 53 | eval_mask = (mask | eval_mask) - batch_data['mask'] # all unseen data 54 | 55 | y = batch_data.pop('y') 56 | 57 | # Compute predictions and compute loss 58 | res = self.predict_batch(batch, preprocess=False, postprocess=False) 59 | imputation, predictions = (res[0], res[1:]) if isinstance(res, (list, tuple)) else (res, []) 60 | 61 | # trim to imputation horizon len 62 | imputation, mask, eval_mask, y = self.trim_seq(imputation, mask, eval_mask, y) 63 | predictions = self.trim_seq(*predictions) 64 | 65 | if self.scaled_target: 66 | target = self._preprocess(y, batch_preprocessing) 67 | else: 68 | target = y 69 | imputation = self._postprocess(imputation, batch_preprocessing) 70 | for i, _ in enumerate(predictions): 71 | predictions[i] = self._postprocess(predictions[i], 
batch_preprocessing) 72 | 73 | loss = self.loss_fn(imputation, target, mask) 74 | for pred in predictions: 75 | loss += self.tradeoff * self.loss_fn(pred, target, mask) 76 | 77 | # Logging 78 | if self.scaled_target: 79 | imputation = self._postprocess(imputation, batch_preprocessing) 80 | self.train_metrics.update(imputation.detach(), y, eval_mask) # all unseen data 81 | self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True) 82 | self.log('train_loss', loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False) 83 | return loss 84 | 85 | def validation_step(self, batch, batch_idx): 86 | # Unpack batch 87 | batch_data, batch_preprocessing = self._unpack_batch(batch) 88 | 89 | # Extract mask and target 90 | mask = batch_data.get('mask') 91 | eval_mask = batch_data.pop('eval_mask', None) 92 | y = batch_data.pop('y') 93 | 94 | # Compute predictions and compute loss 95 | imputation = self.predict_batch(batch, preprocess=False, postprocess=False) 96 | 97 | # trim to imputation horizon len 98 | imputation, mask, eval_mask, y = self.trim_seq(imputation, mask, eval_mask, y) 99 | 100 | if self.scaled_target: 101 | target = self._preprocess(y, batch_preprocessing) 102 | else: 103 | target = y 104 | imputation = self._postprocess(imputation, batch_preprocessing) 105 | 106 | val_loss = self.loss_fn(imputation, target, eval_mask) 107 | 108 | # Logging 109 | if self.scaled_target: 110 | imputation = self._postprocess(imputation, batch_preprocessing) 111 | self.val_metrics.update(imputation.detach(), y, eval_mask) 112 | self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True) 113 | self.log('val_loss', val_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False) 114 | return val_loss 115 | 116 | def test_step(self, batch, batch_idx): 117 | # Unpack batch 118 | batch_data, batch_preprocessing = self._unpack_batch(batch) 119 | 120 | # Extract mask and target 121 | eval_mask = batch_data.pop('eval_mask', None) 122 | y = batch_data.pop('y') 123 | 124 | # Compute outputs and rescale 125 | imputation = self.predict_batch(batch, preprocess=False, postprocess=True) 126 | test_loss = self.loss_fn(imputation, y, eval_mask) 127 | 128 | # Logging 129 | self.test_metrics.update(imputation.detach(), y, eval_mask) 130 | self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True) 131 | self.log('test_loss', test_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False) 132 | return test_loss 133 | -------------------------------------------------------------------------------- /lib/fillers/multi_imputation_filler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning.core.decorators import auto_move_data 3 | 4 | from . 
import Filler 5 | 6 | 7 | class MultiImputationFiller(Filler): 8 | """ 9 | Filler with multiple imputation outputs 10 | """ 11 | 12 | def __init__(self, 13 | model_class, 14 | model_kwargs, 15 | optim_class, 16 | optim_kwargs, 17 | loss_fn, 18 | consistency_loss=False, 19 | scaled_target=False, 20 | whiten_prob=0.05, 21 | metrics=None, 22 | scheduler_class=None, 23 | scheduler_kwargs=None): 24 | 25 | super().__init__(model_class, 26 | model_kwargs, 27 | optim_class, 28 | optim_kwargs, 29 | loss_fn, 30 | scaled_target, 31 | whiten_prob, 32 | metrics, 33 | scheduler_class, 34 | scheduler_kwargs) 35 | self.consistency_loss = consistency_loss 36 | 37 | @auto_move_data 38 | def forward(self, *args, **kwargs): 39 | out = self.model(*args, **kwargs) 40 | assert isinstance(out, (list, tuple)) 41 | if self.training: 42 | return out 43 | return out[0] # we assume that the final imputation is the first one 44 | 45 | def _consistency_loss(self, imputations, mask): 46 | from itertools import combinations 47 | return sum([self.loss_fn(imp1, imp2, mask) for imp1, imp2 in combinations(imputations, 2)]) 48 | 49 | def training_step(self, batch, batch_idx): 50 | # Unpack batch 51 | batch_data, batch_preprocessing = self._unpack_batch(batch) 52 | 53 | # Extract mask and target 54 | mask = batch_data['mask'].clone().detach() 55 | batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte() 56 | eval_mask = batch_data.pop('eval_mask', None) 57 | y = batch_data.pop('y') 58 | 59 | # Compute predictions and compute loss 60 | imputations = self.predict_batch(batch, preprocess=False, postprocess=False) 61 | 62 | if self.scaled_target: 63 | target = self._preprocess(y, batch_preprocessing) 64 | else: 65 | target = y 66 | imputations = [self._postprocess(imp, batch_preprocessing) for imp in imputations] 67 | 68 | loss = sum([self.loss_fn(imp, target, mask) for imp in imputations]) 69 | if self.consistency_loss: 70 | loss += self._consistency_loss(imputations, mask) 71 | 72 | # Logging 73 | metrics_mask = (mask | eval_mask) - batch_data['mask'] # all unseen data 74 | 75 | x_hat = imputations[0] 76 | x_hat = self._postprocess(x_hat, batch_preprocessing) 77 | self.train_metrics.update(x_hat.detach(), y, metrics_mask) 78 | self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True) 79 | self.log('train_loss', loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False) 80 | return loss 81 | -------------------------------------------------------------------------------- /lib/fillers/rgainfiller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from .multi_imputation_filler import MultiImputationFiller 5 | from ..nn.utils.metric_base import MaskedMetric 6 | 7 | 8 | class MaskedBCEWithLogits(MaskedMetric): 9 | def __init__(self, 10 | mask_nans=False, 11 | mask_inf=False, 12 | compute_on_step=True, 13 | dist_sync_on_step=False, 14 | process_group=None, 15 | dist_sync_fn=None, 16 | at=None): 17 | super(MaskedBCEWithLogits, self).__init__(metric_fn=F.binary_cross_entropy_with_logits, 18 | mask_nans=mask_nans, 19 | mask_inf=mask_inf, 20 | compute_on_step=compute_on_step, 21 | dist_sync_on_step=dist_sync_on_step, 22 | process_group=process_group, 23 | dist_sync_fn=dist_sync_fn, 24 | metric_kwargs={'reduction': 'none'}, 25 | at=at) 26 | 27 | 28 | class RGAINFiller(MultiImputationFiller): 29 | def __init__(self, 30 | model_class, 31 | model_kwargs, 32 | 
optim_class, 33 | optim_kwargs, 34 | loss_fn, 35 | g_train_freq=1, 36 | d_train_freq=5, 37 | consistency_loss=False, 38 | scaled_target=True, 39 | whiten_prob=0.05, 40 | hint_rate=0.7, 41 | alpha=10., 42 | metrics=None, 43 | scheduler_class=None, 44 | scheduler_kwargs=None): 45 | super(RGAINFiller, self).__init__(model_class=model_class, 46 | model_kwargs=model_kwargs, 47 | optim_class=optim_class, 48 | optim_kwargs=optim_kwargs, 49 | loss_fn=loss_fn, 50 | scaled_target=scaled_target, 51 | whiten_prob=whiten_prob, 52 | metrics=metrics, 53 | consistency_loss=consistency_loss, 54 | scheduler_class=scheduler_class, 55 | scheduler_kwargs=scheduler_kwargs) 56 | # discriminator training params 57 | self.alpha = alpha 58 | self.g_train_freq = g_train_freq 59 | self.d_train_freq = d_train_freq 60 | self.masked_bce_loss = MaskedBCEWithLogits(compute_on_step=True) 61 | # activate manual optimization 62 | self.automatic_optimization = False 63 | self.hint_rate = hint_rate 64 | 65 | def training_step(self, batch, batch_idx): 66 | # Unpack batch 67 | batch_data, batch_preprocessing = self._unpack_batch(batch) 68 | g_opt, d_opt = self.optimizers() 69 | schedulers = self.lr_schedulers() 70 | 71 | # Extract mask and target 72 | x = batch_data.pop('x') 73 | mask = batch_data['mask'].clone().detach() 74 | training_mask = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte() 75 | eval_mask = batch_data.pop('eval_mask', None) 76 | y = batch_data.pop('y') 77 | 78 | ########################## 79 | # generate imputations 80 | ########################## 81 | 82 | imputations = self.model.generator(x, training_mask) 83 | imputed_seq = imputations[0] 84 | target = self._preprocess(y, batch_preprocessing) 85 | y_hat = self._postprocess(imputed_seq, batch_preprocessing) 86 | 87 | x_in = training_mask * x + (1 - training_mask) * imputed_seq 88 | hint = torch.rand_like(training_mask, dtype=torch.float) < self.hint_rate 89 | hint = hint.byte() 90 | hint = hint * training_mask + (1 - hint) * 0.5 91 | 92 | ######################### 93 | # train generator 94 | ######################### 95 | if (batch_idx % self.g_train_freq) == 0: 96 | 97 | g_opt.zero_grad() 98 | 99 | rec_loss = sum([torch.sqrt(self.loss_fn(imp, target, mask)) for imp in imputations]) 100 | if self.consistency_loss: 101 | rec_loss += self._consistency_loss(imputations, mask) 102 | 103 | logits = self.model.discriminator(x_in, hint) 104 | # maximize logit 105 | adv_loss = self.masked_bce_loss(logits, torch.ones_like(logits), 1 - training_mask) 106 | 107 | g_loss = self.alpha * rec_loss + adv_loss 108 | 109 | self.manual_backward(g_loss) 110 | g_opt.step() 111 | 112 | # Logging 113 | metrics_mask = (mask | eval_mask) - training_mask 114 | self.train_metrics.update(y_hat.detach(), y, metrics_mask) # all unseen data 115 | self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True) 116 | self.log('gen_loss', adv_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False) 117 | self.log('imp_loss', rec_loss.detach(), on_step=True, on_epoch=True, logger=True, prog_bar=True) 118 | 119 | ########################### 120 | # train discriminator 121 | ########################### 122 | 123 | if (batch_idx % self.d_train_freq) == 0: 124 | d_opt.zero_grad() 125 | 126 | logits = self.model.discriminator(x_in.detach(), hint) 127 | d_loss = self.masked_bce_loss(logits, training_mask.to(logits.dtype)) 128 | 129 | self.manual_backward(d_loss) 130 | d_opt.step() 131 | self.log('d_loss', d_loss.detach(), 
on_step=False, on_epoch=True, logger=True, prog_bar=False) 132 | 133 | if (schedulers is not None) and self.trainer.is_last_batch: 134 | for sch in schedulers: 135 | sch.step() 136 | 137 | def configure_optimizers(self): 138 | opt_g = self.optim_class(self.model.generator.parameters(), **self.optim_kwargs) 139 | opt_d = self.optim_class(self.model.discriminator.parameters(), **self.optim_kwargs) 140 | optimizers = [opt_g, opt_d] 141 | if self.scheduler_class is not None: 142 | metric = self.scheduler_kwargs.pop('monitor', None) 143 | schedulers = [{"scheduler": self.scheduler_class(opt, **self.scheduler_kwargs), "monitor": metric} 144 | for opt in optimizers] 145 | return optimizers, schedulers 146 | return optimizers 147 | -------------------------------------------------------------------------------- /lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | -------------------------------------------------------------------------------- /lib/nn/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .rits import RITS, BRITS 2 | from .gril import GRIL, BiGRIL 3 | from .spatial_conv import SpatialConvOrderK 4 | from .mpgru import MPGRUImputer 5 | -------------------------------------------------------------------------------- /lib/nn/layers/gcrnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .spatial_conv import SpatialConvOrderK 5 | 6 | 7 | class GCGRUCell(nn.Module): 8 | """ 9 | Graph Convolution Gated Recurrent Unit Cell. 10 | """ 11 | 12 | def __init__(self, d_in, num_units, support_len, order, activation='tanh'): 13 | """ 14 | :param num_units: hidden dimension of the recurrent cell 15 | :param support_len: number of support (adjacency) matrices fed to each graph convolution 16 | :param order: the max diffusion step 17 | :param activation: activation function applied to the cell state (default: 'tanh') 18 | """ 19 | super(GCGRUCell, self).__init__() 20 | self.activation_fn = getattr(torch, activation) 21 | 22 | self.forget_gate = SpatialConvOrderK(c_in=d_in + num_units, c_out=num_units, support_len=support_len, 23 | order=order) 24 | self.update_gate = SpatialConvOrderK(c_in=d_in + num_units, c_out=num_units, support_len=support_len, 25 | order=order) 26 | self.c_gate = SpatialConvOrderK(c_in=d_in + num_units, c_out=num_units, support_len=support_len, order=order) 27 | 28 | def forward(self, x, h, adj): 29 | """ 30 | :param x: (B, input_dim, num_nodes) 31 | :param h: (B, num_units, num_nodes) 32 | :param adj: (num_nodes, num_nodes) 33 | :return: 34 | """ 35 | # we start with bias 1.0 to not reset and not update 36 | x_gates = torch.cat([x, h], dim=1) 37 | r = torch.sigmoid(self.forget_gate(x_gates, adj)) 38 | u = torch.sigmoid(self.update_gate(x_gates, adj)) 39 | x_c = torch.cat([x, r * h], dim=1) 40 | c = self.c_gate(x_c, adj) # batch_size, self._num_nodes * output_size 41 | c = self.activation_fn(c) 42 | return u * h + (1.
- u) * c 43 | 44 | 45 | class GCRNN(nn.Module): 46 | def __init__(self, 47 | d_in, 48 | d_model, 49 | d_out, 50 | n_layers, 51 | support_len, 52 | kernel_size=2): 53 | super(GCRNN, self).__init__() 54 | self.d_in = d_in 55 | self.d_model = d_model 56 | self.d_out = d_out 57 | self.n_layers = n_layers 58 | self.ks = kernel_size 59 | self.support_len = support_len 60 | self.rnn_cells = nn.ModuleList() 61 | for i in range(self.n_layers): 62 | self.rnn_cells.append(GCGRUCell(d_in=self.d_in if i == 0 else self.d_model, 63 | num_units=self.d_model, support_len=self.support_len, order=self.ks)) 64 | self.output_layer = nn.Conv2d(self.d_model, self.d_out, kernel_size=1) 65 | 66 | def init_hidden_states(self, x): 67 | return [torch.zeros(size=(x.shape[0], self.d_model, x.shape[2])).to(x.device) for _ in range(self.n_layers)] 68 | 69 | def single_pass(self, x, h, adj): 70 | out = x 71 | for l, layer in enumerate(self.rnn_cells): 72 | out = h[l] = layer(out, h[l], adj) 73 | return out, h 74 | 75 | def forward(self, x, adj, h=None): 76 | # x:[batch, features, nodes, steps] 77 | *_, steps = x.size() 78 | if h is None: 79 | h = self.init_hidden_states(x) 80 | # temporal conv 81 | for step in range(steps): 82 | out, h = self.single_pass(x[..., step], h, adj) 83 | 84 | return self.output_layer(out[..., None]) 85 | -------------------------------------------------------------------------------- /lib/nn/layers/gril.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | from .spatial_conv import SpatialConvOrderK 6 | from .gcrnn import GCGRUCell 7 | from .spatial_attention import SpatialAttention 8 | from ..utils.ops import reverse_tensor 9 | 10 | 11 | class SpatialDecoder(nn.Module): 12 | def __init__(self, d_in, d_model, d_out, support_len, order=1, attention_block=False, nheads=2, dropout=0.): 13 | super(SpatialDecoder, self).__init__() 14 | self.order = order 15 | self.lin_in = nn.Conv1d(d_in, d_model, kernel_size=1) 16 | self.graph_conv = SpatialConvOrderK(c_in=d_model, c_out=d_model, 17 | support_len=support_len * order, order=1, include_self=False) 18 | if attention_block: 19 | self.spatial_att = SpatialAttention(d_in=d_model, 20 | d_model=d_model, 21 | nheads=nheads, 22 | dropout=dropout) 23 | self.lin_out = nn.Conv1d(3 * d_model, d_model, kernel_size=1) 24 | else: 25 | self.register_parameter('spatial_att', None) 26 | self.lin_out = nn.Conv1d(2 * d_model, d_model, kernel_size=1) 27 | self.read_out = nn.Conv1d(2 * d_model, d_out, kernel_size=1) 28 | self.activation = nn.PReLU() 29 | self.adj = None 30 | 31 | def forward(self, x, m, h, u, adj, cached_support=False): 32 | # [batch, channels, nodes] 33 | x_in = [x, m, h] if u is None else [x, m, u, h] 34 | x_in = torch.cat(x_in, 1) 35 | if self.order > 1: 36 | if cached_support and (self.adj is not None): 37 | adj = self.adj 38 | else: 39 | adj = SpatialConvOrderK.compute_support_orderK(adj, self.order, include_self=False, device=x_in.device) 40 | self.adj = adj if cached_support else None 41 | 42 | x_in = self.lin_in(x_in) 43 | out = self.graph_conv(x_in, adj) 44 | if self.spatial_att is not None: 45 | # [batch, channels, nodes] -> [batch, steps, nodes, features] 46 | x_in = rearrange(x_in, 'b f n -> b 1 n f') 47 | out_att = self.spatial_att(x_in, torch.eye(x_in.size(2), dtype=torch.bool, device=x_in.device)) 48 | out_att = rearrange(out_att, 'b s n f -> b f (n s)') 49 | out = torch.cat([out, out_att], 1) 50 | out = torch.cat([out, h], 1) 
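        # Note (added comment): the hidden state h has just been concatenated to the graph-conv /
        # attention output as a skip connection before the nonlinear readout below; it is
        # concatenated once more after lin_out, so read_out sees both the decoded features and
        # the raw state.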
51 | out = self.activation(self.lin_out(out)) 52 | # out = self.lin_out(out) 53 | out = torch.cat([out, h], 1) 54 | return self.read_out(out), out 55 | 56 | 57 | class GRIL(nn.Module): 58 | def __init__(self, 59 | input_size, 60 | hidden_size, 61 | u_size=None, 62 | n_layers=1, 63 | dropout=0., 64 | kernel_size=2, 65 | decoder_order=1, 66 | global_att=False, 67 | support_len=2, 68 | n_nodes=None, 69 | layer_norm=False): 70 | super(GRIL, self).__init__() 71 | self.input_size = int(input_size) 72 | self.hidden_size = int(hidden_size) 73 | self.u_size = int(u_size) if u_size is not None else 0 74 | self.n_layers = int(n_layers) 75 | rnn_input_size = 2 * self.input_size + self.u_size # input + mask + (optional) exogenous 76 | 77 | # Spatio-temporal encoder (rnn_input_size -> hidden_size) 78 | self.cells = nn.ModuleList() 79 | self.norms = nn.ModuleList() 80 | for i in range(self.n_layers): 81 | self.cells.append(GCGRUCell(d_in=rnn_input_size if i == 0 else self.hidden_size, 82 | num_units=self.hidden_size, support_len=support_len, order=kernel_size)) 83 | if layer_norm: 84 | self.norms.append(nn.GroupNorm(num_groups=1, num_channels=self.hidden_size)) 85 | else: 86 | self.norms.append(nn.Identity()) 87 | self.dropout = nn.Dropout(dropout) if dropout > 0. else None 88 | 89 | # First stage readout 90 | self.first_stage = nn.Conv1d(in_channels=self.hidden_size, out_channels=self.input_size, kernel_size=1) 91 | 92 | # Spatial decoder (rnn_input_size + hidden_size -> hidden_size) 93 | self.spatial_decoder = SpatialDecoder(d_in=rnn_input_size + self.hidden_size, 94 | d_model=self.hidden_size, 95 | d_out=self.input_size, 96 | support_len=2, 97 | order=decoder_order, 98 | attention_block=global_att) 99 | 100 | # Hidden state initialization embedding 101 | if n_nodes is not None: 102 | self.h0 = self.init_hidden_states(n_nodes) 103 | else: 104 | self.register_parameter('h0', None) 105 | 106 | def init_hidden_states(self, n_nodes): 107 | h0 = [] 108 | for l in range(self.n_layers): 109 | std = 1.
/ torch.sqrt(torch.tensor(self.hidden_size, dtype=torch.float)) 110 | vals = torch.distributions.Normal(0, std).sample((self.hidden_size, n_nodes)) 111 | h0.append(nn.Parameter(vals)) 112 | return nn.ParameterList(h0) 113 | 114 | def get_h0(self, x): 115 | if self.h0 is not None: 116 | return [h.expand(x.shape[0], -1, -1) for h in self.h0] 117 | return [torch.zeros(size=(x.shape[0], self.hidden_size, x.shape[2])).to(x.device)] * self.n_layers 118 | 119 | def update_state(self, x, h, adj): 120 | rnn_in = x 121 | for layer, (cell, norm) in enumerate(zip(self.cells, self.norms)): 122 | rnn_in = h[layer] = norm(cell(rnn_in, h[layer], adj)) 123 | if self.dropout is not None and layer < (self.n_layers - 1): 124 | rnn_in = self.dropout(rnn_in) 125 | return h 126 | 127 | def forward(self, x, adj, mask=None, u=None, h=None, cached_support=False): 128 | # x:[batch, features, nodes, steps] 129 | *_, steps = x.size() 130 | 131 | # infer all valid if mask is None 132 | if mask is None: 133 | mask = torch.ones_like(x, dtype=torch.uint8) 134 | 135 | # init hidden state using node embedding or the empty state 136 | if h is None: 137 | h = self.get_h0(x) 138 | elif not isinstance(h, list): 139 | h = [*h] 140 | 141 | # Temporal conv 142 | predictions, imputations, states = [], [], [] 143 | representations = [] 144 | for step in range(steps): 145 | x_s = x[..., step] 146 | m_s = mask[..., step] 147 | h_s = h[-1] 148 | u_s = u[..., step] if u is not None else None 149 | # firstly impute missing values with predictions from state 150 | xs_hat_1 = self.first_stage(h_s) 151 | # fill missing values in input with prediction 152 | x_s = torch.where(m_s, x_s, xs_hat_1) 153 | # prepare inputs 154 | # retrieve maximum information from neighbors 155 | xs_hat_2, repr_s = self.spatial_decoder(x=x_s, m=m_s, h=h_s, u=u_s, adj=adj, 156 | cached_support=cached_support) # receive messages from neighbors (no self-loop!) 
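            # Note (added comment): xs_hat_1 above is the one-step prediction decoded from the
            # recurrent state alone, while xs_hat_2 refines it with information aggregated from
            # neighboring nodes by the spatial decoder; both are collected and returned so that
            # training can supervise the two stages separately.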
157 | # readout of imputation state + mask to retrieve imputations 158 | # prepare inputs 159 | x_s = torch.where(m_s, x_s, xs_hat_2) 160 | inputs = [x_s, m_s] 161 | if u_s is not None: 162 | inputs.append(u_s) 163 | inputs = torch.cat(inputs, dim=1) # x_hat_2 + mask + exogenous 164 | # update state with original sequence filled using imputations 165 | h = self.update_state(inputs, h, adj) 166 | # store imputations and states 167 | imputations.append(xs_hat_2) 168 | predictions.append(xs_hat_1) 169 | states.append(torch.stack(h, dim=0)) 170 | representations.append(repr_s) 171 | 172 | # Aggregate outputs -> [batch, features, nodes, steps] 173 | imputations = torch.stack(imputations, dim=-1) 174 | predictions = torch.stack(predictions, dim=-1) 175 | states = torch.stack(states, dim=-1) 176 | representations = torch.stack(representations, dim=-1) 177 | 178 | return imputations, predictions, representations, states 179 | 180 | 181 | class BiGRIL(nn.Module): 182 | def __init__(self, 183 | input_size, 184 | hidden_size, 185 | ff_size, 186 | ff_dropout, 187 | n_layers=1, 188 | dropout=0., 189 | n_nodes=None, 190 | support_len=2, 191 | kernel_size=2, 192 | decoder_order=1, 193 | global_att=False, 194 | u_size=0, 195 | embedding_size=0, 196 | layer_norm=False, 197 | merge='mlp'): 198 | super(BiGRIL, self).__init__() 199 | self.fwd_rnn = GRIL(input_size=input_size, 200 | hidden_size=hidden_size, 201 | n_layers=n_layers, 202 | dropout=dropout, 203 | n_nodes=n_nodes, 204 | support_len=support_len, 205 | kernel_size=kernel_size, 206 | decoder_order=decoder_order, 207 | global_att=global_att, 208 | u_size=u_size, 209 | layer_norm=layer_norm) 210 | self.bwd_rnn = GRIL(input_size=input_size, 211 | hidden_size=hidden_size, 212 | n_layers=n_layers, 213 | dropout=dropout, 214 | n_nodes=n_nodes, 215 | support_len=support_len, 216 | kernel_size=kernel_size, 217 | decoder_order=decoder_order, 218 | global_att=global_att, 219 | u_size=u_size, 220 | layer_norm=layer_norm) 221 | 222 | if n_nodes is None: 223 | embedding_size = 0 224 | if embedding_size > 0: 225 | self.emb = nn.Parameter(torch.empty(embedding_size, n_nodes)) 226 | nn.init.kaiming_normal_(self.emb, nonlinearity='relu') 227 | else: 228 | self.register_parameter('emb', None) 229 | 230 | if merge == 'mlp': 231 | self._impute_from_states = True 232 | self.out = nn.Sequential( 233 | nn.Conv2d(in_channels=4 * hidden_size + input_size + embedding_size, 234 | out_channels=ff_size, kernel_size=1), 235 | nn.ReLU(), 236 | nn.Dropout(ff_dropout), 237 | nn.Conv2d(in_channels=ff_size, out_channels=input_size, kernel_size=1) 238 | ) 239 | elif merge in ['mean', 'sum', 'min', 'max']: 240 | self._impute_from_states = False 241 | self.out = getattr(torch, merge) 242 | else: 243 | raise ValueError("Merge option %s not allowed." 
% merge) 244 | self.supp = None 245 | 246 | def forward(self, x, adj, mask=None, u=None, cached_support=False): 247 | if cached_support and (self.supp is not None): 248 | supp = self.supp 249 | else: 250 | supp = SpatialConvOrderK.compute_support(adj, x.device) 251 | self.supp = supp if cached_support else None 252 | # Forward 253 | fwd_out, fwd_pred, fwd_repr, _ = self.fwd_rnn(x, supp, mask=mask, u=u, cached_support=cached_support) 254 | # Backward 255 | rev_x, rev_mask, rev_u = [reverse_tensor(tens) for tens in (x, mask, u)] 256 | *bwd_res, _ = self.bwd_rnn(rev_x, supp, mask=rev_mask, u=rev_u, cached_support=cached_support) 257 | bwd_out, bwd_pred, bwd_repr = [reverse_tensor(res) for res in bwd_res] 258 | 259 | if self._impute_from_states: 260 | inputs = [fwd_repr, bwd_repr, mask] 261 | if self.emb is not None: 262 | b, *_, s = fwd_repr.shape # fwd_h: [batches, channels, nodes, steps] 263 | inputs += [self.emb.view(1, *self.emb.shape, 1).expand(b, -1, -1, s)] # stack emb for batches and steps 264 | imputation = torch.cat(inputs, dim=1) 265 | imputation = self.out(imputation) 266 | else: 267 | imputation = torch.stack([fwd_out, bwd_out], dim=1) 268 | imputation = self.out(imputation, dim=1) 269 | 270 | predictions = torch.stack([fwd_out, bwd_out, fwd_pred, bwd_pred], dim=0) 271 | 272 | return imputation, predictions 273 | -------------------------------------------------------------------------------- /lib/nn/layers/imputation.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class ImputationLayer(nn.Module): 9 | def __init__(self, d_in, bias=True): 10 | super(ImputationLayer, self).__init__() 11 | self.W = nn.Parameter(torch.Tensor(d_in, d_in)) 12 | if bias: 13 | self.b = nn.Parameter(torch.Tensor(d_in)) 14 | else: 15 | self.register_buffer('b', None) 16 | mask = 1. 
- torch.eye(d_in) 17 | self.register_buffer('mask', mask) 18 | self.reset_parameters() 19 | 20 | def reset_parameters(self): 21 | nn.init.kaiming_uniform_(self.W, a=math.sqrt(5)) 22 | if self.b is not None: 23 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W) 24 | bound = 1 / math.sqrt(fan_in) 25 | nn.init.uniform_(self.b, -bound, bound) 26 | 27 | def forward(self, x): 28 | # batch, features 29 | return F.linear(x, self.mask * self.W, self.b) -------------------------------------------------------------------------------- /lib/nn/layers/mpgru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .gcrnn import GCGRUCell 5 | 6 | 7 | class MPGRUImputer(nn.Module): 8 | def __init__(self, 9 | input_size, 10 | hidden_size, 11 | ff_size=None, 12 | u_size=None, 13 | n_layers=1, 14 | dropout=0., 15 | kernel_size=2, 16 | support_len=2, 17 | n_nodes=None, 18 | layer_norm=False, 19 | autoencoder_mode=False): 20 | super(MPGRUImputer, self).__init__() 21 | self.input_size = int(input_size) 22 | self.hidden_size = int(hidden_size) 23 | self.ff_size = int(ff_size) if ff_size is not None else 0 24 | self.u_size = int(u_size) if u_size is not None else 0 25 | self.n_layers = int(n_layers) 26 | rnn_input_size = 2 * self.input_size + self.u_size # input + mask + (eventually) exogenous 27 | 28 | # Spatio-temporal encoder (rnn_input_size -> hidden_size) 29 | self.cells = nn.ModuleList() 30 | self.norms = nn.ModuleList() 31 | for i in range(self.n_layers): 32 | self.cells.append(GCGRUCell(d_in=rnn_input_size if i == 0 else self.hidden_size, 33 | num_units=self.hidden_size, support_len=support_len, order=kernel_size)) 34 | if layer_norm: 35 | self.norms.append(nn.GroupNorm(num_groups=1, num_channels=self.hidden_size)) 36 | else: 37 | self.norms.append(nn.Identity()) 38 | self.dropout = nn.Dropout(dropout) if dropout > 0. else None 39 | 40 | # Readout 41 | if self.ff_size: 42 | self.pred_readout = nn.Sequential( 43 | nn.Conv1d(in_channels=self.hidden_size, out_channels=self.ff_size, kernel_size=1), 44 | nn.PReLU(), 45 | nn.Conv1d(in_channels=self.ff_size, out_channels=self.input_size, kernel_size=1) 46 | ) 47 | else: 48 | self.pred_readout = nn.Conv1d(in_channels=self.hidden_size, out_channels=self.input_size, kernel_size=1) 49 | 50 | # Hidden state initialization embedding 51 | if n_nodes is not None: 52 | self.h0 = self.init_hidden_states(n_nodes) 53 | else: 54 | self.register_parameter('h0', None) 55 | 56 | self.autoencoder_mode = autoencoder_mode 57 | 58 | def init_hidden_states(self, n_nodes): 59 | h0 = [] 60 | for l in range(self.n_layers): 61 | std = 1. 
/ torch.sqrt(torch.tensor(self.hidden_size, dtype=torch.float)) 62 | vals = torch.distributions.Normal(0, std).sample((self.hidden_size, n_nodes)) 63 | h0.append(nn.Parameter(vals)) 64 | return nn.ParameterList(h0) 65 | 66 | def get_h0(self, x): 67 | if self.h0 is not None: 68 | return [h.expand(x.shape[0], -1, -1) for h in self.h0] 69 | return [torch.zeros(size=(x.shape[0], self.hidden_size, x.shape[2])).to(x.device)] * self.n_layers 70 | 71 | def update_state(self, x, h, adj): 72 | rnn_in = x 73 | for layer, (cell, norm) in enumerate(zip(self.cells, self.norms)): 74 | rnn_in = h[layer] = norm(cell(rnn_in, h[layer], adj)) 75 | if self.dropout is not None and layer < (self.n_layers - 1): 76 | rnn_in = self.dropout(rnn_in) 77 | return h 78 | 79 | def forward(self, x, adj, mask=None, u=None, h=None): 80 | # x:[batch, features, nodes, steps] 81 | *_, steps = x.size() 82 | 83 | # infer all valid if mask is None 84 | if mask is None: 85 | mask = torch.ones_like(x, dtype=torch.uint8) 86 | 87 | # init hidden state using node embedding or the empty state 88 | if h is None: 89 | h = self.get_h0(x) 90 | elif not isinstance(h, list): 91 | h = [*h] 92 | 93 | # Temporal conv 94 | predictions, states = [], [] 95 | for step in range(steps): 96 | x_s = x[..., step] 97 | m_s = mask[..., step] 98 | h_s = h[-1] 99 | u_s = u[..., step] if u is not None else None 100 | # impute missing values with predictions from state 101 | x_s_hat = self.pred_readout(h_s) 102 | # store imputations and state 103 | predictions.append(x_s_hat) 104 | states.append(torch.stack(h, dim=0)) 105 | # fill missing values in input with prediction 106 | x_s = torch.where(m_s, x_s, x_s_hat) 107 | inputs = [x_s, m_s] 108 | if u_s is not None: 109 | inputs.append(u_s) 110 | inputs = torch.cat(inputs, dim=1) # x_hat complemented + mask + exogenous 111 | # update state with original sequence filled using imputations 112 | h = self.update_state(inputs, h, adj) 113 | 114 | # In autoencoder mode use states after input processing 115 | if self.autoencoder_mode: 116 | states = states[1:] + [torch.stack(h, dim=0)] 117 | 118 | # Aggregate outputs -> [batch, features, nodes, steps] 119 | predictions = torch.stack(predictions, dim=-1) 120 | states = torch.stack(states, dim=-1) 121 | 122 | return predictions, states 123 | -------------------------------------------------------------------------------- /lib/nn/layers/rits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from torch.nn.parameter import Parameter 8 | 9 | from ..utils.ops import reverse_tensor 10 | 11 | 12 | class FeatureRegression(nn.Module): 13 | def __init__(self, input_size): 14 | super(FeatureRegression, self).__init__() 15 | self.W = Parameter(torch.Tensor(input_size, input_size)) 16 | self.b = Parameter(torch.Tensor(input_size)) 17 | 18 | m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size) 19 | self.register_buffer('m', m) 20 | 21 | self.reset_parameters() 22 | 23 | def reset_parameters(self): 24 | stdv = 1. 
/ math.sqrt(self.W.shape[0]) 25 | self.W.data.uniform_(-stdv, stdv) 26 | if self.b is not None: 27 | self.b.data.uniform_(-stdv, stdv) 28 | 29 | def forward(self, x): 30 | z_h = F.linear(x, self.W * Variable(self.m), self.b) 31 | return z_h 32 | 33 | 34 | class TemporalDecay(nn.Module): 35 | def __init__(self, d_in, d_out, diag=False): 36 | super(TemporalDecay, self).__init__() 37 | self.diag = diag 38 | self.W = Parameter(torch.Tensor(d_out, d_in)) 39 | self.b = Parameter(torch.Tensor(d_out)) 40 | 41 | if self.diag: 42 | assert (d_in == d_out) 43 | m = torch.eye(d_in, d_in) 44 | self.register_buffer('m', m) 45 | 46 | self.reset_parameters() 47 | 48 | def reset_parameters(self): 49 | stdv = 1. / math.sqrt(self.W.shape[0]) 50 | self.W.data.uniform_(-stdv, stdv) 51 | if self.b is not None: 52 | self.b.data.uniform_(-stdv, stdv) 53 | 54 | @staticmethod 55 | def compute_delta(mask, freq=1): 56 | delta = torch.zeros_like(mask).float() 57 | one_step = torch.tensor(freq, dtype=delta.dtype, device=delta.device) 58 | for i in range(1, delta.shape[-2]): 59 | m = mask[..., i - 1, :] 60 | delta[..., i, :] = m * one_step + (1 - m) * torch.add(delta[..., i - 1, :], freq) 61 | return delta 62 | 63 | def forward(self, d): 64 | if self.diag: 65 | gamma = F.relu(F.linear(d, self.W * Variable(self.m), self.b)) 66 | else: 67 | gamma = F.relu(F.linear(d, self.W, self.b)) 68 | gamma = torch.exp(-gamma) 69 | return gamma 70 | 71 | 72 | class RITS(nn.Module): 73 | def __init__(self, 74 | input_size, 75 | hidden_size=64): 76 | super(RITS, self).__init__() 77 | self.input_size = int(input_size) 78 | self.hidden_size = int(hidden_size) 79 | 80 | self.rnn_cell = nn.LSTMCell(2 * self.input_size, self.hidden_size) 81 | 82 | self.temp_decay_h = TemporalDecay(d_in=self.input_size, d_out=self.hidden_size, diag=False) 83 | self.temp_decay_x = TemporalDecay(d_in=self.input_size, d_out=self.input_size, diag=True) 84 | 85 | self.hist_reg = nn.Linear(self.hidden_size, self.input_size) 86 | self.feat_reg = FeatureRegression(self.input_size) 87 | 88 | self.weight_combine = nn.Linear(2 * self.input_size, self.input_size) 89 | 90 | def init_hidden_states(self, x): 91 | return Variable(torch.zeros((x.shape[0], self.hidden_size))).to(x.device) 92 | 93 | def forward(self, x, mask=None, delta=None): 94 | # x : [batch, steps, features] 95 | steps = x.shape[-2] 96 | 97 | if mask is None: 98 | mask = torch.ones_like(x, dtype=torch.uint8) 99 | if delta is None: 100 | delta = TemporalDecay.compute_delta(mask) 101 | 102 | # init rnn states 103 | h = self.init_hidden_states(x) 104 | c = self.init_hidden_states(x) 105 | 106 | imputation = [] 107 | predictions = [] 108 | for step in range(steps): 109 | d = delta[:, step, :] 110 | m = mask[:, step, :] 111 | x_s = x[:, step, :] 112 | 113 | gamma_h = self.temp_decay_h(d) 114 | 115 | # history prediction 116 | x_h = self.hist_reg(h) 117 | x_c = m * x_s + (1 - m) * x_h 118 | h = h * gamma_h 119 | 120 | # feature prediction 121 | z_h = self.feat_reg(x_c) 122 | 123 | # predictions combination 124 | gamma_x = self.temp_decay_x(d) 125 | alpha = self.weight_combine(torch.cat([gamma_x, m], dim=1)) 126 | alpha = torch.sigmoid(alpha) 127 | c_h = alpha * z_h + (1 - alpha) * x_h 128 | 129 | c_c = m * x_s + (1 - m) * c_h 130 | inputs = torch.cat([c_c, m], dim=1) 131 | h, c = self.rnn_cell(inputs, (h, c)) 132 | 133 | imputation.append(c_c) 134 | predictions.append(torch.stack((c_h, z_h, x_h), dim=0)) 135 | 136 | # imputation -> [batch, steps, features] 137 | imputation = torch.stack(imputation, dim=-2) 138 | # 
predictions -> [predictions, batch, steps, features] 139 | predictions = torch.stack(predictions, dim=-2) 140 | c_h, z_h, x_h = predictions 141 | 142 | return imputation, (c_h, z_h, x_h) 143 | 144 | 145 | class BRITS(nn.Module): 146 | 147 | def __init__(self, input_size, hidden_size): 148 | super().__init__() 149 | self.rits_fwd = RITS(input_size, hidden_size) 150 | self.rits_bwd = RITS(input_size, hidden_size) 151 | 152 | def forward(self, x, mask=None): 153 | # x: [batches, steps, features] 154 | # forward 155 | imp_fwd, pred_fwd = self.rits_fwd(x, mask) 156 | # backward 157 | x_bwd = reverse_tensor(x, axis=1) 158 | mask_bwd = reverse_tensor(mask, axis=1) if mask is not None else None 159 | imp_bwd, pred_bwd = self.rits_bwd(x_bwd, mask_bwd) 160 | imp_bwd, pred_bwd = reverse_tensor(imp_bwd, axis=1), [reverse_tensor(pb, axis=1) for pb in pred_bwd] 161 | # stack into shape = [batch, directions, steps, features] 162 | imputation = torch.stack([imp_fwd, imp_bwd], dim=1) 163 | predictions = [torch.stack([pf, pb], dim=1) for pf, pb in zip(pred_fwd, pred_bwd)] 164 | c_h, z_h, x_h = predictions 165 | 166 | return imputation, (c_h, z_h, x_h) 167 | 168 | @staticmethod 169 | def consistency_loss(imp_fwd, imp_bwd): 170 | loss = 0.1 * torch.abs(imp_fwd - imp_bwd).mean() 171 | return loss 172 | -------------------------------------------------------------------------------- /lib/nn/layers/spatial_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from einops import rearrange 4 | 5 | 6 | class SpatialAttention(nn.Module): 7 | def __init__(self, d_in, d_model, nheads, dropout=0.): 8 | super(SpatialAttention, self).__init__() 9 | self.lin_in = nn.Linear(d_in, d_model) 10 | self.self_attn = nn.MultiheadAttention(d_model, nheads, dropout=dropout) 11 | 12 | def forward(self, x, att_mask=None, **kwargs): 13 | r"""Pass the input through the encoder layer. 14 | 15 | Args: 16 | src: the sequence to the encoder layer (required). 17 | src_mask: the mask for the src sequence (optional). 18 | src_key_padding_mask: the mask for the src keys per batch (optional). 19 | 20 | Shape: 21 | see the docs in Transformer class. 22 | """ 23 | b, s, n, f = x.size() 24 | x = rearrange(x, 'b s n f -> n (b s) f') 25 | x = self.lin_in(x) 26 | x = self.self_attn(x, x, x, attn_mask=att_mask)[0] 27 | x = rearrange(x, 'n (b s) f -> b s n f', b=b, s=s) 28 | return x 29 | -------------------------------------------------------------------------------- /lib/nn/layers/spatial_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from ... 
import epsilon 5 | 6 | 7 | class SpatialConvOrderK(nn.Module): 8 | """ 9 | Spatial convolution of order K with possibly different diffusion matrices (useful for directed graphs) 10 | 11 | Efficient implementation inspired from graph-wavenet codebase 12 | """ 13 | 14 | def __init__(self, c_in, c_out, support_len=3, order=2, include_self=True): 15 | super(SpatialConvOrderK, self).__init__() 16 | self.include_self = include_self 17 | c_in = (order * support_len + (1 if include_self else 0)) * c_in 18 | self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1) 19 | self.order = order 20 | 21 | @staticmethod 22 | def compute_support(adj, device=None): 23 | if device is not None: 24 | adj = adj.to(device) 25 | adj_bwd = adj.T 26 | adj_fwd = adj / (adj.sum(1, keepdims=True) + epsilon) 27 | adj_bwd = adj_bwd / (adj_bwd.sum(1, keepdims=True) + epsilon) 28 | support = [adj_fwd, adj_bwd] 29 | return support 30 | 31 | @staticmethod 32 | def compute_support_orderK(adj, k, include_self=False, device=None): 33 | if isinstance(adj, (list, tuple)): 34 | support = adj 35 | else: 36 | support = SpatialConvOrderK.compute_support(adj, device) 37 | supp_k = [] 38 | for a in support: 39 | ak = a 40 | for i in range(k - 1): 41 | ak = torch.matmul(ak, a.T) 42 | if not include_self: 43 | ak.fill_diagonal_(0.) 44 | supp_k.append(ak) 45 | return support + supp_k 46 | 47 | def forward(self, x, support): 48 | # [batch, features, nodes, steps] 49 | if x.dim() < 4: 50 | squeeze = True 51 | x = torch.unsqueeze(x, -1) 52 | else: 53 | squeeze = False 54 | out = [x] if self.include_self else [] 55 | if (type(support) is not list): 56 | support = [support] 57 | for a in support: 58 | x1 = torch.einsum('ncvl,wv->ncwl', (x, a)).contiguous() 59 | out.append(x1) 60 | for k in range(2, self.order + 1): 61 | x2 = torch.einsum('ncvl,wv->ncwl', (x1, a)).contiguous() 62 | out.append(x2) 63 | x1 = x2 64 | 65 | out = torch.cat(out, dim=1) 66 | out = self.mlp(out) 67 | if squeeze: 68 | out = out.squeeze(-1) 69 | return out 70 | -------------------------------------------------------------------------------- /lib/nn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .grin import GRINet 2 | from .brits import BRITSNet 3 | from .mpgru import MPGRUNet, BiMPGRUNet 4 | from .var import VARImputer 5 | from .rgain import RGAINNet 6 | from .rnn_imputers import BiRNNImputer, RNNImputer 7 | -------------------------------------------------------------------------------- /lib/nn/models/brits.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from ..layers import BRITS 5 | 6 | 7 | class BRITSNet(nn.Module): 8 | def __init__(self, 9 | d_in, 10 | d_hidden=64): 11 | super(BRITSNet, self).__init__() 12 | self.birits = BRITS(input_size=d_in, 13 | hidden_size=d_hidden) 14 | 15 | def forward(self, x, mask=None, **kwargs): 16 | # x: [batches, steps, features] 17 | imputations, predictions = self.birits(x, mask=mask) 18 | # predictions: [batch, directions, steps, features] x 3 19 | out = torch.mean(imputations, dim=1) # -> [batch, steps, features] 20 | predictions = torch.cat(predictions, dim=1) # -> [batch, directions * n_predictions, steps, features] 21 | # reshape 22 | imputations = torch.transpose(imputations, 0, 1) # rearrange(imputations, 'b d s f -> d b s f') 23 | predictions = torch.transpose(predictions, 0, 1) # rearrange(predictions, 'b d s f -> d b s f') 24 | return out, imputations, predictions 25 | 26 | 
@staticmethod 27 | def add_model_specific_args(parser): 28 | parser.add_argument('--d-in', type=int) 29 | parser.add_argument('--d-hidden', type=int, default=64) 30 | return parser 31 | -------------------------------------------------------------------------------- /lib/nn/models/grin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn 4 | 5 | from ..layers import BiGRIL 6 | from ...utils.parser_utils import str_to_bool 7 | 8 | 9 | class GRINet(nn.Module): 10 | def __init__(self, 11 | adj, 12 | d_in, 13 | d_hidden, 14 | d_ff, 15 | ff_dropout, 16 | n_layers=1, 17 | kernel_size=2, 18 | decoder_order=1, 19 | global_att=False, 20 | d_u=0, 21 | d_emb=0, 22 | layer_norm=False, 23 | merge='mlp', 24 | impute_only_holes=True): 25 | super(GRINet, self).__init__() 26 | self.d_in = d_in 27 | self.d_hidden = d_hidden 28 | self.d_u = int(d_u) if d_u is not None else 0 29 | self.d_emb = int(d_emb) if d_emb is not None else 0 30 | self.register_buffer('adj', torch.tensor(adj).float()) 31 | self.impute_only_holes = impute_only_holes 32 | 33 | self.bigrill = BiGRIL(input_size=self.d_in, 34 | ff_size=d_ff, 35 | ff_dropout=ff_dropout, 36 | hidden_size=self.d_hidden, 37 | embedding_size=self.d_emb, 38 | n_nodes=self.adj.shape[0], 39 | n_layers=n_layers, 40 | kernel_size=kernel_size, 41 | decoder_order=decoder_order, 42 | global_att=global_att, 43 | u_size=self.d_u, 44 | layer_norm=layer_norm, 45 | merge=merge) 46 | 47 | def forward(self, x, mask=None, u=None, **kwargs): 48 | # x: [batches, steps, nodes, channels] -> [batches, channels, nodes, steps] 49 | x = rearrange(x, 'b s n c -> b c n s') 50 | if mask is not None: 51 | mask = rearrange(mask, 'b s n c -> b c n s') 52 | 53 | if u is not None: 54 | u = rearrange(u, 'b s n c -> b c n s') 55 | 56 | # imputation: [batches, channels, nodes, steps] prediction: [4, batches, channels, nodes, steps] 57 | imputation, prediction = self.bigrill(x, self.adj, mask=mask, u=u, cached_support=self.training) 58 | # In evaluation stage impute only missing values 59 | if self.impute_only_holes and not self.training: 60 | imputation = torch.where(mask, x, imputation) 61 | # out: [batches, channels, nodes, steps] -> [batches, steps, nodes, channels] 62 | imputation = torch.transpose(imputation, -3, -1) 63 | prediction = torch.transpose(prediction, -3, -1) 64 | if self.training: 65 | return imputation, prediction 66 | return imputation 67 | 68 | @staticmethod 69 | def add_model_specific_args(parser): 70 | parser.add_argument('--d-hidden', type=int, default=64) 71 | parser.add_argument('--d-ff', type=int, default=64) 72 | parser.add_argument('--ff-dropout', type=int, default=0.) 
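# note: '--ff-dropout' is declared with type=int although its default is the float 0.;
# argparse would reject a fractional value such as 0.3, so type=float is presumably what
# is intended here (the other models declare their '--dropout' flags as float).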
73 | parser.add_argument('--n-layers', type=int, default=1) 74 | parser.add_argument('--kernel-size', type=int, default=2) 75 | parser.add_argument('--decoder-order', type=int, default=1) 76 | parser.add_argument('--d-u', type=int, default=0) 77 | parser.add_argument('--d-emb', type=int, default=8) 78 | parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False) 79 | parser.add_argument('--global-att', type=str_to_bool, nargs='?', const=True, default=False) 80 | parser.add_argument('--merge', type=str, default='mlp') 81 | parser.add_argument('--impute-only-holes', type=str_to_bool, nargs='?', const=True, default=True) 82 | return parser 83 | -------------------------------------------------------------------------------- /lib/nn/models/mpgru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn 4 | 5 | from ..layers import MPGRUImputer, SpatialConvOrderK 6 | from ..utils.ops import reverse_tensor 7 | from ...utils.parser_utils import str_to_bool 8 | 9 | 10 | class MPGRUNet(nn.Module): 11 | def __init__(self, 12 | adj, 13 | d_in, 14 | d_hidden, 15 | d_ff=0, 16 | d_u=0, 17 | n_layers=1, 18 | dropout=0., 19 | kernel_size=2, 20 | support_len=2, 21 | layer_norm=False, 22 | impute_only_holes=True): 23 | super(MPGRUNet, self).__init__() 24 | self.register_buffer('adj', torch.tensor(adj).float()) 25 | n_nodes = adj.shape[0] 26 | self.gcgru = MPGRUImputer(input_size=d_in, 27 | hidden_size=d_hidden, 28 | ff_size=d_ff, 29 | u_size=d_u, 30 | n_layers=n_layers, 31 | dropout=dropout, 32 | kernel_size=kernel_size, 33 | support_len=support_len, 34 | layer_norm=layer_norm, 35 | n_nodes=n_nodes) 36 | self.impute_only_holes = impute_only_holes 37 | 38 | def forward(self, x, mask=None, u=None, h=None): 39 | # x: [batches, steps, nodes, channels] -> [batches, channels, nodes, steps] 40 | x = rearrange(x, 'b s n c -> b c n s') 41 | if mask is not None: 42 | mask = rearrange(mask, 'b s n c -> b c n s') 43 | if u is not None: 44 | u = rearrange(u, 'b s n c -> b c n s') 45 | 46 | adj = SpatialConvOrderK.compute_support(self.adj, x.device) 47 | imputation, _ = self.gcgru(x, adj, mask=mask, u=u, h=h) 48 | 49 | # In evaluation stage impute only missing values 50 | if self.impute_only_holes and not self.training: 51 | imputation = torch.where(mask, x, imputation) 52 | 53 | # out: [batches, channels, nodes, steps] -> [batches, steps, nodes, channels] 54 | imputation = rearrange(imputation, 'b c n s -> b s n c') 55 | 56 | return imputation 57 | 58 | @staticmethod 59 | def add_model_specific_args(parser): 60 | parser.add_argument('--d-hidden', type=int, default=64) 61 | parser.add_argument('--d-ff', type=int, default=64) 62 | parser.add_argument('--n-layers', type=int, default=1) 63 | parser.add_argument('--kernel-size', type=int, default=2) 64 | parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False) 65 | parser.add_argument('--impute-only-holes', type=str_to_bool, nargs='?', const=True, default=True) 66 | parser.add_argument('--dropout', type=float, default=0.) 
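# note: the constructor arguments 'd_u' and 'support_len' of MPGRUNet are not exposed as
# CLI flags here, so they keep their defaults (d_u=0, support_len=2) unless set elsewhere.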
67 | return parser 68 | 69 | 70 | class BiMPGRUNet(nn.Module): 71 | def __init__(self, 72 | adj, 73 | d_in, 74 | d_hidden, 75 | d_ff=0, 76 | d_u=0, 77 | n_layers=1, 78 | dropout=0., 79 | kernel_size=2, 80 | support_len=2, 81 | layer_norm=False, 82 | embedding_size=0, 83 | merge='mlp', 84 | impute_only_holes=True, 85 | autoencoder_mode=False): 86 | super(BiMPGRUNet, self).__init__() 87 | self.register_buffer('adj', torch.tensor(adj).float()) 88 | n_nodes = adj.shape[0] 89 | self.gcgru_fwd = MPGRUImputer(input_size=d_in, 90 | hidden_size=d_hidden, 91 | u_size=d_u, 92 | n_layers=n_layers, 93 | dropout=dropout, 94 | kernel_size=kernel_size, 95 | support_len=support_len, 96 | layer_norm=layer_norm, 97 | n_nodes=n_nodes, 98 | autoencoder_mode=autoencoder_mode) 99 | self.gcgru_bwd = MPGRUImputer(input_size=d_in, 100 | hidden_size=d_hidden, 101 | u_size=d_u, 102 | n_layers=n_layers, 103 | dropout=dropout, 104 | kernel_size=kernel_size, 105 | support_len=support_len, 106 | layer_norm=layer_norm, 107 | n_nodes=n_nodes, 108 | autoencoder_mode=autoencoder_mode) 109 | self.impute_only_holes = impute_only_holes 110 | 111 | if n_nodes is None: 112 | embedding_size = 0 113 | if embedding_size > 0: 114 | self.emb = nn.Parameter(torch.empty(embedding_size, n_nodes)) 115 | nn.init.kaiming_normal_(self.emb, nonlinearity='relu') 116 | else: 117 | self.register_parameter('emb', None) 118 | 119 | if merge == 'mlp': 120 | self._impute_from_states = True 121 | self.out = nn.Sequential( 122 | nn.Conv2d(in_channels=2 * d_hidden + d_in + embedding_size, 123 | out_channels=d_ff, kernel_size=1), 124 | nn.ReLU(), 125 | nn.Conv2d(in_channels=d_ff, out_channels=d_in, kernel_size=1) 126 | ) 127 | elif merge in ['mean', 'sum', 'min', 'max']: 128 | self._impute_from_states = False 129 | self.out = getattr(torch, merge) 130 | else: 131 | raise ValueError("Merge option %s not allowed." 
% merge) 132 | 133 | def forward(self, x, mask=None, u=None, h=None): 134 | # x: [batches, steps, nodes, channels] -> [batches, channels, nodes, steps] 135 | x = rearrange(x, 'b s n c -> b c n s') 136 | if mask is not None: 137 | mask = rearrange(mask, 'b s n c -> b c n s') 138 | if u is not None: 139 | u = rearrange(u, 'b s n c -> b c n s') 140 | 141 | adj = SpatialConvOrderK.compute_support(self.adj, x.device) 142 | 143 | # Forward 144 | fwd_pred, fwd_states = self.gcgru_fwd(x, adj, mask=mask, u=u) 145 | # Backward 146 | rev_x, rev_mask, rev_u = [reverse_tensor(tens, axis=-1) for tens in (x, mask, u)] 147 | bwd_res = self.gcgru_bwd(rev_x, adj, mask=rev_mask, u=rev_u) 148 | bwd_pred, bwd_states = [reverse_tensor(res, axis=-1) for res in bwd_res] 149 | 150 | if self._impute_from_states: 151 | inputs = [fwd_states[-1], bwd_states[-1], mask] # take only state of last gcgru layer 152 | if self.emb is not None: 153 | b, *_, s = x.shape # fwd_h: [batches, channels, nodes, steps] 154 | inputs += [self.emb.view(1, *self.emb.shape, 1).expand(b, -1, -1, s)] # stack emb for batches and steps 155 | imputation = torch.cat(inputs, dim=1) 156 | imputation = self.out(imputation) 157 | else: 158 | imputation = torch.stack([fwd_pred, bwd_pred], dim=1) 159 | imputation = self.out(imputation, dim=1) 160 | 161 | # In evaluation stage impute only missing values 162 | if self.impute_only_holes and not self.training: 163 | imputation = torch.where(mask, x, imputation) 164 | 165 | # out: [batches, channels, nodes, steps] -> [batches, steps, nodes, channels] 166 | imputation = rearrange(imputation, 'b c n s -> b s n c') 167 | 168 | return imputation 169 | 170 | @staticmethod 171 | def add_model_specific_args(parser): 172 | parser.add_argument('--d-hidden', type=int, default=64) 173 | parser.add_argument('--d-ff', type=int, default=64) 174 | parser.add_argument('--n-layers', type=int, default=1) 175 | parser.add_argument('--kernel-size', type=int, default=2) 176 | parser.add_argument('--d-emb', type=int, default=8) 177 | parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False) 178 | parser.add_argument('--merge', type=str, default='mlp') 179 | parser.add_argument('--impute-only-holes', type=str_to_bool, nargs='?', const=True, default=True) 180 | parser.add_argument('--dropout', type=float, default=0.) 
181 | parser.add_argument('--autoencoder-mode', type=str_to_bool, nargs='?', const=True, default=False) 182 | return parser 183 | -------------------------------------------------------------------------------- /lib/nn/models/rgain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .rnn_imputers import BiRNNImputer 5 | from ...utils.parser_utils import str_to_bool 6 | 7 | 8 | class Generator(nn.Module): 9 | def __init__(self, d_in, d_model, d_z, dropout=0., inject_noise=True): 10 | super(Generator, self).__init__() 11 | self.inject_noise = inject_noise 12 | self.d_z = d_z if inject_noise else 0 13 | self.birnn = BiRNNImputer(d_in, 14 | d_model, 15 | d_u=d_z, 16 | concat_mask=True, 17 | detach_inputs=False, 18 | dropout=dropout, 19 | state_init='zero') 20 | 21 | def forward(self, x, mask): 22 | if self.inject_noise: 23 | z = torch.rand(x.size(0), x.size(1), self.d_z, device=x.device) * 0.1 24 | else: 25 | z = None 26 | return self.birnn(x, mask, u=z) 27 | 28 | 29 | class Discriminator(torch.nn.Module): 30 | def __init__(self, d_in, d_model, dropout=0.): 31 | super(Discriminator, self).__init__() 32 | self.birnn = nn.GRU(2 * d_in, d_model, bidirectional=True, batch_first=True) 33 | self.dropout = nn.Dropout(dropout) 34 | self.read_out = nn.Linear(2 * d_model, d_in) 35 | 36 | def forward(self, x, h): 37 | x_in = torch.cat([x, h], dim=-1) 38 | out, _ = self.birnn(x_in) 39 | logits = self.read_out(self.dropout(out)) 40 | return logits 41 | 42 | 43 | class RGAINNet(torch.nn.Module): 44 | def __init__(self, d_in, d_model, d_z, dropout=0., inject_noise=False, k=5): 45 | super(RGAINNet, self).__init__() 46 | self.inject_noise = inject_noise 47 | self.k = k 48 | self.generator = Generator(d_in, d_model, d_z=d_z, dropout=dropout, inject_noise=inject_noise) 49 | self.discriminator = Discriminator(d_in, d_model, dropout) 50 | 51 | def forward(self, x, mask, **kwargs): 52 | if not self.training and self.inject_noise: 53 | res = [] 54 | for _ in range(self.k): 55 | res.append(self.generator(x, mask)[0]) 56 | return torch.stack(res, 0).mean(0), 57 | 58 | return self.generator(x, mask) 59 | 60 | @staticmethod 61 | def add_model_specific_args(parser): 62 | parser.add_argument('--d-in', type=int) 63 | parser.add_argument('--d-model', type=int, default=None) 64 | parser.add_argument('--d-z', type=int, default=8) 65 | parser.add_argument('--k', type=int, default=5) 66 | parser.add_argument('--inject-noise', type=str_to_bool, nargs='?', const=True, default=False) 67 | parser.add_argument('--dropout', type=float, default=0.) 
68 | return parser 69 | -------------------------------------------------------------------------------- /lib/nn/models/rnn_imputers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from ..utils.ops import reverse_tensor 5 | 6 | 7 | class RNNImputer(nn.Module): 8 | """Fill the blanks with a 1-step-ahead GRU predictor.""" 9 | 10 | def __init__(self, d_in, d_model, concat_mask=True, detach_inputs=False, state_init='zero', d_u=0): 11 | super(RNNImputer, self).__init__() 12 | self.concat_mask = concat_mask 13 | self.detach_inputs = detach_inputs 14 | self.state_init = state_init 15 | self.d_model = d_model 16 | self.input_dim = d_in + d_u if not concat_mask else 2 * d_in + d_u 17 | self.rnn_cell = nn.GRUCell(self.input_dim, d_model) 18 | self.read_out = nn.Linear(d_model, d_in) 19 | 20 | def init_hidden_state(self, x): 21 | if self.state_init == 'zero': 22 | return torch.zeros((x.size(0), self.d_model), device=x.device, dtype=x.dtype) 23 | if self.state_init == 'noise': 24 | return torch.randn(x.size(0), self.d_model, device=x.device, dtype=x.dtype) 25 | 26 | def _preprocess_input(self, x, x_hat, m, u): 27 | if self.detach_inputs: 28 | x_p = torch.where(m, x, x_hat.detach()) 29 | else: 30 | x_p = torch.where(m, x, x_hat) 31 | 32 | if u is not None: 33 | x_p = torch.cat([x_p, u], -1) 34 | if self.concat_mask: 35 | x_p = torch.cat([x_p, m], -1) 36 | return x_p 37 | 38 | def forward(self, x, mask, u=None, return_hidden=False): 39 | # x: [batches, steps, features] 40 | steps = x.size(1) 41 | # ensure masked values are not visible 42 | x = torch.where(mask, x, torch.zeros_like(x)) 43 | 44 | h = self.init_hidden_state(x) 45 | x_hat = self.read_out(h) 46 | hs = [h] 47 | preds = [x_hat] 48 | for s in range(steps - 1): 49 | u_t = None if u is None else u[:, s] 50 | x_t = self._preprocess_input(x[:, s], x_hat, mask[:, s], u_t) 51 | h = self.rnn_cell(x_t, h) 52 | x_hat = self.read_out(h) 53 | hs.append(h) 54 | preds.append(x_hat) 55 | 56 | x_hat = torch.stack(preds, 1) 57 | h = torch.stack(hs, 1) 58 | if return_hidden: 59 | return x_hat, h 60 | return x_hat 61 | 62 | @staticmethod 63 | def add_model_specific_args(parser): 64 | parser.add_argument('--d-in', type=int) 65 | parser.add_argument('--d-model', type=int, default=None) 66 | return parser 67 | 68 | 69 | class BiRNNImputer(nn.Module): 70 | """Fill the blanks with a 1-step-ahead GRU predictor.""" 71 | 72 | def __init__(self, d_in, d_model, dropout=0., concat_mask=True, detach_inputs=False, state_init='zero', d_u=0): 73 | super(BiRNNImputer, self).__init__() 74 | self.d_model = d_model 75 | self.fwd_rnn = RNNImputer(d_in, d_model, concat_mask, detach_inputs=detach_inputs, state_init=state_init, 76 | d_u=d_u) 77 | self.bwd_rnn = RNNImputer(d_in, d_model, concat_mask, detach_inputs=detach_inputs, state_init=state_init, 78 | d_u=d_u) 79 | self.dropout = nn.Dropout(dropout) 80 | self.read_out = nn.Linear(2 * d_model, d_in) 81 | 82 | def forward(self, x, mask, u=None, return_hidden=False): 83 | # x: [batches, steps, features] 84 | x_hat_fwd, h_fwd = self.fwd_rnn(x, mask, u=u, return_hidden=True) 85 | x_hat_bwd, h_bwd = self.bwd_rnn(reverse_tensor(x, 1), 86 | reverse_tensor(mask, 1), 87 | u=reverse_tensor(u, 1) if u is not None else None, 88 | return_hidden=True) 89 | x_hat_bwd = reverse_tensor(x_hat_bwd, 1) 90 | h_bwd = reverse_tensor(h_bwd, 1) 91 | h = self.dropout(torch.cat([h_fwd, h_bwd], -1)) 92 | x_hat = self.read_out(h) 93 | if return_hidden: 94 | return (x_hat, 
x_hat_fwd, x_hat_bwd), h 95 | return x_hat, x_hat_fwd, x_hat_bwd 96 | 97 | @staticmethod 98 | def add_model_specific_args(parser): 99 | parser.add_argument('--d-in', type=int) 100 | parser.add_argument('--d-model', type=int, default=None) 101 | parser.add_argument('--dropout', type=float, default=0.) 102 | return parser 103 | -------------------------------------------------------------------------------- /lib/nn/models/var.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn 4 | 5 | from lib import epsilon 6 | 7 | 8 | class VAR(nn.Module): 9 | def __init__(self, order, d_in, d_out=None, steps_ahead=1, bias=True): 10 | super(VAR, self).__init__() 11 | self.order = order 12 | self.d_in = d_in 13 | self.d_out = d_out if d_out is not None else d_in 14 | self.steps_ahead = steps_ahead 15 | self.lin = nn.Linear(order * d_in, steps_ahead * self.d_out, bias=bias) 16 | 17 | def forward(self, x): 18 | # x: [batches, steps, features] 19 | x = rearrange(x, 'b s f -> b (s f)') 20 | out = self.lin(x) 21 | out = rearrange(out, 'b (s f) -> b s f', s=self.steps_ahead, f=self.d_out) 22 | return out 23 | 24 | @staticmethod 25 | def add_model_specific_args(parser): 26 | parser.add_argument('--order', type=int) 27 | parser.add_argument('--d-in', type=int) 28 | parser.add_argument('--d-out', type=int, default=None) 29 | parser.add_argument('--steps-ahead', type=int, default=1) 30 | return parser 31 | 32 | 33 | class VARImputer(nn.Module): 34 | """Fill the blanks with a 1-step-ahead VAR predictor.""" 35 | 36 | def __init__(self, order, d_in, padding='mean'): 37 | super(VARImputer, self).__init__() 38 | assert padding in ['mean', 'zero'] 39 | self.order = order 40 | self.padding = padding 41 | self.predictor = VAR(order, d_in, d_out=d_in, steps_ahead=1) 42 | 43 | def forward(self, x, mask=None): 44 | # x: [batches, steps, features] 45 | batch_size, steps, n_feats = x.shape 46 | if mask is None: 47 | mask = torch.ones_like(x, dtype=torch.uint8) 48 | x = x * mask 49 | # pad input sequence to start filling from first step 50 | if self.padding == 'mean': 51 | mean = torch.sum(x, 1) / (torch.sum(mask, 1) + epsilon) 52 | pad = torch.repeat_interleave(mean.unsqueeze(1), self.order, 1) 53 | elif self.padding == 'zero': 54 | pad = torch.zeros((batch_size, self.order, n_feats)).to(x.device) 55 | x = torch.cat([pad, x], 1) 56 | # x: [batch, order + steps, features] 57 | x = [x[:, i] for i in range(x.shape[1])] 58 | for s in range(steps): 59 | x_hat = self.predictor(torch.stack(x[s:s + self.order], 1)) 60 | x_hat = x_hat[:, 0] 61 | x[s + self.order] = torch.where(mask[:, s], x[s + self.order], x_hat) 62 | x = torch.stack(x[self.order:], 1) # remove padding 63 | return x 64 | 65 | @staticmethod 66 | def add_model_specific_args(parser): 67 | parser.add_argument('--order', type=int) 68 | parser.add_argument('--d-in', type=int) 69 | parser.add_argument("--padding", type=str, default='mean') 70 | return parser 71 | -------------------------------------------------------------------------------- /lib/nn/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-Machine-Learning-Group/grin/4a28afbb092600b6e6abeeabaaf67e87dbd1ed6e/lib/nn/utils/__init__.py -------------------------------------------------------------------------------- /lib/nn/utils/metric_base.py: -------------------------------------------------------------------------------- 1 | 
from functools import partial 2 | 3 | import torch 4 | from pytorch_lightning.metrics import Metric 5 | from torchmetrics.utilities.checks import _check_same_shape 6 | 7 | 8 | class MaskedMetric(Metric): 9 | def __init__(self, 10 | metric_fn, 11 | mask_nans=False, 12 | mask_inf=False, 13 | compute_on_step=True, 14 | dist_sync_on_step=False, 15 | process_group=None, 16 | dist_sync_fn=None, 17 | metric_kwargs=None, 18 | at=None): 19 | super(MaskedMetric, self).__init__(compute_on_step=compute_on_step, 20 | dist_sync_on_step=dist_sync_on_step, 21 | process_group=process_group, 22 | dist_sync_fn=dist_sync_fn) 23 | 24 | if metric_kwargs is None: 25 | metric_kwargs = dict() 26 | self.metric_fn = partial(metric_fn, **metric_kwargs) 27 | self.mask_nans = mask_nans 28 | self.mask_inf = mask_inf 29 | if at is None: 30 | self.at = slice(None) 31 | else: 32 | self.at = slice(at, at + 1) 33 | self.add_state('value', dist_reduce_fx='sum', default=torch.tensor(0.).float()) 34 | self.add_state('numel', dist_reduce_fx='sum', default=torch.tensor(0)) 35 | 36 | def _check_mask(self, mask, val): 37 | if mask is None: 38 | mask = torch.ones_like(val).byte() 39 | else: 40 | _check_same_shape(mask, val) 41 | if self.mask_nans: 42 | mask = mask * ~torch.isnan(val) 43 | if self.mask_inf: 44 | mask = mask * ~torch.isinf(val) 45 | return mask 46 | 47 | def _compute_masked(self, y_hat, y, mask): 48 | _check_same_shape(y_hat, y) 49 | val = self.metric_fn(y_hat, y) 50 | mask = self._check_mask(mask, val) 51 | val = torch.where(mask, val, torch.tensor(0., device=val.device).float()) 52 | return val.sum(), mask.sum() 53 | 54 | def _compute_std(self, y_hat, y): 55 | _check_same_shape(y_hat, y) 56 | val = self.metric_fn(y_hat, y) 57 | return val.sum(), val.numel() 58 | 59 | def is_masked(self, mask): 60 | return self.mask_inf or self.mask_nans or (mask is not None) 61 | 62 | def update(self, y_hat, y, mask=None): 63 | y_hat = y_hat[:, self.at] 64 | y = y[:, self.at] 65 | if mask is not None: 66 | mask = mask[:, self.at] 67 | if self.is_masked(mask): 68 | val, numel = self._compute_masked(y_hat, y, mask) 69 | else: 70 | val, numel = self._compute_std(y_hat, y) 71 | self.value += val 72 | self.numel += numel 73 | 74 | def compute(self): 75 | if self.numel > 0: 76 | return self.value / self.numel 77 | return self.value 78 | -------------------------------------------------------------------------------- /lib/nn/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from .metric_base import MaskedMetric 2 | from .ops import mape 3 | from torch.nn import functional as F 4 | import torch 5 | 6 | from torchmetrics.utilities.checks import _check_same_shape 7 | 8 | from ... 
import epsilon 9 | 10 | 11 | class MaskedMAE(MaskedMetric): 12 | def __init__(self, 13 | mask_nans=False, 14 | mask_inf=False, 15 | compute_on_step=True, 16 | dist_sync_on_step=False, 17 | process_group=None, 18 | dist_sync_fn=None, 19 | at=None): 20 | super(MaskedMAE, self).__init__(metric_fn=F.l1_loss, 21 | mask_nans=mask_nans, 22 | mask_inf=mask_inf, 23 | compute_on_step=compute_on_step, 24 | dist_sync_on_step=dist_sync_on_step, 25 | process_group=process_group, 26 | dist_sync_fn=dist_sync_fn, 27 | metric_kwargs={'reduction': 'none'}, 28 | at=at) 29 | 30 | 31 | class MaskedMAPE(MaskedMetric): 32 | def __init__(self, 33 | mask_nans=False, 34 | compute_on_step=True, 35 | dist_sync_on_step=False, 36 | process_group=None, 37 | dist_sync_fn=None, 38 | at=None): 39 | super(MaskedMAPE, self).__init__(metric_fn=mape, 40 | mask_nans=mask_nans, 41 | mask_inf=True, 42 | compute_on_step=compute_on_step, 43 | dist_sync_on_step=dist_sync_on_step, 44 | process_group=process_group, 45 | dist_sync_fn=dist_sync_fn, 46 | at=at) 47 | 48 | 49 | class MaskedMSE(MaskedMetric): 50 | def __init__(self, 51 | mask_nans=False, 52 | compute_on_step=True, 53 | dist_sync_on_step=False, 54 | process_group=None, 55 | dist_sync_fn=None, 56 | at=None): 57 | super(MaskedMSE, self).__init__(metric_fn=F.mse_loss, 58 | mask_nans=mask_nans, 59 | mask_inf=True, 60 | compute_on_step=compute_on_step, 61 | dist_sync_on_step=dist_sync_on_step, 62 | process_group=process_group, 63 | dist_sync_fn=dist_sync_fn, 64 | metric_kwargs={'reduction': 'none'}, 65 | at=at) 66 | 67 | 68 | class MaskedMRE(MaskedMetric): 69 | def __init__(self, 70 | mask_nans=False, 71 | mask_inf=False, 72 | compute_on_step=True, 73 | dist_sync_on_step=False, 74 | process_group=None, 75 | dist_sync_fn=None, 76 | at=None): 77 | super(MaskedMRE, self).__init__(metric_fn=F.l1_loss, 78 | mask_nans=mask_nans, 79 | mask_inf=mask_inf, 80 | compute_on_step=compute_on_step, 81 | dist_sync_on_step=dist_sync_on_step, 82 | process_group=process_group, 83 | dist_sync_fn=dist_sync_fn, 84 | metric_kwargs={'reduction': 'none'}, 85 | at=at) 86 | self.add_state('tot', dist_reduce_fx='sum', default=torch.tensor(0., dtype=torch.float)) 87 | 88 | def _compute_masked(self, y_hat, y, mask): 89 | _check_same_shape(y_hat, y) 90 | val = self.metric_fn(y_hat, y) 91 | mask = self._check_mask(mask, val) 92 | val = torch.where(mask, val, torch.tensor(0., device=y.device, dtype=torch.float)) 93 | y_masked = torch.where(mask, y, torch.tensor(0., device=y.device, dtype=torch.float)) 94 | return val.sum(), mask.sum(), y_masked.sum() 95 | 96 | def _compute_std(self, y_hat, y): 97 | _check_same_shape(y_hat, y) 98 | val = self.metric_fn(y_hat, y) 99 | return val.sum(), val.numel(), y.sum() 100 | 101 | def compute(self): 102 | if self.tot > epsilon: 103 | return self.value / self.tot 104 | return self.value 105 | 106 | def update(self, y_hat, y, mask=None): 107 | y_hat = y_hat[:, self.at] 108 | y = y[:, self.at] 109 | if mask is not None: 110 | mask = mask[:, self.at] 111 | if self.is_masked(mask): 112 | val, numel, tot = self._compute_masked(y_hat, y, mask) 113 | else: 114 | val, numel, tot = self._compute_std(y_hat, y) 115 | self.value += val 116 | self.numel += numel 117 | self.tot += tot 118 | 119 | 120 | -------------------------------------------------------------------------------- /lib/nn/utils/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import reduce 4 | from torch.autograd 
import Variable 5 | 6 | from ... import epsilon 7 | 8 | 9 | def mae(y_hat, y, reduction='none'): 10 | return F.l1_loss(y_hat, y, reduction=reduction) 11 | 12 | 13 | def mape(y_hat, y): 14 | return torch.abs((y_hat - y) / y) 15 | 16 | 17 | def wape_loss(y_hat, y): 18 | l = torch.abs(y_hat - y) 19 | return l.sum() / (y.sum() + epsilon) 20 | 21 | 22 | def smape_loss(y_hat, y): 23 | c = torch.abs(y) > epsilon 24 | l_minus = torch.abs(y_hat - y) 25 | l_plus = torch.abs(y_hat + y) + epsilon 26 | l = 2 * l_minus / l_plus * c.float() 27 | return l.sum() / c.sum() 28 | 29 | 30 | def peak_prediction_loss(y_hat, y, reduction='none'): 31 | y_max = reduce(y, 'b s n 1 -> b 1 n 1', 'max') 32 | y_min = reduce(y, 'b s n 1 -> b 1 n 1', 'min') 33 | target = torch.cat([y_max, y_min], dim=1) 34 | return F.mse_loss(y_hat, target, reduction=reduction) 35 | 36 | 37 | def wrap_loss_fn(base_loss): 38 | def loss_fn(y_hat, y_true, mask=None): 39 | scaling = 1. 40 | if mask is not None: 41 | try: 42 | loss = base_loss(y_hat, y_true, reduction='none') 43 | except TypeError: 44 | loss = base_loss(y_hat, y_true) 45 | loss = loss * mask 46 | loss = loss.sum() / (mask.sum() + epsilon) 47 | # scaling = mask.sum() / torch.numel(mask) 48 | else: 49 | loss = base_loss(y_hat, y_true).mean() 50 | return scaling * loss 51 | 52 | return loss_fn 53 | 54 | 55 | def rbf_sim(x, gamma, device='cpu'): 56 | n = x.size()[0] 57 | a = torch.exp(-gamma * F.pdist(x, 2) ** 2) 58 | row_idx, col_idx = torch.triu_indices(n, n, 1) 59 | A = 0.5 * torch.eye(n, n).to(device) 60 | A[row_idx, col_idx] = a 61 | return A + A.T 62 | 63 | 64 | def reverse_tensor(tensor=None, axis=-1): 65 | if tensor is None: 66 | return None 67 | if tensor.dim() <= 1: 68 | return tensor 69 | indices = range(tensor.size()[axis])[::-1] 70 | indices = Variable(torch.LongTensor(indices), requires_grad=False).to(tensor.device) 71 | return tensor.index_select(axis, indices) 72 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | -------------------------------------------------------------------------------- /lib/utils/numpy_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def mae(y_hat, y): 5 | return np.abs(y_hat - y).mean() 6 | 7 | 8 | def nmae(y_hat, y): 9 | delta = np.max(y) - np.min(y) + 1e-8 10 | return mae(y_hat, y) * 100 / delta 11 | 12 | 13 | def mape(y_hat, y): 14 | return 100 * np.abs((y_hat - y) / (y + 1e-8)).mean() 15 | 16 | 17 | def mse(y_hat, y): 18 | return np.square(y_hat - y).mean() 19 | 20 | 21 | def rmse(y_hat, y): 22 | return np.sqrt(mse(y_hat, y)) 23 | 24 | 25 | def nrmse(y_hat, y): 26 | delta = np.max(y) - np.min(y) + 1e-8 27 | return rmse(y_hat, y) * 100 / delta 28 | 29 | 30 | def nrmse_2(y_hat, y): 31 | nrmse_ = np.sqrt(np.square(y_hat - y).sum() / np.square(y).sum()) 32 | return nrmse_ * 100 33 | 34 | 35 | def r2(y_hat, y): 36 | return 1. 
- np.square(y_hat - y).sum() / (np.square(y.mean(0) - y).sum()) 37 | 38 | 39 | def masked_mae(y_hat, y, mask): 40 | err = np.abs(y_hat - y) * mask 41 | return err.sum() / mask.sum() 42 | 43 | 44 | def masked_mape(y_hat, y, mask): 45 | err = np.abs((y_hat - y) / (y + 1e-8)) * mask 46 | return err.sum() / mask.sum() 47 | 48 | 49 | def masked_mse(y_hat, y, mask): 50 | err = np.square(y_hat - y) * mask 51 | return err.sum() / mask.sum() 52 | 53 | 54 | def masked_rmse(y_hat, y, mask): 55 | err = np.square(y_hat - y) * mask 56 | return np.sqrt(err.sum() / mask.sum()) 57 | 58 | 59 | def masked_mre(y_hat, y, mask): 60 | err = np.abs(y_hat - y) * mask 61 | return err.sum() / ((y * mask).sum() + 1e-8) 62 | -------------------------------------------------------------------------------- /lib/utils/parser_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from argparse import Namespace, ArgumentParser 3 | from typing import Union 4 | 5 | 6 | def str_to_bool(value): 7 | if isinstance(value, bool): 8 | return value 9 | if value.lower() in {'false', 'f', '0', 'no', 'n', 'off'}: 10 | return False 11 | elif value.lower() in {'true', 't', '1', 'yes', 'y', 'on'}: 12 | return True 13 | raise ValueError(f'{value} is not a valid boolean value') 14 | 15 | 16 | def config_dict_from_args(args): 17 | """ 18 | Extract a dictionary with the experiment configuration from arguments (necessary to filter TestTube arguments) 19 | 20 | :param args: TTNamespace 21 | :return: hyparams dict 22 | """ 23 | keys_to_remove = {'hpc_exp_number', 'trials', 'optimize_parallel', 'optimize_parallel_gpu', 24 | 'optimize_parallel_cpu', 'generate_trials', 'optimize_trials_parallel_gpu'} 25 | hparams = {key: v for key, v in args.__dict__.items() if key not in keys_to_remove} 26 | return hparams 27 | 28 | 29 | def update_from_config(args: Namespace, config: dict): 30 | assert set(config.keys()) <= set(vars(args)), f'{set(config.keys()).difference(vars(args))} not in args.' 31 | args.__dict__.update(config) 32 | return args 33 | 34 | 35 | def parse_by_group(parser): 36 | """ 37 | Create a nested namespace using the groups defined in the argument parser. 
38 | Adapted from https://stackoverflow.com/a/56631542/6524027 39 | 40 | :param args: arguments 41 | :param parser: the parser 42 | :return: 43 | """ 44 | assert isinstance(parser, ArgumentParser) 45 | args = parser.parse_args() 46 | 47 | # the first two argument groups are 'positional_arguments' and 'optional_arguments' 48 | pos_group, optional_group = parser._action_groups[0], parser._action_groups[1] 49 | args_dict = args._get_kwargs() 50 | pos_optional_arg_names = [arg.dest for arg in pos_group._group_actions] + [arg.dest for arg in 51 | optional_group._group_actions] 52 | pos_optional_args = {name: value for name, value in args_dict if name in pos_optional_arg_names} 53 | other_group_args = dict() 54 | 55 | # If there are additional argument groups, add them as nested namespaces 56 | if len(parser._action_groups) > 2: 57 | for group in parser._action_groups[2:]: 58 | group_arg_names = [arg.dest for arg in group._group_actions] 59 | other_group_args[group.title] = Namespace( 60 | **{name: value for name, value in args_dict if name in group_arg_names}) 61 | 62 | # combine the positiona/optional args and the group args 63 | combined_args = pos_optional_args 64 | combined_args.update(other_group_args) 65 | return Namespace(flat=args, **combined_args) 66 | 67 | 68 | def filter_args(args: Union[Namespace, dict], target_cls, return_dict=False): 69 | argspec = inspect.getfullargspec(target_cls.__init__) 70 | target_args = argspec.args 71 | if isinstance(args, Namespace): 72 | args = vars(args) 73 | filtered_args = {k: args[k] for k in target_args if k in args} 74 | if return_dict: 75 | return filtered_args 76 | return Namespace(**filtered_args) 77 | 78 | 79 | def filter_function_args(args: Union[Namespace, dict], function, return_dict=False): 80 | argspec = inspect.getfullargspec(function) 81 | target_args = argspec.args 82 | if isinstance(args, Namespace): 83 | args = vars(args) 84 | filtered_args = {k: args[k] for k in target_args if k in args} 85 | if return_dict: 86 | return filtered_args 87 | return Namespace(**filtered_args) 88 | -------------------------------------------------------------------------------- /lib/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from sklearn.metrics.pairwise import haversine_distances 5 | 6 | 7 | def sample_mask(shape, p=0.002, p_noise=0., max_seq=1, min_seq=1, rng=None): 8 | if rng is None: 9 | rand = np.random.random 10 | randint = np.random.randint 11 | else: 12 | rand = rng.random 13 | randint = rng.integers 14 | mask = rand(shape) < p 15 | for col in range(mask.shape[1]): 16 | idxs = np.flatnonzero(mask[:, col]) 17 | if not len(idxs): 18 | continue 19 | fault_len = min_seq 20 | if max_seq > min_seq: 21 | fault_len = fault_len + int(randint(max_seq - min_seq)) 22 | idxs_ext = np.concatenate([np.arange(i, i + fault_len) for i in idxs]) 23 | idxs = np.unique(idxs_ext) 24 | idxs = np.clip(idxs, 0, shape[0] - 1) 25 | mask[idxs, col] = True 26 | mask = mask | (rand(mask.shape) < p_noise) 27 | return mask.astype('uint8') 28 | 29 | 30 | def compute_mean(x, index=None): 31 | """Compute the mean values for each datetime. The mean is first computed hourly over the week of the year. 32 | Further NaN values are computed using hourly mean over the same month through the years. If other NaN are present, 33 | they are removed using the mean of the sole hours. 
Hoping reasonably that there is at least a non-NaN entry of the 34 | same hour of the NaN datetime in all the dataset.""" 35 | if isinstance(x, np.ndarray) and index is not None: 36 | shape = x.shape 37 | x = x.reshape((shape[0], -1)) 38 | df_mean = pd.DataFrame(x, index=index) 39 | else: 40 | df_mean = x.copy() 41 | cond0 = [df_mean.index.year, df_mean.index.isocalendar().week, df_mean.index.hour] 42 | cond1 = [df_mean.index.year, df_mean.index.month, df_mean.index.hour] 43 | conditions = [cond0, cond1, cond1[1:], cond1[2:]] 44 | while df_mean.isna().values.sum() and len(conditions): 45 | nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean) 46 | df_mean = df_mean.fillna(nan_mean) 47 | conditions = conditions[1:] 48 | if df_mean.isna().values.sum(): 49 | df_mean = df_mean.fillna(method='ffill') 50 | df_mean = df_mean.fillna(method='bfill') 51 | if isinstance(x, np.ndarray): 52 | df_mean = df_mean.values.reshape(shape) 53 | return df_mean 54 | 55 | 56 | def geographical_distance(x=None, to_rad=True): 57 | """ 58 | Compute the as-the-crow-flies distance between every pair of samples in `x`. The first dimension of each point is 59 | assumed to be the latitude, the second is the longitude. The inputs is assumed to be in degrees. If it is not the 60 | case, `to_rad` must be set to False. The dimension of the data must be 2. 61 | 62 | Parameters 63 | ---------- 64 | x : pd.DataFrame or np.ndarray 65 | array_like structure of shape (n_samples_2, 2). 66 | to_rad : bool 67 | whether to convert inputs to radians (provided that they are in degrees). 68 | 69 | Returns 70 | ------- 71 | distances : 72 | The distance between the points in kilometers. 73 | """ 74 | _AVG_EARTH_RADIUS_KM = 6371.0088 75 | 76 | # Extract values of X if it is a DataFrame, else assume it is 2-dim array of lat-lon pairs 77 | latlon_pairs = x.values if isinstance(x, pd.DataFrame) else x 78 | 79 | # If the input values are in degrees, convert them in radians 80 | if to_rad: 81 | latlon_pairs = np.vectorize(np.radians)(latlon_pairs) 82 | 83 | distances = haversine_distances(latlon_pairs) * _AVG_EARTH_RADIUS_KM 84 | 85 | # Cast response 86 | if isinstance(x, pd.DataFrame): 87 | res = pd.DataFrame(distances, x.index, x.index) 88 | else: 89 | res = distances 90 | 91 | return res 92 | 93 | 94 | def infer_mask(df, infer_from='next'): 95 | """Infer evaluation mask from DataFrame. In the evaluation mask a value is 1 if it is present in the DataFrame and 96 | absent in the `infer_from` month. 97 | 98 | @param pd.DataFrame df: the DataFrame. 99 | @param str infer_from: denotes from which month the evaluation value must be inferred. 100 | Can be either `previous` or `next`. 
101 | @return: pd.DataFrame eval_mask: the evaluation mask for the DataFrame 102 | """ 103 | mask = (~df.isna()).astype('uint8') 104 | eval_mask = pd.DataFrame(index=mask.index, columns=mask.columns, data=0).astype('uint8') 105 | if infer_from == 'previous': 106 | offset = -1 107 | elif infer_from == 'next': 108 | offset = 1 109 | else: 110 | raise ValueError('infer_from can only be one of %s' % ['previous', 'next']) 111 | months = sorted(set(zip(mask.index.year, mask.index.month))) 112 | length = len(months) 113 | for i in range(length): 114 | j = (i + offset) % length 115 | year_i, month_i = months[i] 116 | year_j, month_j = months[j] 117 | mask_j = mask[(mask.index.year == year_j) & (mask.index.month == month_j)] 118 | mask_i = mask_j.shift(1, pd.DateOffset(months=12 * (year_i - year_j) + (month_i - month_j))) 119 | mask_i = mask_i[~mask_i.index.duplicated(keep='first')] 120 | mask_i = mask_i[np.in1d(mask_i.index, mask.index)] 121 | eval_mask.loc[mask_i.index] = ~mask_i.loc[mask_i.index] & mask.loc[mask_i.index] 122 | return eval_mask 123 | 124 | 125 | def prediction_dataframe(y, index, columns=None, aggregate_by='mean'): 126 | """Aggregate batched predictions in a single DataFrame. 127 | 128 | @param (list or np.ndarray) y: the list of predictions. 129 | @param (list or np.ndarray) index: the list of time indexes coupled with the predictions. 130 | @param (list or pd.Index) columns: the columns of the returned DataFrame. 131 | @param (str or list) aggregate_by: how to aggregate the predictions in case there are more than one for a step. 132 | - `mean`: take the mean of the predictions 133 | - `central`: take the prediction at the central position, assuming that the predictions are ordered chronologically 134 | - `smooth_central`: average the predictions weighted by a gaussian signal with std=1 135 | - `last`: take the last prediction 136 | @return: pd.DataFrame df: the evaluation mask for the DataFrame 137 | """ 138 | dfs = [pd.DataFrame(data=data.reshape(data.shape[:2]), index=idx, columns=columns) for data, idx in zip(y, index)] 139 | df = pd.concat(dfs) 140 | preds_by_step = df.groupby(df.index) 141 | # aggregate according passed methods 142 | aggr_methods = ensure_list(aggregate_by) 143 | dfs = [] 144 | for aggr_by in aggr_methods: 145 | if aggr_by == 'mean': 146 | dfs.append(preds_by_step.mean()) 147 | elif aggr_by == 'central': 148 | dfs.append(preds_by_step.aggregate(lambda x: x[int(len(x) // 2)])) 149 | elif aggr_by == 'smooth_central': 150 | from scipy.signal import gaussian 151 | dfs.append(preds_by_step.aggregate(lambda x: np.average(x, weights=gaussian(len(x), 1)))) 152 | elif aggr_by == 'last': 153 | dfs.append(preds_by_step.aggregate(lambda x: x[0])) # first imputation has missing value in last position 154 | else: 155 | raise ValueError('aggregate_by can only be one of %s' % ['mean', 'central' 'smooth_central', 'last']) 156 | if isinstance(aggregate_by, str): 157 | return dfs[0] 158 | return dfs 159 | 160 | 161 | def ensure_list(obj): 162 | if isinstance(obj, (list, tuple)): 163 | return list(obj) 164 | else: 165 | return [obj] 166 | 167 | 168 | def missing_val_lens(mask): 169 | m = np.concatenate([np.zeros((1, mask.shape[1])), 170 | (~mask.astype('bool')).astype('int'), 171 | np.zeros((1, mask.shape[1]))]) 172 | mdiff = np.diff(m, axis=0) 173 | lens = [] 174 | for c in range(m.shape[1]): 175 | mj, = mdiff[:, c].nonzero() 176 | diff = np.diff(mj)[::2] 177 | lens.extend(list(diff)) 178 | return lens 179 | 180 | 181 | def disjoint_months(dataset, months=None, 
synch_mode='window'): 182 | idxs = np.arange(len(dataset)) 183 | months = ensure_list(months) 184 | # divide indices according to window or horizon 185 | if synch_mode == 'window': 186 | start, end = 0, dataset.window - 1 187 | elif synch_mode == 'horizon': 188 | start, end = dataset.horizon_offset, dataset.horizon_offset + dataset.horizon - 1 189 | else: 190 | raise ValueError('synch_mode can only be one of %s' % ['window', 'horizon']) 191 | # after idxs 192 | start_in_months = np.in1d(dataset.index[dataset._indices + start].month, months) 193 | end_in_months = np.in1d(dataset.index[dataset._indices + end].month, months) 194 | idxs_in_months = start_in_months & end_in_months 195 | after_idxs = idxs[idxs_in_months] 196 | # previous idxs 197 | months = np.setdiff1d(np.arange(1, 13), months) 198 | start_in_months = np.in1d(dataset.index[dataset._indices + start].month, months) 199 | end_in_months = np.in1d(dataset.index[dataset._indices + end].month, months) 200 | idxs_in_months = start_in_months & end_in_months 201 | prev_idxs = idxs[idxs_in_months] 202 | return prev_idxs, after_idxs 203 | 204 | 205 | def thresholded_gaussian_kernel(x, theta=None, threshold=None, threshold_on_input=False): 206 | if theta is None: 207 | theta = np.std(x) 208 | weights = np.exp(-np.square(x / theta)) 209 | if threshold is not None: 210 | mask = x > threshold if threshold_on_input else weights < threshold 211 | weights[mask] = 0. 212 | return weights 213 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | fancyimpute==0.6 3 | h5py 4 | openpyxl 5 | numpy 6 | pandas 7 | pytorch-lightning==1.4 8 | pyyaml 9 | scikit-learn 10 | scipy 11 | tables 12 | tensorboard 13 | tensorflow==2.5.0 14 | tensorflow-gpu==2.4.0 15 | torch==1.8 16 | torchvision 17 | torchaudio 18 | torchmetrics==0.5 19 | -------------------------------------------------------------------------------- /scripts/run_baselines.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import numpy as np 4 | from fancyimpute import MatrixFactorization, IterativeImputer 5 | from sklearn.neighbors import kneighbors_graph 6 | 7 | from lib import datasets 8 | from lib.utils import numpy_metrics 9 | from lib.utils.parser_utils import str_to_bool 10 | 11 | metrics = { 12 | 'mae': numpy_metrics.masked_mae, 13 | 'mse': numpy_metrics.masked_mse, 14 | 'mre': numpy_metrics.masked_mre, 15 | 'mape': numpy_metrics.masked_mape 16 | } 17 | 18 | 19 | def parse_args(): 20 | parser = ArgumentParser() 21 | # experiment setting 22 | parser.add_argument('--datasets', nargs='+', type=str, default=['all']) 23 | parser.add_argument('--imputers', nargs='+', type=str, default=['all']) 24 | parser.add_argument('--n-runs', type=int, default=5) 25 | parser.add_argument('--in-sample', type=str_to_bool, nargs='?', const=True, default=True) 26 | # SpatialKNNImputer params 27 | parser.add_argument('--k', type=int, default=10) 28 | # MFImputer params 29 | parser.add_argument('--rank', type=int, default=10) 30 | # MICEImputer params 31 | parser.add_argument('--mice-iterations', type=int, default=100) 32 | parser.add_argument('--mice-n-features', type=int, default=None) 33 | args = parser.parse_args() 34 | # parse dataset 35 | if args.datasets[0] == 'all': 36 | args.datasets = ['air36', 'air', 'bay', 'irish', 'la', 'bay_noise', 'irish_noise', 'la_noise'] 37 | # parse 
imputers 38 | if args.imputers[0] == 'all': 39 | args.imputers = ['mean', 'knn', 'mf', 'mice'] 40 | if not args.in_sample: 41 | args.imputers = [name for name in args.imputers if name in ['mean', 'mice']] 42 | return args 43 | 44 | 45 | class Imputer: 46 | short_name: str 47 | 48 | def __init__(self, method=None, is_deterministic=True, in_sample=True): 49 | self.name = self.__class__.__name__ 50 | self.method = method 51 | self.is_deterministic = is_deterministic 52 | self.in_sample = in_sample 53 | 54 | def fit(self, x, mask): 55 | if not self.in_sample: 56 | x_hat = np.where(mask, x, np.nan) 57 | return self.method.fit(x_hat) 58 | 59 | def predict(self, x, mask): 60 | x_hat = np.where(mask, x, np.nan) 61 | if self.in_sample: 62 | return self.method.fit_transform(x_hat) 63 | else: 64 | return self.method.transform(x_hat) 65 | 66 | def params(self): 67 | return dict() 68 | 69 | 70 | class SpatialKNNImputer(Imputer): 71 | short_name = 'knn' 72 | 73 | def __init__(self, adj, k=20): 74 | super(SpatialKNNImputer, self).__init__() 75 | self.k = k 76 | # normalize sim between [0, 1] 77 | sim = (adj + adj.min()) / (adj.max() + adj.min()) 78 | knns = kneighbors_graph(1 - sim, 79 | n_neighbors=self.k, 80 | include_self=False, 81 | metric='precomputed').toarray() 82 | self.knns = knns 83 | 84 | def fit(self, x, mask): 85 | pass 86 | 87 | def predict(self, x, mask): 88 | x = np.where(mask, x, 0) 89 | with np.errstate(divide='ignore', invalid='ignore'): 90 | y_hat = (x @ self.knns.T) / (mask @ self.knns.T) 91 | y_hat[~np.isfinite(y_hat)] = x.mean() 92 | return np.where(mask, x, y_hat) 93 | 94 | def params(self): 95 | return dict(k=self.k) 96 | 97 | 98 | class MeanImputer(Imputer): 99 | short_name = 'mean' 100 | 101 | def fit(self, x, mask): 102 | d = np.where(mask, x, np.nan) 103 | self.means = np.nanmean(d, axis=0, keepdims=True) 104 | 105 | def predict(self, x, mask): 106 | if self.in_sample: 107 | d = np.where(mask, x, np.nan) 108 | means = np.nanmean(d, axis=0, keepdims=True) 109 | else: 110 | means = self.means 111 | return np.where(mask, x, means) 112 | 113 | 114 | class MatrixFactorizationImputer(Imputer): 115 | short_name = 'mf' 116 | 117 | def __init__(self, rank=10, loss='mae', verbose=0): 118 | method = MatrixFactorization(rank=rank, loss=loss, verbose=verbose) 119 | super(MatrixFactorizationImputer, self).__init__(method, is_deterministic=False, in_sample=True) 120 | 121 | def params(self): 122 | return dict(rank=self.method.rank) 123 | 124 | 125 | class MICEImputer(Imputer): 126 | short_name = 'mice' 127 | 128 | def __init__(self, max_iter=100, n_nearest_features=None, in_sample=True, verbose=False): 129 | method = IterativeImputer(max_iter=max_iter, n_nearest_features=n_nearest_features, verbose=verbose) 130 | is_deterministic = n_nearest_features is None 131 | super(MICEImputer, self).__init__(method, is_deterministic=is_deterministic, in_sample=in_sample) 132 | 133 | def params(self): 134 | return dict(max_iter=self.method.max_iter, k=self.method.n_nearest_features or -1) 135 | 136 | 137 | def get_dataset(dataset_name): 138 | if dataset_name[:3] == 'air': 139 | dataset = datasets.AirQuality(impute_nans=True, small=dataset_name[3:] == '36') 140 | elif dataset_name == 'bay': 141 | dataset = datasets.MissingValuesPemsBay() 142 | elif dataset_name == 'la': 143 | dataset = datasets.MissingValuesMetrLA() 144 | elif dataset_name == 'la_noise': 145 | dataset = datasets.MissingValuesMetrLA(p_fault=0., p_noise=0.25) 146 | elif dataset_name == 'bay_noise': 147 | dataset = 
datasets.MissingValuesPemsBay(p_fault=0., p_noise=0.25) 148 | else: 149 | raise ValueError(f"Dataset {dataset_name} not available in this setting.") 150 | 151 | # split in train/test 152 | if isinstance(dataset, datasets.AirQuality): 153 | test_slice = np.in1d(dataset.df.index.month, dataset.test_months) 154 | train_slice = ~test_slice 155 | else: 156 | train_slice = np.zeros(len(dataset)).astype(bool) 157 | train_slice[:-int(0.2 * len(dataset))] = True 158 | 159 | # integrate back eval values in dataset 160 | dataset.eval_mask[train_slice] = 0 161 | 162 | return dataset, train_slice 163 | 164 | 165 | def get_imputer(imputer_name, args): 166 | if imputer_name == 'mean': 167 | imputer = MeanImputer(in_sample=args.in_sample) 168 | elif imputer_name == 'knn': 169 | imputer = SpatialKNNImputer(adj=args.adj, k=args.k) 170 | elif imputer_name == 'mf': 171 | imputer = MatrixFactorizationImputer(rank=args.rank) 172 | elif imputer_name == 'mice': 173 | imputer = MICEImputer(max_iter=args.mice_iterations, 174 | n_nearest_features=args.mice_n_features, 175 | in_sample=args.in_sample) 176 | else: 177 | raise ValueError(f"Imputer {imputer_name} not available in this setting.") 178 | return imputer 179 | 180 | 181 | def run(imputer, dataset, train_slice): 182 | test_slice = ~train_slice 183 | if args.in_sample: 184 | x_train, mask_train = dataset.numpy(), dataset.training_mask 185 | y_hat = imputer.predict(x_train, mask_train)[test_slice] 186 | else: 187 | x_train, mask_train = dataset.numpy()[train_slice], dataset.training_mask[train_slice] 188 | imputer.fit(x_train, mask_train) 189 | x_test, mask_test = dataset.numpy()[test_slice], dataset.training_mask[test_slice] 190 | y_hat = imputer.predict(x_test, mask_test) 191 | 192 | # Evaluate model 193 | y_true = dataset.numpy()[test_slice] 194 | eval_mask = dataset.eval_mask[test_slice] 195 | 196 | for metric, metric_fn in metrics.items(): 197 | error = metric_fn(y_hat, y_true, eval_mask) 198 | print(f'{imputer.name} on {ds_name} {metric}: {error:.4f}') 199 | 200 | 201 | if __name__ == '__main__': 202 | 203 | args = parse_args() 204 | print(args.__dict__) 205 | 206 | for ds_name in args.datasets: 207 | 208 | dataset, train_slice = get_dataset(ds_name) 209 | args.adj = dataset.get_similarity(thr=0.1) 210 | 211 | # Instantiate imputers 212 | imputers = [get_imputer(name, args) for name in args.imputers] 213 | 214 | for imputer in imputers: 215 | n_runs = 1 if imputer.is_deterministic else args.n_runs 216 | for _ in range(n_runs): 217 | run(imputer, dataset, train_slice) 218 | -------------------------------------------------------------------------------- /scripts/run_imputation.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import datetime 3 | import os 4 | import pathlib 5 | from argparse import ArgumentParser 6 | 7 | import numpy as np 8 | import pytorch_lightning as pl 9 | import torch 10 | import torch.nn.functional as F 11 | import yaml 12 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 13 | from pytorch_lightning.loggers import TensorBoardLogger 14 | from torch.optim.lr_scheduler import CosineAnnealingLR 15 | 16 | from lib import fillers, datasets, config 17 | from lib.data.datamodule import SpatioTemporalDataModule 18 | from lib.data.imputation_dataset import ImputationDataset, GraphImputationDataset 19 | from lib.nn import models 20 | from lib.nn.utils.metric_base import MaskedMetric 21 | from lib.nn.utils.metrics import MaskedMAE, MaskedMAPE, MaskedMSE, MaskedMRE 22 | 
from lib.utils import parser_utils, numpy_metrics, ensure_list, prediction_dataframe 23 | from lib.utils.parser_utils import str_to_bool 24 | 25 | 26 | def has_graph_support(model_cls): 27 | return model_cls in [models.GRINet, models.MPGRUNet, models.BiMPGRUNet] 28 | 29 | 30 | def get_model_classes(model_str): 31 | if model_str == 'brits': 32 | model, filler = models.BRITSNet, fillers.BRITSFiller 33 | elif model_str == 'grin': 34 | model, filler = models.GRINet, fillers.GraphFiller 35 | elif model_str == 'mpgru': 36 | model, filler = models.MPGRUNet, fillers.GraphFiller 37 | elif model_str == 'bimpgru': 38 | model, filler = models.BiMPGRUNet, fillers.GraphFiller 39 | elif model_str == 'var': 40 | model, filler = models.VARImputer, fillers.Filler 41 | elif model_str == 'gain': 42 | model, filler = models.RGAINNet, fillers.RGAINFiller 43 | elif model_str == 'birnn': 44 | model, filler = models.BiRNNImputer, fillers.MultiImputationFiller 45 | elif model_str == 'rnn': 46 | model, filler = models.RNNImputer, fillers.Filler 47 | else: 48 | raise ValueError(f'Model {model_str} not available.') 49 | return model, filler 50 | 51 | 52 | def get_dataset(dataset_name): 53 | if dataset_name[:3] == 'air': 54 | dataset = datasets.AirQuality(impute_nans=True, small=dataset_name[3:] == '36') 55 | elif dataset_name == 'bay_block': 56 | dataset = datasets.MissingValuesPemsBay() 57 | elif dataset_name == 'la_block': 58 | dataset = datasets.MissingValuesMetrLA() 59 | elif dataset_name == 'la_point': 60 | dataset = datasets.MissingValuesMetrLA(p_fault=0., p_noise=0.25) 61 | elif dataset_name == 'bay_point': 62 | dataset = datasets.MissingValuesPemsBay(p_fault=0., p_noise=0.25) 63 | else: 64 | raise ValueError(f"Dataset {dataset_name} not available in this setting.") 65 | return dataset 66 | 67 | 68 | def parse_args(): 69 | # Argument parser 70 | parser = ArgumentParser() 71 | parser.add_argument('--seed', type=int, default=-1) 72 | parser.add_argument("--model-name", type=str, default='brits') 73 | parser.add_argument("--dataset-name", type=str, default='air36') 74 | parser.add_argument("--config", type=str, default=None) 75 | # Splitting/aggregation params 76 | parser.add_argument('--in-sample', type=str_to_bool, nargs='?', const=True, default=False) 77 | parser.add_argument('--val-len', type=float, default=0.1) 78 | parser.add_argument('--test-len', type=float, default=0.2) 79 | parser.add_argument('--aggregate-by', type=str, default='mean') 80 | # Training params 81 | parser.add_argument('--lr', type=float, default=0.001) 82 | parser.add_argument('--epochs', type=int, default=300) 83 | parser.add_argument('--patience', type=int, default=40) 84 | parser.add_argument('--l2-reg', type=float, default=0.) 85 | parser.add_argument('--scaled-target', type=str_to_bool, nargs='?', const=True, default=True) 86 | parser.add_argument('--grad-clip-val', type=float, default=5.) 
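# --- [editor's note] hedged sketch, not part of the original script ----------
# Any default defined in parse_args() can be overridden through --config: at
# the end of this function the YAML file is read and each key is written back
# onto the parsed Namespace with setattr. A hypothetical override file
# (illustrative name and values, not shipped with the repository) could be:
#
#   # my_override.yaml
#   lr: 0.0005
#   epochs: 200
#   patience: 20
#
# Note that setattr uses the YAML keys verbatim, so keys must match the names
# argparse stores (underscores, e.g. grad_clip_val); otherwise the value ends
# up under an attribute the rest of the script never reads.
# -----------------------------------------------------------------------------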
87 | parser.add_argument('--grad-clip-algorithm', type=str, default='norm') 88 | parser.add_argument('--loss-fn', type=str, default='l1_loss') 89 | parser.add_argument('--use-lr-schedule', type=str_to_bool, nargs='?', const=True, default=True) 90 | parser.add_argument('--consistency-loss', type=str_to_bool, nargs='?', const=True, default=False) 91 | parser.add_argument('--whiten-prob', type=float, default=0.05) 92 | parser.add_argument('--pred-loss-weight', type=float, default=1.0) 93 | parser.add_argument('--warm-up', type=int, default=0) 94 | # graph params 95 | parser.add_argument("--adj-threshold", type=float, default=0.1) 96 | # gain hparams 97 | parser.add_argument('--alpha', type=float, default=10.) 98 | parser.add_argument('--hint-rate', type=float, default=0.7) 99 | parser.add_argument('--g-train-freq', type=int, default=1) 100 | parser.add_argument('--d-train-freq', type=int, default=5) 101 | 102 | known_args, _ = parser.parse_known_args() 103 | model_cls, _ = get_model_classes(known_args.model_name) 104 | parser = model_cls.add_model_specific_args(parser) 105 | parser = SpatioTemporalDataModule.add_argparse_args(parser) 106 | parser = ImputationDataset.add_argparse_args(parser) 107 | 108 | args = parser.parse_args() 109 | if args.config is not None: 110 | with open(args.config, 'r') as fp: 111 | config_args = yaml.load(fp, Loader=yaml.FullLoader) 112 | for arg in config_args: 113 | setattr(args, arg, config_args[arg]) 114 | 115 | return args 116 | 117 | 118 | def run_experiment(args): 119 | # Set configuration and seed 120 | args = copy.deepcopy(args) 121 | if args.seed < 0: 122 | args.seed = np.random.randint(1e9) 123 | torch.set_num_threads(1) 124 | pl.seed_everything(args.seed) 125 | 126 | model_cls, filler_cls = get_model_classes(args.model_name) 127 | dataset = get_dataset(args.dataset_name) 128 | 129 | ######################################## 130 | # create logdir and save configuration # 131 | ######################################## 132 | 133 | exp_name = f"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{args.seed}" 134 | logdir = os.path.join(config['logs'], args.dataset_name, args.model_name, exp_name) 135 | # save config for logging 136 | pathlib.Path(logdir).mkdir(parents=True) 137 | with open(os.path.join(logdir, 'config.yaml'), 'w') as fp: 138 | yaml.dump(parser_utils.config_dict_from_args(args), fp, indent=4, sort_keys=True) 139 | 140 | ######################################## 141 | # data module # 142 | ######################################## 143 | 144 | # instantiate dataset 145 | dataset_cls = GraphImputationDataset if has_graph_support(model_cls) else ImputationDataset 146 | torch_dataset = dataset_cls(*dataset.numpy(return_idx=True), 147 | mask=dataset.training_mask, 148 | eval_mask=dataset.eval_mask, 149 | window=args.window, 150 | stride=args.stride) 151 | 152 | # get train/val/test indices 153 | split_conf = parser_utils.filter_function_args(args, dataset.splitter, return_dict=True) 154 | train_idxs, val_idxs, test_idxs = dataset.splitter(torch_dataset, **split_conf) 155 | 156 | # configure datamodule 157 | data_conf = parser_utils.filter_args(args, SpatioTemporalDataModule, return_dict=True) 158 | dm = SpatioTemporalDataModule(torch_dataset, train_idxs=train_idxs, val_idxs=val_idxs, test_idxs=test_idxs, 159 | **data_conf) 160 | dm.setup() 161 | 162 | # if out of sample in air, add values removed for evaluation in train set 163 | if not args.in_sample and args.dataset_name[:3] == 'air': 164 | dm.torch_dataset.mask[dm.train_slice] |= 
dm.torch_dataset.eval_mask[dm.train_slice] 165 | 166 | # get adjacency matrix 167 | adj = dataset.get_similarity(thr=args.adj_threshold) 168 | # force adj with no self loop 169 | np.fill_diagonal(adj, 0.) 170 | 171 | ######################################## 172 | # predictor # 173 | ######################################## 174 | 175 | # model's inputs 176 | additional_model_hparams = dict(adj=adj, d_in=dm.d_in, n_nodes=dm.n_nodes) 177 | model_kwargs = parser_utils.filter_args(args={**vars(args), **additional_model_hparams}, 178 | target_cls=model_cls, 179 | return_dict=True) 180 | 181 | # loss and metrics 182 | loss_fn = MaskedMetric(metric_fn=getattr(F, args.loss_fn), 183 | compute_on_step=True, 184 | metric_kwargs={'reduction': 'none'}) 185 | 186 | metrics = {'mae': MaskedMAE(compute_on_step=False), 187 | 'mape': MaskedMAPE(compute_on_step=False), 188 | 'mse': MaskedMSE(compute_on_step=False), 189 | 'mre': MaskedMRE(compute_on_step=False)} 190 | 191 | # filler's inputs 192 | scheduler_class = CosineAnnealingLR if args.use_lr_schedule else None 193 | additional_filler_hparams = dict(model_class=model_cls, 194 | model_kwargs=model_kwargs, 195 | optim_class=torch.optim.Adam, 196 | optim_kwargs={'lr': args.lr, 197 | 'weight_decay': args.l2_reg}, 198 | loss_fn=loss_fn, 199 | metrics=metrics, 200 | scheduler_class=scheduler_class, 201 | scheduler_kwargs={ 202 | 'eta_min': 0.0001, 203 | 'T_max': args.epochs 204 | }, 205 | alpha=args.alpha, 206 | hint_rate=args.hint_rate, 207 | g_train_freq=args.g_train_freq, 208 | d_train_freq=args.d_train_freq) 209 | filler_kwargs = parser_utils.filter_args(args={**vars(args), **additional_filler_hparams}, 210 | target_cls=filler_cls, 211 | return_dict=True) 212 | filler = filler_cls(**filler_kwargs) 213 | 214 | ######################################## 215 | # training # 216 | ######################################## 217 | 218 | # callbacks 219 | early_stop_callback = EarlyStopping(monitor='val_mae', patience=args.patience, mode='min') 220 | checkpoint_callback = ModelCheckpoint(dirpath=logdir, save_top_k=1, monitor='val_mae', mode='min') 221 | 222 | logger = TensorBoardLogger(logdir, name="model") 223 | 224 | trainer = pl.Trainer(max_epochs=args.epochs, 225 | logger=logger, 226 | default_root_dir=logdir, 227 | gpus=1 if torch.cuda.is_available() else None, 228 | gradient_clip_val=args.grad_clip_val, 229 | gradient_clip_algorithm=args.grad_clip_algorithm, 230 | callbacks=[early_stop_callback, checkpoint_callback]) 231 | 232 | trainer.fit(filler, datamodule=dm) 233 | 234 | ######################################## 235 | # testing # 236 | ######################################## 237 | 238 | filler.load_state_dict(torch.load(checkpoint_callback.best_model_path, 239 | lambda storage, loc: storage)['state_dict']) 240 | filler.freeze() 241 | trainer.test() 242 | filler.eval() 243 | 244 | if torch.cuda.is_available(): 245 | filler.cuda() 246 | 247 | with torch.no_grad(): 248 | y_true, y_hat, mask = filler.predict_loader(dm.test_dataloader(), return_mask=True) 249 | y_hat = y_hat.detach().cpu().numpy().reshape(y_hat.shape[:3]) # reshape to (eventually) squeeze node channels 250 | 251 | # Test imputations in whole series 252 | eval_mask = dataset.eval_mask[dm.test_slice] 253 | df_true = dataset.df.iloc[dm.test_slice] 254 | metrics = { 255 | 'mae': numpy_metrics.masked_mae, 256 | 'mse': numpy_metrics.masked_mse, 257 | 'mre': numpy_metrics.masked_mre, 258 | 'mape': numpy_metrics.masked_mape 259 | } 260 | # Aggregate predictions in dataframes 261 | index = 
dm.torch_dataset.data_timestamps(dm.testset.indices, flatten=False)['horizon'] 262 | aggr_methods = ensure_list(args.aggregate_by) 263 | df_hats = prediction_dataframe(y_hat, index, dataset.df.columns, aggregate_by=aggr_methods) 264 | df_hats = dict(zip(aggr_methods, df_hats)) 265 | for aggr_by, df_hat in df_hats.items(): 266 | # Compute error 267 | print(f'- AGGREGATE BY {aggr_by.upper()}') 268 | for metric_name, metric_fn in metrics.items(): 269 | error = metric_fn(df_hat.values, df_true.values, eval_mask).item() 270 | print(f' {metric_name}: {error:.4f}') 271 | 272 | return y_true, y_hat, mask 273 | 274 | 275 | if __name__ == '__main__': 276 | args = parse_args() 277 | run_experiment(args) 278 | -------------------------------------------------------------------------------- /scripts/run_synthetic.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import datetime 3 | import os 4 | import pathlib 5 | from argparse import ArgumentParser 6 | 7 | import numpy as np 8 | import pytorch_lightning as pl 9 | import torch 10 | import torch.nn.functional as F 11 | import yaml 12 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 13 | from pytorch_lightning.loggers import TensorBoardLogger 14 | from torch.optim.lr_scheduler import CosineAnnealingLR 15 | 16 | from lib import fillers, config 17 | from lib.datasets import ChargedParticles 18 | from lib.nn import models 19 | from lib.nn.utils.metric_base import MaskedMetric 20 | from lib.nn.utils.metrics import MaskedMAE, MaskedMAPE, MaskedMSE, MaskedMRE 21 | from lib.utils import parser_utils 22 | from lib.utils.parser_utils import str_to_bool 23 | 24 | 25 | def has_graph_support(model_cls): 26 | return model_cls is models.GRINet 27 | 28 | 29 | def get_model_classes(model_str): 30 | if model_str == 'brits': 31 | model, filler = models.BRITSNet, fillers.BRITSFiller 32 | elif model_str == 'grin': 33 | model, filler = models.GRINet, fillers.GraphFiller 34 | else: 35 | raise ValueError(f'Model {model_str} not available.') 36 | return model, filler 37 | 38 | 39 | def parse_args(): 40 | # Argument parser 41 | parser = ArgumentParser() 42 | parser.add_argument('--seed', type=int, default=-1) 43 | parser.add_argument("--model-name", type=str, default='bigrill') 44 | parser.add_argument("--config", type=str, default=None) 45 | # Dataset params 46 | parser.add_argument('--static-adj', type=str_to_bool, nargs='?', const=True, default=False) 47 | parser.add_argument('--window', type=int, default=50) 48 | parser.add_argument('--p-block', type=float, default=0.025) 49 | parser.add_argument('--p-point', type=float, default=0.025) 50 | parser.add_argument('--min-seq', type=int, default=5) 51 | parser.add_argument('--max-seq', type=int, default=10) 52 | parser.add_argument('--use-exogenous', type=str_to_bool, nargs='?', const=True, default=True) 53 | # Splitting/aggregation params 54 | parser.add_argument('--val-len', type=float, default=0.1) 55 | parser.add_argument('--test-len', type=float, default=0.2) 56 | # Training params 57 | parser.add_argument('--lr', type=float, default=0.001) 58 | parser.add_argument('--epochs', type=int, default=300) 59 | parser.add_argument('--patience', type=int, default=40) 60 | parser.add_argument('--l2-reg', type=float, default=0.) 61 | parser.add_argument('--scaled-target', type=str_to_bool, nargs='?', const=True, default=False) 62 | parser.add_argument('--grad-clip-val', type=float, default=5.) 
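# --- [editor's note] hedged observation, not part of the original script -----
# get_model_classes() above only recognizes 'brits' and 'grin', while the
# --model-name default is 'bigrill'; running with the default therefore raises
# ValueError before any config file is read, because the model name is consumed
# by parse_known_args() to attach the model-specific arguments. A working
# invocation must pass the model explicitly, e.g. (config path assumed from the
# config/ folder of the repository):
#
#   python scripts/run_synthetic.py --model-name grin --config config/grin/synthetic.yaml
#
# -----------------------------------------------------------------------------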
63 | parser.add_argument('--grad-clip-algorithm', type=str, default='norm') 64 | parser.add_argument('--loss-fn', type=str, default='mse_loss') 65 | parser.add_argument('--use-lr-schedule', type=str_to_bool, nargs='?', const=True, default=True) 66 | parser.add_argument('--whiten-prob', type=float, default=0.05) 67 | parser.add_argument('--pred-loss-weight', type=float, default=1.0) 68 | parser.add_argument('--warm-up', type=int, default=0) 69 | # graph params 70 | parser.add_argument("--adj-threshold", type=float, default=0.1) 71 | 72 | known_args, _ = parser.parse_known_args() 73 | model_cls, _ = get_model_classes(known_args.model_name) 74 | parser = model_cls.add_model_specific_args(parser) 75 | 76 | args = parser.parse_args() 77 | if args.config is not None: 78 | with open(args.config, 'r') as fp: 79 | config_args = yaml.load(fp, Loader=yaml.FullLoader) 80 | for arg in config_args: 81 | setattr(args, arg, config_args[arg]) 82 | 83 | return args 84 | 85 | 86 | def run_experiment(args): 87 | # Set configuration and seed 88 | args = copy.deepcopy(args) 89 | if args.seed < 0: 90 | args.seed = np.random.randint(1e9) 91 | torch.set_num_threads(1) 92 | pl.seed_everything(args.seed) 93 | 94 | ######################################## 95 | # load dataset and model # 96 | ######################################## 97 | 98 | model_cls, filler_cls = get_model_classes(args.model_name) 99 | 100 | dataset = ChargedParticles(static_adj=args.static_adj, 101 | window=args.window, 102 | p_block=args.p_block, 103 | p_point=args.p_point, 104 | max_seq=args.max_seq, 105 | min_seq=args.min_seq, 106 | use_exogenous=args.use_exogenous, 107 | graph_mode=has_graph_support(model_cls)) 108 | 109 | dataset.split(args.val_len, args.test_len) 110 | 111 | # get adjacency matrix 112 | adj = dataset.get_similarity() 113 | np.fill_diagonal(adj, 0.) 
# force adj with no self loop 114 | 115 | ######################################## 116 | # create logdir and save configuration # 117 | ######################################## 118 | 119 | exp_name = f"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{args.seed}" 120 | logdir = os.path.join(config['logs'], 'synthetic', args.model_name, exp_name) 121 | # save config for logging 122 | pathlib.Path(logdir).mkdir(parents=True) 123 | with open(os.path.join(logdir, 'config.yaml'), 'w') as fp: 124 | yaml.dump(parser_utils.config_dict_from_args(args), fp, indent=4, sort_keys=True) 125 | 126 | ######################################## 127 | # predictor # 128 | ######################################## 129 | 130 | # model's inputs 131 | if has_graph_support(model_cls): 132 | model_params = dict(adj=adj, d_in=dataset.n_channels, d_u=dataset.n_exogenous, n_nodes=dataset.n_nodes) 133 | else: 134 | model_params = dict(d_in=(dataset.n_channels * dataset.n_nodes), d_u=(dataset.n_channels * dataset.n_exogenous)) 135 | model_kwargs = parser_utils.filter_args(args={**vars(args), **model_params}, 136 | target_cls=model_cls, 137 | return_dict=True) 138 | 139 | # loss and metrics 140 | loss_fn = MaskedMetric(metric_fn=getattr(F, args.loss_fn), 141 | compute_on_step=True, 142 | metric_kwargs={'reduction': 'none'}) 143 | 144 | metrics = {'mae': MaskedMAE(compute_on_step=False), 145 | 'mape': MaskedMAPE(compute_on_step=False), 146 | 'mse': MaskedMSE(compute_on_step=False), 147 | 'mre': MaskedMRE(compute_on_step=False)} 148 | 149 | # filler's inputs 150 | scheduler_class = CosineAnnealingLR if args.use_lr_schedule else None 151 | additional_filler_hparams = dict(model_class=model_cls, 152 | model_kwargs=model_kwargs, 153 | optim_class=torch.optim.Adam, 154 | optim_kwargs={'lr': args.lr, 155 | 'weight_decay': args.l2_reg}, 156 | loss_fn=loss_fn, 157 | metrics=metrics, 158 | scheduler_class=scheduler_class, 159 | scheduler_kwargs={ 160 | 'eta_min': 0.0001, 161 | 'T_max': args.epochs 162 | }, 163 | alpha=args.alpha, 164 | hint_rate=args.hint_rate, 165 | g_train_freq=args.g_train_freq, 166 | d_train_freq=args.d_train_freq) 167 | filler_kwargs = parser_utils.filter_args(args={**vars(args), **additional_filler_hparams}, 168 | target_cls=filler_cls, 169 | return_dict=True) 170 | filler = filler_cls(**filler_kwargs) 171 | 172 | ######################################## 173 | # logging options # 174 | ######################################## 175 | 176 | # log number of parameters 177 | args.trainable_parameters = filler.trainable_parameters 178 | 179 | # log statistics on masks 180 | for mask_type in ['mask', 'eval_mask', 'training_mask']: 181 | mask_type_mean = getattr(dataset, mask_type).float().mean().item() 182 | setattr(args, mask_type, mask_type_mean) 183 | 184 | print(args) 185 | 186 | ######################################## 187 | # training # 188 | ######################################## 189 | 190 | # callbacks 191 | early_stop_callback = EarlyStopping(monitor='val_mse', patience=args.patience, mode='min') 192 | checkpoint_callback = ModelCheckpoint(dirpath=logdir, save_top_k=1, monitor='val_mse', mode='min') 193 | 194 | logger = TensorBoardLogger(logdir, name="model") 195 | 196 | trainer = pl.Trainer(max_epochs=args.epochs, 197 | default_root_dir=logdir, 198 | logger=logger, 199 | gpus=1 if torch.cuda.is_available() else None, 200 | gradient_clip_val=args.grad_clip_val, 201 | gradient_clip_algorithm=args.grad_clip_algorithm, 202 | callbacks=[early_stop_callback, checkpoint_callback]) 203 | 204 | 
trainer.fit(filler, 205 | train_dataloader=dataset.train_dataloader(batch_size=args.batch_size), 206 | val_dataloaders=dataset.val_dataloader(batch_size=args.batch_size)) 207 | 208 | ######################################## 209 | # testing # 210 | ######################################## 211 | 212 | filler.load_state_dict(torch.load(checkpoint_callback.best_model_path, 213 | lambda storage, loc: storage)['state_dict']) 214 | filler.freeze() 215 | trainer.test(filler, test_dataloaders=dataset.test_dataloader(batch_size=args.batch_size)) 216 | filler.eval() 217 | 218 | 219 | if __name__ == '__main__': 220 | args = parse_args() 221 | run_experiment(args) 222 | --------------------------------------------------------------------------------
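Editor's note: the helpers in lib/utils/utils.py are driven by the scripts above, but they can also be exercised in isolation. Below is a minimal, hedged sketch (added by the editor, not part of the repository) of how a few of them fit together, assuming the repository root is on the PYTHONPATH and the packages from requirements.txt are installed; the readings, coordinates, and window indices are synthetic and purely illustrative.

import numpy as np
import pandas as pd

from lib.utils.utils import (sample_mask, geographical_distance,
                             thresholded_gaussian_kernel, prediction_dataframe)

# Simulate a week of hourly readings from 3 sensors.
rng = np.random.default_rng(0)
index = pd.date_range('2015-01-01', periods=7 * 24, freq='H')
data = rng.normal(size=(len(index), 3))

# Sample a synthetic missing-data mask: isolated missing points (p_noise) plus
# short failure bursts whose length is controlled by min_seq / max_seq.
eval_mask = sample_mask(data.shape, p=0.01, p_noise=0.05,
                        min_seq=4, max_seq=12, rng=rng)

# Build a thresholded Gaussian kernel from pairwise geographic distances
# (coordinates are made up; geographical_distance expects lat/lon pairs in degrees).
coords = pd.DataFrame([[45.46, 9.19], [45.07, 7.69], [44.49, 11.34]],
                      columns=['lat', 'lon'])
dist = geographical_distance(coords, to_rad=True).values
adj = thresholded_gaussian_kernel(dist, threshold=0.1)

# Re-assemble two overlapping windows of predictions into a single DataFrame,
# averaging the timesteps covered by both windows.
y = [rng.normal(size=(24, 3, 1)), rng.normal(size=(24, 3, 1))]
idx = [index[:24], index[12:36]]
df_hat = prediction_dataframe(y, idx, columns=coords.index, aggregate_by='mean')
print(eval_mask.mean(), adj.shape, df_hat.shape)

In scripts/run_imputation.py the adjacency is not built by hand as above: it comes from dataset.get_similarity(thr=args.adj_threshold), and its diagonal is zeroed with np.fill_diagonal(adj, 0.) before being passed to the model.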