├── .gitignore
├── LICENSE
├── README.md
├── conda_env.yml
├── config
│   ├── imputation
│   │   ├── brits.yaml
│   │   ├── grin.yaml
│   │   ├── saits.yaml
│   │   ├── spin.yaml
│   │   ├── spin_h.yaml
│   │   └── transformer.yaml
│   └── inference.yaml
├── experiments
│   ├── run_imputation.py
│   └── run_inference.py
├── paper_neurips.pdf
├── poster_neurips.pdf
├── sparse_att.png
├── spin
│   ├── __init__.py
│   ├── baselines
│   │   ├── __init__.py
│   │   ├── brits
│   │   │   ├── __init__.py
│   │   │   ├── brits.py
│   │   │   └── layers.py
│   │   ├── saits
│   │   │   ├── __init__.py
│   │   │   ├── layers.py
│   │   │   └── saits.py
│   │   └── transformer
│   │       ├── __init__.py
│   │       └── transformer.py
│   ├── imputers
│   │   ├── __init__.py
│   │   ├── brits_imputer.py
│   │   ├── saits_imputer.py
│   │   └── spin_imputer.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── additive_attention.py
│   │   ├── hierarchical_temporal_graph_attention.py
│   │   ├── postional_encoding.py
│   │   └── temporal_graph_additive_attention.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── spin.py
│   │   └── spin_hierarchical.py
│   ├── scheduler.py
│   └── utils.py
└── tsl_config.yaml

/.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_STORE 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Graph Machine Learning Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning to Reconstruct Missing Data from Spatiotemporal Graphs with Sparse Observations (NeurIPS 2022) 2 | 3 | [![NeurIPS](https://img.shields.io/badge/NeurIPS-2022-blue.svg?style=flat-square)](#) 4 | [![PDF](https://img.shields.io/badge/%E2%87%A9-PDF-orange.svg?style=flat-square)](https://arxiv.org/pdf/2205.13479) 5 | [![arXiv](https://img.shields.io/badge/arXiv-2205.13479-b31b1b.svg?style=flat-square)](https://arxiv.org/abs/2205.13479) 6 | 7 | This repository contains the code for the reproducibility of the experiments presented in the paper "Learning to Reconstruct Missing Data from Spatiotemporal Graphs with Sparse Observations" (NeurIPS 2022). We propose a graph neural network that exploits a novel spatiotemporal attention to impute missing values leveraging only (sparse) valid observations. 8 | 9 | **Authors**: [Ivan Marisca](mailto:ivan.marisca@usi.ch), [Andrea Cini](mailto:andrea.cini@usi.ch), Cesare Alippi 10 | 11 | ## SPIN in a nutshell 12 | 13 | Spatiotemporal graphs are often highly sparse, with time series characterized by multiple, concurrent, and even long sequences of missing data, e.g., due to the unreliable underlying sensor network. In this context, autoregressive models can be brittle and exhibit unstable learning dynamics. 
The objective of this paper is, then, to tackle the problem of learning effective models to reconstruct, i.e., impute, missing data points by conditioning the reconstruction only on the available observations. In particular, we propose a novel class of attention-based architectures that, given a set of highly sparse discrete observations, learn a representation for points in time and space by exploiting a spatiotemporal diffusion architecture aligned with the imputation task. Representations are trained end-to-end to reconstruct observations w.r.t. the corresponding sensor and its neighboring nodes. Compared to the state of the art, our model handles sparse data without propagating prediction errors or requiring a bidirectional model to encode forward and backward time dependencies. 14 | 15 |
16 | 17 | Example of the sparse spatiotemporal attention layer. On the left, the input spatiotemporal graph, with a time series associated with every node. On the right, the layer updating a target representation (highlighted by the green box) by simultaneously performing inter-node spatiotemporal cross-attention (red block) and intra-node temporal self-attention (violet block). 18 |
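The caption describes a target representation that is updated only from valid observations. To illustrate that masking principle alone, here is a minimal NumPy sketch under our own simplifying assumptions: a single head, scaled dot-product scores, and a hypothetical helper `masked_attention`. It is not the SPIN layer; the files in `spin/layers/` (e.g., `additive_attention.py`) suggest that the actual model uses an additive formulation combined with graph message passing.

```python
import numpy as np


def masked_attention(query, keys, values, observed):
    """Attend from one target to candidate entries, skipping missing ones.

    query:    (d,) representation of the target being updated
    keys:     (n, d) representations of candidate entries
    values:   (n, d) messages carried by the candidates
    observed: (n,) boolean mask, True where the entry is a valid observation
    Returns the aggregated message (d,) and the attention weights (n,).
    """
    scores = keys @ query / np.sqrt(query.shape[-1])
    if not observed.any():
        # Nothing to condition on: no message and all-zero weights.
        return np.zeros_like(query), np.zeros_like(scores)
    # Missing entries get -inf, hence exactly zero weight after exp().
    scores = np.where(observed, scores, -np.inf)
    scores = scores - scores[observed].max()  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum()
    return weights @ values, weights


# Hypothetical toy input: 5 candidate entries, 2 of them missing.
rng = np.random.default_rng(0)
q = rng.normal(size=4)
k = rng.normal(size=(5, 4))
v = rng.normal(size=(5, 4))
obs = np.array([True, False, True, True, False])
message, w = masked_attention(q, k, v, obs)
```

Because missing entries never enter the softmax, the update depends only on the observed values, which is the property the caption emphasizes; no placeholder value for the missing data is ever propagated.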
19 | 20 | --- 21 | 22 | ## Directory structure 23 | 24 | The directory is structured as follows: 25 | 26 | ``` 27 | . 28 | ├── config/ 29 | │   ├── imputation/ 30 | │   │   ├── brits.yaml 31 | │   │   ├── grin.yaml 32 | │   │   ├── saits.yaml 33 | │   │   ├── spin.yaml 34 | │   │   ├── spin_h.yaml 35 | │   │   └── transformer.yaml 36 | │   └── inference.yaml 37 | ├── experiments/ 38 | │   ├── run_imputation.py 39 | │   └── run_inference.py 40 | ├── spin/ 41 | │   ├── baselines/ 42 | │   ├── imputers/ 43 | │   ├── layers/ 44 | │   ├── models/ 45 | │   └── ... 46 | ├── conda_env.yml 47 | └── tsl_config.yaml 48 | 49 | ``` 50 | 51 | ## Installation 52 | 53 | We provide a conda environment with all the project dependencies. To install the environment, run: 54 | 55 | ```bash 56 | conda env create -f conda_env.yml 57 | conda activate spin 58 | ``` 59 | 60 | ## Configuration files 61 | 62 | The `config/` directory stores all the configuration files used to run the experiments. `config/imputation/` stores the model configurations used in the imputation experiments. 63 | 64 | ## Python package `spin` 65 | 66 | The support code, including models and baselines, is packaged in a Python package named `spin`. 67 | 68 | ## Experiments 69 | 70 | The scripts used for the experiments in the paper are in the `experiments` folder. 71 | 72 | * `run_imputation.py` is used to compute the metrics for the deep imputation methods. An example of usage is: 73 | 74 | ```bash 75 | conda activate spin 76 | python ./experiments/run_imputation.py --config imputation/spin.yaml --model-name spin --dataset-name bay_block 77 | ``` 78 | 79 | * `run_inference.py` is used for the experiments on sparse datasets using pre-trained models.
An example of usage is the following, where `{exp_name}` is the name of the run directory created by `run_imputation.py` (a timestamp followed by the seed): 80 | 81 | ```bash 82 | conda activate spin 83 | python ./experiments/run_inference.py --config inference.yaml --model-name spin --dataset-name bay_point --exp-name {exp_name} 84 | ``` 85 | 86 | ## Bibtex reference 87 | 88 | If you find this code useful, please consider citing our paper: 89 | 90 | ``` 91 | @article{marisca2022learning, 92 | title={Learning to Reconstruct Missing Data from Spatiotemporal Graphs with Sparse Observations}, 93 | author={Marisca, Ivan and Cini, Andrea and Alippi, Cesare}, 94 | journal={arXiv preprint arXiv:2205.13479}, 95 | year={2022} 96 | } 97 | ``` 98 | -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: spin 2 | channels: 3 | - pytorch 4 | - pyg 5 | - defaults 6 | - conda-forge 7 | dependencies: 8 | - pip 9 | - pyg>=2.0 10 | - python=3.8 11 | - pytorch>=1.9 12 | - torchvision 13 | - torchaudio 14 | - wheel 15 | - pip: 16 | - torch_spatiotemporal==0.1.1 -------------------------------------------------------------------------------- /config/imputation/brits.yaml: -------------------------------------------------------------------------------- 1 | ######################### BRITS CONFIG ########################## 2 | 3 | #### Dataset params ########################################################### 4 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36] 5 | val_len: 0.1 6 | 7 | window: 24 # [24, 36] 8 | stride: 1 9 | 10 | #### Training params ########################################################## 11 | whiten_prob: 0.05 12 | scale_target: True 13 | 14 | epochs: 300 15 | loss_fn: l1_loss 16 | lr_scheduler: cosine 17 | lr: 0.001 18 | batch_size: 32 19 | batches_epoch: 160 20 | 21 | #### Model params ############################################################# 22 | model_name: 'brits' 23 | hidden_size: 64 # [64, 128, 256] 24 |
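Both experiment scripts parse the command-line flags first and then copy every key of the YAML file given via `--config` onto the parsed namespace with `setattr`. As a result, a key defined in the YAML file overrides the corresponding flag even when that flag is passed explicitly (for example, `--lr 0.01` has no effect together with `--config imputation/spin.yaml`, which sets `lr: 0.0008`). The sketch below reproduces this merge order under stated assumptions: standard-library `argparse` stands in for tsl's `ArgParser`, a plain dict stands in for the parsed YAML, and `parse_with_config` is a hypothetical helper, not a function of the repository.

```python
import argparse


def parse_with_config(argv, config_args):
    """Parse CLI flags, then overwrite them with config-file values.

    Mirrors the order used in parse_args() of the experiment scripts.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--lr-scheduler', type=str, default=None)
    args = parser.parse_args(argv)
    # YAML keys use underscores, matching argparse's dest names
    # (e.g. '--lr-scheduler' -> args.lr_scheduler).
    for key, value in config_args.items():
        setattr(args, key, value)
    return args


# A subset of config/imputation/spin.yaml, as yaml.load would return it.
spin_cfg = {'lr': 0.0008, 'lr_scheduler': 'magic', 'epochs': 300}

args = parse_with_config(['--lr', '0.01'], spin_cfg)
print(args.lr, args.lr_scheduler)  # 0.0008 magic: the config wins
```

With this merge order, a command-line flag only takes effect if the corresponding key is absent from (or commented out in) the YAML file.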
-------------------------------------------------------------------------------- /config/imputation/grin.yaml: -------------------------------------------------------------------------------- 1 | ########################## GRIN CONFIG ########################## 2 | 3 | #### Dataset params ########################################################### 4 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36] 5 | val_len: 0.1 6 | 7 | window: 24 # [24, 36] 8 | stride: 1 9 | 10 | #### Training params ########################################################## 11 | whiten_prob: 0.05 12 | scale_target: True 13 | 14 | epochs: 300 15 | loss_fn: l1_loss 16 | lr_scheduler: cosine 17 | lr: 0.001 18 | batch_size: 32 19 | batches_epoch: 160 20 | 21 | #### Model params ############################################################# 22 | model_name: 'grin' 23 | 24 | hidden_size: 64 25 | ff_size: 64 26 | embedding_size: 8 27 | n_layers: 1 28 | kernel_size: 2 29 | decoder_order: 1 30 | layer_norm: false 31 | ff_dropout: 0 32 | merge_mode: 'mlp' 33 | -------------------------------------------------------------------------------- /config/imputation/saits.yaml: -------------------------------------------------------------------------------- 1 | ######################### SAITS CONFIG ########################## 2 | 3 | #### Dataset params ########################################################### 4 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36] 5 | val_len: 0.1 6 | 7 | window: 24 # [24, 36] 8 | stride: 1 9 | 10 | #### Training params ########################################################## 11 | whiten_prob: 0.2 12 | prediction_loss_weight: 1 13 | scale_target: True 14 | 15 | epochs: 300 16 | loss_fn: l1_loss 17 | lr: 0.0003 18 | batch_size: 128 19 | batches_epoch: 300 20 | patience: 40 21 | 22 | #### Model params ############################################################# 23 | model_name: 'saits' 24 | input_with_mask: True 25 | 26 | n_groups: 1 27 |
n_group_inner_layers: 1 28 | param_sharing_strategy: inner_group 29 | d_model: 1024 30 | d_inner: 1024 31 | n_head: 4 32 | d_k: 256 33 | d_v: 256 34 | dropout: 0 35 | diagonal_attention_mask: True -------------------------------------------------------------------------------- /config/imputation/spin.yaml: -------------------------------------------------------------------------------- 1 | ########################## SPIN CONFIG ########################## 2 | 3 | #### Dataset params ########################################################### 4 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36] 5 | val_len: 0.1 6 | 7 | window: 24 # [24, 36] 8 | stride: 1 9 | 10 | #### Training params ########################################################## 11 | whiten_prob: [0.2, 0.5, 0.8] 12 | scale_target: True 13 | 14 | epochs: 300 15 | loss_fn: l1_loss 16 | lr_scheduler: magic 17 | lr: 0.0008 18 | patience: 40 19 | precision: 16 20 | batch_size: 8 21 | split_batch_in: 2 22 | batches_epoch: 300 23 | 24 | #### Model params ############################################################# 25 | model_name: 'spin' 26 | hidden_size: 32 27 | eta: 3 28 | n_layers: 4 29 | message_layers: 1 30 | temporal_self_attention: True 31 | reweight: 'softmax' 32 | -------------------------------------------------------------------------------- /config/imputation/spin_h.yaml: -------------------------------------------------------------------------------- 1 | ########################## SPIN-H CONFIG ######################## 2 | 3 | #### Dataset params ########################################################### 4 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36] 5 | val_len: 0.1 6 | 7 | window: 24 # [24, 36] 8 | stride: 1 9 | 10 | #### Training params ########################################################## 11 | whiten_prob: [0.2, 0.5, 0.8] 12 | scale_target: True 13 | 14 | epochs: 300 15 | loss_fn: l1_loss 16 | lr_scheduler: magic 17 | lr: 0.0008 18 | patience:
40 19 | precision: 16 20 | batch_size: 8 21 | batch_inference: 20 22 | batches_epoch: 300 23 | 24 | #### Model params ############################################################# 25 | model_name: 'spin_h' 26 | h_size: 32 27 | z_size: 128 28 | z_heads: 4 29 | eta: 3 30 | n_layers: 5 31 | message_layers: 1 32 | update_z_cross: False 33 | norm: True 34 | reweight: 'softmax' 35 | spatial_aggr: 'softmax' 36 | -------------------------------------------------------------------------------- /config/imputation/transformer.yaml: -------------------------------------------------------------------------------- 1 | ###################### TRANSFORMER CONFIG ####################### 2 | 3 | #### Dataset params ########################################################### 4 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36] 5 | val_len: 0.1 6 | 7 | window: 24 # [24, 36] 8 | stride: 1 9 | 10 | #### Training params ########################################################## 11 | whiten_prob: [0.2, 0.5, 0.8] 12 | scale_target: True 13 | 14 | epochs: 300 15 | loss_fn: l1_loss 16 | lr_scheduler: magic 17 | lr: 0.0008 18 | patience: 40 19 | precision: 16 20 | batch_size: 8 21 | batch_inference: 32 22 | batches_epoch: 300 23 | 24 | #### Model params ############################################################# 25 | model_name: 'transformer' 26 | condition_on_u: True 27 | hidden_size: 64 28 | ff_size: 128 29 | n_heads: 4 30 | n_layers: 5 31 | dropout: 0 32 | axis: 'both' 33 | -------------------------------------------------------------------------------- /config/inference.yaml: -------------------------------------------------------------------------------- 1 | ######################### INFERENCE CONFIG ########################## 2 | 3 | #### Experiment params ######################################################## 4 | root: 'log' 5 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36] 6 | #model_name: [spin, spin_h, grin, transformer, saits, brits] 7 |
#exp_name: {exp_name} 8 | 9 | #### Data params ############################################################## 10 | p_fault: 0 11 | p_noise: 0.95 # [0.5, 0.75, 0.95] 12 | test_mask_seed: [1043, 2043, 3043, 4043, 5043] 13 | 14 | #### Windowing params ######################################################### 15 | batch_size: 128 16 | -------------------------------------------------------------------------------- /experiments/run_imputation.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import datetime 3 | import os 4 | 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | import torch 8 | import yaml 9 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 10 | from pytorch_lightning.loggers import TensorBoardLogger 11 | from torch.optim.lr_scheduler import CosineAnnealingLR 12 | from tsl import config, logger 13 | from tsl.data import SpatioTemporalDataModule, ImputationDataset 14 | from tsl.data.preprocessing import StandardScaler 15 | from tsl.datasets import AirQuality, MetrLA, PemsBay 16 | from tsl.imputers import Imputer 17 | from tsl.nn.metrics import MaskedMetric, MaskedMAE, MaskedMSE, MaskedMRE 18 | from tsl.nn.models.imputation import GRINModel 19 | from tsl.nn.utils import casting 20 | from tsl.ops.imputation import add_missing_values 21 | from tsl.utils import parser_utils, numpy_metrics 22 | from tsl.utils.parser_utils import ArgParser 23 | 24 | from spin.baselines import SAITS, TransformerModel, BRITS 25 | from spin.imputers import SPINImputer, SAITSImputer, BRITSImputer 26 | from spin.models import SPINModel, SPINHierarchicalModel 27 | from spin.scheduler import CosineSchedulerWithRestarts 28 | 29 | 30 | def get_model_classes(model_str): 31 | if model_str == 'spin': 32 | model, filler = SPINModel, SPINImputer 33 | elif model_str == 'spin_h': 34 | model, filler = SPINHierarchicalModel, SPINImputer 35 | elif model_str == 'grin': 36 | model, filler = GRINModel, Imputer 37 
| elif model_str == 'saits': 38 | model, filler = SAITS, SAITSImputer 39 | elif model_str == 'transformer': 40 | model, filler = TransformerModel, SPINImputer 41 | elif model_str == 'brits': 42 | model, filler = BRITS, BRITSImputer 43 | else: 44 | raise ValueError(f'Model {model_str} not available.') 45 | return model, filler 46 | 47 | 48 | def get_dataset(dataset_name: str): 49 | if dataset_name.startswith('air'): 50 | return AirQuality(impute_nans=True, small=dataset_name[3:] == '36') 51 | # build missing dataset 52 | if dataset_name.endswith('_point'): 53 | p_fault, p_noise = 0., 0.25 54 | dataset_name = dataset_name[:-6] 55 | elif dataset_name.endswith('_block'): 56 | p_fault, p_noise = 0.0015, 0.05 57 | dataset_name = dataset_name[:-6] 58 | else: 59 | raise ValueError(f"Invalid dataset name: {dataset_name}.") 60 | if dataset_name == 'la': 61 | return add_missing_values(MetrLA(), p_fault=p_fault, p_noise=p_noise, 62 | min_seq=12, max_seq=12 * 4, seed=9101112) 63 | if dataset_name == 'bay': 64 | return add_missing_values(PemsBay(), p_fault=p_fault, p_noise=p_noise, 65 | min_seq=12, max_seq=12 * 4, seed=56789) 66 | raise ValueError(f"Invalid dataset name: {dataset_name}.") 67 | 68 | 69 | def get_scheduler(scheduler_name: str = None, args=None): 70 | if scheduler_name is None: 71 | return None, None 72 | scheduler_name = scheduler_name.lower() 73 | if scheduler_name == 'cosine': 74 | scheduler_class = CosineAnnealingLR 75 | scheduler_kwargs = dict(eta_min=0.1 * args.lr, T_max=args.epochs) 76 | elif scheduler_name == 'magic': 77 | scheduler_class = CosineSchedulerWithRestarts 78 | scheduler_kwargs = dict(num_warmup_steps=12, min_factor=0.1, 79 | linear_decay=0.67, 80 | num_training_steps=args.epochs, 81 | num_cycles=args.epochs // 100) 82 | else: 83 | raise ValueError(f"Invalid scheduler name: {scheduler_name}.") 84 | return scheduler_class, scheduler_kwargs 85 | 86 | 87 | def parse_args(): 88 | # Argument parser 89 | parser = ArgParser() 90 | 91 | 
parser.add_argument('--seed', type=int, default=-1) 92 | parser.add_argument('--precision', type=int, default=32) 93 | parser.add_argument("--model-name", type=str, default='spin') 94 | parser.add_argument("--dataset-name", type=str, default='air36') 95 | parser.add_argument("--config", type=str, default='imputation/spin.yaml') 96 | 97 | # Splitting/aggregation params 98 | parser.add_argument('--val-len', type=float, default=0.1) 99 | parser.add_argument('--test-len', type=float, default=0.2) 100 | 101 | # Training params 102 | parser.add_argument('--lr', type=float, default=0.001) 103 | parser.add_argument('--epochs', type=int, default=300) 104 | parser.add_argument('--patience', type=int, default=40) 105 | parser.add_argument('--l2-reg', type=float, default=0.) 106 | parser.add_argument('--batches-epoch', type=int, default=300) 107 | parser.add_argument('--batch-inference', type=int, default=32) 108 | parser.add_argument('--split-batch-in', type=int, default=1) 109 | parser.add_argument('--grad-clip-val', type=float, default=5.) 
110 | parser.add_argument('--loss-fn', type=str, default='l1_loss') 111 | parser.add_argument('--lr-scheduler', type=str, default=None) 112 | 113 | # Connectivity params 114 | parser.add_argument("--adj-threshold", type=float, default=0.1) 115 | 116 | known_args, _ = parser.parse_known_args() 117 | model_cls, imputer_cls = get_model_classes(known_args.model_name) 118 | parser = model_cls.add_model_specific_args(parser) 119 | parser = imputer_cls.add_argparse_args(parser) 120 | parser = SpatioTemporalDataModule.add_argparse_args(parser) 121 | parser = ImputationDataset.add_argparse_args(parser) 122 | 123 | args = parser.parse_args() 124 | if args.config is not None: 125 | cfg_path = os.path.join(config.config_dir, args.config) 126 | with open(cfg_path, 'r') as fp: 127 | config_args = yaml.load(fp, Loader=yaml.FullLoader) 128 | for arg in config_args: 129 | setattr(args, arg, config_args[arg]) 130 | 131 | return args 132 | 133 | 134 | def run_experiment(args): 135 | # Set configuration and seed 136 | args = copy.deepcopy(args) 137 | if args.seed < 0: 138 | args.seed = np.random.randint(1e9) 139 | torch.set_num_threads(1) 140 | pl.seed_everything(args.seed) 141 | 142 | # script flags 143 | is_spin = args.model_name in ['spin', 'spin_h'] 144 | 145 | model_cls, imputer_class = get_model_classes(args.model_name) 146 | dataset = get_dataset(args.dataset_name) 147 | 148 | logger.info(args) 149 | 150 | ######################################## 151 | # create logdir and save configuration # 152 | ######################################## 153 | 154 | exp_name = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 155 | exp_name = f"{exp_name}_{args.seed}" 156 | logdir = os.path.join(config.log_dir, args.dataset_name, 157 | args.model_name, exp_name) 158 | # save config for logging 159 | os.makedirs(logdir, exist_ok=True) 160 | with open(os.path.join(logdir, 'config.yaml'), 'w') as fp: 161 | yaml.dump(parser_utils.config_dict_from_args(args), fp, 162 | indent=4, sort_keys=True) 163 
| 164 | ######################################## 165 | # data module # 166 | ######################################## 167 | 168 | # time embedding 169 | if is_spin or args.model_name == 'transformer': 170 | time_emb = dataset.datetime_encoded(['day', 'week']).values 171 | exog_map = {'global_temporal_encoding': time_emb} 172 | 173 | input_map = { 174 | 'u': 'temporal_encoding', 175 | 'x': 'data' 176 | } 177 | else: 178 | exog_map = input_map = None 179 | 180 | if is_spin or args.model_name == 'grin': 181 | adj = dataset.get_connectivity(threshold=args.adj_threshold, 182 | include_self=False, 183 | force_symmetric=is_spin) 184 | else: 185 | adj = None 186 | 187 | # instantiate dataset 188 | torch_dataset = ImputationDataset(*dataset.numpy(return_idx=True), 189 | training_mask=dataset.training_mask, 190 | eval_mask=dataset.eval_mask, 191 | connectivity=adj, 192 | exogenous=exog_map, 193 | input_map=input_map, 194 | window=args.window, 195 | stride=args.stride) 196 | 197 | # get train/val/test indices 198 | splitter = dataset.get_splitter(val_len=args.val_len, 199 | test_len=args.test_len) 200 | 201 | scalers = {'data': StandardScaler(axis=(0, 1))} 202 | 203 | dm = SpatioTemporalDataModule(torch_dataset, 204 | scalers=scalers, 205 | splitter=splitter, 206 | batch_size=args.batch_size // args.split_batch_in) 207 | dm.setup() 208 | 209 | ######################################## 210 | # predictor # 211 | ######################################## 212 | 213 | additional_model_hparams = dict(n_nodes=dm.n_nodes, 214 | input_size=dm.n_channels, 215 | u_size=4, 216 | output_size=dm.n_channels, 217 | window_size=dm.window) 218 | 219 | # model's inputs 220 | model_kwargs = parser_utils.filter_args( 221 | args={**vars(args), **additional_model_hparams}, 222 | target_cls=model_cls, 223 | return_dict=True) 224 | 225 | # loss and metrics 226 | loss_fn = MaskedMetric(metric_fn=getattr(torch.nn.functional, args.loss_fn), 227 | compute_on_step=True, 228 | metric_kwargs={'reduction': 
'none'}) 229 | 230 | metrics = {'mae': MaskedMAE(compute_on_step=False), 231 | 'mse': MaskedMSE(compute_on_step=False), 232 | 'mre': MaskedMRE(compute_on_step=False)} 233 | 234 | scheduler_class, scheduler_kwargs = get_scheduler(args.lr_scheduler, args) 235 | 236 | # setup imputer 237 | imputer_kwargs = parser_utils.filter_argparse_args(args, imputer_class, 238 | return_dict=True) 239 | imputer = imputer_class( 240 | model_class=model_cls, 241 | model_kwargs=model_kwargs, 242 | optim_class=torch.optim.Adam, 243 | optim_kwargs={'lr': args.lr, 244 | 'weight_decay': args.l2_reg}, 245 | loss_fn=loss_fn, 246 | metrics=metrics, 247 | scheduler_class=scheduler_class, 248 | scheduler_kwargs=scheduler_kwargs, 249 | **imputer_kwargs 250 | ) 251 | 252 | ######################################## 253 | # training # 254 | ######################################## 255 | 256 | # callbacks 257 | early_stop_callback = EarlyStopping(monitor='val_mae', 258 | patience=args.patience, mode='min') 259 | checkpoint_callback = ModelCheckpoint(dirpath=logdir, save_top_k=1, 260 | monitor='val_mae', mode='min') 261 | 262 | tb_logger = TensorBoardLogger(logdir, name="model") 263 | 264 | trainer = pl.Trainer(max_epochs=args.epochs, 265 | default_root_dir=logdir, 266 | logger=tb_logger, 267 | precision=args.precision, 268 | accumulate_grad_batches=args.split_batch_in, 269 | gpus=int(torch.cuda.is_available()), 270 | gradient_clip_val=args.grad_clip_val, 271 | limit_train_batches=args.batches_epoch * args.split_batch_in, 272 | callbacks=[early_stop_callback, checkpoint_callback]) 273 | 274 | trainer.fit(imputer, 275 | train_dataloaders=dm.train_dataloader(), 276 | val_dataloaders=dm.val_dataloader( 277 | batch_size=args.batch_inference)) 278 | 279 | ######################################## 280 | # testing # 281 | ######################################## 282 | 283 | imputer.load_model(checkpoint_callback.best_model_path) 284 | imputer.freeze() 285 | trainer.test(imputer, 
dataloaders=dm.test_dataloader( 286 | batch_size=args.batch_inference)) 287 | 288 | output = trainer.predict(imputer, dataloaders=dm.test_dataloader( 289 | batch_size=args.batch_inference)) 290 | output = casting.numpy(output) 291 | y_hat, y_true, mask = output['y_hat'].squeeze(-1), \ 292 | output['y'].squeeze(-1), \ 293 | output['mask'].squeeze(-1) 294 | check_mae = numpy_metrics.masked_mae(y_hat, y_true, mask) 295 | print(f'Test MAE: {check_mae:.2f}') 296 | return y_hat 297 | 298 | 299 | if __name__ == '__main__': 300 | args = parse_args() 301 | run_experiment(args) 302 | -------------------------------------------------------------------------------- /experiments/run_inference.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import torch 7 | import tsl 8 | import yaml 9 | from tsl import config 10 | from tsl.data import SpatioTemporalDataModule, ImputationDataset 11 | from tsl.data.preprocessing import StandardScaler 12 | from tsl.datasets import AirQuality, MetrLA, PemsBay 13 | from tsl.imputers import Imputer 14 | from tsl.nn.models.imputation import GRINModel 15 | from tsl.nn.utils import casting 16 | from tsl.ops.imputation import add_missing_values, sample_mask 17 | from tsl.utils import ArgParser, parser_utils, numpy_metrics 18 | from tsl.utils.python_utils import ensure_list 19 | 20 | from spin.baselines import SAITS, TransformerModel, BRITS 21 | from spin.imputers import SPINImputer, SAITSImputer, BRITSImputer 22 | from spin.models import SPINModel, SPINHierarchicalModel 23 | 24 | 25 | def get_model_classes(model_str): 26 | if model_str == 'spin': 27 | model, filler = SPINModel, SPINImputer 28 | elif model_str == 'spin_h': 29 | model, filler = SPINHierarchicalModel, SPINImputer 30 | elif model_str == 'grin': 31 | model, filler = GRINModel, Imputer 32 | elif model_str == 'saits': 33 | model, filler = SAITS, SAITSImputer 34 | 
elif model_str == 'transformer': 35 | model, filler = TransformerModel, SPINImputer 36 | elif model_str == 'brits': 37 | model, filler = BRITS, BRITSImputer 38 | else: 39 | raise ValueError(f'Model {model_str} not available.') 40 | return model, filler 41 | 42 | 43 | def get_dataset(dataset_name: str): 44 | if dataset_name.startswith('air'): 45 | return AirQuality(impute_nans=True, small=dataset_name[3:] == '36') 46 | # build missing dataset 47 | if dataset_name.endswith('_point'): 48 | p_fault, p_noise = 0., 0.25 49 | dataset_name = dataset_name[:-6] 50 | elif dataset_name.endswith('_block'): 51 | p_fault, p_noise = 0.0015, 0.05 52 | dataset_name = dataset_name[:-6] 53 | else: 54 | raise ValueError(f"Invalid dataset name: {dataset_name}.") 55 | if dataset_name == 'la': 56 | return add_missing_values(MetrLA(), p_fault=p_fault, p_noise=p_noise, 57 | min_seq=12, max_seq=12 * 4, seed=9101112) 58 | if dataset_name == 'bay': 59 | return add_missing_values(PemsBay(), p_fault=p_fault, p_noise=p_noise, 60 | min_seq=12, max_seq=12 * 4, seed=56789) 61 | raise ValueError(f"Invalid dataset name: {dataset_name}.") 62 | 63 | 64 | def parse_args(): 65 | # Argument parser 66 | parser = ArgParser() 67 | 68 | parser.add_argument("--model-name", type=str) 69 | parser.add_argument("--dataset-name", type=str) 70 | parser.add_argument("--exp-name", type=str) 71 | parser.add_argument("--config", type=str, default='inference.yaml') 72 | parser.add_argument("--root", type=str, default='log') 73 | 74 | # Data sparsity params 75 | parser.add_argument('--p-fault', type=float, default=0.0) 76 | parser.add_argument('--p-noise', type=float, default=0.75) 77 | parser.add_argument('--test-mask-seed', type=int, default=None) 78 | 79 | # Splitting/aggregation params 80 | parser.add_argument('--val-len', type=float, default=0.1) 81 | parser.add_argument('--test-len', type=float, default=0.2) 82 | parser.add_argument('--batch-size', type=int, default=32) 83 | 84 | # Connectivity params 85 | 
parser.add_argument("--adj-threshold", type=float, default=0.1) 86 | 87 | args = parser.parse_args() 88 | if args.config is not None: 89 | cfg_path = os.path.join(config.config_dir, args.config) 90 | with open(cfg_path, 'r') as fp: 91 | config_args = yaml.load(fp, Loader=yaml.FullLoader) 92 | for arg in config_args: 93 | setattr(args, arg, config_args[arg]) 94 | 95 | return args 96 | 97 | 98 | def load_model(exp_dir, exp_config, dm): 99 | model_cls, imputer_class = get_model_classes(exp_config['model_name']) 100 | additional_model_hparams = dict(n_nodes=dm.n_nodes, 101 | input_size=dm.n_channels, 102 | u_size=4, 103 | output_size=dm.n_channels, 104 | window_size=dm.window) 105 | 106 | # model's inputs 107 | model_kwargs = parser_utils.filter_args( 108 | args={**exp_config, **additional_model_hparams}, 109 | target_cls=model_cls, 110 | return_dict=True) 111 | 112 | # setup imputer 113 | imputer_kwargs = parser_utils.filter_argparse_args(exp_config, 114 | imputer_class, 115 | return_dict=True) 116 | imputer = imputer_class( 117 | model_class=model_cls, 118 | model_kwargs=model_kwargs, 119 | optim_class=torch.optim.Adam, 120 | optim_kwargs={}, 121 | loss_fn=None, 122 | **imputer_kwargs 123 | ) 124 | 125 | model_path = None 126 | for file in os.listdir(exp_dir): 127 | if file.endswith(".ckpt"): 128 | model_path = os.path.join(exp_dir, file) 129 | break 130 | if model_path is None: 131 | raise ValueError(f"Model not found.") 132 | 133 | imputer.load_model(model_path) 134 | imputer.freeze() 135 | return imputer 136 | 137 | 138 | def update_test_eval_mask(dm, dataset, p_fault, p_noise, seed=None): 139 | if seed is None: 140 | seed = np.random.randint(1e9) 141 | random = np.random.default_rng(seed) 142 | dataset.set_eval_mask( 143 | sample_mask(dataset.shape, p=p_fault, p_noise=p_noise, 144 | min_seq=12, max_seq=36, rng=random) 145 | ) 146 | dm.torch_dataset.set_mask(dataset.training_mask) 147 | dm.torch_dataset.update_exogenous('eval_mask', dataset.eval_mask) 148 | 149 | 
150 | def run_experiment(args): 151 | # Set configuration 152 | args = copy.deepcopy(args) 153 | tsl.logger.disabled = True 154 | 155 | # script flags 156 | is_spin = args.model_name in ['spin', 'spin_h'] 157 | 158 | ######################################## 159 | # load config # 160 | ######################################## 161 | 162 | if args.root is None: 163 | root = tsl.config.log_dir 164 | else: 165 | root = os.path.join(tsl.config.curr_dir, args.root) 166 | exp_dir = os.path.join(root, args.dataset_name, 167 | args.model_name, args.exp_name) 168 | 169 | with open(os.path.join(exp_dir, 'config.yaml'), 'r') as fp: 170 | exp_config = yaml.load(fp, Loader=yaml.FullLoader) 171 | 172 | ######################################## 173 | # load dataset # 174 | ######################################## 175 | 176 | dataset = get_dataset(exp_config['dataset_name']) 177 | 178 | ######################################## 179 | # load data module # 180 | ######################################## 181 | 182 | # time embedding 183 | if is_spin or args.model_name == 'transformer': 184 | time_emb = dataset.datetime_encoded(['day', 'week']).values 185 | exog_map = {'global_temporal_encoding': time_emb} 186 | 187 | input_map = { 188 | 'u': 'temporal_encoding', 189 | 'x': 'data' 190 | } 191 | else: 192 | exog_map = input_map = None 193 | 194 | if is_spin or args.model_name == 'grin': 195 | adj = dataset.get_connectivity(threshold=args.adj_threshold, 196 | include_self=False, 197 | force_symmetric=is_spin) 198 | else: 199 | adj = None 200 | 201 | # instantiate dataset 202 | torch_dataset = ImputationDataset(*dataset.numpy(return_idx=True), 203 | training_mask=dataset.training_mask, 204 | eval_mask=dataset.eval_mask, 205 | connectivity=adj, 206 | exogenous=exog_map, 207 | input_map=input_map, 208 | window=exp_config['window'], 209 | stride=exp_config['stride']) 210 | 211 | # get train/val/test indices 212 | splitter = dataset.get_splitter(val_len=args.val_len, 213 | test_len=args.test_len) 

    scalers = {'data': StandardScaler(axis=(0, 1))}

    dm = SpatioTemporalDataModule(torch_dataset,
                                  scalers=scalers,
                                  splitter=splitter,
                                  batch_size=args.batch_size)
    dm.setup()

    ########################################
    # load model #
    ########################################

    imputer = load_model(exp_dir, exp_config, dm)

    trainer = pl.Trainer(gpus=int(torch.cuda.is_available()))

    ########################################
    # inference #
    ########################################

    seeds = ensure_list(args.test_mask_seed)
    mae = []

    for seed in seeds:
        # Change evaluation mask
        update_test_eval_mask(dm, dataset, args.p_fault, args.p_noise, seed)

        output = trainer.predict(imputer, dataloaders=dm.test_dataloader())
        output = casting.numpy(output)
        y_hat, y_true, mask = output['y_hat'].squeeze(-1), \
                              output['y'].squeeze(-1), \
                              output['mask'].squeeze(-1)

        check_mae = numpy_metrics.masked_mae(y_hat, y_true, mask)
        mae.append(check_mae)
        print(f'SEED {seed} - Test MAE: {check_mae:.2f}')

    print(f'MAE over {len(seeds)} runs: {np.mean(mae):.2f}±{np.std(mae):.2f}')


if __name__ == '__main__':
    args = parse_args()
    run_experiment(args)

--------------------------------------------------------------------------------
/paper_neurips.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-Machine-Learning-Group/spin/7349ba31da7306e7e96c13668a3f1f0a4df90902/paper_neurips.pdf
--------------------------------------------------------------------------------
/poster_neurips.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-Machine-Learning-Group/spin/7349ba31da7306e7e96c13668a3f1f0a4df90902/poster_neurips.pdf
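The inference loop in `run_experiment` computes, for each evaluation-mask seed, an MAE restricted to the masked entries, then reports the mean and standard deviation across seeds. A minimal NumPy sketch of that bookkeeping follows, assuming `numpy_metrics.masked_mae` averages absolute errors only where the mask is nonzero. `masked_mae` and `summarize_runs` here are hypothetical stand-ins, not the tsl API, and the toy data and 25% mask rate are invented for illustration.

```python
import numpy as np


def masked_mae(y_hat: np.ndarray, y_true: np.ndarray, mask: np.ndarray) -> float:
    """Mean absolute error over the entries selected by `mask`.

    Hypothetical stand-in for tsl's numpy_metrics.masked_mae, assuming it
    averages |y_hat - y_true| only where the evaluation mask is nonzero.
    """
    mask = mask.astype(bool)
    if not mask.any():
        return float('nan')  # nothing to evaluate under this mask
    return float(np.abs(y_hat - y_true)[mask].mean())


def summarize_runs(maes):
    """Mean and population std (ddof=0) across seeds, like np.mean/np.std."""
    maes = np.asarray(maes, dtype=float)
    return float(maes.mean()), float(maes.std())


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    y_true = rng.normal(size=(4, 24, 10))  # toy data: [windows, steps, nodes]
    maes = []
    for seed in [1, 2, 3]:
        # Hypothetical: a different random evaluation mask (25% of entries) per seed.
        eval_mask = np.random.default_rng(seed).random(y_true.shape) < 0.25
        y_hat = y_true + 0.1  # toy predictions with a constant offset
        maes.append(masked_mae(y_hat, y_true, eval_mask))
    mean, std = summarize_runs(maes)
    print(f'MAE over {len(maes)} runs: {mean:.2f}±{std:.2f}')
```

Because `np.std` defaults to `ddof=0`, the reported spread is the population standard deviation over seeds; with only a handful of seeds, `ddof=1` would give a somewhat larger value.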
--------------------------------------------------------------------------------
/sparse_att.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-Machine-Learning-Group/spin/7349ba31da7306e7e96c13668a3f1f0a4df90902/sparse_att.png
--------------------------------------------------------------------------------
/spin/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-Machine-Learning-Group/spin/7349ba31da7306e7e96c13668a3f1f0a4df90902/spin/__init__.py
--------------------------------------------------------------------------------
/spin/baselines/__init__.py:
--------------------------------------------------------------------------------
from .brits import BRITS
from .saits import SAITS
from .transformer import TransformerModel

--------------------------------------------------------------------------------
/spin/baselines/brits/__init__.py:
--------------------------------------------------------------------------------
from .brits import BRITS

--------------------------------------------------------------------------------
/spin/baselines/brits/brits.py:
--------------------------------------------------------------------------------
import torch
from einops import rearrange
from torch import nn
from tsl.nn.functional import reverse_tensor

from .layers import RITS


class BRITS(nn.Module):

    def __init__(self, input_size: int, n_nodes: int, hidden_size: int = 64):
        super().__init__()
        self.n_nodes = n_nodes
        self.rits_fwd = RITS(input_size * n_nodes, hidden_size)
        self.rits_bwd = RITS(input_size * n_nodes, hidden_size)

    def forward(self, x, mask=None):
        # tsl shape to original shape: [b s n c] -> [b s c]
        x = rearrange(x, 'b s n c -> b s (n c)')
        mask = rearrange(mask, 'b s n c -> b s (n c)')
        # forward
        imp_fwd, pred_fwd = self.rits_fwd(x, mask)
        # backward
        x_bwd = reverse_tensor(x, dim=1)
        mask_bwd = reverse_tensor(mask, dim=1) if mask is not None else None
        imp_bwd, pred_bwd = self.rits_bwd(x_bwd, mask_bwd)
        imp_bwd, pred_bwd = reverse_tensor(imp_bwd, dim=1), \
                            [reverse_tensor(pb, dim=1) for pb in pred_bwd]
        # stack into shape = [batch, directions, steps, features]
        imputation = (imp_fwd + imp_bwd) / 2
        predictions = [imp_fwd, imp_bwd] + pred_fwd + pred_bwd

        imputation = rearrange(imputation, 'b s (n c) -> b s n c',
                               n=self.n_nodes)
        predictions = [rearrange(pred, 'b s (n c) -> b s n c', n=self.n_nodes)
                       for pred in predictions]

        return imputation, predictions

    @staticmethod
    def consistency_loss(imp_fwd, imp_bwd):
        loss = 0.1 * torch.abs(imp_fwd - imp_bwd).mean()
        return loss

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument('--hidden-size', type=int, default=64)
        return parser

--------------------------------------------------------------------------------
/spin/baselines/brits/layers.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureRegression(nn.Module):
    def __init__(self, input_size):
        super(FeatureRegression, self).__init__()
        self.W = nn.Parameter(torch.Tensor(input_size, input_size))
        self.b = nn.Parameter(torch.Tensor(input_size))

        m = 1 - torch.eye(input_size, input_size)
        self.register_buffer('m', m)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.shape[0])
        self.W.data.uniform_(-stdv, stdv)
        self.b.data.uniform_(-stdv, stdv)

    def forward(self, x):
        z_h = F.linear(x, self.W * self.m, self.b)
        return z_h


class TemporalDecay(nn.Module):
    def __init__(self, d_in, d_out, diag=False):
        super(TemporalDecay, self).__init__()
        self.diag = diag
        self.W = nn.Parameter(torch.Tensor(d_out, d_in))
        self.b = nn.Parameter(torch.Tensor(d_out))

        if self.diag:
            assert (d_in == d_out)
            m = torch.eye(d_in, d_in)
            self.register_buffer('m', m)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.shape[0])
        self.W.data.uniform_(-stdv, stdv)
        self.b.data.uniform_(-stdv, stdv)

    @staticmethod
    def compute_delta(mask, freq=1):
        delta = torch.zeros_like(mask).float()
        one_step = torch.tensor(freq, dtype=delta.dtype, device=delta.device)
        for i in range(1, delta.shape[-2]):
            m = mask[..., i - 1, :]
            delta[..., i, :] = m * one_step + (1 - m) * torch.add(
                delta[..., i - 1, :], freq)
        return delta

    def forward(self, d):
        if self.diag:
            gamma = F.relu(F.linear(d, self.W * self.m, self.b))
        else:
            gamma = F.relu(F.linear(d, self.W, self.b))
        gamma = torch.exp(-gamma)
        return gamma


class RITS(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size=64):
        super(RITS, self).__init__()
        self.input_size = int(input_size)
        self.hidden_size = int(hidden_size)

        self.rnn_cell = nn.LSTMCell(2 * self.input_size, self.hidden_size)

        self.temp_decay_h = TemporalDecay(d_in=self.input_size,
                                          d_out=self.hidden_size, diag=False)
        self.temp_decay_x = TemporalDecay(d_in=self.input_size,
                                          d_out=self.input_size, diag=True)

        self.hist_reg = nn.Linear(self.hidden_size, self.input_size)
        self.feat_reg = FeatureRegression(self.input_size)

        self.weight_combine = nn.Linear(2 * self.input_size, self.input_size)

    def init_hidden_states(self, x):
        return torch.zeros((x.shape[0], self.hidden_size)).to(x.device)

    def forward(self, x, mask=None, delta=None):
        # x : [batch, steps, features]
        steps = x.shape[-2]

        if mask is None:
            mask = torch.ones_like(x, dtype=torch.uint8)
        if delta is None:
            delta = TemporalDecay.compute_delta(mask)

        # init rnn states
        h = self.init_hidden_states(x)
        c = self.init_hidden_states(x)

        imputation = []
        predictions = []
        for step in range(steps):
            d = delta[:, step, :]
            m = mask[:, step, :]
            x_s = x[:, step, :]

            gamma_h = self.temp_decay_h(d)

            # history prediction
            x_h = self.hist_reg(h)
            x_c = m * x_s + (1 - m) * x_h
            h = h * gamma_h

            # feature prediction
            z_h = self.feat_reg(x_c)

            # predictions combination
            gamma_x = self.temp_decay_x(d)
            alpha = self.weight_combine(torch.cat([gamma_x, m], dim=1))
            alpha = torch.sigmoid(alpha)
            c_h = alpha * z_h + (1 - alpha) * x_h

            c_c = m * x_s + (1 - m) * c_h
            inputs = torch.cat([c_c, m], dim=1)
            h, c = self.rnn_cell(inputs, (h, c))

            imputation.append(c_h)
            predictions.append(torch.stack((z_h, x_h), dim=0))

        # imputation -> [batch, steps, features]
        imputation = torch.stack(imputation, dim=-2)
        # predictions -> [predictions, batch, steps, features]
        predictions = torch.stack(predictions, dim=-2)

        return imputation, [*predictions]

--------------------------------------------------------------------------------
/spin/baselines/saits/__init__.py:
--------------------------------------------------------------------------------
from .saits import SAITS

--------------------------------------------------------------------------------
/spin/baselines/saits/layers.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class ScaledDotProductAttention(nn.Module):
    """scaled dot-product attention"""

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, attn_mask=None):
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 1, -1e9)
        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn


class MultiHeadAttention(nn.Module):
    """original Transformer multi-head attention"""

    def __init__(self, n_head, d_model, d_k, d_v, attn_dropout):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)

        self.attention = ScaledDotProductAttention(d_k ** 0.5, attn_dropout)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

    def forward(self, q, k, v, attn_mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if attn_mask is not None:
            # this mask is imputation mask, which is not generated from
            # each batch, so needs broadcasting on batch dim
            attn_mask = attn_mask.unsqueeze(0).unsqueeze(1)

        v, attn_weights = self.attention(q, k, v, attn_mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Concatenate all the heads together: b x lq x (n*dv)
        v = v.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        v = self.fc(v)
        return v, attn_weights


class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.layer_norm(x)
        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual
        return x


class EncoderLayer(nn.Module):
    def __init__(self, d_time, d_feature, d_model, d_inner, n_head, d_k, d_v,
                 diagonal_attention_mask: bool = True,
                 dropout: float = 0.1,
                 attn_dropout: float = 0.1):
        super(EncoderLayer, self).__init__()

        self.diagonal_attention_mask = diagonal_attention_mask
        self.d_time = d_time
        self.d_feature = d_feature

        self.layer_norm = nn.LayerNorm(d_model)
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v,
                                           attn_dropout)
        self.dropout = nn.Dropout(dropout)
        self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)

    def forward(self, x: Tensor):
        if self.diagonal_attention_mask:
            mask_time = torch.eye(self.d_time).to(x.device)
        else:
            mask_time = None

        residual = x
        # here we apply LN before attention cal, namely Pre-LN
        enc_input = self.layer_norm(x)
        enc_output, attn_weights = self.slf_attn(enc_input, enc_input,
                                                 enc_input,
                                                 attn_mask=mask_time)
        enc_output = self.dropout(enc_output)
        enc_output += residual

        enc_output = self.pos_ffn(enc_output)
        return enc_output, attn_weights


class PositionalEncoding(nn.Module):
    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()
        # Not a parameter
        pos_table = self._get_sinusoid_encoding_table(n_position, d_hid)
        self.register_buffer('pos_table', pos_table)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """ Sinusoid position encoding table """

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for
                    hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i)
                                   for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
        return torch.Tensor(sinusoid_table)

    def forward(self, x):
        return x + self.pos_table[:x.size(1)]

--------------------------------------------------------------------------------
/spin/baselines/saits/saits.py:
--------------------------------------------------------------------------------
import torch
from einops import rearrange
from torch import nn, Tensor
from torch.nn import functional as F
from torch_geometric.nn import inits

from .layers import EncoderLayer, PositionalEncoding


class TransformerEncoder(nn.Module):
    def __init__(self, n_groups, n_group_inner_layers, d_time, d_feature,
                 d_model, d_inner, n_head, d_k, d_v, dropout,
                 **kwargs):
        super().__init__()
        self.n_groups = n_groups
        self.n_group_inner_layers = n_group_inner_layers
        self.input_with_mask = kwargs['input_with_mask']
        actual_d_feature = d_feature * 2 if self.input_with_mask else d_feature
        self.param_sharing_strategy = kwargs['param_sharing_strategy']
        self.MIT = kwargs['MIT']

        if self.param_sharing_strategy == 'between_group':
            # For between_group, only need to create 1 group and
            # repeat n_groups times while forwarding
            self.layer_stack = nn.ModuleList([
                EncoderLayer(d_time, actual_d_feature, d_model, d_inner, n_head,
                             d_k, d_v, dropout, dropout, **kwargs)
                for _ in range(n_group_inner_layers)
            ])
        else:  # then inner_group,inner_group is the way used in ALBERT
            # For inner_group, only need to create n_groups layers and
            # repeat n_group_inner_layers times in each group while forwarding
            self.layer_stack = nn.ModuleList([
                EncoderLayer(d_time, actual_d_feature, d_model, d_inner, n_head,
                             d_k, d_v, dropout, dropout, **kwargs)
                for _ in range(n_groups)
            ])

        self.embedding = nn.Linear(actual_d_feature, d_model)
        self.position_enc = PositionalEncoding(d_model, n_position=d_time)
        self.dropout = nn.Dropout(p=dropout)
        self.reduce_dim = nn.Linear(d_model, d_feature)

    def impute(self, x, mask, **kwargs):
        # tsl shape to original shape: [b s n c=1] -> [b s c]
        is_bsnc = False
        if x.ndim == 4:
            is_bsnc = True
            x, mask = x.squeeze(-1), mask.squeeze(-1)

        # 1st DMSA ############################################################

        # Cat mask (eventually)
        if self.input_with_mask:
            input_X = torch.cat([x, mask], dim=2)
        else:
            input_X = x
        input_X = self.embedding(input_X)
        enc_output = self.dropout(self.position_enc(input_X))

        if self.param_sharing_strategy == 'between_group':
            for _ in range(self.n_groups):
                for encoder_layer in self.layer_stack:
                    enc_output, _ = encoder_layer(enc_output)
        else:
            for encoder_layer in self.layer_stack:
                for _ in range(self.n_group_inner_layers):
                    enc_output, _ = encoder_layer(enc_output)

        learned_presentation = self.reduce_dim(enc_output)
        # replace non-missing part with original data
        imputed_data = mask * x + (1 - mask) * learned_presentation

        if is_bsnc:
            imputed_data.unsqueeze_(-1), learned_presentation.unsqueeze_(-1)
        return imputed_data, learned_presentation


class SAITS(nn.Module):
    def __init__(self, input_size: int, window_size: int, n_nodes: int,
                 d_model: int = 256,
                 d_inner: int = 128,
                 n_head: int = 4,
                 d_k: int = None,  # or 64
                 d_v: int = 64,
                 n_groups: int = 2,
                 n_group_inner_layers: int = 1,
                 param_sharing_strategy: str = 'inner_group',
                 dropout: float = 0.1,
                 input_with_mask: bool = True,
                 diagonal_attention_mask: bool = True,
                 trainable_mask_token: bool = False):
        super().__init__()
        self.n_nodes = n_nodes
        self.input_size = input_size
        self.n_groups = n_groups
        self.n_group_inner_layers = n_group_inner_layers
        self.input_with_mask = input_with_mask
        self.param_sharing_strategy = param_sharing_strategy

        d_in = in_features = input_size * n_nodes
        if self.input_with_mask:
            d_in = 2 * d_in
        d_k = d_k or (d_model // n_head)  # from the appendix

        if trainable_mask_token:
            self.mask_token = nn.Parameter(torch.Tensor(1, 1, in_features))
            inits.uniform(in_features, self.mask_token)
        else:
            self.register_buffer('mask_token', torch.zeros(1, 1, in_features))

        if self.param_sharing_strategy == 'between_group':
            # For between_group, only need to create 1 group and
            # repeat n_groups times while forwarding
            n_layers = n_group_inner_layers
        else:  # then inner_group,inner_group is the way used in ALBERT
            # For inner_group, only need to create n_groups layers and
            # repeat n_group_inner_layers times in each group while forwarding
            n_layers = n_groups

        self.layer_stack_for_first_block = nn.ModuleList([
            EncoderLayer(d_time=window_size,
                         d_feature=d_in,
                         d_model=d_model,
                         d_inner=d_inner,
                         n_head=n_head,
                         d_k=d_k,
                         d_v=d_v,
                         dropout=dropout,
                         attn_dropout=0,
                         diagonal_attention_mask=diagonal_attention_mask)
            for _ in range(n_layers)
        ])
        self.layer_stack_for_second_block = nn.ModuleList([
            EncoderLayer(d_time=window_size,
                         d_feature=d_in,
                         d_model=d_model,
                         d_inner=d_inner,
                         n_head=n_head,
                         d_k=d_k,
                         d_v=d_v,
                         dropout=dropout,
                         attn_dropout=0,
                         diagonal_attention_mask=diagonal_attention_mask)
            for _ in range(n_layers)
        ])

        self.dropout = nn.Dropout(p=dropout)
        self.position_enc = PositionalEncoding(d_model, n_position=window_size)
        # for operation on time dim
        self.embedding_1 = nn.Linear(d_in, d_model)
        self.reduce_dim_z = nn.Linear(d_model, in_features)
        # for operation on measurement dim
        self.embedding_2 = nn.Linear(d_in, d_model)
        self.reduce_dim_beta = nn.Linear(d_model, in_features)
        self.reduce_dim_gamma = nn.Linear(in_features, in_features)
        # for delta decay factor
        self.weight_combine = nn.Linear(in_features + window_size, in_features)

    def forward(self, x: Tensor, mask: Tensor, **kwargs):
        # tsl shape to original shape: [b s n c] -> [b s c]
        x = rearrange(x, 'b s n c -> b s (n c)')
        mask = rearrange(mask, 'b s n c -> b s (n c)')
        # whiten missing values
        x = torch.where(mask.bool(), x, self.mask_token)

        # 1st DMSA ############################################################

        # Cat mask (eventually)
        if self.input_with_mask:
            x_in = torch.cat([x, mask], dim=-1)
        else:
            x_in = x
        x_in = self.embedding_1(x_in)
        z = self.dropout(self.position_enc(x_in))

        # Encode (deeply?)
        if self.param_sharing_strategy == 'between_group':
            for _ in range(self.n_groups):
                for encoder_layer in self.layer_stack_for_first_block:
                    z, _ = encoder_layer(z)
        else:
            for encoder_layer in self.layer_stack_for_first_block:
                for _ in range(self.n_group_inner_layers):
                    z, _ = encoder_layer(z)

        x_tilde_1 = self.reduce_dim_z(z)
        x_hat_1 = mask * x + (1 - mask) * x_tilde_1

        # 2nd DMSA ############################################################

        # Cat mask (eventually)
        if self.input_with_mask:
            x_in = torch.cat([x_hat_1, mask], dim=-1)
        else:
            x_in = x_hat_1
        x_in = self.embedding_2(x_in)
        z = self.position_enc(x_in)

        # Encode
        if self.param_sharing_strategy == 'between_group':
            for _ in range(self.n_groups):
                for encoder_layer in self.layer_stack_for_second_block:
                    z, attn_weights = encoder_layer(z)
        else:
            for encoder_layer in self.layer_stack_for_second_block:
                for _ in range(self.n_group_inner_layers):
                    z, attn_weights = encoder_layer(z)

        x_tilde_2 = self.reduce_dim_gamma(F.relu(self.reduce_dim_beta(z)))

        # Average attention heads
        if attn_weights.size(1) > 1:
            attn_weights = attn_weights.mean(dim=1)

        weights = torch.cat([mask, attn_weights], dim=2)
        weights = F.sigmoid(self.weight_combine(weights))
        # combine x_tilde_1 and X_tilde_2
        # x_tilde_3 = (1 - weights) * x_tilde_2 + weights * x_tilde_1
        x_hat = torch.lerp(x_tilde_2, x_tilde_1, weights)
        # replace non-missing part with original data
        x_tilde = [x_tilde_1, x_tilde_2, x_hat]

        # restore original shape
        x_hat = rearrange(x_hat, 'b s (n c) -> b s n c', n=self.n_nodes)
        x_tilde = [rearrange(tens, 'b s (n c) -> b s n c', n=self.n_nodes)
                   for tens in x_tilde]

        return x_hat, x_tilde

    @staticmethod
    def add_model_specific_args(parser):
        parser.opt_list('--d-model', type=int, default=256, tunable=True,
                        options=[64, 128, 256, 512, 1024])
        parser.opt_list('--d-inner', type=int, default=128, tunable=True,
                        options=[128, 256, 512, 1024, 2048, 4096])
        parser.opt_list('--n-head', type=int, default=4, tunable=True,
                        options=[2, 4, 8])
        parser.add_argument('--d-k', type=int, default=None)
        parser.opt_list('--d-v', type=int, default=64, tunable=True,
                        options=[64, 128, 256, 512])
        parser.add_argument('--dropout', type=float, default=0.1)
        #
        parser.opt_list('--n-groups', type=int, default=2, tunable=True,
                        options=[1, 2, 4, 6, 8])
        parser.add_argument('--n-group-inner-layers', type=int, default=1)
        parser.add_argument('--param-sharing-strategy', type=str,
                            default='inner_group')
        #
        parser.add_argument('--input-with-mask', type=bool, default=True)
        parser.add_argument('--diagonal-attention-mask', type=bool,
                            default=True)
        parser.add_argument('--trainable-mask-token', type=bool,
                            default=False)
        return parser

--------------------------------------------------------------------------------
/spin/baselines/transformer/__init__.py:
--------------------------------------------------------------------------------
from .transformer import TransformerModel

--------------------------------------------------------------------------------
/spin/baselines/transformer/transformer.py:
--------------------------------------------------------------------------------
from torch import nn
from tsl.nn.base import StaticGraphEmbedding
from tsl.nn.blocks.encoders.mlp import MLP
from tsl.nn.blocks.encoders.transformer import SpatioTemporalTransformerLayer, \
    TransformerLayer
from tsl.nn.layers import PositionalEncoding
from tsl.utils.parser_utils import ArgParser, str_to_bool


class TransformerModel(nn.Module):
    r"""Spatiotemporal Transformer for multivariate time series imputation.

    Args:
        input_size (int): Input size.
        hidden_size (int): Dimension of the learned representations.
        output_size (int): Dimension of the output.
        ff_size (int): Units in the MLP after self attention.
        u_size (int): Dimension of the exogenous variables.
        n_heads (int, optional): Number of parallel attention heads.
        n_layers (int, optional): Number of layers.
        dropout (float, optional): Dropout probability.
        axis (str, optional): Dimension on which to apply attention to update
            the representations.
        activation (str, optional): Activation function.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 output_size: int,
                 ff_size: int,
                 u_size: int,
                 n_heads: int = 1,
                 n_layers: int = 1,
                 dropout: float = 0.,
                 condition_on_u: bool = True,
                 axis: str = 'both',
                 activation: str = 'elu'):
        super(TransformerModel, self).__init__()

        self.condition_on_u = condition_on_u
        if condition_on_u:
            self.u_enc = MLP(u_size, hidden_size, n_layers=2)
        self.h_enc = MLP(input_size, hidden_size, n_layers=2)

        self.mask_token = StaticGraphEmbedding(1, hidden_size)

        self.pe = PositionalEncoding(hidden_size)

        kwargs = dict(input_size=hidden_size,
                      hidden_size=hidden_size,
                      ff_size=ff_size,
                      n_heads=n_heads,
                      activation=activation,
                      causal=False,
                      dropout=dropout)

        if axis in ['steps', 'nodes']:
            transformer_layer = TransformerLayer
            kwargs['axis'] = axis
        elif axis == 'both':
            transformer_layer = SpatioTemporalTransformerLayer
        else:
            raise ValueError(f'"{axis}" is not a valid axis.')

        self.encoder = nn.ModuleList()
        self.readout = nn.ModuleList()
        for _ in range(n_layers):
            self.encoder.append(transformer_layer(**kwargs))
            self.readout.append(MLP(input_size=hidden_size,
                                    hidden_size=ff_size,
                                    output_size=output_size,
                                    n_layers=2,
                                    dropout=dropout))

    def forward(self, x, u, mask):
        # x: [batches steps nodes features]
        # u: [batches steps (nodes) features]
        x = x * mask

        h = self.h_enc(x)
        h = mask * h + (1 - mask) * self.mask_token()

        if self.condition_on_u:
            h = h + self.u_enc(u).unsqueeze(-2)

        h = self.pe(h)

        out = []
        for encoder, mlp in zip(self.encoder, self.readout):
            h = encoder(h)
            out.append(mlp(h))

        x_hat = out.pop(-1)
        return x_hat, out

    @staticmethod
    def add_model_specific_args(parser: ArgParser):
        parser.opt_list('--hidden-size', type=int, default=32, tunable=True,
                        options=[16, 32, 64, 128, 256])
        parser.opt_list('--ff-size', type=int, default=32, tunable=True,
                        options=[32, 64, 128, 256, 512, 1024])
        parser.opt_list('--n-layers', type=int, default=1, tunable=True,
                        options=[1, 2, 3])
        parser.opt_list('--n-heads', type=int, default=1, tunable=True,
                        options=[1, 2, 3])
        parser.opt_list('--dropout', type=float, default=0., tunable=True,
                        options=[0., 0.1, 0.25, 0.5])
        parser.add_argument('--condition-on-u', type=str_to_bool, nargs='?',
                            const=True, default=True)
        parser.opt_list('--axis', type=str, default='both', tunable=True,
                        options=['steps', 'both'])
        return parser

--------------------------------------------------------------------------------
/spin/imputers/__init__.py:
--------------------------------------------------------------------------------
from .brits_imputer import BRITSImputer
from .saits_imputer import SAITSImputer
from .spin_imputer import SPINImputer

--------------------------------------------------------------------------------
/spin/imputers/brits_imputer.py:
--------------------------------------------------------------------------------
from tsl.imputers import Imputer

from ..baselines import BRITS


class BRITSImputer(Imputer):

    def shared_step(self, batch, mask):
        y = y_loss = batch.y
        y_hat = y_hat_loss = self.predict_batch(batch, preprocess=False,
                                                postprocess=not self.scale_target)

        if self.scale_target:
            y_loss = batch.transform['y'].transform(y)
            y_hat = batch.transform['y'].inverse_transform(y_hat)

        y_hat_loss, y_loss, mask = self.trim_warm_up(y_hat_loss, y_loss, mask)

        imputation, predictions = y_hat_loss
        imp_fwd, imp_bwd = predictions[:2]
        y_hat = y_hat[0]

        loss = sum([self.loss_fn(pred, y_loss, mask) for pred in predictions])
        loss += BRITS.consistency_loss(imp_fwd, imp_bwd)

        return y_hat.detach(), y, loss

--------------------------------------------------------------------------------
/spin/imputers/saits_imputer.py:
--------------------------------------------------------------------------------
from tsl.imputers import Imputer


class SAITSImputer(Imputer):

    def shared_step(self, batch, mask):
        y = y_loss = batch.y
        y_hat = y_hat_loss = self.predict_batch(batch, preprocess=False,
                                                postprocess=not self.scale_target)

        if self.scale_target:
            y_loss = batch.transform['y'].transform(y)
            y_hat = batch.transform['y'].inverse_transform(y_hat)

        y_hat_loss, y_loss, mask = self.trim_warm_up(y_hat_loss, y_loss, mask)

        if isinstance(y_hat_loss, (list, tuple)):
            imputation, predictions = y_hat_loss
            y_hat = y_hat[0]
        else:
            imputation, predictions = y_hat_loss, []

        # Imputation loss
        if self.training:
            injected_missing = batch.original_mask - batch.mask
            mask = batch.mask
            loss = self.loss_fn(imputation, y_loss, injected_missing)
        else:
            loss = 0

        # Reconstruction loss
        for pred in predictions:
            pred_loss = self.loss_fn(pred, y_loss, mask)
            loss += self.prediction_loss_weight * pred_loss / 3

        return y_hat.detach(), y, loss

--------------------------------------------------------------------------------
/spin/imputers/spin_imputer.py:
--------------------------------------------------------------------------------
from typing import Type, Mapping, Callable, Optional, Union, List

import torch
from torchmetrics import Metric
from tsl.imputers import Imputer
from tsl.predictors import Predictor

from ..utils import k_hop_subgraph_sampler


class SPINImputer(Imputer):

    def __init__(self,
                 model_class: Type,
                 model_kwargs: Mapping,
                 optim_class: Type,
                 optim_kwargs: Mapping,
                 loss_fn: Callable,
                 scale_target: bool = True,
                 whiten_prob: Union[float, List[float]] = 0.2,
                 n_roots_subgraph: Optional[int] = None,
                 n_hops: int = 2,
                 max_edges_subgraph: Optional[int] = 1000,
                 cut_edges_uniformly: bool = False,
                 prediction_loss_weight: float = 1.0,
                 metrics: Optional[Mapping[str, Metric]] = None,
                 scheduler_class: Optional = None,
                 scheduler_kwargs: Optional[Mapping] = None):
        super(SPINImputer, self).__init__(model_class=model_class,
                                          model_kwargs=model_kwargs,
                                          optim_class=optim_class,
                                          optim_kwargs=optim_kwargs,
                                          loss_fn=loss_fn,
                                          scale_target=scale_target,
                                          whiten_prob=whiten_prob,
                                          prediction_loss_weight=prediction_loss_weight,
                                          metrics=metrics,
                                          scheduler_class=scheduler_class,
                                          scheduler_kwargs=scheduler_kwargs)
        self.n_roots = n_roots_subgraph
        self.n_hops = n_hops
        self.max_edges_subgraph = max_edges_subgraph
        self.cut_edges_uniformly = cut_edges_uniformly

    def on_after_batch_transfer(self, batch, dataloader_idx):
        if self.training and self.n_roots is not None:
            batch = k_hop_subgraph_sampler(batch, self.n_hops, self.n_roots,
                                           max_edges=self.max_edges_subgraph,
                                           cut_edges_uniformly=self.cut_edges_uniformly)
        return super(SPINImputer, self).on_after_batch_transfer(batch,
                                                                dataloader_idx)

    def training_step(self, batch, batch_idx):
        injected_missing = (batch.original_mask - batch.mask)
        if 'target_nodes' in batch:
            injected_missing = injected_missing[..., batch.target_nodes, :]
        # batch.input.target_mask = injected_missing
        y_hat, y, loss = self.shared_step(batch, mask=injected_missing)

        # Logging
        self.train_metrics.update(y_hat, y, batch.eval_mask)
        self.log_metrics(self.train_metrics, batch_size=batch.batch_size)
        self.log_loss('train', loss, batch_size=batch.batch_size)
        if 'target_nodes' in batch:
            torch.cuda.empty_cache()
        return loss

    def validation_step(self, batch, batch_idx):
        # batch.input.target_mask = batch.eval_mask
        y_hat, y, val_loss = self.shared_step(batch, batch.eval_mask)

        # Logging
        self.val_metrics.update(y_hat, y, batch.eval_mask)
        self.log_metrics(self.val_metrics, batch_size=batch.batch_size)
        self.log_loss('val', val_loss, batch_size=batch.batch_size)
        return val_loss

    def test_step(self, batch, batch_idx):
        # batch.input.target_mask = batch.eval_mask
        # Compute outputs and rescale
        y_hat = self.predict_batch(batch, preprocess=False, postprocess=True)

        if isinstance(y_hat, (list, tuple)):
            y_hat = y_hat[0]

        y, eval_mask = batch.y, batch.eval_mask
        test_loss = self.loss_fn(y_hat, y, eval_mask)

        # Logging
        self.test_metrics.update(y_hat.detach(), y, eval_mask)
        self.log_metrics(self.test_metrics, batch_size=batch.batch_size)
        self.log_loss('test', test_loss, batch_size=batch.batch_size)
        return test_loss

    @staticmethod
    def add_argparse_args(parser, **kwargs):
        parser = Predictor.add_argparse_args(parser)
        parser.add_argument('--whiten-prob', type=float, default=0.05)
        parser.add_argument('--prediction-loss-weight', type=float, default=1.0)
        parser.add_argument('--n-roots-subgraph', type=int, default=None)
        parser.add_argument('--n-hops', type=int, default=2)
        parser.add_argument('--max-edges-subgraph', type=int, default=1000)
        parser.add_argument('--cut-edges-uniformly', type=bool, default=False)
        return parser

--------------------------------------------------------------------------------
/spin/layers/__init__.py:
--------------------------------------------------------------------------------
from .postional_encoding import PositionalEncoder
from .additive_attention import AdditiveAttention
from .temporal_graph_additive_attention import TemporalAdditiveAttention, \
    TemporalGraphAdditiveAttention
from .hierarchical_temporal_graph_attention import \
    HierarchicalTemporalGraphAttention

--------------------------------------------------------------------------------
/spin/layers/additive_attention.py:
--------------------------------------------------------------------------------
from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch import nn
from torch.nn import LayerNorm, functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_scatter import scatter
from torch_scatter.utils import broadcast
from tsl.nn.blocks.encoders import MLP
from tsl.nn.functional import sparse_softmax


class AdditiveAttention(MessagePassing):
    def __init__(self, input_size: Union[int, Tuple[int, int]],
                 output_size: int,
                 msg_size: Optional[int] = None,
                 msg_layers: int = 1,
                 root_weight: bool = True,
                 reweight: Optional[str] = None,
                 norm: bool = True,
                 dropout: float = 0.0,
                 dim: int = -2,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=dim, **kwargs)

self.output_size = output_size 31 | if isinstance(input_size, int): 32 | self.src_size = self.tgt_size = input_size 33 | else: 34 | self.src_size, self.tgt_size = input_size 35 | 36 | self.msg_size = msg_size or self.output_size 37 | self.msg_layers = msg_layers 38 | 39 | assert reweight in ['softmax', 'l1', None] 40 | self.reweight = reweight 41 | 42 | self.root_weight = root_weight 43 | self.dropout = dropout 44 | 45 | # key bias is discarded in softmax 46 | self.lin_src = Linear(self.src_size, self.output_size, 47 | weight_initializer='glorot', 48 | bias_initializer='zeros') 49 | self.lin_tgt = Linear(self.tgt_size, self.output_size, 50 | weight_initializer='glorot', bias=False) 51 | 52 | if self.root_weight: 53 | self.lin_skip = Linear(self.tgt_size, self.output_size, 54 | bias=False) 55 | else: 56 | self.register_parameter('lin_skip', None) 57 | 58 | self.msg_nn = nn.Sequential( 59 | nn.PReLU(init=0.2), 60 | MLP(self.output_size, self.msg_size, self.output_size, 61 | n_layers=self.msg_layers, dropout=self.dropout, 62 | activation='prelu') 63 | ) 64 | 65 | if self.reweight == 'softmax': 66 | self.msg_gate = nn.Linear(self.output_size, 1, bias=False) 67 | else: 68 | self.msg_gate = nn.Sequential(nn.Linear(self.output_size, 1), 69 | nn.Sigmoid()) 70 | 71 | if norm: 72 | self.norm = LayerNorm(self.output_size) 73 | else: 74 | self.register_parameter('norm', None) 75 | 76 | self.reset_parameters() 77 | 78 | def reset_parameters(self): 79 | self.lin_src.reset_parameters() 80 | self.lin_tgt.reset_parameters() 81 | if self.lin_skip is not None: 82 | self.lin_skip.reset_parameters() 83 | 84 | def forward(self, x: PairTensor, edge_index: Adj, mask: OptTensor = None): 85 | # if query/key not provided, defaults to x (e.g., for self-attention) 86 | if isinstance(x, Tensor): 87 | x_src = x_tgt = x 88 | else: 89 | x_src, x_tgt = x 90 | x_tgt = x_tgt if x_tgt is not None else x_src 91 | 92 | N_src, N_tgt = x_src.size(self.node_dim), x_tgt.size(self.node_dim) 93 | 94 | msg_src 
= self.lin_src(x_src) 95 | msg_tgt = self.lin_tgt(x_tgt) 96 | 97 | msg = (msg_src, msg_tgt) 98 | 99 | # propagate_type: (msg: PairTensor, mask: OptTensor) 100 | out = self.propagate(edge_index, msg=msg, mask=mask, 101 | size=(N_src, N_tgt)) 102 | 103 | # skip connection 104 | if self.root_weight: 105 | out = out + self.lin_skip(x_tgt) 106 | 107 | if self.norm is not None: 108 | out = self.norm(out) 109 | 110 | return out 111 | 112 | def normalize_weights(self, weights, index, num_nodes, mask=None): 113 | # mask weights 114 | if mask is not None: 115 | fill_value = float("-inf") if self.reweight == 'softmax' else 0. 116 | weights = weights.masked_fill(torch.logical_not(mask), fill_value) 117 | # optionally reweight (L1 normalization or sparse softmax) 118 | if self.reweight == 'l1': 119 | expanded_index = broadcast(index, weights, self.node_dim) 120 | weights_sum = scatter(weights, expanded_index, self.node_dim, 121 | dim_size=num_nodes, reduce='sum') 122 | weights_sum = weights_sum.index_select(self.node_dim, index) 123 | weights = weights / (weights_sum + 1e-5) 124 | elif self.reweight == 'softmax': 125 | weights = sparse_softmax(weights, index, num_nodes=num_nodes, 126 | dim=self.node_dim) 127 | return weights 128 | 129 | def message(self, msg_j: Tensor, msg_i: Tensor, index, size_i, 130 | mask_j: OptTensor = None) -> Tensor: 131 | msg = self.msg_nn(msg_j + msg_i) 132 | gate = self.msg_gate(msg) 133 | alpha = self.normalize_weights(gate, index, size_i, mask_j) 134 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 135 | out = alpha * msg 136 | return out 137 | 138 | def __repr__(self) -> str: 139 | return (f'{self.__class__.__name__}({self.output_size}, ' 140 | f'dim={self.node_dim}, ' 141 | f'root_weight={self.root_weight})') 142 | 143 | 144 | class TemporalAdditiveAttention(AdditiveAttention): 145 | def __init__(self, input_size: Union[int, Tuple[int, int]], 146 | output_size: int, 147 | msg_size: Optional[int] = None, 148 | msg_layers: int = 1, 149 | root_weight: bool = True, 150
| reweight: Optional[str] = None, 151 | norm: bool = True, 152 | dropout: float = 0.0, 153 | **kwargs): 154 | kwargs.setdefault('dim', 1) 155 | super().__init__(input_size=input_size, 156 | output_size=output_size, 157 | msg_size=msg_size, 158 | msg_layers=msg_layers, 159 | root_weight=root_weight, 160 | reweight=reweight, 161 | dropout=dropout, 162 | norm=norm, 163 | **kwargs) 164 | 165 | def forward(self, x: PairTensor, mask: OptTensor = None, 166 | temporal_mask: OptTensor = None, 167 | causal_lag: Optional[int] = None): 168 | # x: [b s * c] query: [b l * c] key: [b s * c] 169 | # mask: [b s * c] temporal_mask: [l s] 170 | if isinstance(x, Tensor): 171 | x_src = x_tgt = x 172 | else: 173 | x_src, x_tgt = x 174 | x_tgt = x_tgt if x_tgt is not None else x_src 175 | 176 | l, s = x_tgt.size(self.node_dim), x_src.size(self.node_dim) 177 | i = torch.arange(l, dtype=torch.long, device=x_src.device) 178 | j = torch.arange(s, dtype=torch.long, device=x_src.device) 179 | 180 | # compute temporal index, from j to i (with causal_lag: keep j <= i - causal_lag) 181 | if temporal_mask is None and isinstance(causal_lag, int): 182 | temporal_mask = torch.ones(l, s, dtype=torch.bool, 183 | device=x_src.device).tril(diagonal=-causal_lag) 184 | if temporal_mask is not None: 185 | assert temporal_mask.size() == (l, s) 186 | i, j = torch.meshgrid(i, j) 187 | edge_index = torch.stack((j[temporal_mask], i[temporal_mask])) 188 | else: 189 | edge_index = torch.cartesian_prod(j, i).T 190 | 191 | return super(TemporalAdditiveAttention, self).forward(x, edge_index, 192 | mask=mask) 193 | -------------------------------------------------------------------------------- /spin/layers/hierarchical_temporal_graph_attention.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torch import Tensor 4 | from torch_geometric.nn.conv import MessagePassing 5 | from torch_geometric.nn.dense.linear import Linear 6 | from 
torch_geometric.typing import Adj, OptTensor 7 | from
tsl.nn.functional import sparse_softmax 8 | from tsl.nn.layers.norm import LayerNorm 9 | 10 | from .additive_attention import TemporalAdditiveAttention 11 | 12 | 13 | class HierarchicalTemporalGraphAttention(MessagePassing): 14 | def __init__(self, h_size: int, z_size: int, 15 | msg_size: Optional[int] = None, 16 | msg_layers: int = 1, 17 | root_weight: bool = True, 18 | reweight: Optional[str] = None, 19 | update_z_cross: bool = True, 20 | mask_temporal: bool = True, 21 | mask_spatial: bool = True, 22 | norm: bool = True, 23 | dropout: float = 0., 24 | aggr: str = 'add', 25 | **kwargs): 26 | self.spatial_aggr = aggr 27 | if aggr == 'softmax': 28 | aggr = 'add' 29 | super(HierarchicalTemporalGraphAttention, self).__init__(node_dim=-2, 30 | aggr=aggr, 31 | **kwargs) 32 | 33 | # store dimensions 34 | self.h_size = h_size 35 | self.z_size = z_size 36 | self.msg_size = msg_size or z_size 37 | 38 | self.mask_temporal = mask_temporal 39 | self.mask_spatial = mask_spatial 40 | 41 | self.root_weight = root_weight 42 | self.norm = norm 43 | self.dropout = dropout 44 | self._z_cross = None 45 | 46 | self.zh_self = TemporalAdditiveAttention( 47 | input_size=(h_size, z_size), 48 | output_size=z_size, 49 | msg_size=msg_size, 50 | msg_layers=msg_layers, 51 | reweight=reweight, 52 | dropout=dropout, 53 | root_weight=True, 54 | norm=True 55 | ) 56 | 57 | self.hz_self = TemporalAdditiveAttention( 58 | input_size=(z_size, h_size), 59 | output_size=h_size, 60 | msg_size=msg_size, 61 | msg_layers=msg_layers, 62 | reweight=reweight, 63 | dropout=dropout, 64 | root_weight=True, 65 | norm=False 66 | ) 67 | 68 | if update_z_cross: 69 | self.zh_cross = TemporalAdditiveAttention( 70 | input_size=(h_size, z_size), 71 | output_size=z_size, 72 | msg_size=msg_size, 73 | msg_layers=msg_layers, 74 | reweight=reweight, 75 | dropout=dropout, 76 | root_weight=True, 77 | norm=True 78 | ) 79 | else: 80 | self.register_parameter('zh_cross', None) 81 | 82 | self.hz_cross = TemporalAdditiveAttention( 83 
| input_size=(z_size, h_size), 84 | output_size=h_size, 85 | msg_size=msg_size, 86 | msg_layers=msg_layers, 87 | reweight=None, 88 | dropout=dropout, 89 | root_weight=True, 90 | norm=False 91 | ) 92 | 93 | if self.spatial_aggr == 'softmax': 94 | self.lin_alpha_h = Linear(h_size, 1, bias=False) 95 | self.lin_alpha_z = Linear(z_size, 1, bias=False) 96 | else: 97 | self.register_parameter('lin_alpha_h', None) 98 | self.register_parameter('lin_alpha_z', None) 99 | 100 | if self.root_weight: 101 | self.h_skip = Linear(h_size, h_size, bias_initializer='zeros') 102 | self.z_skip = Linear(z_size, z_size, bias_initializer='zeros') 103 | else: 104 | self.register_parameter('h_skip', None) 105 | self.register_parameter('z_skip', None) 106 | 107 | if self.norm: 108 | self.h_norm = LayerNorm(h_size) 109 | self.z_norm = LayerNorm(z_size) 110 | else: 111 | self.register_parameter('h_norm', None) 112 | self.register_parameter('z_norm', None) 113 | 114 | self.reset_parameters() 115 | 116 | def reset_parameters(self): 117 | self.zh_self.reset_parameters() 118 | self.hz_self.reset_parameters() 119 | if self.zh_cross is not None: 120 | self.zh_cross.reset_parameters() 121 | self.hz_cross.reset_parameters() 122 | if self.spatial_aggr == 'softmax': 123 | self.lin_alpha_h.reset_parameters() 124 | self.lin_alpha_z.reset_parameters() 125 | if self.root_weight: 126 | self.h_skip.reset_parameters() 127 | self.z_skip.reset_parameters() 128 | if self.norm: 129 | self.h_norm.reset_parameters() 130 | self.z_norm.reset_parameters() 131 | 132 | def forward(self, h: Tensor, z: Tensor, edge_index: Adj, 133 | mask: OptTensor = None): 134 | # inputs: [batch, steps, nodes, channels] 135 | 136 | z_out = self.zh_self(x=(h, z), 137 | mask=mask if self.mask_temporal else None) 138 | h_self = self.hz_self(x=(z_out, h)) 139 | 140 | # propagate query, key and value 141 | n_src, n_tgt = h.size(-2), z.size(-2) 142 | h_out = self.propagate(h=h_self, z=z_out, 143 | edge_index=edge_index, 144 | mask=mask if 
self.mask_spatial else None, 145 | size=(n_src, n_tgt)) 146 | 147 | if self._z_cross is not None: 148 | z_out = self.aggregate(self._z_cross, edge_index[1], dim_size=n_tgt) 149 | self._z_cross = None 150 | 151 | # skip connection 152 | if self.root_weight: 153 | h_out = h_out + self.h_skip(h) 154 | z_out = z_out + self.z_skip(z) 155 | 156 | if self.norm: 157 | h_out = self.h_norm(h_out) 158 | z_out = self.z_norm(z_out) 159 | 160 | return h_out, z_out 161 | 162 | def h_cross_message(self, h_i: Tensor, z_j: Tensor, index, size_i) -> Tensor: 163 | # [batch, steps, edges, channels] 164 | h_cross = self.hz_cross((z_j, h_i)) 165 | if self.spatial_aggr == 'softmax': 166 | alpha_h = self.lin_alpha_h(h_cross) 167 | alpha_h = sparse_softmax(alpha_h, index, num_nodes=size_i, 168 | dim=self.node_dim) 169 | h_cross = alpha_h * h_cross 170 | return h_cross 171 | 172 | def hz_cross_message(self, h_i: Tensor, h_j: Tensor, z_i: Tensor, 173 | index, size_i, mask_j: OptTensor) -> Tensor: 174 | # [batch, steps, edges, channels] 175 | z_cross = self.zh_cross((h_j, z_i), mask=mask_j) 176 | h_cross = self.hz_cross((z_cross, h_i)) 177 | if self.spatial_aggr == 'softmax': 178 | # reweight z 179 | alpha_z = self.lin_alpha_z(z_cross) 180 | alpha_z = sparse_softmax(alpha_z, index, num_nodes=size_i, 181 | dim=self.node_dim) 182 | z_cross = alpha_z * z_cross 183 | # reweight h 184 | alpha_h = self.lin_alpha_h(h_cross) 185 | alpha_h = sparse_softmax(alpha_h, index, num_nodes=size_i, 186 | dim=self.node_dim) 187 | h_cross = alpha_h * h_cross 188 | self._z_cross = z_cross 189 | return h_cross 190 | 191 | def message(self, h_i: Tensor, h_j: Tensor, z_i: Tensor, z_j: Tensor, 192 | index, size_i, mask_j: OptTensor) -> Tensor: 193 | if self.zh_cross is not None: 194 | return self.hz_cross_message(h_i, h_j, z_i, index, size_i, mask_j) 195 | return self.h_cross_message(h_i, z_j, index, size_i) 196 | -------------------------------------------------------------------------------- 
/spin/layers/postional_encoding.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torch import nn 4 | from tsl.nn.base import StaticGraphEmbedding 5 | from tsl.nn.blocks.encoders import MLP 6 | from tsl.nn.layers import PositionalEncoding 7 | 8 | 9 | class PositionalEncoder(nn.Module): 10 | 11 | def __init__(self, in_channels, out_channels, 12 | n_layers: int = 1, 13 | n_nodes: Optional[int] = None): 14 | super(PositionalEncoder, self).__init__() 15 | self.lin = nn.Linear(in_channels, out_channels) 16 | self.activation = nn.LeakyReLU() 17 | self.mlp = MLP(out_channels, out_channels, out_channels, 18 | n_layers=n_layers, activation='relu') 19 | self.positional = PositionalEncoding(out_channels) 20 | if n_nodes is not None: 21 | self.node_emb = StaticGraphEmbedding(n_nodes, out_channels) 22 | else: 23 | self.register_parameter('node_emb', None) 24 | 25 | def forward(self, x, node_emb=None, node_index=None): 26 | if node_emb is None: 27 | node_emb = self.node_emb(token_index=node_index) 28 | # x: [b s c], node_emb: [n c] -> [b s n c] 29 | x = self.lin(x) 30 | x = self.activation(x.unsqueeze(-2) + node_emb) 31 | out = self.mlp(x) 32 | out = self.positional(out) 33 | return out 34 | -------------------------------------------------------------------------------- /spin/layers/temporal_graph_additive_attention.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch_geometric.nn.conv import MessagePassing 6 | from torch_geometric.nn.dense.linear import Linear 7 | from torch_geometric.typing import Adj, OptTensor, OptPairTensor 8 | from tsl.nn.layers.norm import LayerNorm 9 | 10 | from .additive_attention import TemporalAdditiveAttention 11 | 12 | 13 | class TemporalGraphAdditiveAttention(MessagePassing): 14 | def __init__(self, input_size: Union[int, 
Tuple[int, int]], 15 | output_size: int, 16 | msg_size: Optional[int] = None, 17 | msg_layers: int = 1, 18 | root_weight: bool = True, 19 | reweight: Optional[str] = None, 20 | temporal_self_attention: bool = True, 21 | mask_temporal: bool = True, 22 | mask_spatial: bool = True, 23 | norm: bool = True, 24 | dropout: float = 0., 25 | **kwargs): 26 | kwargs.setdefault('aggr', 'add') 27 | super(TemporalGraphAdditiveAttention, self).__init__(node_dim=-2, 28 | **kwargs) 29 | 30 | # store dimensions 31 | if isinstance(input_size, int): 32 | self.src_size = self.tgt_size = input_size 33 | else: 34 | self.src_size, self.tgt_size = input_size 35 | self.output_size = output_size 36 | self.msg_size = msg_size or self.output_size 37 | 38 | self.mask_temporal = mask_temporal 39 | self.mask_spatial = mask_spatial 40 | 41 | self.root_weight = root_weight 42 | self.dropout = dropout 43 | 44 | if temporal_self_attention: 45 | self.self_attention = TemporalAdditiveAttention( 46 | input_size=input_size, 47 | output_size=output_size, 48 | msg_size=msg_size, 49 | msg_layers=msg_layers, 50 | reweight=reweight, 51 | dropout=dropout, 52 | root_weight=False, 53 | norm=False 54 | ) 55 | else: 56 | self.register_parameter('self_attention', None) 57 | 58 | self.cross_attention = TemporalAdditiveAttention(input_size=input_size, 59 | output_size=output_size, 60 | msg_size=msg_size, 61 | msg_layers=msg_layers, 62 | reweight=reweight, 63 | dropout=dropout, 64 | root_weight=False, 65 | norm=False) 66 | 67 | if self.root_weight: 68 | self.lin_skip = Linear(self.tgt_size, self.output_size, 69 | bias_initializer='zeros') 70 | else: 71 | self.register_parameter('lin_skip', None) 72 | 73 | if norm: 74 | self.norm = LayerNorm(output_size) 75 | else: 76 | self.register_parameter('norm', None) 77 | 78 | self.reset_parameters() 79 | 80 | def reset_parameters(self): 81 | self.cross_attention.reset_parameters() 82 | if self.self_attention is not None: 83 | self.self_attention.reset_parameters() 84 | if 
self.lin_skip is not None: 85 | self.lin_skip.reset_parameters() 86 | if self.norm is not None: 87 | self.norm.reset_parameters() 88 | 89 | def forward(self, x: OptPairTensor, 90 | edge_index: Adj, edge_weight: OptTensor = None, 91 | mask: OptTensor = None): 92 | # inputs: [batch, steps, nodes, channels] 93 | if isinstance(x, Tensor): 94 | x_src = x_tgt = x 95 | else: 96 | x_src, x_tgt = x 97 | x_tgt = x_tgt if x_tgt is not None else x_src 98 | 99 | n_src, n_tgt = x_src.size(-2), x_tgt.size(-2) 100 | 101 | # propagate query, key and value 102 | out = self.propagate(x=(x_src, x_tgt), 103 | edge_index=edge_index, edge_weight=edge_weight, 104 | mask=mask if self.mask_spatial else None, 105 | size=(n_src, n_tgt)) 106 | 107 | if self.self_attention is not None: 108 | s, l = x_src.size(1), x_tgt.size(1) 109 | if s == l: 110 | attn_mask = ~torch.eye(l, l, dtype=torch.bool, 111 | device=x_tgt.device) 112 | else: 113 | attn_mask = None 114 | temp = self.self_attention(x=(x_src, x_tgt), 115 | mask=mask if self.mask_temporal else None, 116 | temporal_mask=attn_mask) 117 | out = out + temp 118 | 119 | # skip connection 120 | if self.root_weight: 121 | out = out + self.lin_skip(x_tgt) 122 | 123 | if self.norm is not None: 124 | out = self.norm(out) 125 | 126 | return out 127 | 128 | def message(self, x_i: Tensor, x_j: Tensor, 129 | edge_weight: OptTensor, mask_j: OptTensor) -> Tensor: 130 | # [batch, steps, edges, channels] 131 | 132 | out = self.cross_attention((x_j, x_i), mask=mask_j) 133 | 134 | if edge_weight is not None: 135 | out = out * edge_weight.view(-1, 1) 136 | return out 137 | -------------------------------------------------------------------------------- /spin/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .spin import SPINModel 2 | from .spin_hierarchical import SPINHierarchicalModel 3 | -------------------------------------------------------------------------------- /spin/models/spin.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn, Tensor 5 | from torch.nn import LayerNorm 6 | from torch_geometric.typing import OptTensor 7 | from tsl.nn.base import StaticGraphEmbedding 8 | from tsl.nn.blocks.encoders import MLP 9 | 10 | from ..layers import PositionalEncoder, TemporalGraphAdditiveAttention 11 | 12 | 13 | class SPINModel(nn.Module): 14 | 15 | def __init__(self, input_size: int, 16 | hidden_size: int, 17 | n_nodes: int, 18 | u_size: Optional[int] = None, 19 | output_size: Optional[int] = None, 20 | temporal_self_attention: bool = True, 21 | reweight: Optional[str] = 'softmax', 22 | n_layers: int = 4, 23 | eta: int = 3, 24 | message_layers: int = 1): 25 | super(SPINModel, self).__init__() 26 | 27 | u_size = u_size or input_size 28 | output_size = output_size or input_size 29 | self.n_nodes = n_nodes 30 | self.n_layers = n_layers 31 | self.eta = eta 32 | self.temporal_self_attention = temporal_self_attention 33 | 34 | self.u_enc = PositionalEncoder(in_channels=u_size, 35 | out_channels=hidden_size, 36 | n_layers=2, 37 | n_nodes=n_nodes) 38 | 39 | self.h_enc = MLP(input_size, hidden_size, n_layers=2) 40 | self.h_norm = LayerNorm(hidden_size) 41 | 42 | self.valid_emb = StaticGraphEmbedding(n_nodes, hidden_size) 43 | self.mask_emb = StaticGraphEmbedding(n_nodes, hidden_size) 44 | 45 | self.x_skip = nn.ModuleList() 46 | self.encoder, self.readout = nn.ModuleList(), nn.ModuleList() 47 | for l in range(n_layers): 48 | x_skip = nn.Linear(input_size, hidden_size) 49 | encoder = TemporalGraphAdditiveAttention( 50 | input_size=hidden_size, 51 | output_size=hidden_size, 52 | msg_size=hidden_size, 53 | msg_layers=message_layers, 54 | temporal_self_attention=temporal_self_attention, 55 | reweight=reweight, 56 | mask_temporal=True, 57 | mask_spatial=l < eta, 58 | norm=True, 59 | root_weight=True, 60 | dropout=0.0 61 | ) 62 | readout = MLP(hidden_size, 
hidden_size, output_size, 63 | n_layers=2) 64 | self.x_skip.append(x_skip) 65 | self.encoder.append(encoder) 66 | self.readout.append(readout) 67 | 68 | def forward(self, x: Tensor, u: Tensor, mask: Tensor, 69 | edge_index: Tensor, edge_weight: OptTensor = None, 70 | node_index: OptTensor = None, target_nodes: OptTensor = None): 71 | if target_nodes is None: 72 | target_nodes = slice(None) 73 | 74 | # Whiten missing values 75 | x = x * mask 76 | 77 | # POSITIONAL ENCODING ################################################# 78 | # Obtain spatio-temporal positional encoding for every node-step pair # 79 | # in both observed and target sets. Encoding are obtained by jointly # 80 | # processing node and time positional encoding. # 81 | 82 | # Build (node, timestamp) encoding 83 | q = self.u_enc(u, node_index=node_index) 84 | # Condition value on key 85 | h = self.h_enc(x) + q 86 | 87 | # ENCODER ############################################################# 88 | # Obtain representations h^i_t for every (i, t) node-step pair by # 89 | # only taking into account valid data in representation set. 
# 90 | 91 | # Replace H in missing entries with queries Q 92 | h = torch.where(mask.bool(), h, q) 93 | # Normalize features 94 | h = self.h_norm(h) 95 | 96 | imputations = [] 97 | 98 | for l in range(self.n_layers): 99 | if l == self.eta: 100 | # Condition H on two different embeddings to distinguish 101 | # valid values from masked ones 102 | valid = self.valid_emb(token_index=node_index) 103 | masked = self.mask_emb(token_index=node_index) 104 | h = torch.where(mask.bool(), h + valid, h + masked) 105 | # Masked Temporal GAT for encoding representation 106 | h = h + self.x_skip[l](x) * mask # skip connection for valid x 107 | h = self.encoder[l](h, edge_index, mask=mask) 108 | # Read from H to get imputations 109 | target_readout = self.readout[l](h[..., target_nodes, :]) 110 | imputations.append(target_readout) 111 | 112 | # Get final layer imputations 113 | x_hat = imputations.pop(-1) 114 | 115 | return x_hat, imputations 116 | 117 | @staticmethod 118 | def add_model_specific_args(parser): 119 | parser.opt_list('--hidden-size', type=int, tunable=True, default=32, 120 | options=[32, 64, 128, 256]) 121 | parser.add_argument('--u-size', type=int, default=None) 122 | parser.add_argument('--output-size', type=int, default=None) 123 | parser.add_argument('--temporal-self-attention', type=bool, 124 | default=True) 125 | parser.add_argument('--reweight', type=str, default='softmax') 126 | parser.add_argument('--n-layers', type=int, default=4) 127 | parser.add_argument('--eta', type=int, default=3) 128 | parser.add_argument('--message-layers', type=int, default=1) 129 | return parser 130 | -------------------------------------------------------------------------------- /spin/models/spin_hierarchical.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn, Tensor 5 | from torch.nn import LayerNorm 6 | from torch_geometric.nn import inits 7 | from 
torch_geometric.typing import OptTensor 8 | from tsl.nn.base import StaticGraphEmbedding 9 | from tsl.nn.blocks.encoders import MLP 10 | 11 | from ..layers import PositionalEncoder, HierarchicalTemporalGraphAttention 12 | 13 | 14 | class SPINHierarchicalModel(nn.Module): 15 | 16 | def __init__(self, input_size: int, 17 | h_size: int, 18 | z_size: int, 19 | n_nodes: int, 20 | z_heads: int = 1, 21 | u_size: Optional[int] = None, 22 | output_size: Optional[int] = None, 23 | n_layers: int = 5, 24 | eta: int = 3, 25 | message_layers: int = 1, 26 | reweight: Optional[str] = 'softmax', 27 | update_z_cross: bool = True, 28 | norm: bool = True, 29 | spatial_aggr: str = 'add'): 30 | super(SPINHierarchicalModel, self).__init__() 31 | 32 | u_size = u_size or input_size 33 | output_size = output_size or input_size 34 | self.h_size = h_size 35 | self.z_size = z_size 36 | 37 | self.n_nodes = n_nodes 38 | self.z_heads = z_heads 39 | self.n_layers = n_layers 40 | self.eta = eta 41 | 42 | self.v = StaticGraphEmbedding(n_nodes, h_size) 43 | self.lin_v = nn.Linear(h_size, z_size, bias=False) 44 | self.z = nn.Parameter(torch.Tensor(1, z_heads, n_nodes, z_size)) 45 | inits.uniform(z_size, self.z) 46 | self.z_norm = LayerNorm(z_size) 47 | 48 | self.u_enc = PositionalEncoder(in_channels=u_size, 49 | out_channels=h_size, 50 | n_layers=2) 51 | 52 | self.h_enc = MLP(input_size, h_size, n_layers=2) 53 | self.h_norm = LayerNorm(h_size) 54 | 55 | self.v1 = StaticGraphEmbedding(n_nodes, h_size) 56 | self.m1 = StaticGraphEmbedding(n_nodes, h_size) 57 | 58 | self.v2 = StaticGraphEmbedding(n_nodes, h_size) 59 | self.m2 = StaticGraphEmbedding(n_nodes, h_size) 60 | 61 | self.x_skip = nn.ModuleList() 62 | self.encoder, self.readout = nn.ModuleList(), nn.ModuleList() 63 | for l in range(n_layers): 64 | x_skip = nn.Linear(input_size, h_size) 65 | encoder = HierarchicalTemporalGraphAttention( 66 | h_size=h_size, z_size=z_size, 67 | msg_size=h_size, 68 | msg_layers=message_layers, 69 | reweight=reweight, 
70 | mask_temporal=True, 71 | mask_spatial=l < eta, 72 | update_z_cross=update_z_cross, 73 | norm=norm, 74 | root_weight=True, 75 | aggr=spatial_aggr, 76 | dropout=0.0 77 | ) 78 | readout = MLP(h_size, z_size, output_size, 79 | n_layers=2) 80 | self.x_skip.append(x_skip) 81 | self.encoder.append(encoder) 82 | self.readout.append(readout) 83 | 84 | def forward(self, x: Tensor, u: Tensor, mask: Tensor, 85 | edge_index: Tensor, edge_weight: OptTensor = None, 86 | node_index: OptTensor = None, target_nodes: OptTensor = None): 87 | if target_nodes is None: 88 | target_nodes = slice(None) 89 | if node_index is None: 90 | node_index = slice(None) 91 | 92 | # POSITIONAL ENCODING ################################################# 93 | # Obtain spatio-temporal positional encoding for every node-step pair # 94 | # in both observed and target sets. Encodings are obtained by jointly # 95 | # processing node and time positional encodings. # 96 | # Also condition the embeddings Z on V. # 97 | 98 | v_nodes = self.v(token_index=node_index) 99 | z = self.z[..., node_index, :] + self.lin_v(v_nodes) 100 | 101 | # Build (node, timestamp) encoding 102 | q = self.u_enc(u, node_index=node_index, node_emb=v_nodes) 103 | # Condition value on key 104 | h = self.h_enc(x) + q 105 | 106 | # ENCODER ############################################################# 107 | # Obtain representations h^i_t for every (i, t) node-step pair by # 108 | # only taking into account valid data in the representation set. # 109 | 110 | # Replace H in missing entries with queries Q. Then, condition H on two 111 | # different embeddings to distinguish valid values from masked ones.
112 | h = torch.where(mask.bool(), h + self.v1(token_index=node_index), q + self.m1(token_index=node_index)) 113 | # Normalize features 114 | h, z = self.h_norm(h), self.z_norm(z) 115 | 116 | imputations = [] 117 | 118 | for l in range(self.n_layers): 119 | if l == self.eta: 120 | # Condition H on two different embeddings to distinguish 121 | # valid values from masked ones 122 | h = torch.where(mask.bool(), h + self.v2(token_index=node_index), h + self.m2(token_index=node_index)) 123 | # Skip connection from input x 124 | h = h + self.x_skip[l](x) * mask 125 | # Masked Temporal GAT for encoding representation 126 | h, z = self.encoder[l](h, z, edge_index, mask=mask) 127 | target_readout = self.readout[l](h[..., target_nodes, :]) 128 | imputations.append(target_readout) 129 | 130 | x_hat = imputations.pop(-1) 131 | 132 | return x_hat, imputations 133 | 134 | @staticmethod 135 | def add_model_specific_args(parser): 136 | parser.opt_list('--h-size', type=int, tunable=True, default=32, 137 | options=[16, 32]) 138 | parser.opt_list('--z-size', type=int, tunable=True, default=32, 139 | options=[32, 64, 128]) 140 | parser.opt_list('--z-heads', type=int, tunable=True, default=2, 141 | options=[1, 2, 4, 6]) 142 | parser.add_argument('--u-size', type=int, default=None) 143 | parser.add_argument('--output-size', type=int, default=None) 144 | parser.opt_list('--encoder-layers', type=int, tunable=True, default=2, 145 | options=[1, 2, 3, 4]) 146 | parser.opt_list('--decoder-layers', type=int, tunable=True, default=2, 147 | options=[1, 2, 3, 4]) 148 | parser.add_argument('--message-layers', type=int, default=1) 149 | parser.opt_list('--reweight', type=str, tunable=True, default='softmax', 150 | options=[None, 'softmax']) 151 | parser.add_argument('--update-z-cross', type=bool, default=True) 152 | parser.opt_list('--norm', type=bool, default=True, tunable=True, 153 | options=[True, False]) 154 | parser.opt_list('--spatial-aggr', type=str, tunable=True, 155 | default='add', options=['add', 
'softmax']) 156 | return parser 157 |
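The learning-rate schedule defined next in `spin/scheduler.py` (`CosineSchedulerWithRestarts`) combines a linear warm-up, cosine cycles with hard restarts, a linear decay envelope, and a `min_factor` floor. The sketch below transcribes its `lr_lambda` closure into a standalone, dependency-free function so the multiplier can be inspected without building an optimizer. `cosine_restart_factor` is a hypothetical name introduced here; it is not part of the repository or of any library API.

```python
import math


def cosine_restart_factor(step: int, num_warmup_steps: int,
                          num_training_steps: int, min_factor: float = 0.1,
                          linear_decay: float = 0.67,
                          num_cycles: int = 1) -> float:
    """Hypothetical standalone copy of the lr_lambda in spin/scheduler.py.

    Returns the multiplier that LambdaLR would apply to the base lr.
    """
    if step < num_warmup_steps:
        # Linear warm-up from 0, but never below min_factor.
        factor = float(step) / float(max(1, num_warmup_steps))
        return max(min_factor, factor)
    # Fraction of the post-warm-up budget already consumed.
    progress = float(step - num_warmup_steps)
    progress /= float(max(1, num_training_steps - num_warmup_steps))
    if progress >= 1.0:
        # After num_training_steps the factor is 0 (the floor no longer applies).
        return 0.0
    # Position inside the current cosine cycle, in [0, 1); wraps at restarts.
    cycle_pos = (float(num_cycles) * progress) % 1.0
    cos = 0.5 * (1.0 + math.cos(math.pi * cycle_pos))
    # Linear envelope that lowers the peak of each successive cycle.
    lin = 1.0 - progress * linear_decay
    return max(min_factor, cos * lin)


if __name__ == "__main__":
    # Print the multiplier at a few steps: 100 warm-up, 400 total, 2 cycles.
    for s in (0, 50, 100, 150, 250, 399, 400):
        print(s, cosine_restart_factor(s, 100, 400, num_cycles=2))
```

With `num_cycles=2`, the factor jumps back up to the envelope value halfway through the post-warm-up budget (step 250 in the loop above); this jump is the hard restart. Because of the floor, the multiplier stays at or above `min_factor` until training ends, after which it is 0.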
--------------------------------------------------------------------------------
/spin/scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from torch.optim import Optimizer
4 | from torch.optim.lr_scheduler import LambdaLR
5 |
6 |
7 | class CosineSchedulerWithRestarts(LambdaLR):
8 |
9 |     def __init__(self, optimizer: Optimizer,
10 |                  num_warmup_steps: int,
11 |                  num_training_steps: int,
12 |                  min_factor: float = 0.1,
13 |                  linear_decay: float = 0.67,
14 |                  num_cycles: int = 1,
15 |                  last_epoch: int = -1):
16 |         """Adapted from https://github.com/huggingface/transformers/blob/v4.18.0/src/transformers/optimization.py#L138
17 |
18 |         Create a schedule whose learning-rate factor follows `num_cycles` cosine
19 |         cycles separated by hard restarts, damped by a linear decay, after a
20 |         linear warmup towards the initial lr set in the optimizer. The factor
21 |         never falls below `min_factor` and drops to 0 once training ends.
22 |
23 |         Args:
24 |             optimizer ([`~torch.optim.Optimizer`]):
25 |                 The optimizer for which to schedule the learning rate.
26 |             num_warmup_steps (`int`):
27 |                 The number of steps for the warmup phase.
28 |             num_training_steps (`int`):
29 |                 The total number of training steps.
30 |             min_factor (`float`, *optional*, defaults to 0.1):
31 |                 Lower bound on the lr factor during warmup and decay.
32 |             linear_decay (`float`, *optional*, defaults to 0.67):
33 |                 Scales the cosine curve by `1 - progress * linear_decay`.
34 |             num_cycles (`int`, *optional*, defaults to 1):
35 |                 The number of cosine cycles; `num_cycles - 1` hard restarts occur.
36 |             last_epoch (`int`, *optional*, defaults to -1):
37 |                 The index of the last epoch when resuming training.
38 |         Return:
39 |             `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
40 |         """
41 |
42 |         def lr_lambda(current_step):
43 |             if current_step < num_warmup_steps:
44 |                 factor = float(current_step) / float(max(1, num_warmup_steps))
45 |                 return max(min_factor, factor)
46 |             progress = float(current_step - num_warmup_steps)
47 |             progress /= float(max(1, num_training_steps - num_warmup_steps))
48 |             if progress >= 1.0:
49 |                 return 0.0
50 |             factor = (float(num_cycles) * progress) % 1.0
51 |             cos = 0.5 * (1.0 + math.cos(math.pi * factor))
52 |             lin = 1.0 - (progress * linear_decay)
53 |             return max(min_factor, cos * lin)
54 |
55 |         super(CosineSchedulerWithRestarts, self).__init__(optimizer, lr_lambda,
56 |                                                           last_epoch)
57 |
--------------------------------------------------------------------------------
/spin/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import numpy as np
4 | from numpy.random import choice
5 | from torch_geometric.utils import k_hop_subgraph
6 | from tsl.data import Batch
7 | from tsl.ops.connectivity import weighted_degree
8 |
9 |
10 | def k_hop_subgraph_sampler(batch: Batch, k: int, num_nodes: int,
11 |                            max_edges: Optional[int] = None,
12 |                            cut_edges_uniformly: bool = False):
13 |     N = batch.x.size(-2)
14 |     roots = choice(np.arange(N), num_nodes, replace=False).tolist()
15 |     subgraph = k_hop_subgraph(roots, k, batch.edge_index, relabel_nodes=True,
16 |                               num_nodes=N, flow='target_to_source')
17 |     node_idx, edge_index, node_map, edge_mask = subgraph
18 |
19 |     col = edge_index[1]
20 |     if max_edges is not None and max_edges < edge_index.size(1):
21 |         if not cut_edges_uniformly:
22 |             in_degree = weighted_degree(col, num_nodes=len(node_idx))
23 |             deg = (1 / in_degree)[col].cpu().numpy()
24 |             p = deg / deg.sum()
25 |         else:
26 |             p = None
27 |         keep_edges = sorted(choice(len(col), max_edges, replace=False, p=p))
28 |     else:
29 |         keep_edges = slice(None)
30 |     for key, pattern in batch.pattern.items():
31 |         if key in batch.target or key == 'eval_mask':
32 |             batch[key] = batch[key][..., roots, :]
33 |         elif 'n' in pattern:
34 |             batch[key] = batch[key][..., node_idx, :]
35 |         elif 'e' in pattern and key != 'edge_index':
36 |             batch[key] = batch[key][edge_mask][keep_edges]
37 |     batch.input.node_index = node_idx  # index of nodes in subgraph
38 |     batch.input.target_nodes = node_map  # index of roots in subgraph
39 |     batch.edge_index = edge_index[:, keep_edges]
40 |     return batch
41 |
--------------------------------------------------------------------------------
/tsl_config.yaml:
--------------------------------------------------------------------------------
1 | config_dir: 'config/'
2 | log_dir: 'log/'
--------------------------------------------------------------------------------
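`k_hop_subgraph_sampler` in `spin/utils.py` thins out edges when `max_edges` is set and `cut_edges_uniformly` is false. It picks the edges to keep with probability proportional to the inverse in-degree of the node in `col = edge_index[1]`. The NumPy sketch below reproduces that weighting on a toy edge list. It rests on three assumptions:
- `np.bincount` stands in for `tsl.ops.connectivity.weighted_degree` with unit edge weights.
- A seeded `np.random.default_rng` replaces the global `numpy.random.choice`.
- Both function names are hypothetical.

```python
import numpy as np


def inverse_in_degree_probs(col: np.ndarray, num_nodes: int) -> np.ndarray:
    """Per-edge sampling probabilities proportional to 1 / in_degree[col]."""
    # Count incoming edges per node (unit edge weights assumed).
    in_degree = np.bincount(col, minlength=num_nodes).astype(float)
    # Every node appearing in `col` has in-degree >= 1, so no division by zero.
    inv = 1.0 / in_degree[col]
    return inv / inv.sum()


def subsample_edges(col: np.ndarray, num_nodes: int, max_edges: int,
                    uniform: bool = False, seed: int = 0) -> np.ndarray:
    """Sorted indices of `max_edges` edges drawn without replacement."""
    p = None if uniform else inverse_in_degree_probs(col, num_nodes)
    rng = np.random.default_rng(seed)
    return np.sort(rng.choice(len(col), size=max_edges, replace=False, p=p))
```

Take `col = np.array([0, 0, 1, 2, 2, 2])`. The per-edge probabilities are 1/6, 1/6, 1/3, 1/9, 1/9, 1/9, so each of the three nodes receives the same total mass of 1/3 regardless of its degree. Individual edges into high-degree nodes are therefore the least likely to be kept.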