├── .gitignore
├── LICENSE
├── README.md
├── conda_env.yml
├── config
│   ├── imputation
│   │   ├── brits.yaml
│   │   ├── grin.yaml
│   │   ├── saits.yaml
│   │   ├── spin.yaml
│   │   ├── spin_h.yaml
│   │   └── transformer.yaml
│   └── inference.yaml
├── experiments
│   ├── run_imputation.py
│   └── run_inference.py
├── paper_neurips.pdf
├── poster_neurips.pdf
├── sparse_att.png
├── spin
│   ├── __init__.py
│   ├── baselines
│   │   ├── __init__.py
│   │   ├── brits
│   │   │   ├── __init__.py
│   │   │   ├── brits.py
│   │   │   └── layers.py
│   │   ├── saits
│   │   │   ├── __init__.py
│   │   │   ├── layers.py
│   │   │   └── saits.py
│   │   └── transformer
│   │       ├── __init__.py
│   │       └── transformer.py
│   ├── imputers
│   │   ├── __init__.py
│   │   ├── brits_imputer.py
│   │   ├── saits_imputer.py
│   │   └── spin_imputer.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── additive_attention.py
│   │   ├── hierarchical_temporal_graph_attention.py
│   │   ├── postional_encoding.py
│   │   └── temporal_graph_additive_attention.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── spin.py
│   │   └── spin_hierarchical.py
│   ├── scheduler.py
│   └── utils.py
└── tsl_config.yaml
/.gitignore:
--------------------------------------------------------------------------------
1 | *.DS_STORE
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Graph Machine Learning Group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning to Reconstruct Missing Data from Spatiotemporal Graphs with Sparse Observations (NeurIPS 2022)
2 |
3 | NeurIPS 2022
4 | [PDF](https://arxiv.org/pdf/2205.13479)
5 | [arXiv](https://arxiv.org/abs/2205.13479)
6 |
7 | This repository contains the code to reproduce the experiments presented in the paper "Learning to Reconstruct Missing Data from Spatiotemporal Graphs with Sparse Observations" (NeurIPS 2022). We propose a graph neural network that exploits a novel spatiotemporal attention mechanism to impute missing values by leveraging only the (sparse) valid observations.
8 |
9 | **Authors**: [Ivan Marisca](mailto:ivan.marisca@usi.ch), [Andrea Cini](mailto:andrea.cini@usi.ch), Cesare Alippi
10 |
11 | ## SPIN in a nutshell
12 |
13 | Spatiotemporal graphs are often highly sparse, with time series characterized by multiple, concurrent, and even long sequences of missing data, e.g., due to the unreliable underlying sensor network. In this context, autoregressive models can be brittle and exhibit unstable learning dynamics. The objective of this paper is, then, to tackle the problem of learning effective models to reconstruct, i.e., impute, missing data points by conditioning the reconstruction only on the available observations. In particular, we propose a novel class of attention-based architectures that, given a set of highly sparse discrete observations, learn a representation for points in time and space by exploiting a spatiotemporal diffusion architecture aligned with the imputation task. Representations are trained end-to-end to reconstruct observations w.r.t. the corresponding sensor and its neighboring nodes. Compared to the state of the art, our model handles sparse data without propagating prediction errors or requiring a bidirectional model to encode forward and backward time dependencies.
14 |
15 |
16 |
17 |
Example of the sparse spatiotemporal attention layer. On the left, the input spatiotemporal graph, with a time series associated with every node. On the right, how the layer updates the target representation (highlighted by the green box) by simultaneously performing inter-node spatiotemporal cross-attention (red block) and intra-node temporal self-attention (violet block).
18 |
19 |
20 | ---
21 |
22 | ## Directory structure
23 |
24 | The directory is structured as follows:
25 |
26 | ```
27 | .
28 | ├── config/
29 | │   ├── imputation/
30 | │   │   ├── brits.yaml
31 | │   │   ├── grin.yaml
32 | │   │   ├── saits.yaml
33 | │   │   ├── spin.yaml
34 | │   │   ├── spin_h.yaml
35 | │   │   └── transformer.yaml
36 | │   └── inference.yaml
37 | ├── experiments/
38 | │   ├── run_imputation.py
39 | │   └── run_inference.py
40 | ├── spin/
41 | │   ├── baselines/
42 | │   ├── imputers/
43 | │   ├── layers/
44 | │   ├── models/
45 | │   └── ...
46 | ├── conda_env.yml
47 | └── tsl_config.yaml
48 |
49 | ```
50 |
51 | ## Installation
52 |
53 | We provide a conda environment with all the project dependencies. To create and activate it, run:
54 |
55 | ```bash
56 | conda env create -f conda_env.yml
57 | conda activate spin
58 | ```
59 |
60 | ## Configuration files
61 |
62 | The `config/` directory stores all the configuration files used to run the experiments. `config/imputation/` stores the model configurations used in the imputation experiments, while `config/inference.yaml` stores the settings for the inference experiments.
63 |
64 | ## Python package `spin`
65 |
66 | The support code, including models and baselines, is packaged in a Python package named `spin`.
67 |
68 | ## Experiments
69 |
70 | The scripts used for the experiments in the paper are in the `experiments` folder.
71 |
72 | * `run_imputation.py` trains the deep imputation models and computes their test metrics. Example usage:
73 |
74 | ```bash
75 | conda activate spin
76 | python ./experiments/run_imputation.py --config imputation/spin.yaml --model-name spin --dataset-name bay_block
77 | ```
78 |
79 | * `run_inference.py` is used for the experiments on sparse datasets with pre-trained models; `{exp_name}` is the name of a run folder saved by `run_imputation.py` (formatted as `<timestamp>_<seed>`). Example usage:
80 |
81 | ```bash
82 | conda activate spin
83 | python ./experiments/run_inference.py --config inference.yaml --model-name spin --dataset-name bay_point --exp-name {exp_name}
84 | ```
85 |
86 | ## Bibtex reference
87 |
88 | If you find this code useful, please consider citing our paper:
89 |
90 | ```
91 | @article{marisca2022learning,
92 | title={Learning to Reconstruct Missing Data from Spatiotemporal Graphs with Sparse Observations},
93 | author={Marisca, Ivan and Cini, Andrea and Alippi, Cesare},
94 | journal={arXiv preprint arXiv:2205.13479},
95 | year={2022}
96 | }
97 | ```
98 |
--------------------------------------------------------------------------------
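In the `run_inference.py` example, `{exp_name}` refers to a run folder written by `run_imputation.py`. Each training run is named after its launch time and seed. It stores `config.yaml` plus the best checkpoint under `<log_dir>/<dataset_name>/<model_name>/<exp_name>`. `run_inference.py` rebuilds the same path from `--root` (set to `log` in `inference.yaml`), `--dataset-name`, `--model-name` and `--exp-name`.

The stdlib-only sketch below reproduces that naming scheme. `make_exp_name` and `make_exp_dir` are hypothetical helpers, not part of `spin`. The sketch assumes that tsl's `log_dir` and the inference `root` point to the same folder.

```python
import datetime
import os


def make_exp_name(now: datetime.datetime, seed: int) -> str:
    """Run name as built in run_imputation.py: '<YYYYmmddTHHMMSS>_<seed>'."""
    return f"{now.strftime('%Y%m%dT%H%M%S')}_{seed}"


def make_exp_dir(root: str, dataset_name: str, model_name: str,
                 exp_name: str) -> str:
    """Folder layout <root>/<dataset_name>/<model_name>/<exp_name>."""
    return os.path.join(root, dataset_name, model_name, exp_name)


if __name__ == '__main__':
    name = make_exp_name(datetime.datetime(2022, 11, 28, 9, 30, 0), seed=42)
    print(name)  # 20221128T093000_42
    print(make_exp_dir('log', 'bay_point', 'spin', name))
```

The printed run name (here `20221128T093000_42`) is the value to pass as `--exp-name`.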
/conda_env.yml:
--------------------------------------------------------------------------------
1 | name: spin
2 | channels:
3 |   - pytorch
4 |   - pyg
5 |   - defaults
6 |   - conda-forge
7 | dependencies:
8 |   - pip
9 |   - pyg>=2.0
10 |   - python=3.8
11 |   - pytorch>=1.9
12 |   - torchvision
13 |   - torchaudio
14 |   - wheel
15 |   - pip:
16 |     - torch_spatiotemporal==0.1.1
--------------------------------------------------------------------------------
/config/imputation/brits.yaml:
--------------------------------------------------------------------------------
1 | ######################### BRITS CONFIG ##########################
2 |
3 | #### Dataset params ###########################################################
4 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36]
5 | val_len: 0.1
6 |
7 | window: 24 # [24, 36]
8 | stride: 1
9 |
10 | #### Training params ##########################################################
11 | whiten_prob: 0.05
12 | scale_target: True
13 |
14 | epochs: 300
15 | loss_fn: l1_loss
16 | lr_scheduler: cosine
17 | lr: 0.001
18 | batch_size: 32
19 | batches_epoch: 160
20 |
21 | #### Model params #############################################################
22 | model_name: 'brits'
23 | hidden_size: 64 # [64, 128, 256]
24 |
--------------------------------------------------------------------------------
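Every config lists the accepted `dataset_name` values. `get_dataset()` in `experiments/run_imputation.py` maps each of them to a base dataset and a missing-data pattern:

- `air` and `air36` load AirQuality with its own missing values.
- The `_point` and `_block` suffixes inject synthetic missing values into METR-LA (`la`) or PEMS-BAY (`bay`).

The sketch below mirrors only the name parsing, without loading any data. `DatasetSpec` and `parse_dataset_name` are hypothetical names used for illustration.

```python
from typing import NamedTuple, Optional


class DatasetSpec(NamedTuple):
    base: str                 # 'air', 'la' (METR-LA) or 'bay' (PEMS-BAY)
    small: bool               # True only for 'air36' (small AirQuality)
    p_fault: Optional[float]  # None: no synthetic missing values are added
    p_noise: Optional[float]


def parse_dataset_name(name: str) -> DatasetSpec:
    """Mirror the naming rules of get_dataset() in run_imputation.py."""
    if name.startswith('air'):
        return DatasetSpec('air', name[3:] == '36', None, None)
    if name.endswith('_point'):
        # point missing: values dropped at random with probability 0.25
        p_fault, p_noise = 0., 0.25
    elif name.endswith('_block'):
        # block missing: faults (min_seq=12, max_seq=48 steps) start with
        # probability 0.0015, on top of point noise with probability 0.05
        p_fault, p_noise = 0.0015, 0.05
    else:
        raise ValueError(f"Invalid dataset name: {name}.")
    base = name[:-6]  # '_point' and '_block' are both 6 characters long
    if base not in ('la', 'bay'):
        raise ValueError(f"Invalid dataset name: {name}.")
    return DatasetSpec(base, False, p_fault, p_noise)


if __name__ == '__main__':
    print(parse_dataset_name('bay_block'))
```

A misspelled name such as `la_bock` matches neither suffix. It ends up in the `ValueError` branch, both in this sketch and in the script.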
/config/imputation/grin.yaml:
--------------------------------------------------------------------------------
1 | ########################## GRIN CONFIG ##########################
2 |
3 | #### Dataset params ###########################################################
4 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36]
5 | val_len: 0.1
6 |
7 | window: 24 # [24, 36]
8 | stride: 1
9 |
10 | #### Training params ##########################################################
11 | whiten_prob: 0.05
12 | scale_target: True
13 |
14 | epochs: 300
15 | loss_fn: l1_loss
16 | lr_scheduler: cosine
17 | lr: 0.001
18 | batch_size: 32
19 | batches_epoch: 160
20 |
21 | #### Model params #############################################################
22 | model_name: 'grin'
23 |
24 | hidden_size: 64
25 | ff_size: 64
26 | embedding_size: 8
27 | n_layers: 1
28 | kernel_size: 2
29 | decoder_order: 1
30 | layer_norm: false
31 | ff_dropout: 0
32 | merge_mode: 'mlp'
33 |
--------------------------------------------------------------------------------
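`brits.yaml` and `grin.yaml` set `lr_scheduler: cosine`. `get_scheduler()` in `run_imputation.py` turns this into PyTorch's `CosineAnnealingLR` with `T_max=epochs` and `eta_min=0.1 * lr`.

The function below evaluates the documented closed form of that schedule, `lr_t = eta_min + (lr - eta_min) * (1 + cos(pi * t / T_max)) / 2`. It assumes the scheduler is stepped once per epoch, which is the Lightning default; the imputer's own scheduler configuration is not shown here. `cosine_annealing_lr` is a hypothetical helper.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float, eta_min: float,
                        t_max: int) -> float:
    """Closed form of PyTorch's CosineAnnealingLR (single cycle, no restarts)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


if __name__ == '__main__':
    lr, epochs = 0.001, 300  # values from brits.yaml / grin.yaml
    for epoch in (0, 150, 300):
        print(epoch, cosine_annealing_lr(epoch, lr, 0.1 * lr, epochs))
```

With these values, the rate falls from 1e-3 to 5.5e-4 at epoch 150 and would reach 1e-4 at epoch 300. Early stopping (`patience`, 40 by default) may end training earlier. The `magic` schedule used by the SPIN configs is `CosineSchedulerWithRestarts` from `spin/scheduler.py`, which is not reproduced here.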
/config/imputation/saits.yaml:
--------------------------------------------------------------------------------
1 | ######################### SAITS CONFIG ##########################
2 |
3 | #### Dataset params ###########################################################
4 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36]
5 | val_len: 0.1
6 |
7 | window: 24 # [24, 36]
8 | stride: 1
9 |
10 | #### Training params ##########################################################
11 | whiten_prob: 0.2
12 | prediction_loss_weight: 1
13 | scale_target: True
14 |
15 | epochs: 300
16 | loss_fn: l1_loss
17 | lr: 0.0003
18 | batch_size: 128
19 | batches_epoch: 300
20 | patience: 40
21 |
22 | #### Model params #############################################################
23 | model_name: 'saits'
24 | input_with_mask: True
25 |
26 | n_groups: 1
27 | n_group_inner_layers: 1
28 | param_sharing_strategy: inner_group
29 | d_model: 1024
30 | d_inner: 1024
31 | n_head: 4
32 | d_k: 256
33 | d_v: 256
34 | dropout: 0
35 | diagonal_attention_mask: True
--------------------------------------------------------------------------------
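In both scripts, `parse_args()` first parses the command line and then copies every key of the `--config` YAML file onto the result with `setattr`. As a result, values set in the file take precedence over flags given on the command line. The stdlib sketch below reproduces that merge on an `argparse.Namespace`. `apply_config` is a hypothetical helper, and the dictionary stands in for the output of `yaml.load`.

```python
import argparse


def apply_config(args: argparse.Namespace, config_args: dict) -> argparse.Namespace:
    """Copy every config entry onto args, overwriting values parsed from the CLI."""
    for key, value in config_args.items():
        setattr(args, key, value)
    return args


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.001)
    cli = parser.parse_args(['--lr', '0.01'])
    # stand-in for yaml.load() of config/imputation/saits.yaml
    merged = apply_config(cli, {'lr': 0.0003, 'patience': 40})
    print(merged.lr, merged.patience)  # 0.0003 40
```

So `run_imputation.py --config imputation/saits.yaml --lr 0.01` still trains with `lr = 0.0003`. To change a value that the YAML file sets, edit the file or point `--config` to a modified copy. In the same way, `inference.yaml` overrides `--p-noise` and `--batch-size` in `run_inference.py`.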
/config/imputation/spin.yaml:
--------------------------------------------------------------------------------
1 | ########################## SPIN CONFIG ##########################
2 |
3 | #### Dataset params ###########################################################
4 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36]
5 | val_len: 0.1
6 |
7 | window: 24 # [24, 36]
8 | stride: 1
9 |
10 | #### Training params ##########################################################
11 | whiten_prob: [0.2, 0.5, 0.8]
12 | scale_target: True
13 |
14 | epochs: 300
15 | loss_fn: l1_loss
16 | lr_scheduler: magic
17 | lr: 0.0008
18 | patience: 40
19 | precision: 16
20 | batch_size: 8
21 | split_batch_in: 2
22 | batches_epoch: 300
23 |
24 | #### Model params #############################################################
25 | model_name: 'spin'
26 | hidden_size: 32
27 | eta: 3
28 | n_layers: 4
29 | message_layers: 1
30 | temporal_self_attention: True
31 | reweight: 'softmax'
32 |
--------------------------------------------------------------------------------
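`spin.yaml` combines `batch_size: 8` with `split_batch_in: 2`. In `run_imputation.py` this has three effects:

- The data module loads `batch_size // split_batch_in` windows per forward pass.
- The trainer accumulates gradients over `split_batch_in` passes (`accumulate_grad_batches`).
- `limit_train_batches` is scaled by the same factor.

Each epoch therefore still performs `batches_epoch` optimizer updates on batches of `batch_size` windows, using smaller forward passes. The arithmetic is sketched below; `BatchPlan` and `batch_plan` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BatchPlan:
    loader_batch_size: int        # windows per forward pass
    accumulate_grad_batches: int  # forward passes per optimizer step
    train_batches_per_epoch: int  # value passed as limit_train_batches
    samples_per_step: int         # windows contributing to each optimizer step


def batch_plan(batch_size: int, split_batch_in: int, batches_epoch: int) -> BatchPlan:
    """Reproduce the batch arithmetic of run_imputation.py."""
    loader_bs = batch_size // split_batch_in  # integer division, as in the script
    return BatchPlan(loader_batch_size=loader_bs,
                     accumulate_grad_batches=split_batch_in,
                     train_batches_per_epoch=batches_epoch * split_batch_in,
                     samples_per_step=loader_bs * split_batch_in)


if __name__ == '__main__':
    print(batch_plan(8, 2, 300))  # values from spin.yaml
```

`spin_h.yaml` does not set `split_batch_in`, so the command-line default of 1 applies and each forward pass uses the full `batch_size`. If `batch_size` is not a multiple of `split_batch_in`, the integer division makes the effective batch slightly smaller than requested.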
/config/imputation/spin_h.yaml:
--------------------------------------------------------------------------------
1 | ########################## SPIN-H CONFIG ########################
2 |
3 | #### Dataset params ###########################################################
4 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36]
5 | val_len: 0.1
6 |
7 | window: 24 # [24, 36]
8 | stride: 1
9 |
10 | #### Training params ##########################################################
11 | whiten_prob: [0.2, 0.5, 0.8]
12 | scale_target: True
13 |
14 | epochs: 300
15 | loss_fn: l1_loss
16 | lr_scheduler: magic
17 | lr: 0.0008
18 | patience: 40
19 | precision: 16
20 | batch_size: 8
21 | batch_inference: 20
22 | batches_epoch: 300
23 |
24 | #### Model params #############################################################
25 | model_name: 'spin_h'
26 | h_size: 32
27 | z_size: 128
28 | z_heads: 4
29 | eta: 3
30 | n_layers: 5
31 | message_layers: 1
32 | update_z_cross: False
33 | norm: True
34 | reweight: 'softmax'
35 | spatial_aggr: 'softmax'
36 |
--------------------------------------------------------------------------------
/config/imputation/transformer.yaml:
--------------------------------------------------------------------------------
1 | ###################### TRANSFORMER CONFIG #######################
2 |
3 | #### Dataset params ###########################################################
4 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36]
5 | val_len: 0.1
6 |
7 | window: 24 # [24, 36]
8 | stride: 1
9 |
10 | #### Training params ##########################################################
11 | whiten_prob: [0.2, 0.5, 0.8]
12 | scale_target: True
13 |
14 | epochs: 300
15 | loss_fn: l1_loss
16 | lr_scheduler: magic
17 | lr: 0.0008
18 | patience: 40
19 | precision: 16
20 | batch_size: 8
21 | batch_inference: 32
22 | batches_epoch: 300
23 |
24 | #### Model params #############################################################
25 | model_name: 'transformer'
26 | condition_on_u: True
27 | hidden_size: 64
28 | ff_size: 128
29 | n_heads: 4
30 | n_layers: 5
31 | dropout: 0
32 | axis: 'both'
33 |
--------------------------------------------------------------------------------
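For SPIN, SPIN-H and the Transformer baseline, the scripts feed a temporal encoding `u` built with `dataset.datetime_encoded(['day', 'week'])`, and the model receives `u_size=4`. `transformer.yaml` additionally sets `condition_on_u: True`. Four features are consistent with a sine/cosine pair for each of the two periods.

The sketch below shows such an encoding using only the standard library. It is an assumption about what tsl computes, not its implementation. The phase origins (midnight, Monday 00:00) are illustrative choices, and `encode_datetime` is a hypothetical helper.

```python
import math
from datetime import datetime

DAY_SECONDS = 24 * 60 * 60
WEEK_SECONDS = 7 * DAY_SECONDS


def encode_datetime(ts: datetime) -> list:
    """Return [day_sin, day_cos, week_sin, week_cos] for a timestamp."""
    since_monday = (ts.weekday() * DAY_SECONDS + ts.hour * 3600
                    + ts.minute * 60 + ts.second)
    since_midnight = since_monday % DAY_SECONDS
    features = []
    for elapsed, period in ((since_midnight, DAY_SECONDS),
                            (since_monday, WEEK_SECONDS)):
        angle = 2 * math.pi * elapsed / period  # position on the cycle, in radians
        features += [math.sin(angle), math.cos(angle)]
    return features


if __name__ == '__main__':
    # 28 Nov 2022 was a Monday; 06:00 is a quarter of the way through the day
    print(encode_datetime(datetime(2022, 11, 28, 6, 0)))
```

Using both sine and cosine keeps the encoding continuous across period boundaries (23:55 and 00:00 map to nearby points) and makes every time of day distinguishable.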
/config/inference.yaml:
--------------------------------------------------------------------------------
1 | ######################### INFERENCE CONFIG ##########################
2 |
3 | #### Experiment params ########################################################
4 | root: 'log'
5 | #dataset_name: [la_point, bay_point, la_block, bay_block, air, air36]
6 | #model_name: [spin, spin_h, grin, transformer, saits, brits]
7 | #exp_name: {exp_name}
8 |
9 | #### Data params ##############################################################
10 | p_fault: 0
11 | p_noise: 0.95 # [0.5, 0.75, 0.95]
12 | test_mask_seed: [1043, 2043, 3043, 4043, 5043]
13 |
14 | #### Windowing params #########################################################
15 | batch_size: 128
16 |
--------------------------------------------------------------------------------
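`inference.yaml` controls how sparse the test data is. `update_test_eval_mask` in `run_inference.py` passes `p_fault` and `p_noise` to tsl's `sample_mask`, with `min_seq=12` and `max_seq=36`. It accepts a `seed`, and `test_mask_seed` lists five of them.

The stdlib sketch below illustrates the two kinds of missing data those parameters describe:

- blocks of consecutive missing steps for a node;
- independent point noise.

It is a simplified re-implementation under these assumptions, not tsl's code, and `sample_eval_mask` is a hypothetical name.

```python
import random


def sample_eval_mask(n_steps, n_nodes, p_fault, p_noise, min_seq, max_seq,
                     seed=None):
    """Boolean mask [n_steps][n_nodes]; True marks a value hidden for evaluation."""
    rng = random.Random(seed)
    mask = [[False] * n_nodes for _ in range(n_steps)]
    # Block missing: at each (step, node) a fault starts with probability p_fault
    # and hides the next min_seq..max_seq steps of that node.
    for node in range(n_nodes):
        for t in range(n_steps):
            if rng.random() < p_fault:
                length = rng.randint(min_seq, max_seq)  # both bounds inclusive
                for s in range(t, min(t + length, n_steps)):
                    mask[s][node] = True
    # Point missing: each value is also hidden independently with probability p_noise.
    for t in range(n_steps):
        for node in range(n_nodes):
            if rng.random() < p_noise:
                mask[t][node] = True
    return mask


if __name__ == '__main__':
    m = sample_eval_mask(288, 10, p_fault=0.0, p_noise=0.95,
                         min_seq=12, max_seq=36, seed=1043)
    hidden = sum(map(sum, m)) / (288 * 10)
    print(f"hidden fraction: {hidden:.3f}")
```

With the defaults in this file (`p_fault: 0`, `p_noise: 0.95`), roughly 95% of the test values are hidden from the model and scored as imputation targets. Fixing the seed makes the mask reproducible.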
/experiments/run_imputation.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import datetime
3 | import os
4 |
5 | import numpy as np
6 | import pytorch_lightning as pl
7 | import torch
8 | import yaml
9 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
10 | from pytorch_lightning.loggers import TensorBoardLogger
11 | from torch.optim.lr_scheduler import CosineAnnealingLR
12 | from tsl import config, logger
13 | from tsl.data import SpatioTemporalDataModule, ImputationDataset
14 | from tsl.data.preprocessing import StandardScaler
15 | from tsl.datasets import AirQuality, MetrLA, PemsBay
16 | from tsl.imputers import Imputer
17 | from tsl.nn.metrics import MaskedMetric, MaskedMAE, MaskedMSE, MaskedMRE
18 | from tsl.nn.models.imputation import GRINModel
19 | from tsl.nn.utils import casting
20 | from tsl.ops.imputation import add_missing_values
21 | from tsl.utils import parser_utils, numpy_metrics
22 | from tsl.utils.parser_utils import ArgParser
23 |
24 | from spin.baselines import SAITS, TransformerModel, BRITS
25 | from spin.imputers import SPINImputer, SAITSImputer, BRITSImputer
26 | from spin.models import SPINModel, SPINHierarchicalModel
27 | from spin.scheduler import CosineSchedulerWithRestarts
28 |
29 |
30 | def get_model_classes(model_str):
31 |     if model_str == 'spin':
32 |         model, filler = SPINModel, SPINImputer
33 |     elif model_str == 'spin_h':
34 |         model, filler = SPINHierarchicalModel, SPINImputer
35 |     elif model_str == 'grin':
36 |         model, filler = GRINModel, Imputer
37 |     elif model_str == 'saits':
38 |         model, filler = SAITS, SAITSImputer
39 |     elif model_str == 'transformer':
40 |         model, filler = TransformerModel, SPINImputer
41 |     elif model_str == 'brits':
42 |         model, filler = BRITS, BRITSImputer
43 |     else:
44 |         raise ValueError(f'Model {model_str} not available.')
45 |     return model, filler
46 |
47 |
48 | def get_dataset(dataset_name: str):
49 |     if dataset_name.startswith('air'):
50 |         return AirQuality(impute_nans=True, small=dataset_name[3:] == '36')
51 |     # build missing dataset
52 |     if dataset_name.endswith('_point'):
53 |         p_fault, p_noise = 0., 0.25
54 |         dataset_name = dataset_name[:-6]
55 |     elif dataset_name.endswith('_block'):
56 |         p_fault, p_noise = 0.0015, 0.05
57 |         dataset_name = dataset_name[:-6]
58 |     else:
59 |         raise ValueError(f"Invalid dataset name: {dataset_name}.")
60 |     if dataset_name == 'la':
61 |         return add_missing_values(MetrLA(), p_fault=p_fault, p_noise=p_noise,
62 |                                   min_seq=12, max_seq=12 * 4, seed=9101112)
63 |     if dataset_name == 'bay':
64 |         return add_missing_values(PemsBay(), p_fault=p_fault, p_noise=p_noise,
65 |                                   min_seq=12, max_seq=12 * 4, seed=56789)
66 |     raise ValueError(f"Invalid dataset name: {dataset_name}.")
67 |
68 |
69 | def get_scheduler(scheduler_name: str = None, args=None):
70 |     if scheduler_name is None:
71 |         return None, None
72 |     scheduler_name = scheduler_name.lower()
73 |     if scheduler_name == 'cosine':
74 |         scheduler_class = CosineAnnealingLR
75 |         scheduler_kwargs = dict(eta_min=0.1 * args.lr, T_max=args.epochs)
76 |     elif scheduler_name == 'magic':
77 |         scheduler_class = CosineSchedulerWithRestarts
78 |         scheduler_kwargs = dict(num_warmup_steps=12, min_factor=0.1,
79 |                                 linear_decay=0.67,
80 |                                 num_training_steps=args.epochs,
81 |                                 num_cycles=args.epochs // 100)
82 |     else:
83 |         raise ValueError(f"Invalid scheduler name: {scheduler_name}.")
84 |     return scheduler_class, scheduler_kwargs
85 |
86 |
87 | def parse_args():
88 |     # Argument parser
89 |     parser = ArgParser()
90 |
91 |     parser.add_argument('--seed', type=int, default=-1)
92 |     parser.add_argument('--precision', type=int, default=32)
93 |     parser.add_argument("--model-name", type=str, default='spin')
94 |     parser.add_argument("--dataset-name", type=str, default='air36')
95 |     parser.add_argument("--config", type=str, default='imputation/spin.yaml')
96 |
97 |     # Splitting/aggregation params
98 |     parser.add_argument('--val-len', type=float, default=0.1)
99 |     parser.add_argument('--test-len', type=float, default=0.2)
100 |
101 |     # Training params
102 |     parser.add_argument('--lr', type=float, default=0.001)
103 |     parser.add_argument('--epochs', type=int, default=300)
104 |     parser.add_argument('--patience', type=int, default=40)
105 |     parser.add_argument('--l2-reg', type=float, default=0.)
106 |     parser.add_argument('--batches-epoch', type=int, default=300)
107 |     parser.add_argument('--batch-inference', type=int, default=32)
108 |     parser.add_argument('--split-batch-in', type=int, default=1)
109 |     parser.add_argument('--grad-clip-val', type=float, default=5.)
110 |     parser.add_argument('--loss-fn', type=str, default='l1_loss')
111 |     parser.add_argument('--lr-scheduler', type=str, default=None)
112 |
113 |     # Connectivity params
114 |     parser.add_argument("--adj-threshold", type=float, default=0.1)
115 |
116 |     known_args, _ = parser.parse_known_args()
117 |     model_cls, imputer_cls = get_model_classes(known_args.model_name)
118 |     parser = model_cls.add_model_specific_args(parser)
119 |     parser = imputer_cls.add_argparse_args(parser)
120 |     parser = SpatioTemporalDataModule.add_argparse_args(parser)
121 |     parser = ImputationDataset.add_argparse_args(parser)
122 |
123 |     args = parser.parse_args()
124 |     if args.config is not None:
125 |         cfg_path = os.path.join(config.config_dir, args.config)
126 |         with open(cfg_path, 'r') as fp:
127 |             config_args = yaml.load(fp, Loader=yaml.FullLoader)
128 |         for arg in config_args:
129 |             setattr(args, arg, config_args[arg])
130 |
131 |     return args
132 |
133 |
134 | def run_experiment(args):
135 |     # Set configuration and seed
136 |     args = copy.deepcopy(args)
137 |     if args.seed < 0:
138 |         args.seed = np.random.randint(1e9)
139 |     torch.set_num_threads(1)
140 |     pl.seed_everything(args.seed)
141 |
142 |     # script flags
143 |     is_spin = args.model_name in ['spin', 'spin_h']
144 |
145 |     model_cls, imputer_class = get_model_classes(args.model_name)
146 |     dataset = get_dataset(args.dataset_name)
147 |
148 |     logger.info(args)
149 |
150 |     ########################################
151 |     # create logdir and save configuration #
152 |     ########################################
153 |
154 |     exp_name = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
155 |     exp_name = f"{exp_name}_{args.seed}"
156 |     logdir = os.path.join(config.log_dir, args.dataset_name,
157 |                           args.model_name, exp_name)
158 |     # save config for logging
159 |     os.makedirs(logdir, exist_ok=True)
160 |     with open(os.path.join(logdir, 'config.yaml'), 'w') as fp:
161 |         yaml.dump(parser_utils.config_dict_from_args(args), fp,
162 |                   indent=4, sort_keys=True)
163 |
164 |     ########################################
165 |     # data module #
166 |     ########################################
167 |
168 |     # time embedding
169 |     if is_spin or args.model_name == 'transformer':
170 |         time_emb = dataset.datetime_encoded(['day', 'week']).values
171 |         exog_map = {'global_temporal_encoding': time_emb}
172 |
173 |         input_map = {
174 |             'u': 'temporal_encoding',
175 |             'x': 'data'
176 |         }
177 |     else:
178 |         exog_map = input_map = None
179 |
180 |     if is_spin or args.model_name == 'grin':
181 |         adj = dataset.get_connectivity(threshold=args.adj_threshold,
182 |                                        include_self=False,
183 |                                        force_symmetric=is_spin)
184 |     else:
185 |         adj = None
186 |
187 |     # instantiate dataset
188 |     torch_dataset = ImputationDataset(*dataset.numpy(return_idx=True),
189 |                                       training_mask=dataset.training_mask,
190 |                                       eval_mask=dataset.eval_mask,
191 |                                       connectivity=adj,
192 |                                       exogenous=exog_map,
193 |                                       input_map=input_map,
194 |                                       window=args.window,
195 |                                       stride=args.stride)
196 |
197 |     # get train/val/test indices
198 |     splitter = dataset.get_splitter(val_len=args.val_len,
199 |                                     test_len=args.test_len)
200 |
201 |     scalers = {'data': StandardScaler(axis=(0, 1))}
202 |
203 |     dm = SpatioTemporalDataModule(torch_dataset,
204 |                                   scalers=scalers,
205 |                                   splitter=splitter,
206 |                                   batch_size=args.batch_size // args.split_batch_in)
207 |     dm.setup()
208 |
209 |     ########################################
210 |     # predictor #
211 |     ########################################
212 |
213 |     additional_model_hparams = dict(n_nodes=dm.n_nodes,
214 |                                     input_size=dm.n_channels,
215 |                                     u_size=4,
216 |                                     output_size=dm.n_channels,
217 |                                     window_size=dm.window)
218 |
219 |     # model's inputs
220 |     model_kwargs = parser_utils.filter_args(
221 |         args={**vars(args), **additional_model_hparams},
222 |         target_cls=model_cls,
223 |         return_dict=True)
224 |
225 |     # loss and metrics
226 |     loss_fn = MaskedMetric(metric_fn=getattr(torch.nn.functional, args.loss_fn),
227 |                            compute_on_step=True,
228 |                            metric_kwargs={'reduction': 'none'})
229 |
230 |     metrics = {'mae': MaskedMAE(compute_on_step=False),
231 |                'mse': MaskedMSE(compute_on_step=False),
232 |                'mre': MaskedMRE(compute_on_step=False)}
233 |
234 |     scheduler_class, scheduler_kwargs = get_scheduler(args.lr_scheduler, args)
235 |
236 |     # setup imputer
237 |     imputer_kwargs = parser_utils.filter_argparse_args(args, imputer_class,
238 |                                                        return_dict=True)
239 |     imputer = imputer_class(
240 |         model_class=model_cls,
241 |         model_kwargs=model_kwargs,
242 |         optim_class=torch.optim.Adam,
243 |         optim_kwargs={'lr': args.lr,
244 |                       'weight_decay': args.l2_reg},
245 |         loss_fn=loss_fn,
246 |         metrics=metrics,
247 |         scheduler_class=scheduler_class,
248 |         scheduler_kwargs=scheduler_kwargs,
249 |         **imputer_kwargs
250 |     )
251 |
252 |     ########################################
253 |     # training #
254 |     ########################################
255 |
256 |     # callbacks
257 |     early_stop_callback = EarlyStopping(monitor='val_mae',
258 |                                         patience=args.patience, mode='min')
259 |     checkpoint_callback = ModelCheckpoint(dirpath=logdir, save_top_k=1,
260 |                                           monitor='val_mae', mode='min')
261 |
262 |     tb_logger = TensorBoardLogger(logdir, name="model")
263 |
264 |     trainer = pl.Trainer(max_epochs=args.epochs,
265 |                          default_root_dir=logdir,
266 |                          logger=tb_logger,
267 |                          precision=args.precision,
268 |                          accumulate_grad_batches=args.split_batch_in,
269 |                          gpus=int(torch.cuda.is_available()),
270 |                          gradient_clip_val=args.grad_clip_val,
271 |                          limit_train_batches=args.batches_epoch * args.split_batch_in,
272 |                          callbacks=[early_stop_callback, checkpoint_callback])
273 |
274 |     trainer.fit(imputer,
275 |                 train_dataloaders=dm.train_dataloader(),
276 |                 val_dataloaders=dm.val_dataloader(
277 |                     batch_size=args.batch_inference))
278 |
279 |     ########################################
280 |     # testing #
281 |     ########################################
282 |
283 |     imputer.load_model(checkpoint_callback.best_model_path)
284 |     imputer.freeze()
285 |     trainer.test(imputer, dataloaders=dm.test_dataloader(
286 |         batch_size=args.batch_inference))
287 |
288 |     output = trainer.predict(imputer, dataloaders=dm.test_dataloader(
289 |         batch_size=args.batch_inference))
290 |     output = casting.numpy(output)
291 |     y_hat, y_true, mask = output['y_hat'].squeeze(-1), \
292 |                           output['y'].squeeze(-1), \
293 |                           output['mask'].squeeze(-1)
294 |     check_mae = numpy_metrics.masked_mae(y_hat, y_true, mask)
295 |     print(f'Test MAE: {check_mae:.2f}')
296 |     return y_hat
297 |
298 |
299 | if __name__ == '__main__':
300 |     args = parse_args()
301 |     run_experiment(args)
302 |
--------------------------------------------------------------------------------
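At the end of `run_experiment`, `run_imputation.py` double-checks the test score. It calls tsl's `numpy_metrics.masked_mae` on the predictions, targets and mask returned by `trainer.predict`. A masked MAE averages the absolute error only over the entries selected by the mask, so entries outside it do not affect the score.

The sketch below shows that computation on flat Python sequences. `masked_mae` here is a hypothetical stand-in written for illustration, not tsl's function, and multi-dimensional arrays would need to be flattened first.

```python
from typing import Sequence


def masked_mae(y_hat: Sequence[float], y_true: Sequence[float],
               mask: Sequence[bool]) -> float:
    """Mean absolute error over the positions where mask is truthy."""
    errors = [abs(p - t) for p, t, m in zip(y_hat, y_true, mask) if m]
    if not errors:
        raise ValueError("the mask selects no entries")
    return sum(errors) / len(errors)


if __name__ == '__main__':
    # the third entry is outside the mask, so its large error is ignored
    print(masked_mae([1.0, 2.0, 10.0], [1.5, 2.0, 0.0], [True, True, False]))  # 0.25
```

This sketch raises an error on an empty mask rather than returning NaN. That is a design choice of the example, not necessarily the behaviour of tsl's metric.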
/experiments/run_inference.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 |
4 | import numpy as np
5 | import pytorch_lightning as pl
6 | import torch
7 | import tsl
8 | import yaml
9 | from tsl import config
10 | from tsl.data import SpatioTemporalDataModule, ImputationDataset
11 | from tsl.data.preprocessing import StandardScaler
12 | from tsl.datasets import AirQuality, MetrLA, PemsBay
13 | from tsl.imputers import Imputer
14 | from tsl.nn.models.imputation import GRINModel
15 | from tsl.nn.utils import casting
16 | from tsl.ops.imputation import add_missing_values, sample_mask
17 | from tsl.utils import ArgParser, parser_utils, numpy_metrics
18 | from tsl.utils.python_utils import ensure_list
19 |
20 | from spin.baselines import SAITS, TransformerModel, BRITS
21 | from spin.imputers import SPINImputer, SAITSImputer, BRITSImputer
22 | from spin.models import SPINModel, SPINHierarchicalModel
23 |
24 |
25 | def get_model_classes(model_str):
26 |     if model_str == 'spin':
27 |         model, filler = SPINModel, SPINImputer
28 |     elif model_str == 'spin_h':
29 |         model, filler = SPINHierarchicalModel, SPINImputer
30 |     elif model_str == 'grin':
31 |         model, filler = GRINModel, Imputer
32 |     elif model_str == 'saits':
33 |         model, filler = SAITS, SAITSImputer
34 |     elif model_str == 'transformer':
35 |         model, filler = TransformerModel, SPINImputer
36 |     elif model_str == 'brits':
37 |         model, filler = BRITS, BRITSImputer
38 |     else:
39 |         raise ValueError(f'Model {model_str} not available.')
40 |     return model, filler
41 |
42 |
43 | def get_dataset(dataset_name: str):
44 |     if dataset_name.startswith('air'):
45 |         return AirQuality(impute_nans=True, small=dataset_name[3:] == '36')
46 |     # build missing dataset
47 |     if dataset_name.endswith('_point'):
48 |         p_fault, p_noise = 0., 0.25
49 |         dataset_name = dataset_name[:-6]
50 |     elif dataset_name.endswith('_block'):
51 |         p_fault, p_noise = 0.0015, 0.05
52 |         dataset_name = dataset_name[:-6]
53 |     else:
54 |         raise ValueError(f"Invalid dataset name: {dataset_name}.")
55 |     if dataset_name == 'la':
56 |         return add_missing_values(MetrLA(), p_fault=p_fault, p_noise=p_noise,
57 |                                   min_seq=12, max_seq=12 * 4, seed=9101112)
58 |     if dataset_name == 'bay':
59 |         return add_missing_values(PemsBay(), p_fault=p_fault, p_noise=p_noise,
60 |                                   min_seq=12, max_seq=12 * 4, seed=56789)
61 |     raise ValueError(f"Invalid dataset name: {dataset_name}.")
62 |
63 |
64 | def parse_args():
65 |     # Argument parser
66 |     parser = ArgParser()
67 |
68 |     parser.add_argument("--model-name", type=str)
69 |     parser.add_argument("--dataset-name", type=str)
70 |     parser.add_argument("--exp-name", type=str)
71 |     parser.add_argument("--config", type=str, default='inference.yaml')
72 |     parser.add_argument("--root", type=str, default='log')
73 |
74 |     # Data sparsity params
75 |     parser.add_argument('--p-fault', type=float, default=0.0)
76 |     parser.add_argument('--p-noise', type=float, default=0.75)
77 |     parser.add_argument('--test-mask-seed', type=int, default=None)
78 |
79 |     # Splitting/aggregation params
80 |     parser.add_argument('--val-len', type=float, default=0.1)
81 |     parser.add_argument('--test-len', type=float, default=0.2)
82 |     parser.add_argument('--batch-size', type=int, default=32)
83 |
84 |     # Connectivity params
85 |     parser.add_argument("--adj-threshold", type=float, default=0.1)
86 |
87 |     args = parser.parse_args()
88 |     if args.config is not None:
89 |         cfg_path = os.path.join(config.config_dir, args.config)
90 |         with open(cfg_path, 'r') as fp:
91 |             config_args = yaml.load(fp, Loader=yaml.FullLoader)
92 |         for arg in config_args:
93 |             setattr(args, arg, config_args[arg])
94 |
95 |     return args
96 |
97 |
98 | def load_model(exp_dir, exp_config, dm):
99 |     model_cls, imputer_class = get_model_classes(exp_config['model_name'])
100 |     additional_model_hparams = dict(n_nodes=dm.n_nodes,
101 |                                     input_size=dm.n_channels,
102 |                                     u_size=4,
103 |                                     output_size=dm.n_channels,
104 |                                     window_size=dm.window)
105 |
106 |     # model's inputs
107 |     model_kwargs = parser_utils.filter_args(
108 |         args={**exp_config, **additional_model_hparams},
109 |         target_cls=model_cls,
110 |         return_dict=True)
111 |
112 |     # setup imputer
113 |     imputer_kwargs = parser_utils.filter_argparse_args(exp_config,
114 |                                                        imputer_class,
115 |                                                        return_dict=True)
116 |     imputer = imputer_class(
117 |         model_class=model_cls,
118 |         model_kwargs=model_kwargs,
119 |         optim_class=torch.optim.Adam,
120 |         optim_kwargs={},
121 |         loss_fn=None,
122 |         **imputer_kwargs
123 |     )
124 |
125 |     model_path = None
126 |     for file in os.listdir(exp_dir):
127 |         if file.endswith(".ckpt"):
128 |             model_path = os.path.join(exp_dir, file)
129 |             break
130 |     if model_path is None:
131 |         raise ValueError(f"No model checkpoint (.ckpt) found in {exp_dir}.")
132 |
133 |     imputer.load_model(model_path)
134 |     imputer.freeze()
135 |     return imputer
136 |
137 |
138 | def update_test_eval_mask(dm, dataset, p_fault, p_noise, seed=None):
139 |     if seed is None:
140 |         seed = np.random.randint(1e9)
141 |     random = np.random.default_rng(seed)
142 |     dataset.set_eval_mask(
143 |         sample_mask(dataset.shape, p=p_fault, p_noise=p_noise,
144 |                     min_seq=12, max_seq=36, rng=random)
145 |     )
146 |     dm.torch_dataset.set_mask(dataset.training_mask)
147 |     dm.torch_dataset.update_exogenous('eval_mask', dataset.eval_mask)
148 |
149 |
150 | def run_experiment(args):
151 |     # Set configuration
152 |     args = copy.deepcopy(args)
153 |     tsl.logger.disabled = True
154 |
155 |     # script flags
156 |     is_spin = args.model_name in ['spin', 'spin_h']
157 |
158 |     ########################################
159 |     # load config #
160 |     ########################################
161 |
162 |     if args.root is None:
163 |         root = tsl.config.log_dir
164 |     else:
165 |         root = os.path.join(tsl.config.curr_dir, args.root)
166 |     exp_dir = os.path.join(root, args.dataset_name,
167 |                            args.model_name, args.exp_name)
168 |
169 |     with open(os.path.join(exp_dir, 'config.yaml'), 'r') as fp:
170 |         exp_config = yaml.load(fp, Loader=yaml.FullLoader)
171 |
172 |     ########################################
173 |     # load dataset #
174 |     ########################################
175 |
176 |     dataset = get_dataset(exp_config['dataset_name'])
177 |
178 |     ########################################
179 |     # load data module #
180 |     ########################################
181 |
182 |     # time embedding
183 |     if is_spin or args.model_name == 'transformer':
184 |         time_emb = dataset.datetime_encoded(['day', 'week']).values
185 |         exog_map = {'global_temporal_encoding': time_emb}
186 |
187 |         input_map = {
188 |             'u': 'temporal_encoding',
189 |             'x': 'data'
190 |         }
191 |     else:
192 |         exog_map = input_map = None
193 |
194 |     if is_spin or args.model_name == 'grin':
195 |         adj = dataset.get_connectivity(threshold=args.adj_threshold,
196 |                                        include_self=False,
197 |                                        force_symmetric=is_spin)
198 |     else:
199 |         adj = None
200 |
201 |     # instantiate dataset
202 |     torch_dataset = ImputationDataset(*dataset.numpy(return_idx=True),
203 |                                       training_mask=dataset.training_mask,
204 |                                       eval_mask=dataset.eval_mask,
205 |                                       connectivity=adj,
206 |                                       exogenous=exog_map,
207 |                                       input_map=input_map,
208 |                                       window=exp_config['window'],
209 |                                       stride=exp_config['stride'])
210 |
211 |     # get train/val/test indices
212 |     splitter = dataset.get_splitter(val_len=args.val_len,
213 |                                     test_len=args.test_len)
214 |
215 |     scalers = {'data': StandardScaler(axis=(0, 1))}
216 |
217 |     dm = SpatioTemporalDataModule(torch_dataset,
218 |                                   scalers=scalers,
219 |                                   splitter=splitter,
220 |                                   batch_size=args.batch_size)
221 |     dm.setup()
222 |
223 |     ########################################
224 |     # load model #
225 |     ########################################
226 |
227 |     imputer = load_model(exp_dir, exp_config, dm)
228 |
229 |     trainer = pl.Trainer(gpus=int(torch.cuda.is_available()))
230 |
231 |     ########################################
232 |     # inference #
233 |     ########################################
234 |
235 | seeds = ensure_list(args.test_mask_seed)
236 | mae = []
237 |
238 | for seed in seeds:
239 | # Change evaluation mask
240 | update_test_eval_mask(dm, dataset, args.p_fault, args.p_noise, seed)
241 |
242 | output = trainer.predict(imputer, dataloaders=dm.test_dataloader())
243 | output = casting.numpy(output)
244 | y_hat, y_true, mask = output['y_hat'].squeeze(-1), \
245 | output['y'].squeeze(-1), \
246 | output['mask'].squeeze(-1)
247 |
248 | check_mae = numpy_metrics.masked_mae(y_hat, y_true, mask)
249 | mae.append(check_mae)
250 | print(f'SEED {seed} - Test MAE: {check_mae:.2f}')
251 |
252 | print(f'MAE over {len(seeds)} runs: {np.mean(mae):.2f}±{np.std(mae):.2f}')
253 |
254 |
255 | if __name__ == '__main__':
256 | args = parse_args()
257 | run_experiment(args)
258 |
--------------------------------------------------------------------------------
/paper_neurips.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-Machine-Learning-Group/spin/7349ba31da7306e7e96c13668a3f1f0a4df90902/paper_neurips.pdf
--------------------------------------------------------------------------------
/poster_neurips.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-Machine-Learning-Group/spin/7349ba31da7306e7e96c13668a3f1f0a4df90902/poster_neurips.pdf
--------------------------------------------------------------------------------
/sparse_att.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-Machine-Learning-Group/spin/7349ba31da7306e7e96c13668a3f1f0a4df90902/sparse_att.png
--------------------------------------------------------------------------------
/spin/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-Machine-Learning-Group/spin/7349ba31da7306e7e96c13668a3f1f0a4df90902/spin/__init__.py
--------------------------------------------------------------------------------
/spin/baselines/__init__.py:
--------------------------------------------------------------------------------
1 | from .brits import BRITS
2 | from .saits import SAITS
3 | from .transformer import TransformerModel
4 |
--------------------------------------------------------------------------------
/spin/baselines/brits/__init__.py:
--------------------------------------------------------------------------------
1 | from .brits import BRITS
2 |
--------------------------------------------------------------------------------
/spin/baselines/brits/brits.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 | from torch import nn
4 | from tsl.nn.functional import reverse_tensor
5 |
6 | from .layers import RITS
7 |
8 |
9 | class BRITS(nn.Module):
10 |
11 |     def __init__(self, input_size: int, n_nodes: int, hidden_size: int = 64):
12 |         super().__init__()
13 |         self.n_nodes = n_nodes
14 |         self.rits_fwd = RITS(input_size * n_nodes, hidden_size)
15 |         self.rits_bwd = RITS(input_size * n_nodes, hidden_size)
16 |
17 |     def forward(self, x, mask=None):
18 |         # tsl shape to original shape: [b s n c] -> [b s (n c)]
19 |         x = rearrange(x, 'b s n c -> b s (n c)')
20 |         if mask is not None:
21 |             mask = rearrange(mask, 'b s n c -> b s (n c)')
22 |         # forward
23 |         imp_fwd, pred_fwd = self.rits_fwd(x, mask)
24 |         # backward
25 |         x_bwd = reverse_tensor(x, dim=1)
26 |         mask_bwd = reverse_tensor(mask, dim=1) if mask is not None else None
27 |         imp_bwd, pred_bwd = self.rits_bwd(x_bwd, mask_bwd)
28 |         imp_bwd, pred_bwd = reverse_tensor(imp_bwd, dim=1), \
29 |                             [reverse_tensor(pb, dim=1) for pb in pred_bwd]
30 |         imputation = (imp_fwd + imp_bwd) / 2  # average both directions
31 |         predictions = [imp_fwd, imp_bwd] + pred_fwd + pred_bwd
32 |
33 |         imputation = rearrange(imputation, 'b s (n c) -> b s n c',
34 |                                n=self.n_nodes)
35 |         predictions = [rearrange(pred, 'b s (n c) -> b s n c', n=self.n_nodes)
36 |                        for pred in predictions]
37 |
38 |         return imputation, predictions
39 |
40 |     @staticmethod
41 |     def consistency_loss(imp_fwd, imp_bwd):
42 |         loss = 0.1 * torch.abs(imp_fwd - imp_bwd).mean()
43 |         return loss
44 |
45 |     @staticmethod
46 |     def add_model_specific_args(parser):
47 |         parser.add_argument('--hidden-size', type=int, default=64)
48 |         return parser
49 |
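The bidirectional step in `BRITS.forward` only works once the backward pass is flipped back into forward time: the two imputations are then averaged, and `consistency_loss` penalizes their disagreement with weight 0.1. A NumPy sketch of that arithmetic, with plain arrays standing in for the RITS outputs (the function name `combine_directions` is hypothetical):

```python
import numpy as np


def combine_directions(imp_fwd, imp_bwd_reversed, weight=0.1):
    """Merge the two RITS passes the way BRITS.forward and
    BRITS.consistency_loss do.

    imp_fwd: [batch, steps, features] from the forward pass.
    imp_bwd_reversed: output of the backward pass, still in reversed time.
    Returns the averaged imputation and weight * mean(|fwd - bwd|).
    """
    imp_bwd = imp_bwd_reversed[:, ::-1, :]  # flip back to forward time
    imputation = (imp_fwd + imp_bwd) / 2
    consistency = weight * float(np.abs(imp_fwd - imp_bwd).mean())
    return imputation, consistency


# two passes that agree once time is realigned give zero penalty
t = np.arange(4, dtype=float).reshape(1, 4, 1)
imputation, penalty = combine_directions(t, t[:, ::-1, :])
```

Skipping the re-reversal would compare step `t` of one pass with step `steps - 1 - t` of the other, so the penalty would punish correct estimates.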
--------------------------------------------------------------------------------
/spin/baselines/brits/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class FeatureRegression(nn.Module):
9 |     def __init__(self, input_size):
10 |         super(FeatureRegression, self).__init__()
11 |         self.W = nn.Parameter(torch.Tensor(input_size, input_size))
12 |         self.b = nn.Parameter(torch.Tensor(input_size))
13 |
14 |         m = 1 - torch.eye(input_size, input_size)
15 |         self.register_buffer('m', m)
16 |
17 |         self.reset_parameters()
18 |
19 |     def reset_parameters(self):
20 |         stdv = 1. / math.sqrt(self.W.shape[0])
21 |         self.W.data.uniform_(-stdv, stdv)
22 |         self.b.data.uniform_(-stdv, stdv)
23 |
24 |     def forward(self, x):
25 |         z_h = F.linear(x, self.W * self.m, self.b)
26 |         return z_h
27 |
28 |
29 | class TemporalDecay(nn.Module):
30 |     def __init__(self, d_in, d_out, diag=False):
31 |         super(TemporalDecay, self).__init__()
32 |         self.diag = diag
33 |         self.W = nn.Parameter(torch.Tensor(d_out, d_in))
34 |         self.b = nn.Parameter(torch.Tensor(d_out))
35 |
36 |         if self.diag:
37 |             assert (d_in == d_out)
38 |             m = torch.eye(d_in, d_in)
39 |             self.register_buffer('m', m)
40 |
41 |         self.reset_parameters()
42 |
43 |     def reset_parameters(self):
44 |         stdv = 1. / math.sqrt(self.W.shape[0])
45 |         self.W.data.uniform_(-stdv, stdv)
46 |         self.b.data.uniform_(-stdv, stdv)
47 |
48 |     @staticmethod
49 |     def compute_delta(mask, freq=1):
50 |         delta = torch.zeros_like(mask).float()
51 |         one_step = torch.tensor(freq, dtype=delta.dtype, device=delta.device)
52 |         for i in range(1, delta.shape[-2]):
53 |             m = mask[..., i - 1, :]
54 |             delta[..., i, :] = m * one_step + (1 - m) * torch.add(
55 |                 delta[..., i - 1, :], freq)
56 |         return delta
57 |
58 |     def forward(self, d):
59 |         if self.diag:
60 |             gamma = F.relu(F.linear(d, self.W * self.m, self.b))
61 |         else:
62 |             gamma = F.relu(F.linear(d, self.W, self.b))
63 |         gamma = torch.exp(-gamma)
64 |         return gamma
65 |
66 |
67 | class RITS(nn.Module):
68 |     def __init__(self,
69 |                  input_size,
70 |                  hidden_size=64):
71 |         super(RITS, self).__init__()
72 |         self.input_size = int(input_size)
73 |         self.hidden_size = int(hidden_size)
74 |
75 |         self.rnn_cell = nn.LSTMCell(2 * self.input_size, self.hidden_size)
76 |
77 |         self.temp_decay_h = TemporalDecay(d_in=self.input_size,
78 |                                           d_out=self.hidden_size, diag=False)
79 |         self.temp_decay_x = TemporalDecay(d_in=self.input_size,
80 |                                           d_out=self.input_size, diag=True)
81 |
82 |         self.hist_reg = nn.Linear(self.hidden_size, self.input_size)
83 |         self.feat_reg = FeatureRegression(self.input_size)
84 |
85 |         self.weight_combine = nn.Linear(2 * self.input_size, self.input_size)
86 |
87 |     def init_hidden_states(self, x):
88 |         return torch.zeros((x.shape[0], self.hidden_size)).to(x.device)
89 |
90 |     def forward(self, x, mask=None, delta=None):
91 |         # x : [batch, steps, features]
92 |         steps = x.shape[-2]
93 |
94 |         if mask is None:
95 |             mask = torch.ones_like(x)  # float, so it concatenates with x
96 |         if delta is None:
97 |             delta = TemporalDecay.compute_delta(mask)
98 |
99 |         # init rnn states
100 |         h = self.init_hidden_states(x)
101 |         c = self.init_hidden_states(x)
102 |
103 |         imputation = []
104 |         predictions = []
105 |         for step in range(steps):
106 |             d = delta[:, step, :]
107 |             m = mask[:, step, :]
108 |             x_s = x[:, step, :]
109 |
110 |             gamma_h = self.temp_decay_h(d)
111 |
112 |             # history prediction
113 |             x_h = self.hist_reg(h)
114 |             x_c = m * x_s + (1 - m) * x_h
115 |             h = h * gamma_h
116 |
117 |             # feature prediction
118 |             z_h = self.feat_reg(x_c)
119 |
120 |             # predictions combination
121 |             gamma_x = self.temp_decay_x(d)
122 |             alpha = self.weight_combine(torch.cat([gamma_x, m], dim=1))
123 |             alpha = torch.sigmoid(alpha)
124 |             c_h = alpha * z_h + (1 - alpha) * x_h
125 |
126 |             c_c = m * x_s + (1 - m) * c_h
127 |             inputs = torch.cat([c_c, m], dim=1)
128 |             h, c = self.rnn_cell(inputs, (h, c))
129 |
130 |             imputation.append(c_h)
131 |             predictions.append(torch.stack((z_h, x_h), dim=0))
132 |
133 |         # imputation -> [batch, steps, features]
134 |         imputation = torch.stack(imputation, dim=-2)
135 |         # predictions -> [predictions, batch, steps, features]
136 |         predictions = torch.stack(predictions, dim=-2)
137 |
138 |         return imputation, [*predictions]
139 |
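`TemporalDecay.compute_delta` tracks, per feature, how many steps have passed since the last observation, and `TemporalDecay.forward` turns that gap into a factor gamma = exp(-relu(W d + b)) in (0, 1], which RITS uses to shrink the hidden state (`gamma_h`) and to weight the regression estimates (`gamma_x`). The NumPy sketch below reimplements both for a single series; it drops the batch dimension and the diagonal option, and the helper names are hypothetical.

```python
import numpy as np


def compute_delta(mask, freq=1.0):
    """Steps elapsed since each feature was last observed.

    mask: [steps, features], 1 = observed, 0 = missing. Same recurrence as
    TemporalDecay.compute_delta: delta[0] = 0, and delta[t] = freq if the
    value at t - 1 was observed, else delta[t - 1] + freq.
    """
    delta = np.zeros(mask.shape, dtype=float)
    for t in range(1, mask.shape[0]):
        m = mask[t - 1]
        delta[t] = m * freq + (1 - m) * (delta[t - 1] + freq)
    return delta


def temporal_decay(delta, W, b):
    """gamma = exp(-relu(delta @ W.T + b)), always in (0, 1]."""
    return np.exp(-np.maximum(delta @ W.T + b, 0.0))


mask = np.array([[1], [0], [0], [1], [0]], dtype=float)
delta = compute_delta(mask)          # gaps: 0, 1, 2, 3, 1
gamma = temporal_decay(delta, np.array([[1.0]]), np.array([0.0]))
```

With a positive weight, the longer a feature stays missing, the smaller gamma gets and the less the stale hidden state contributes.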
--------------------------------------------------------------------------------
/spin/baselines/saits/__init__.py:
--------------------------------------------------------------------------------
1 | from .saits import SAITS
2 |
--------------------------------------------------------------------------------
/spin/baselines/saits/layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch import Tensor
6 |
7 |
8 | class ScaledDotProductAttention(nn.Module):
9 |     """scaled dot-product attention"""
10 |
11 |     def __init__(self, temperature, attn_dropout=0.1):
12 |         super().__init__()
13 |         self.temperature = temperature
14 |         self.dropout = nn.Dropout(attn_dropout)
15 |
16 |     def forward(self, q, k, v, attn_mask=None):
17 |         attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
18 |         if attn_mask is not None:
19 |             attn = attn.masked_fill(attn_mask == 1, -1e9)
20 |         attn = self.dropout(F.softmax(attn, dim=-1))
21 |         output = torch.matmul(attn, v)
22 |         return output, attn
23 |
24 |
25 | class MultiHeadAttention(nn.Module):
26 |     """original Transformer multi-head attention"""
27 |
28 |     def __init__(self, n_head, d_model, d_k, d_v, attn_dropout):
29 |         super().__init__()
30 |
31 |         self.n_head = n_head
32 |         self.d_k = d_k
33 |         self.d_v = d_v
34 |
35 |         self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
36 |         self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
37 |         self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
38 |
39 |         self.attention = ScaledDotProductAttention(d_k ** 0.5, attn_dropout)
40 |         self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
41 |
42 |     def forward(self, q, k, v, attn_mask=None):
43 |         d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
44 |         sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
45 |
46 |         # Pass through the pre-attention projection: b x lq x (n*dv)
47 |         # Separate different heads: b x lq x n x dv
48 |         q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
49 |         k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
50 |         v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
51 |
52 |         # Transpose for attention dot product: b x n x lq x dv
53 |         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
54 |
55 |         if attn_mask is not None:
56 |             # the (diagonal) mask is shared by all samples and heads, so
57 |             # broadcast it over the batch and head dimensions
58 |             attn_mask = attn_mask.unsqueeze(0).unsqueeze(1)
59 |
60 |         v, attn_weights = self.attention(q, k, v, attn_mask)
61 |
62 |         # Transpose to move the head dimension back: b x lq x n x dv
63 |         # Concatenate all the heads together: b x lq x (n*dv)
64 |         v = v.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
65 |         v = self.fc(v)
66 |         return v, attn_weights
67 |
68 |
69 | class PositionWiseFeedForward(nn.Module):
70 |     def __init__(self, d_in, d_hid, dropout=0.1):
71 |         super().__init__()
72 |         self.w_1 = nn.Linear(d_in, d_hid)
73 |         self.w_2 = nn.Linear(d_hid, d_in)
74 |         self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
75 |         self.dropout = nn.Dropout(dropout)
76 |
77 |     def forward(self, x):
78 |         residual = x
79 |         x = self.layer_norm(x)
80 |         x = self.w_2(F.relu(self.w_1(x)))
81 |         x = self.dropout(x)
82 |         x += residual
83 |         return x
84 |
85 |
86 | class EncoderLayer(nn.Module):
87 |     def __init__(self, d_time, d_feature, d_model, d_inner, n_head, d_k, d_v,
88 |                  diagonal_attention_mask: bool = True,
89 |                  dropout: float = 0.1,
90 |                  attn_dropout: float = 0.1):
91 |         super(EncoderLayer, self).__init__()
92 |
93 |         self.diagonal_attention_mask = diagonal_attention_mask
94 |         self.d_time = d_time
95 |         self.d_feature = d_feature
96 |
97 |         self.layer_norm = nn.LayerNorm(d_model)
98 |         self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v,
99 |                                            attn_dropout)
100 |         self.dropout = nn.Dropout(dropout)
101 |         self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)
102 |
103 |     def forward(self, x: Tensor):
104 |         if self.diagonal_attention_mask:
105 |             mask_time = torch.eye(self.d_time).to(x.device)
106 |         else:
107 |             mask_time = None
108 |
109 |         residual = x
110 |         # apply LayerNorm before the attention computation (Pre-LN)
111 |         enc_input = self.layer_norm(x)
112 |         enc_output, attn_weights = self.slf_attn(enc_input, enc_input,
113 |                                                  enc_input, attn_mask=mask_time)
114 |         enc_output = self.dropout(enc_output)
115 |         enc_output += residual
116 |
117 |         enc_output = self.pos_ffn(enc_output)
118 |         return enc_output, attn_weights
119 |
120 |
121 | class PositionalEncoding(nn.Module):
122 |     def __init__(self, d_hid, n_position=200):
123 |         super(PositionalEncoding, self).__init__()
124 |         # Not a parameter
125 |         pos_table = self._get_sinusoid_encoding_table(n_position, d_hid)
126 |         self.register_buffer('pos_table', pos_table)
127 |
128 |     def _get_sinusoid_encoding_table(self, n_position, d_hid):
129 |         """ Sinusoid position encoding table """
130 |
131 |         def get_position_angle_vec(position):
132 |             return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for
133 |                     hid_j in range(d_hid)]
134 |
135 |         sinusoid_table = np.array([get_position_angle_vec(pos_i)
136 |                                    for pos_i in range(n_position)])
137 |         sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
138 |         sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
139 |         return torch.Tensor(sinusoid_table)
140 |
141 |     def forward(self, x):
142 |         return x + self.pos_table[:x.size(1)]
143 |
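With `diagonal_attention_mask=True`, `EncoderLayer` passes `torch.eye(d_time)` as the attention mask, and `ScaledDotProductAttention` sets those scores to -1e9 before the softmax. No time step can therefore attend to itself, and each one has to be rebuilt from the other steps. A single-head NumPy sketch of that computation; dropout and the multi-head projections are omitted, and `diag_masked_attention` is a hypothetical name.

```python
import numpy as np


def diag_masked_attention(q, k, v, temperature):
    """Single-head attention with the diagonal masked out.

    q, k: [steps, d_k]; v: [steps, d_v]. Returns (output, weights).
    """
    scores = (q / temperature) @ k.T
    scores = np.where(np.eye(len(q)) == 1, -1e9, scores)  # no self-attention
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(3, 6, 4))
out, weights = diag_masked_attention(q, k, v, temperature=np.sqrt(4))
```

The temperature mirrors `d_k ** 0.5` in `MultiHeadAttention`; here `d_k = 4`.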
--------------------------------------------------------------------------------
/spin/baselines/saits/saits.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 | from torch import nn, Tensor
4 | from torch.nn import functional as F
5 | from torch_geometric.nn import inits
6 |
7 | from .layers import EncoderLayer, PositionalEncoding
8 |
9 |
10 | class TransformerEncoder(nn.Module):
11 |     def __init__(self, n_groups, n_group_inner_layers, d_time, d_feature,
12 |                  d_model, d_inner, n_head, d_k, d_v, dropout,
13 |                  **kwargs):
14 |         super().__init__()
15 |         self.n_groups = n_groups
16 |         self.n_group_inner_layers = n_group_inner_layers
17 |         self.input_with_mask = kwargs['input_with_mask']
18 |         actual_d_feature = d_feature * 2 if self.input_with_mask else d_feature
19 |         self.param_sharing_strategy = kwargs['param_sharing_strategy']
20 |         self.MIT = kwargs['MIT']
21 |         diag_mask = kwargs.get('diagonal_attention_mask', True)
22 |         if self.param_sharing_strategy == 'between_group':
23 |             # For between_group, only need to create 1 group and
24 |             # repeat n_groups times while forwarding
25 |             self.layer_stack = nn.ModuleList([
26 |                 EncoderLayer(d_time, actual_d_feature, d_model, d_inner, n_head,
27 |                              d_k, d_v, diag_mask, dropout, dropout)
28 |                 for _ in range(n_group_inner_layers)
29 |             ])
30 |         else:  # inner_group: the parameter-sharing scheme used in ALBERT
31 |             # For inner_group, only need to create n_groups layers and
32 |             # repeat n_group_inner_layers times in each group while forwarding
33 |             self.layer_stack = nn.ModuleList([
34 |                 EncoderLayer(d_time, actual_d_feature, d_model, d_inner, n_head,
35 |                              d_k, d_v, diag_mask, dropout, dropout)
36 |                 for _ in range(n_groups)
37 |             ])
38 |
39 |         self.embedding = nn.Linear(actual_d_feature, d_model)
40 |         self.position_enc = PositionalEncoding(d_model, n_position=d_time)
41 |         self.dropout = nn.Dropout(p=dropout)
42 |         self.reduce_dim = nn.Linear(d_model, d_feature)
43 |
44 |     def impute(self, x, mask, **kwargs):
45 |         # tsl shape to original shape: [b s n c=1] -> [b s c]
46 |         is_bsnc = False
47 |         if x.ndim == 4:
48 |             is_bsnc = True
49 |             x, mask = x.squeeze(-1), mask.squeeze(-1)
50 |
51 |         # 1st DMSA ############################################################
52 |
53 |         # Concatenate the mask (optional)
54 |         if self.input_with_mask:
55 |             input_X = torch.cat([x, mask], dim=2)
56 |         else:
57 |             input_X = x
58 |         input_X = self.embedding(input_X)
59 |         enc_output = self.dropout(self.position_enc(input_X))
60 |
61 |         if self.param_sharing_strategy == 'between_group':
62 |             for _ in range(self.n_groups):
63 |                 for encoder_layer in self.layer_stack:
64 |                     enc_output, _ = encoder_layer(enc_output)
65 |         else:
66 |             for encoder_layer in self.layer_stack:
67 |                 for _ in range(self.n_group_inner_layers):
68 |                     enc_output, _ = encoder_layer(enc_output)
69 |
70 |         learned_presentation = self.reduce_dim(enc_output)
71 |         # replace non-missing part with original data
72 |         imputed_data = mask * x + (1 - mask) * learned_presentation
73 |         if is_bsnc:  # restore the trailing channel dimension
74 |             imputed_data = imputed_data.unsqueeze(-1)
75 |             learned_presentation = learned_presentation.unsqueeze(-1)
76 |         return imputed_data, learned_presentation
77 |
78 |
79 | class SAITS(nn.Module):
80 |     def __init__(self, input_size: int, window_size: int, n_nodes: int,
81 |                  d_model: int = 256,
82 |                  d_inner: int = 128,
83 |                  n_head: int = 4,
84 |                  d_k: int = None,  # or 64
85 |                  d_v: int = 64,
86 |                  n_groups: int = 2,
87 |                  n_group_inner_layers: int = 1,
88 |                  param_sharing_strategy: str = 'inner_group',
89 |                  dropout: float = 0.1,
90 |                  input_with_mask: bool = True,
91 |                  diagonal_attention_mask: bool = True,
92 |                  trainable_mask_token: bool = False):
93 |         super().__init__()
94 |         self.n_nodes = n_nodes
95 |         self.input_size = input_size
96 |         self.n_groups = n_groups
97 |         self.n_group_inner_layers = n_group_inner_layers
98 |         self.input_with_mask = input_with_mask
99 |         self.param_sharing_strategy = param_sharing_strategy
100 |
101 |         d_in = in_features = input_size * n_nodes
102 |         if self.input_with_mask:
103 |             d_in = 2 * d_in
104 |         d_k = d_k or (d_model // n_head)  # from the appendix
105 |
106 |         if trainable_mask_token:
107 |             self.mask_token = nn.Parameter(torch.Tensor(1, 1, in_features))
108 |             inits.uniform(in_features, self.mask_token)
109 |         else:
110 |             self.register_buffer('mask_token', torch.zeros(1, 1, in_features))
111 |
112 |         if self.param_sharing_strategy == 'between_group':
113 |             # For between_group, only need to create 1 group and
114 |             # repeat n_groups times while forwarding
115 |             n_layers = n_group_inner_layers
116 |         else:  # inner_group: the parameter-sharing scheme used in ALBERT
117 |             # For inner_group, only need to create n_groups layers and
118 |             # repeat n_group_inner_layers times in each group while forwarding
119 |             n_layers = n_groups
120 |
121 |         self.layer_stack_for_first_block = nn.ModuleList([
122 |             EncoderLayer(d_time=window_size,
123 |                          d_feature=d_in,
124 |                          d_model=d_model,
125 |                          d_inner=d_inner,
126 |                          n_head=n_head,
127 |                          d_k=d_k,
128 |                          d_v=d_v,
129 |                          dropout=dropout,
130 |                          attn_dropout=0,
131 |                          diagonal_attention_mask=diagonal_attention_mask)
132 |             for _ in range(n_layers)
133 |         ])
134 |         self.layer_stack_for_second_block = nn.ModuleList([
135 |             EncoderLayer(d_time=window_size,
136 |                          d_feature=d_in,
137 |                          d_model=d_model,
138 |                          d_inner=d_inner,
139 |                          n_head=n_head,
140 |                          d_k=d_k,
141 |                          d_v=d_v,
142 |                          dropout=dropout,
143 |                          attn_dropout=0,
144 |                          diagonal_attention_mask=diagonal_attention_mask)
145 |             for _ in range(n_layers)
146 |         ])
147 |
148 |         self.dropout = nn.Dropout(p=dropout)
149 |         self.position_enc = PositionalEncoding(d_model, n_position=window_size)
150 |         # for operation on time dim
151 |         self.embedding_1 = nn.Linear(d_in, d_model)
152 |         self.reduce_dim_z = nn.Linear(d_model, in_features)
153 |         # for operation on measurement dim
154 |         self.embedding_2 = nn.Linear(d_in, d_model)
155 |         self.reduce_dim_beta = nn.Linear(d_model, in_features)
156 |         self.reduce_dim_gamma = nn.Linear(in_features, in_features)
157 |         # weights that combine the two DMSA outputs (from mask + attention)
158 |         self.weight_combine = nn.Linear(in_features + window_size, in_features)
159 |
160 |     def forward(self, x: Tensor, mask: Tensor, **kwargs):
161 |         # tsl shape to original shape: [b s n c] -> [b s (n c)]
162 |         x = rearrange(x, 'b s n c -> b s (n c)')
163 |         mask = rearrange(mask, 'b s n c -> b s (n c)')
164 |         # replace missing values with the mask token
165 |         x = torch.where(mask.bool(), x, self.mask_token)
166 |
167 |         # 1st DMSA ############################################################
168 |
169 |         # Concatenate the mask (optional)
170 |         if self.input_with_mask:
171 |             x_in = torch.cat([x, mask], dim=-1)
172 |         else:
173 |             x_in = x
174 |         x_in = self.embedding_1(x_in)
175 |         z = self.dropout(self.position_enc(x_in))
176 |
177 |         # Encode with the first DMSA block
178 |         if self.param_sharing_strategy == 'between_group':
179 |             for _ in range(self.n_groups):
180 |                 for encoder_layer in self.layer_stack_for_first_block:
181 |                     z, _ = encoder_layer(z)
182 |         else:
183 |             for encoder_layer in self.layer_stack_for_first_block:
184 |                 for _ in range(self.n_group_inner_layers):
185 |                     z, _ = encoder_layer(z)
186 |
187 |         x_tilde_1 = self.reduce_dim_z(z)
188 |         x_hat_1 = mask * x + (1 - mask) * x_tilde_1
189 |
190 |         # 2nd DMSA ############################################################
191 |
192 |         # Concatenate the mask (optional)
193 |         if self.input_with_mask:
194 |             x_in = torch.cat([x_hat_1, mask], dim=-1)
195 |         else:
196 |             x_in = x_hat_1
197 |         x_in = self.embedding_2(x_in)
198 |         z = self.position_enc(x_in)
199 |
200 |         # Encode with the second DMSA block
201 |         if self.param_sharing_strategy == 'between_group':
202 |             for _ in range(self.n_groups):
203 |                 for encoder_layer in self.layer_stack_for_second_block:
204 |                     z, attn_weights = encoder_layer(z)
205 |         else:
206 |             for encoder_layer in self.layer_stack_for_second_block:
207 |                 for _ in range(self.n_group_inner_layers):
208 |                     z, attn_weights = encoder_layer(z)
209 |
210 |         x_tilde_2 = self.reduce_dim_gamma(F.relu(self.reduce_dim_beta(z)))
211 |
212 |         # Average attention heads: [b, heads, s, s] -> [b, s, s]
213 |         # (this also drops the head dimension when there is a single head)
214 |         attn_weights = attn_weights.mean(dim=1)
215 |
216 |         weights = torch.cat([mask, attn_weights], dim=2)
217 |         weights = torch.sigmoid(self.weight_combine(weights))
218 |         # combine x_tilde_1 and x_tilde_2:
219 |         # x_hat = (1 - weights) * x_tilde_2 + weights * x_tilde_1
220 |         x_hat = torch.lerp(x_tilde_2, x_tilde_1, weights)
221 |         # keep every reconstruction for the loss (no observed values restored)
222 |         x_tilde = [x_tilde_1, x_tilde_2, x_hat]
223 |
224 |         # restore original shape
225 |         x_hat = rearrange(x_hat, 'b s (n c) -> b s n c', n=self.n_nodes)
226 |         x_tilde = [rearrange(tens, 'b s (n c) -> b s n c', n=self.n_nodes)
227 |                    for tens in x_tilde]
228 |
229 |         return x_hat, x_tilde
230 |
231 |     @staticmethod
232 |     def add_model_specific_args(parser):
233 |         parser.opt_list('--d-model', type=int, default=256, tunable=True,
234 |                         options=[64, 128, 256, 512, 1024])
235 |         parser.opt_list('--d-inner', type=int, default=128, tunable=True,
236 |                         options=[128, 256, 512, 1024, 2048, 4096])
237 |         parser.opt_list('--n-head', type=int, default=4, tunable=True,
238 |                         options=[2, 4, 8])
239 |         parser.add_argument('--d-k', type=int, default=None)
240 |         parser.opt_list('--d-v', type=int, default=64, tunable=True,
241 |                         options=[64, 128, 256, 512])
242 |         parser.add_argument('--dropout', type=float, default=0.1)
243 |         #
244 |         parser.opt_list('--n-groups', type=int, default=2, tunable=True,
245 |                         options=[1, 2, 4, 6, 8])
246 |         parser.add_argument('--n-group-inner-layers', type=int, default=1)
247 |         parser.add_argument('--param-sharing-strategy', type=str,
248 |                             default='inner_group')
249 |         from tsl.utils.parser_utils import str_to_bool  # bool('False') is True
250 |         parser.add_argument('--input-with-mask', type=str_to_bool, default=True)
251 |         parser.add_argument('--diagonal-attention-mask', type=str_to_bool,
252 |                             default=True)
253 |         parser.add_argument('--trainable-mask-token', type=str_to_bool,
254 |                             default=False)
255 |         return parser
256 |
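The two `param_sharing_strategy` values differ only in how many distinct `EncoderLayer`s are built and in the order they are reused during the forward pass. The pure-Python sketch below lists the index of the layer applied at each step by the nested loops in `SAITS.forward`. `layer_schedule` is a hypothetical helper that mirrors the loops rather than calling the model.

```python
def layer_schedule(strategy, n_groups, n_group_inner_layers):
    """Indices of the shared layers, in the order SAITS applies them.

    'between_group' builds n_group_inner_layers distinct layers and runs the
    whole stack n_groups times; any other value ('inner_group', the ALBERT
    scheme) builds n_groups distinct layers and repeats each one
    n_group_inner_layers times before moving on to the next.
    """
    if strategy == 'between_group':
        layers = list(range(n_group_inner_layers))
        return [layer for _ in range(n_groups) for layer in layers]
    layers = list(range(n_groups))
    return [layer for layer in layers for _ in range(n_group_inner_layers)]


print(layer_schedule('between_group', n_groups=2, n_group_inner_layers=3))
# [0, 1, 2, 0, 1, 2]
print(layer_schedule('inner_group', n_groups=2, n_group_inner_layers=3))
# [0, 0, 0, 1, 1, 1]
```

Either way each DMSA block performs `n_groups * n_group_inner_layers` layer calls; only the number of distinct parameter sets changes.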
--------------------------------------------------------------------------------
/spin/baselines/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import TransformerModel
2 |
--------------------------------------------------------------------------------
/spin/baselines/transformer/transformer.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from tsl.nn.base import StaticGraphEmbedding
3 | from tsl.nn.blocks.encoders.mlp import MLP
4 | from tsl.nn.blocks.encoders.transformer import SpatioTemporalTransformerLayer, \
5 |     TransformerLayer
6 | from tsl.nn.layers import PositionalEncoding
7 | from tsl.utils.parser_utils import ArgParser, str_to_bool
8 |
9 |
10 | class TransformerModel(nn.Module):
11 |     r"""Spatiotemporal Transformer for multivariate time series imputation.
12 |
13 |     Args:
14 |         input_size (int): Input size.
15 |         hidden_size (int): Dimension of the learned representations.
16 |         output_size (int): Dimension of the output.
17 |         ff_size (int): Units in the MLP after self attention.
18 |         u_size (int): Dimension of the exogenous variables.
19 |         n_heads (int, optional): Number of parallel attention heads.
20 |         n_layers (int, optional): Number of layers.
21 |         dropout (float, optional): Dropout probability.
22 |         condition_on_u (bool, optional): Whether to condition on ``u``.
23 |         axis (str, optional): Attention axis: 'steps', 'nodes' or 'both'.
24 |         activation (str, optional): Activation function.
25 |     """
26 |
27 |     def __init__(self,
28 |                  input_size: int,
29 |                  hidden_size: int,
30 |                  output_size: int,
31 |                  ff_size: int,
32 |                  u_size: int,
33 |                  n_heads: int = 1,
34 |                  n_layers: int = 1,
35 |                  dropout: float = 0.,
36 |                  condition_on_u: bool = True,
37 |                  axis: str = 'both',
38 |                  activation: str = 'elu'):
39 |         super(TransformerModel, self).__init__()
40 |
41 |         self.condition_on_u = condition_on_u
42 |         if condition_on_u:
43 |             self.u_enc = MLP(u_size, hidden_size, n_layers=2)
44 |         self.h_enc = MLP(input_size, hidden_size, n_layers=2)
45 |
46 |         self.mask_token = StaticGraphEmbedding(1, hidden_size)
47 |
48 |         self.pe = PositionalEncoding(hidden_size)
49 |
50 |         kwargs = dict(input_size=hidden_size,
51 |                       hidden_size=hidden_size,
52 |                       ff_size=ff_size,
53 |                       n_heads=n_heads,
54 |                       activation=activation,
55 |                       causal=False,
56 |                       dropout=dropout)
57 |
58 |         if axis in ['steps', 'nodes']:
59 |             transformer_layer = TransformerLayer
60 |             kwargs['axis'] = axis
61 |         elif axis == 'both':
62 |             transformer_layer = SpatioTemporalTransformerLayer
63 |         else:
64 |             raise ValueError(f'"{axis}" is not a valid axis.')
65 |
66 |         self.encoder = nn.ModuleList()
67 |         self.readout = nn.ModuleList()
68 |         for _ in range(n_layers):
69 |             self.encoder.append(transformer_layer(**kwargs))
70 |             self.readout.append(MLP(input_size=hidden_size,
71 |                                     hidden_size=ff_size,
72 |                                     output_size=output_size,
73 |                                     n_layers=2,
74 |                                     dropout=dropout))
75 |
76 |     def forward(self, x, u, mask):
77 |         # x: [batches steps nodes features]
78 |         # u: [batches steps features] (global, broadcast over the nodes)
79 |         x = x * mask
80 |
81 |         h = self.h_enc(x)
82 |         h = mask * h + (1 - mask) * self.mask_token()
83 |
84 |         if self.condition_on_u:
85 |             h = h + self.u_enc(u).unsqueeze(-2)
86 |
87 |         h = self.pe(h)
88 |
89 |         out = []
90 |         for encoder, mlp in zip(self.encoder, self.readout):
91 |             h = encoder(h)
92 |             out.append(mlp(h))
93 |
94 |         x_hat = out.pop(-1)
95 |         return x_hat, out
96 |
97 |     @staticmethod
98 |     def add_model_specific_args(parser: ArgParser):
99 |         parser.opt_list('--hidden-size', type=int, default=32, tunable=True,
100 |                         options=[16, 32, 64, 128, 256])
101 |         parser.opt_list('--ff-size', type=int, default=32, tunable=True,
102 |                         options=[32, 64, 128, 256, 512, 1024])
103 |         parser.opt_list('--n-layers', type=int, default=1, tunable=True,
104 |                         options=[1, 2, 3])
105 |         parser.opt_list('--n-heads', type=int, default=1, tunable=True,
106 |                         options=[1, 2, 3])
107 |         parser.opt_list('--dropout', type=float, default=0., tunable=True,
108 |                         options=[0., 0.1, 0.25, 0.5])
109 |         parser.add_argument('--condition-on-u', type=str_to_bool, nargs='?',
110 |                             const=True, default=True)
111 |         parser.opt_list('--axis', type=str, default='both', tunable=True,
112 |                         options=['steps', 'both'])
113 |         return parser
114 |
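`TransformerModel.forward` hides missing inputs before attention in two steps. It zeroes them (`x * mask`) and encodes, then overwrites the encodings of missing entries with the learned `mask_token`. The exogenous encoding is then broadcast over the node axis with `unsqueeze(-2)`, which assumes `u` carries no node dimension. A NumPy sketch of just those two broadcasts, with a constant array in place of the learned token (helper names are hypothetical):

```python
import numpy as np


def apply_mask_token(h, mask, token):
    """Swap the encodings of missing entries for a shared token, like
    `h = mask * h + (1 - mask) * self.mask_token()`.

    h: [batch, steps, nodes, hidden]; mask: [batch, steps, nodes, 1]
    (1 = observed); token: [hidden].
    """
    return mask * h + (1 - mask) * token


def add_global_exogenous(h, u_enc):
    """Add a node-independent encoding u_enc: [batch, steps, hidden] to
    every node, like `h + self.u_enc(u).unsqueeze(-2)`."""
    return h + u_enc[:, :, None, :]


h = np.ones((1, 2, 3, 4))
mask = np.ones((1, 2, 3, 1))
mask[0, 1, 2, 0] = 0.0                      # one missing value
h = apply_mask_token(h, mask, token=np.full(4, 7.0))
h = add_global_exogenous(h, np.arange(8.0).reshape(1, 2, 4))
```

Because the mask keeps a trailing axis of size 1, the same observed/missing flag is applied to all hidden units of an entry.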
--------------------------------------------------------------------------------
/spin/imputers/__init__.py:
--------------------------------------------------------------------------------
1 | from .brits_imputer import BRITSImputer
2 | from .saits_imputer import SAITSImputer
3 | from .spin_imputer import SPINImputer
4 |
--------------------------------------------------------------------------------
/spin/imputers/brits_imputer.py:
--------------------------------------------------------------------------------
1 | from tsl.imputers import Imputer
2 |
3 | from ..baselines import BRITS
4 |
5 |
6 | class BRITSImputer(Imputer):
7 |
8 |     def shared_step(self, batch, mask):
9 |         y = y_loss = batch.y
10 |         y_hat = y_hat_loss = self.predict_batch(batch, preprocess=False,
11 |                                                 postprocess=not self.scale_target)
12 |
13 |         if self.scale_target:
14 |             y_loss = batch.transform['y'].transform(y)
15 |             y_hat = batch.transform['y'].inverse_transform(y_hat)
16 |
17 |         y_hat_loss, y_loss, mask = self.trim_warm_up(y_hat_loss, y_loss, mask)
18 |
19 |         _, predictions = y_hat_loss  # (imputation, predictions)
20 |         imp_fwd, imp_bwd = predictions[:2]
21 |         y_hat = y_hat[0]
22 |         # loss on every estimate plus forward/backward consistency
23 |         loss = sum([self.loss_fn(pred, y_loss, mask) for pred in predictions])
24 |         loss += BRITS.consistency_loss(imp_fwd, imp_bwd)
25 |
26 |         return y_hat.detach(), y, loss
27 |
--------------------------------------------------------------------------------
/spin/imputers/saits_imputer.py:
--------------------------------------------------------------------------------
1 | from tsl.imputers import Imputer
2 |
3 |
4 | class SAITSImputer(Imputer):
5 |
6 | def shared_step(self, batch, mask):
7 | y = y_loss = batch.y
8 | y_hat = y_hat_loss = self.predict_batch(batch, preprocess=False,
9 | postprocess=not self.scale_target)
10 |
11 | if self.scale_target:
12 | y_loss = batch.transform['y'].transform(y)
13 | y_hat = batch.transform['y'].inverse_transform(y_hat)
14 |
15 | y_hat_loss, y_loss, mask = self.trim_warm_up(y_hat_loss, y_loss, mask)
16 |
17 | if isinstance(y_hat_loss, (list, tuple)):
18 | imputation, predictions = y_hat_loss
19 | y_hat = y_hat[0]
20 | else:
21 | imputation, predictions = y_hat_loss, []
22 |
23 | # Imputation loss
24 | if self.training:
25 | injected_missing = batch.original_mask - batch.mask
26 | mask = batch.mask
27 | loss = self.loss_fn(imputation, y_loss, injected_missing)
28 | else:
29 | loss = 0
30 |
31 | # Reconstruction loss
32 | for pred in predictions:
33 | pred_loss = self.loss_fn(pred, y_loss, mask)
34 | loss += self.prediction_loss_weight * pred_loss / len(predictions)
35 |
36 | return y_hat.detach(), y, loss
37 |
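During training, `SAITSImputer.shared_step` optimizes two terms. The first is an imputation term on the entries that are present in `batch.original_mask` but were removed from `batch.mask`. The second is a reconstruction term on the visible entries, averaged over the model's intermediate predictions (three for SAITS) and scaled by `prediction_loss_weight`. The sketch below reproduces that bookkeeping on flat Python lists. Using a masked MAE as the loss is an assumption, and `saits_style_loss` is a hypothetical name.

```python
from typing import List, Sequence


def masked_mae(pred: Sequence[float], target: Sequence[float],
               mask: Sequence[int]) -> float:
    """Mean absolute error over entries where mask is 1 (0.0 if none)."""
    errors = [abs(p - t) for p, t, m in zip(pred, target, mask) if m]
    return sum(errors) / len(errors) if errors else 0.0


def saits_style_loss(imputation: Sequence[float],
                     predictions: List[Sequence[float]],
                     target: Sequence[float],
                     original_mask: Sequence[int], mask: Sequence[int],
                     prediction_loss_weight: float = 1.0) -> float:
    # Entries observed in the data but hidden from the model for training
    injected = [o - m for o, m in zip(original_mask, mask)]
    # Imputation loss: scored only on the injected (artificially removed) entries
    loss = masked_mae(imputation, target, injected)
    # Reconstruction loss: scored on visible entries, averaged over predictions
    for pred in predictions:
        loss += (prediction_loss_weight
                 * masked_mae(pred, target, mask) / len(predictions))
    return loss


target = [1.0, 2.0, 3.0, 4.0]
original_mask = [1, 1, 1, 0]   # entry 3 is missing in the raw data
mask = [1, 0, 1, 0]            # entry 1 was additionally hidden for training
imputation = [1.0, 2.5, 3.0, 0.0]
predictions = [[1.0, 0.0, 3.0, 0.0], [2.0, 0.0, 3.0, 0.0], [1.0, 0.0, 5.0, 0.0]]
print(saits_style_loss(imputation, predictions, target, original_mask, mask))
```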
--------------------------------------------------------------------------------
/spin/imputers/spin_imputer.py:
--------------------------------------------------------------------------------
1 | from typing import Type, Mapping, Callable, Optional, Union, List
2 |
3 | import torch
4 | from torchmetrics import Metric
5 | from tsl.imputers import Imputer
6 | from tsl.predictors import Predictor
7 |
8 | from ..utils import k_hop_subgraph_sampler
9 |
10 |
11 | class SPINImputer(Imputer):
12 |
13 | def __init__(self,
14 | model_class: Type,
15 | model_kwargs: Mapping,
16 | optim_class: Type,
17 | optim_kwargs: Mapping,
18 | loss_fn: Callable,
19 | scale_target: bool = True,
20 | whiten_prob: Union[float, List[float]] = 0.2,
21 | n_roots_subgraph: Optional[int] = None,
22 | n_hops: int = 2,
23 | max_edges_subgraph: Optional[int] = 1000,
24 | cut_edges_uniformly: bool = False,
25 | prediction_loss_weight: float = 1.0,
26 | metrics: Optional[Mapping[str, Metric]] = None,
27 | scheduler_class: Optional = None,
28 | scheduler_kwargs: Optional[Mapping] = None):
29 | super(SPINImputer, self).__init__(model_class=model_class,
30 | model_kwargs=model_kwargs,
31 | optim_class=optim_class,
32 | optim_kwargs=optim_kwargs,
33 | loss_fn=loss_fn,
34 | scale_target=scale_target,
35 | whiten_prob=whiten_prob,
36 | prediction_loss_weight=prediction_loss_weight,
37 | metrics=metrics,
38 | scheduler_class=scheduler_class,
39 | scheduler_kwargs=scheduler_kwargs)
40 | self.n_roots = n_roots_subgraph
41 | self.n_hops = n_hops
42 | self.max_edges_subgraph = max_edges_subgraph
43 | self.cut_edges_uniformly = cut_edges_uniformly
44 |
45 | def on_after_batch_transfer(self, batch, dataloader_idx):
46 | if self.training and self.n_roots is not None:
47 | batch = k_hop_subgraph_sampler(batch, self.n_hops, self.n_roots,
48 | max_edges=self.max_edges_subgraph,
49 | cut_edges_uniformly=self.cut_edges_uniformly)
50 | return super(SPINImputer, self).on_after_batch_transfer(batch,
51 | dataloader_idx)
52 |
53 | def training_step(self, batch, batch_idx):
54 | injected_missing = (batch.original_mask - batch.mask)
55 | if 'target_nodes' in batch:
56 | injected_missing = injected_missing[..., batch.target_nodes, :]
57 | # batch.input.target_mask = injected_missing
58 | y_hat, y, loss = self.shared_step(batch, mask=injected_missing)
59 |
60 | # Logging
61 | self.train_metrics.update(y_hat, y, batch.eval_mask)
62 | self.log_metrics(self.train_metrics, batch_size=batch.batch_size)
63 | self.log_loss('train', loss, batch_size=batch.batch_size)
64 | if 'target_nodes' in batch:
65 | torch.cuda.empty_cache()
66 | return loss
67 |
68 | def validation_step(self, batch, batch_idx):
69 | # batch.input.target_mask = batch.eval_mask
70 | y_hat, y, val_loss = self.shared_step(batch, batch.eval_mask)
71 |
72 | # Logging
73 | self.val_metrics.update(y_hat, y, batch.eval_mask)
74 | self.log_metrics(self.val_metrics, batch_size=batch.batch_size)
75 | self.log_loss('val', val_loss, batch_size=batch.batch_size)
76 | return val_loss
77 |
78 | def test_step(self, batch, batch_idx):
79 | # batch.input.target_mask = batch.eval_mask
80 | # Compute outputs and rescale
81 | y_hat = self.predict_batch(batch, preprocess=False, postprocess=True)
82 |
83 | if isinstance(y_hat, (list, tuple)):
84 | y_hat = y_hat[0]
85 |
86 | y, eval_mask = batch.y, batch.eval_mask
87 | test_loss = self.loss_fn(y_hat, y, eval_mask)
88 |
89 | # Logging
90 | self.test_metrics.update(y_hat.detach(), y, eval_mask)
91 | self.log_metrics(self.test_metrics, batch_size=batch.batch_size)
92 | self.log_loss('test', test_loss, batch_size=batch.batch_size)
93 | return test_loss
94 |
95 | @staticmethod
96 | def add_argparse_args(parser, **kwargs):
97 | parser = Predictor.add_argparse_args(parser)
98 | parser.add_argument('--whiten-prob', type=float, default=0.05)
99 | parser.add_argument('--prediction-loss-weight', type=float, default=1.0)
100 | parser.add_argument('--n-roots-subgraph', type=int, default=None)
101 | parser.add_argument('--n-hops', type=int, default=2)
102 | parser.add_argument('--max-edges-subgraph', type=int, default=1000)
103 | parser.add_argument('--cut-edges-uniformly', default=False, type=lambda v: str(v).lower() in ('true', '1', 'yes'))
104 | return parser
105 |
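Boolean options such as `--cut-edges-uniformly` are easy to declare incorrectly. `argparse` passes the raw command-line string to the `type` callable, and `bool()` of any non-empty string, including `'False'`, is `True`. The stdlib-only sketch below shows the pitfall and one fix that uses a text-parsing converter. `parse_flag` is a hypothetical helper that plays the same role as the `str_to_bool` converter used by another parser in this repository.

```python
import argparse


def parse_flag(value) -> bool:
    """Map common spellings of a boolean to True/False (hypothetical helper)."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {'true', 't', '1', 'yes', 'y'}:
        return True
    if text in {'false', 'f', '0', 'no', 'n'}:
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


naive = argparse.ArgumentParser()
naive.add_argument('--cut-edges-uniformly', type=bool, default=False)

fixed = argparse.ArgumentParser()
fixed.add_argument('--cut-edges-uniformly', type=parse_flag, default=False)

argv = ['--cut-edges-uniformly', 'False']
print(naive.parse_args(argv).cut_edges_uniformly)  # True: bool('False') is truthy
print(fixed.parse_args(argv).cut_edges_uniformly)  # False
```

Raising `argparse.ArgumentTypeError` makes unrecognized spellings fail with a normal usage error, which is better than silently mapping them to a value.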
--------------------------------------------------------------------------------
/spin/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .postional_encoding import PositionalEncoder
2 | from .additive_attention import AdditiveAttention
3 | from .temporal_graph_additive_attention import TemporalAdditiveAttention, \
4 | TemporalGraphAdditiveAttention
5 | from .hierarchical_temporal_graph_attention import \
6 | HierarchicalTemporalGraphAttention
7 |
--------------------------------------------------------------------------------
/spin/layers/additive_attention.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch import nn
6 | from torch.nn import LayerNorm, functional as F
7 | from torch_geometric.nn.conv import MessagePassing
8 | from torch_geometric.nn.dense.linear import Linear
9 | from torch_geometric.typing import Adj, OptTensor, PairTensor
10 | from torch_scatter import scatter
11 | from torch_scatter.utils import broadcast
12 | from tsl.nn.blocks.encoders import MLP
13 | from tsl.nn.functional import sparse_softmax
14 |
15 |
16 | class AdditiveAttention(MessagePassing):
17 | def __init__(self, input_size: Union[int, Tuple[int, int]],
18 | output_size: int,
19 | msg_size: Optional[int] = None,
20 | msg_layers: int = 1,
21 | root_weight: bool = True,
22 | reweight: Optional[str] = None,
23 | norm: bool = True,
24 | dropout: float = 0.0,
25 | dim: int = -2,
26 | **kwargs):
27 | kwargs.setdefault('aggr', 'add')
28 | super().__init__(node_dim=dim, **kwargs)
29 |
30 | self.output_size = output_size
31 | if isinstance(input_size, int):
32 | self.src_size = self.tgt_size = input_size
33 | else:
34 | self.src_size, self.tgt_size = input_size
35 |
36 | self.msg_size = msg_size or self.output_size
37 | self.msg_layers = msg_layers
38 |
39 | assert reweight in ['softmax', 'l1', None]
40 | self.reweight = reweight
41 |
42 | self.root_weight = root_weight
43 | self.dropout = dropout
44 |
45 | # key bias is discarded in softmax
46 | self.lin_src = Linear(self.src_size, self.output_size,
47 | weight_initializer='glorot',
48 | bias_initializer='zeros')
49 | self.lin_tgt = Linear(self.tgt_size, self.output_size,
50 | weight_initializer='glorot', bias=False)
51 |
52 | if self.root_weight:
53 | self.lin_skip = Linear(self.tgt_size, self.output_size,
54 | bias=False)
55 | else:
56 | self.register_parameter('lin_skip', None)
57 |
58 | self.msg_nn = nn.Sequential(
59 | nn.PReLU(init=0.2),
60 | MLP(self.output_size, self.msg_size, self.output_size,
61 | n_layers=self.msg_layers, dropout=self.dropout,
62 | activation='prelu')
63 | )
64 |
65 | if self.reweight == 'softmax':
66 | self.msg_gate = nn.Linear(self.output_size, 1, bias=False)
67 | else:
68 | self.msg_gate = nn.Sequential(nn.Linear(self.output_size, 1),
69 | nn.Sigmoid())
70 |
71 | if norm:
72 | self.norm = LayerNorm(self.output_size)
73 | else:
74 | self.register_parameter('norm', None)
75 |
76 | self.reset_parameters()
77 |
78 | def reset_parameters(self):
79 | self.lin_src.reset_parameters()
80 | self.lin_tgt.reset_parameters()
81 | if self.lin_skip is not None:
82 | self.lin_skip.reset_parameters()
83 |
84 | def forward(self, x: PairTensor, edge_index: Adj, mask: OptTensor = None):
85 | # if query/key not provided, defaults to x (e.g., for self-attention)
86 | if isinstance(x, Tensor):
87 | x_src = x_tgt = x
88 | else:
89 | x_src, x_tgt = x
90 | x_tgt = x_tgt if x_tgt is not None else x_src
91 |
92 | N_src, N_tgt = x_src.size(self.node_dim), x_tgt.size(self.node_dim)
93 |
94 | msg_src = self.lin_src(x_src)
95 | msg_tgt = self.lin_tgt(x_tgt)
96 |
97 | msg = (msg_src, msg_tgt)
98 |
99 | # propagate_type: (msg: PairTensor, mask: OptTensor)
100 | out = self.propagate(edge_index, msg=msg, mask=mask,
101 | size=(N_src, N_tgt))
102 |
103 | # skip connection
104 | if self.root_weight:
105 | out = out + self.lin_skip(x_tgt)
106 |
107 | if self.norm is not None:
108 | out = self.norm(out)
109 |
110 | return out
111 |
112 | def normalize_weights(self, weights, index, num_nodes, mask=None):
113 | # mask weights
114 | if mask is not None:
115 | fill_value = float("-inf") if self.reweight == 'softmax' else 0.
116 | weights = weights.masked_fill(torch.logical_not(mask), fill_value)
117 | # optionally reweight
118 | if self.reweight == 'l1':
119 | expanded_index = broadcast(index, weights, self.node_dim)
120 | weights_sum = scatter(weights, expanded_index, self.node_dim,
121 | dim_size=num_nodes, reduce='sum')
122 | weights_sum = weights_sum.index_select(self.node_dim, index)
123 | weights = weights / (weights_sum + 1e-5)
124 | elif self.reweight == 'softmax':
125 | weights = sparse_softmax(weights, index, num_nodes=num_nodes,
126 | dim=self.node_dim)
127 | return weights
128 |
129 | def message(self, msg_j: Tensor, msg_i: Tensor, index, size_i,
130 | mask_j: OptTensor = None) -> Tensor:
131 | msg = self.msg_nn(msg_j + msg_i)
132 | gate = self.msg_gate(msg)
133 | alpha = self.normalize_weights(gate, index, size_i, mask_j)
134 | alpha = F.dropout(alpha, p=self.dropout, training=self.training)
135 | out = alpha * msg
136 | return out
137 |
138 | def __repr__(self) -> str:
139 | return (f'{self.__class__.__name__}({self.output_size}, '
140 | f'dim={self.node_dim}, '
141 | f'root_weight={self.root_weight})')
142 |
143 |
144 | class TemporalAdditiveAttention(AdditiveAttention):
145 | def __init__(self, input_size: Union[int, Tuple[int, int]],
146 | output_size: int,
147 | msg_size: Optional[int] = None,
148 | msg_layers: int = 1,
149 | root_weight: bool = True,
150 | reweight: Optional[str] = None,
151 | norm: bool = True,
152 | dropout: float = 0.0,
153 | **kwargs):
154 | kwargs.setdefault('dim', 1)
155 | super().__init__(input_size=input_size,
156 | output_size=output_size,
157 | msg_size=msg_size,
158 | msg_layers=msg_layers,
159 | root_weight=root_weight,
160 | reweight=reweight,
161 | dropout=dropout,
162 | norm=norm,
163 | **kwargs)
164 |
165 | def forward(self, x: PairTensor, mask: OptTensor = None,
166 | temporal_mask: OptTensor = None,
167 | causal_lag: Optional[int] = None):
168 | # x: [b s * c] query: [b l * c] key: [b s * c]
169 | # mask: [b s * c] temporal_mask: [l s]
170 | if isinstance(x, Tensor):
171 | x_src = x_tgt = x
172 | else:
173 | x_src, x_tgt = x
174 | x_tgt = x_tgt if x_tgt is not None else x_src
175 |
176 | l, s = x_tgt.size(self.node_dim), x_src.size(self.node_dim)
177 | i = torch.arange(l, dtype=torch.long, device=x_src.device)
178 | j = torch.arange(s, dtype=torch.long, device=x_src.device)
179 |
180 | # compute temporal index, from j to i
181 | if temporal_mask is None and isinstance(causal_lag, int):
182 | temporal_mask = torch.ones(l, s, device=x_src.device).tril(
183 | diagonal=-causal_lag).bool()
184 | if temporal_mask is not None:
185 | assert temporal_mask.size() == (l, s)
186 | i, j = torch.meshgrid(i, j)
187 | edge_index = torch.stack((j[temporal_mask], i[temporal_mask]))
188 | else:
189 | edge_index = torch.cartesian_prod(j, i).T
190 |
191 | return super(TemporalAdditiveAttention, self).forward(x, edge_index,
192 | mask=mask)
193 |
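`AdditiveAttention.normalize_weights` turns one gate score per edge into attention coefficients in two steps. First, edges whose source entry is masked get `-inf` (softmax) or `0` (otherwise). Second, scores are normalized over all edges that share a target index, either with a sparse softmax or by dividing by their sum plus `1e-5` (`'l1'`). The pure-Python sketch below mirrors that logic on a flat edge list. `normalize_per_target` is a hypothetical name. When every incoming edge of a target is masked, the layer leaves the outcome to `sparse_softmax`; this sketch returns zeros in that case, which is an assumption.

```python
import math
from collections import defaultdict
from typing import List, Optional, Sequence


def normalize_per_target(scores: Sequence[float], targets: Sequence[int],
                         valid: Sequence[bool],
                         reweight: Optional[str]) -> List[float]:
    """Normalize edge scores over the incoming edges of each target node."""
    # Edges from invalid sources get -inf (softmax) or 0 (l1 / no reweight)
    fill = -math.inf if reweight == 'softmax' else 0.0
    scores = [s if v else fill for s, v in zip(scores, valid)]
    incoming = defaultdict(list)
    for edge, tgt in enumerate(targets):
        incoming[tgt].append(edge)
    out = list(scores)
    for edges in incoming.values():
        if reweight == 'softmax':
            top = max(scores[e] for e in edges)
            if top == -math.inf:  # every edge masked: this sketch returns zeros
                for e in edges:
                    out[e] = 0.0
                continue
            exps = {e: math.exp(scores[e] - top) for e in edges}
            total = sum(exps.values())
            for e in edges:
                out[e] = exps[e] / total
        elif reweight == 'l1':
            total = sum(scores[e] for e in edges)
            for e in edges:
                out[e] = scores[e] / (total + 1e-5)
    return out


targets = [0, 0, 1]  # edges 0 and 1 point to node 0, edge 2 to node 1
print(normalize_per_target([0.0, math.log(3.0), 0.5], targets,
                           [True, True, True], 'softmax'))  # ~[0.25, 0.75, 1.0]
print(normalize_per_target([0.2, 0.6, 0.5], targets,
                           [True, True, False], 'l1'))      # ~[0.25, 0.75, 0.0]
```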
--------------------------------------------------------------------------------
/spin/layers/hierarchical_temporal_graph_attention.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from torch import Tensor
4 | from torch_geometric.nn.conv import MessagePassing
5 | from torch_geometric.nn.dense.linear import Linear
6 | from torch_geometric.typing import Adj, OptTensor
7 | from tsl.nn.functional import sparse_softmax
8 | from tsl.nn.layers.norm import LayerNorm
9 |
10 | from .additive_attention import TemporalAdditiveAttention
11 |
12 |
13 | class HierarchicalTemporalGraphAttention(MessagePassing):
14 | def __init__(self, h_size: int, z_size: int,
15 | msg_size: Optional[int] = None,
16 | msg_layers: int = 1,
17 | root_weight: bool = True,
18 | reweight: Optional[str] = None,
19 | update_z_cross: bool = True,
20 | mask_temporal: bool = True,
21 | mask_spatial: bool = True,
22 | norm: bool = True,
23 | dropout: float = 0.,
24 | aggr: str = 'add',
25 | **kwargs):
26 | self.spatial_aggr = aggr
27 | if aggr == 'softmax':
28 | aggr = 'add'
29 | super(HierarchicalTemporalGraphAttention, self).__init__(node_dim=-2,
30 | aggr=aggr,
31 | **kwargs)
32 |
33 | # store dimensions
34 | self.h_size = h_size
35 | self.z_size = z_size
36 | self.msg_size = msg_size or z_size
37 |
38 | self.mask_temporal = mask_temporal
39 | self.mask_spatial = mask_spatial
40 |
41 | self.root_weight = root_weight
42 | self.norm = norm
43 | self.dropout = dropout
44 | self._z_cross = None
45 |
46 | self.zh_self = TemporalAdditiveAttention(
47 | input_size=(h_size, z_size),
48 | output_size=z_size,
49 | msg_size=msg_size,
50 | msg_layers=msg_layers,
51 | reweight=reweight,
52 | dropout=dropout,
53 | root_weight=True,
54 | norm=True
55 | )
56 |
57 | self.hz_self = TemporalAdditiveAttention(
58 | input_size=(z_size, h_size),
59 | output_size=h_size,
60 | msg_size=msg_size,
61 | msg_layers=msg_layers,
62 | reweight=reweight,
63 | dropout=dropout,
64 | root_weight=True,
65 | norm=False
66 | )
67 |
68 | if update_z_cross:
69 | self.zh_cross = TemporalAdditiveAttention(
70 | input_size=(h_size, z_size),
71 | output_size=z_size,
72 | msg_size=msg_size,
73 | msg_layers=msg_layers,
74 | reweight=reweight,
75 | dropout=dropout,
76 | root_weight=True,
77 | norm=True
78 | )
79 | else:
80 | self.register_parameter('zh_cross', None)
81 |
82 | self.hz_cross = TemporalAdditiveAttention(
83 | input_size=(z_size, h_size),
84 | output_size=h_size,
85 | msg_size=msg_size,
86 | msg_layers=msg_layers,
87 | reweight=None,
88 | dropout=dropout,
89 | root_weight=True,
90 | norm=False
91 | )
92 |
93 | if self.spatial_aggr == 'softmax':
94 | self.lin_alpha_h = Linear(h_size, 1, bias=False)
95 | self.lin_alpha_z = Linear(z_size, 1, bias=False)
96 | else:
97 | self.register_parameter('lin_alpha_h', None)
98 | self.register_parameter('lin_alpha_z', None)
99 |
100 | if self.root_weight:
101 | self.h_skip = Linear(h_size, h_size, bias_initializer='zeros')
102 | self.z_skip = Linear(z_size, z_size, bias_initializer='zeros')
103 | else:
104 | self.register_parameter('h_skip', None)
105 | self.register_parameter('z_skip', None)
106 |
107 | if self.norm:
108 | self.h_norm = LayerNorm(h_size)
109 | self.z_norm = LayerNorm(z_size)
110 | else:
111 | self.register_parameter('h_norm', None)
112 | self.register_parameter('z_norm', None)
113 |
114 | self.reset_parameters()
115 |
116 | def reset_parameters(self):
117 | self.zh_self.reset_parameters()
118 | self.hz_self.reset_parameters()
119 | if self.zh_cross is not None:
120 | self.zh_cross.reset_parameters()
121 | self.hz_cross.reset_parameters()
122 | if self.spatial_aggr == 'softmax':
123 | self.lin_alpha_h.reset_parameters()
124 | self.lin_alpha_z.reset_parameters()
125 | if self.root_weight:
126 | self.h_skip.reset_parameters()
127 | self.z_skip.reset_parameters()
128 | if self.norm:
129 | self.h_norm.reset_parameters()
130 | self.z_norm.reset_parameters()
131 |
132 | def forward(self, h: Tensor, z: Tensor, edge_index: Adj,
133 | mask: OptTensor = None):
134 | # inputs: [batch, steps, nodes, channels]
135 |
136 | z_out = self.zh_self(x=(h, z),
137 | mask=mask if self.mask_temporal else None)
138 | h_self = self.hz_self(x=(z_out, h))
139 |
140 | # propagate query, key and value
141 | n_src, n_tgt = h.size(-2), z.size(-2)
142 | h_out = self.propagate(h=h_self, z=z_out,
143 | edge_index=edge_index,
144 | mask=mask if self.mask_spatial else None,
145 | size=(n_src, n_tgt))
146 |
147 | if self._z_cross is not None:
148 | z_out = self.aggregate(self._z_cross, edge_index[1], dim_size=n_tgt)
149 | self._z_cross = None
150 |
151 | # skip connection
152 | if self.root_weight:
153 | h_out = h_out + self.h_skip(h)
154 | z_out = z_out + self.z_skip(z)
155 |
156 | if self.norm:
157 | h_out = self.h_norm(h_out)
158 | z_out = self.z_norm(z_out)
159 |
160 | return h_out, z_out
161 |
162 | def h_cross_message(self, h_i: Tensor, z_j: Tensor, index, size_i) -> Tensor:
163 | # [batch, steps, edges, channels]
164 | h_cross = self.hz_cross((z_j, h_i))
165 | if self.spatial_aggr == 'softmax':
166 | alpha_h = self.lin_alpha_h(h_cross)
167 | alpha_h = sparse_softmax(alpha_h, index, num_nodes=size_i,
168 | dim=self.node_dim)
169 | h_cross = alpha_h * h_cross
170 | return h_cross
171 |
172 | def hz_cross_message(self, h_i: Tensor, h_j: Tensor, z_i: Tensor,
173 | index, size_i, mask_j: OptTensor) -> Tensor:
174 | # [batch, steps, edges, channels]
175 | z_cross = self.zh_cross((h_j, z_i), mask=mask_j)
176 | h_cross = self.hz_cross((z_cross, h_i))
177 | if self.spatial_aggr == 'softmax':
178 | # reweight z
179 | alpha_z = self.lin_alpha_z(z_cross)
180 | alpha_z = sparse_softmax(alpha_z, index, num_nodes=size_i,
181 | dim=self.node_dim)
182 | z_cross = alpha_z * z_cross
183 | # reweight h
184 | alpha_h = self.lin_alpha_h(h_cross)
185 | alpha_h = sparse_softmax(alpha_h, index, num_nodes=size_i,
186 | dim=self.node_dim)
187 | h_cross = alpha_h * h_cross
188 | self._z_cross = z_cross
189 | return h_cross
190 |
191 | def message(self, h_i: Tensor, h_j: Tensor, z_i: Tensor, z_j: Tensor,
192 | index, size_i, mask_j: OptTensor) -> Tensor:
193 | if self.zh_cross is not None:
194 | return self.hz_cross_message(h_i, h_j, z_i, index, size_i, mask_j)
195 | return self.h_cross_message(h_i, z_j, index, size_i)
196 |
--------------------------------------------------------------------------------
/spin/layers/postional_encoding.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from torch import nn
4 | from tsl.nn.base import StaticGraphEmbedding
5 | from tsl.nn.blocks.encoders import MLP
6 | from tsl.nn.layers import PositionalEncoding
7 |
8 |
9 | class PositionalEncoder(nn.Module):
10 |
11 | def __init__(self, in_channels, out_channels,
12 | n_layers: int = 1,
13 | n_nodes: Optional[int] = None):
14 | super(PositionalEncoder, self).__init__()
15 | self.lin = nn.Linear(in_channels, out_channels)
16 | self.activation = nn.LeakyReLU()
17 | self.mlp = MLP(out_channels, out_channels, out_channels,
18 | n_layers=n_layers, activation='relu')
19 | self.positional = PositionalEncoding(out_channels)
20 | if n_nodes is not None:
21 | self.node_emb = StaticGraphEmbedding(n_nodes, out_channels)
22 | else:
23 | self.register_parameter('node_emb', None)
24 |
25 | def forward(self, x, node_emb=None, node_index=None):
26 | if node_emb is None:
27 | node_emb = self.node_emb(token_index=node_index)
28 | # x: [b s c], node_emb: [n c] -> [b s n c]
29 | x = self.lin(x)
30 | x = self.activation(x.unsqueeze(-2) + node_emb)
31 | out = self.mlp(x)
32 | out = self.positional(out)
33 | return out
34 |
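`PositionalEncoder.forward` sums projected exogenous features of shape `[b s c]` with node embeddings of shape `[n c]` after `unsqueeze(-2)`. This broadcasts to one code per (step, node) pair, `[b s n c]`, before `PositionalEncoding` is applied. The pure-Python sketch below keeps only the broadcasting sum and a time encoding. It assumes that tsl's `PositionalEncoding` adds the standard sine/cosine table of Vaswani et al. along the step axis. The table formula, the dropped batch dimension, and the helper names are assumptions. The Linear, LeakyReLU, and MLP stages are skipped, so here the order of the two additions does not matter.

```python
import math
from typing import List


def sinusoidal_table(n_steps: int, d_model: int) -> List[List[float]]:
    """Sin/cos encoding: row t is the code of step t (d_model assumed even)."""
    table = []
    for pos in range(n_steps):
        row = []
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            row.extend([math.sin(angle), math.cos(angle)])
        table.append(row)
    return table


def st_encoding(u_steps: List[List[float]],
                node_emb: List[List[float]]) -> List[List[List[float]]]:
    """[s][c] step features + [n][c] node embeddings -> [s][n][c] codes."""
    pe = sinusoidal_table(len(u_steps), len(node_emb[0]))
    # Every step vector is paired with every node vector (the unsqueeze(-2)
    # broadcast), then the code of that step is added
    return [[[u + v + p for u, v, p in zip(u_t, v_n, pe[t])]
             for v_n in node_emb]
            for t, u_t in enumerate(u_steps)]


codes = st_encoding(u_steps=[[0.0] * 4, [1.0] * 4, [2.0] * 4],  # s=3 steps
                    node_emb=[[0.1] * 4, [0.2] * 4])           # n=2 nodes
print(len(codes), len(codes[0]), len(codes[0][0]))  # 3 2 4
```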
--------------------------------------------------------------------------------
/spin/layers/temporal_graph_additive_attention.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch_geometric.nn.conv import MessagePassing
6 | from torch_geometric.nn.dense.linear import Linear
7 | from torch_geometric.typing import Adj, OptTensor, OptPairTensor
8 | from tsl.nn.layers.norm import LayerNorm
9 |
10 | from .additive_attention import TemporalAdditiveAttention
11 |
12 |
13 | class TemporalGraphAdditiveAttention(MessagePassing):
14 | def __init__(self, input_size: Union[int, Tuple[int, int]],
15 | output_size: int,
16 | msg_size: Optional[int] = None,
17 | msg_layers: int = 1,
18 | root_weight: bool = True,
19 | reweight: Optional[str] = None,
20 | temporal_self_attention: bool = True,
21 | mask_temporal: bool = True,
22 | mask_spatial: bool = True,
23 | norm: bool = True,
24 | dropout: float = 0.,
25 | **kwargs):
26 | kwargs.setdefault('aggr', 'add')
27 | super(TemporalGraphAdditiveAttention, self).__init__(node_dim=-2,
28 | **kwargs)
29 |
30 | # store dimensions
31 | if isinstance(input_size, int):
32 | self.src_size = self.tgt_size = input_size
33 | else:
34 | self.src_size, self.tgt_size = input_size
35 | self.output_size = output_size
36 | self.msg_size = msg_size or self.output_size
37 |
38 | self.mask_temporal = mask_temporal
39 | self.mask_spatial = mask_spatial
40 |
41 | self.root_weight = root_weight
42 | self.dropout = dropout
43 |
44 | if temporal_self_attention:
45 | self.self_attention = TemporalAdditiveAttention(
46 | input_size=input_size,
47 | output_size=output_size,
48 | msg_size=msg_size,
49 | msg_layers=msg_layers,
50 | reweight=reweight,
51 | dropout=dropout,
52 | root_weight=False,
53 | norm=False
54 | )
55 | else:
56 | self.register_parameter('self_attention', None)
57 |
58 | self.cross_attention = TemporalAdditiveAttention(input_size=input_size,
59 | output_size=output_size,
60 | msg_size=msg_size,
61 | msg_layers=msg_layers,
62 | reweight=reweight,
63 | dropout=dropout,
64 | root_weight=False,
65 | norm=False)
66 |
67 | if self.root_weight:
68 | self.lin_skip = Linear(self.tgt_size, self.output_size,
69 | bias_initializer='zeros')
70 | else:
71 | self.register_parameter('lin_skip', None)
72 |
73 | if norm:
74 | self.norm = LayerNorm(output_size)
75 | else:
76 | self.register_parameter('norm', None)
77 |
78 | self.reset_parameters()
79 |
80 | def reset_parameters(self):
81 | self.cross_attention.reset_parameters()
82 | if self.self_attention is not None:
83 | self.self_attention.reset_parameters()
84 | if self.lin_skip is not None:
85 | self.lin_skip.reset_parameters()
86 | if self.norm is not None:
87 | self.norm.reset_parameters()
88 |
89 | def forward(self, x: OptPairTensor,
90 | edge_index: Adj, edge_weight: OptTensor = None,
91 | mask: OptTensor = None):
92 | # inputs: [batch, steps, nodes, channels]
93 | if isinstance(x, Tensor):
94 | x_src = x_tgt = x
95 | else:
96 | x_src, x_tgt = x
97 | x_tgt = x_tgt if x_tgt is not None else x_src
98 |
99 | n_src, n_tgt = x_src.size(-2), x_tgt.size(-2)
100 |
101 | # propagate query, key and value
102 | out = self.propagate(x=(x_src, x_tgt),
103 | edge_index=edge_index, edge_weight=edge_weight,
104 | mask=mask if self.mask_spatial else None,
105 | size=(n_src, n_tgt))
106 |
107 | if self.self_attention is not None:
108 | s, l = x_src.size(1), x_tgt.size(1)
109 | if s == l:
110 | attn_mask = ~torch.eye(l, l, dtype=torch.bool,
111 | device=x_tgt.device)
112 | else:
113 | attn_mask = None
114 | temp = self.self_attention(x=(x_src, x_tgt),
115 | mask=mask if self.mask_temporal else None,
116 | temporal_mask=attn_mask)
117 | out = out + temp
118 |
119 | # skip connection
120 | if self.root_weight:
121 | out = out + self.lin_skip(x_tgt)
122 |
123 | if self.norm is not None:
124 | out = self.norm(out)
125 |
126 | return out
127 |
128 | def message(self, x_i: Tensor, x_j: Tensor,
129 | edge_weight: OptTensor, mask_j: OptTensor) -> Tensor:
130 | # [batch, steps, edges, channels]
131 |
132 | out = self.cross_attention((x_j, x_i), mask=mask_j)
133 |
134 | if edge_weight is not None:
135 | out = out * edge_weight.view(-1, 1)
136 | return out
137 |
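When source and target windows have the same length, `TemporalGraphAdditiveAttention.forward` passes `attn_mask = ~torch.eye(l, l)` to its temporal branch. `TemporalAdditiveAttention` then turns that boolean `[l s]` mask into an edge list running from source step `j` to target step `i`. As a result, each step of a node attends to every other step of the same node but not to itself. The stdlib sketch below enumerates the same edges. `temporal_edges` is a hypothetical helper, and its ordering follows the masked `meshgrid` path. For unequal lengths, the layer uses `cartesian_prod` instead, which yields the same pairs in a different order.

```python
from itertools import product
from typing import List, Tuple


def temporal_edges(n_src_steps: int, n_tgt_steps: int) -> List[Tuple[int, int]]:
    """(source step j, target step i) pairs for the temporal self-attention.

    With equal window lengths the diagonal is excluded (the ~eye mask);
    otherwise every pair is kept.
    """
    exclude_self = n_src_steps == n_tgt_steps
    return [(j, i)
            for i, j in product(range(n_tgt_steps), range(n_src_steps))
            if not (exclude_self and i == j)]


print(temporal_edges(3, 3))
# [(1, 0), (2, 0), (0, 1), (2, 1), (0, 2), (1, 2)]
```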
--------------------------------------------------------------------------------
/spin/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .spin import SPINModel
2 | from .spin_hierarchical import SPINHierarchicalModel
3 |
--------------------------------------------------------------------------------
/spin/models/spin.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch import nn, Tensor
5 | from torch.nn import LayerNorm
6 | from torch_geometric.typing import OptTensor
7 | from tsl.nn.base import StaticGraphEmbedding
8 | from tsl.nn.blocks.encoders import MLP
9 |
10 | from ..layers import PositionalEncoder, TemporalGraphAdditiveAttention
11 |
12 |
13 | class SPINModel(nn.Module):
14 |
15 | def __init__(self, input_size: int,
16 | hidden_size: int,
17 | n_nodes: int,
18 | u_size: Optional[int] = None,
19 | output_size: Optional[int] = None,
20 | temporal_self_attention: bool = True,
21 | reweight: Optional[str] = 'softmax',
22 | n_layers: int = 4,
23 | eta: int = 3,
24 | message_layers: int = 1):
25 | super(SPINModel, self).__init__()
26 |
27 | u_size = u_size or input_size
28 | output_size = output_size or input_size
29 | self.n_nodes = n_nodes
30 | self.n_layers = n_layers
31 | self.eta = eta
32 | self.temporal_self_attention = temporal_self_attention
33 |
34 | self.u_enc = PositionalEncoder(in_channels=u_size,
35 | out_channels=hidden_size,
36 | n_layers=2,
37 | n_nodes=n_nodes)
38 |
39 | self.h_enc = MLP(input_size, hidden_size, n_layers=2)
40 | self.h_norm = LayerNorm(hidden_size)
41 |
42 | self.valid_emb = StaticGraphEmbedding(n_nodes, hidden_size)
43 | self.mask_emb = StaticGraphEmbedding(n_nodes, hidden_size)
44 |
45 | self.x_skip = nn.ModuleList()
46 | self.encoder, self.readout = nn.ModuleList(), nn.ModuleList()
47 | for l in range(n_layers):
48 | x_skip = nn.Linear(input_size, hidden_size)
49 | encoder = TemporalGraphAdditiveAttention(
50 | input_size=hidden_size,
51 | output_size=hidden_size,
52 | msg_size=hidden_size,
53 | msg_layers=message_layers,
54 | temporal_self_attention=temporal_self_attention,
55 | reweight=reweight,
56 | mask_temporal=True,
57 | mask_spatial=l < eta,
58 | norm=True,
59 | root_weight=True,
60 | dropout=0.0
61 | )
62 | readout = MLP(hidden_size, hidden_size, output_size,
63 | n_layers=2)
64 | self.x_skip.append(x_skip)
65 | self.encoder.append(encoder)
66 | self.readout.append(readout)
67 |
68 | def forward(self, x: Tensor, u: Tensor, mask: Tensor,
69 | edge_index: Tensor, edge_weight: OptTensor = None,
70 | node_index: OptTensor = None, target_nodes: OptTensor = None):
71 | if target_nodes is None:
72 | target_nodes = slice(None)
73 |
74 | # Whiten missing values
75 | x = x * mask
76 |
77 | # POSITIONAL ENCODING #################################################
78 | # Obtain spatio-temporal positional encoding for every node-step pair #
79 | # in both observed and target sets. Encodings are obtained by jointly #
80 | # processing node and time positional encoding. #
81 |
82 | # Build (node, timestamp) encoding
83 | q = self.u_enc(u, node_index=node_index)
84 | # Condition value on key
85 | h = self.h_enc(x) + q
86 |
87 | # ENCODER #############################################################
88 | # Obtain representations h^i_t for every (i, t) node-step pair by #
89 | # only taking into account valid data in representation set. #
90 |
91 | # Replace H in missing entries with queries Q
92 | h = torch.where(mask.bool(), h, q)
93 | # Normalize features
94 | h = self.h_norm(h)
95 |
96 | imputations = []
97 |
98 | for l in range(self.n_layers):
99 | if l == self.eta:
100 | # Condition H on two different embeddings to distinguish
101 | # valid values from masked ones
102 | valid = self.valid_emb(token_index=node_index)
103 | masked = self.mask_emb(token_index=node_index)
104 | h = torch.where(mask.bool(), h + valid, h + masked)
105 | # Masked Temporal GAT for encoding representation
106 | h = h + self.x_skip[l](x) * mask # skip connection for valid x
107 | h = self.encoder[l](h, edge_index, mask=mask)
108 | # Read from H to get imputations
109 | target_readout = self.readout[l](h[..., target_nodes, :])
110 | imputations.append(target_readout)
111 |
112 | # Get final layer imputations
113 | x_hat = imputations.pop(-1)
114 |
115 | return x_hat, imputations
116 |
117 | @staticmethod
118 | def add_model_specific_args(parser):
119 | parser.opt_list('--hidden-size', type=int, tunable=True, default=32,
120 | options=[32, 64, 128, 256])
121 | parser.add_argument('--u-size', type=int, default=None)
122 | parser.add_argument('--output-size', type=int, default=None)
123 | parser.add_argument('--temporal-self-attention', default=True,
124 | type=lambda v: str(v).lower() in ('true', '1', 'yes'))
125 | parser.add_argument('--reweight', type=str, default='softmax')
126 | parser.add_argument('--n-layers', type=int, default=4)
127 | parser.add_argument('--eta', type=int, default=3)
128 | parser.add_argument('--message-layers', type=int, default=1)
129 | return parser
130 |
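`SPINModel` changes its behavior with depth. The first `eta` encoder layers are built with `mask_spatial=l < eta`, so their spatial attention ignores source entries that are missing. At layer `eta`, the representation is offset by `valid_emb` or `mask_emb`, depending on the mask. Missing entries start from the query `q` via `torch.where(mask.bool(), h, q)`. `forward` returns the last readout as `x_hat` and the earlier readouts as auxiliary imputations. The stdlib sketch below spells out that bookkeeping with the defaults `n_layers=4` and `eta=3`. The helper names are hypothetical, and flat lists stand in for tensors.

```python
from typing import Dict, List, Sequence, Tuple


def spin_layer_schedule(n_layers: int = 4, eta: int = 3) -> List[Dict[str, bool]]:
    """Per-layer switches mirroring SPINModel.__init__ and forward."""
    return [{
        # mask_spatial=l < eta: spatial attention skips missing source entries
        'mask_spatial': l < eta,
        # at l == eta, H is offset by the valid/mask node embeddings
        'add_valid_mask_emb': l == eta,
    } for l in range(n_layers)]


def init_representation(h: Sequence[float], q: Sequence[float],
                        mask: Sequence[int]) -> List[float]:
    """Element-wise torch.where(mask.bool(), h, q) on flat lists."""
    return [h_k if m else q_k for h_k, q_k, m in zip(h, q, mask)]


def split_readouts(readouts: List[str]) -> Tuple[str, List[str]]:
    """Last layer's readout is x_hat; earlier ones are auxiliary outputs."""
    *aux, x_hat = readouts
    return x_hat, aux


schedule = spin_layer_schedule()
print([s['mask_spatial'] for s in schedule])               # [True, True, True, False]
print(init_representation([5.0, 6.0], [0.1, 0.2], [1, 0]))  # [5.0, 0.2]
print(split_readouts(['r0', 'r1', 'r2', 'r3']))             # ('r3', ['r0', 'r1', 'r2'])
```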
--------------------------------------------------------------------------------
/spin/models/spin_hierarchical.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch import nn, Tensor
5 | from torch.nn import LayerNorm
6 | from torch_geometric.nn import inits
7 | from torch_geometric.typing import OptTensor
8 | from tsl.nn.base import StaticGraphEmbedding
9 | from tsl.nn.blocks.encoders import MLP
10 |
11 | from ..layers import PositionalEncoder, HierarchicalTemporalGraphAttention
12 |
13 |
14 | class SPINHierarchicalModel(nn.Module):
15 |
16 | def __init__(self, input_size: int,
17 | h_size: int,
18 | z_size: int,
19 | n_nodes: int,
20 | z_heads: int = 1,
21 | u_size: Optional[int] = None,
22 | output_size: Optional[int] = None,
23 | n_layers: int = 5,
24 | eta: int = 3,
25 | message_layers: int = 1,
26 | reweight: Optional[str] = 'softmax',
27 | update_z_cross: bool = True,
28 | norm: bool = True,
29 | spatial_aggr: str = 'add'):
30 | super(SPINHierarchicalModel, self).__init__()
31 |
32 | u_size = u_size or input_size
33 | output_size = output_size or input_size
34 | self.h_size = h_size
35 | self.z_size = z_size
36 |
37 | self.n_nodes = n_nodes
38 | self.z_heads = z_heads
39 | self.n_layers = n_layers
40 | self.eta = eta
41 |
42 | self.v = StaticGraphEmbedding(n_nodes, h_size)
43 | self.lin_v = nn.Linear(h_size, z_size, bias=False)
44 | self.z = nn.Parameter(torch.Tensor(1, z_heads, n_nodes, z_size))
45 | inits.uniform(z_size, self.z)
46 | self.z_norm = LayerNorm(z_size)
47 |
48 | self.u_enc = PositionalEncoder(in_channels=u_size,
49 | out_channels=h_size,
50 | n_layers=2)
51 |
52 | self.h_enc = MLP(input_size, h_size, n_layers=2)
53 | self.h_norm = LayerNorm(h_size)
54 |
55 | self.v1 = StaticGraphEmbedding(n_nodes, h_size)
56 | self.m1 = StaticGraphEmbedding(n_nodes, h_size)
57 |
58 | self.v2 = StaticGraphEmbedding(n_nodes, h_size)
59 | self.m2 = StaticGraphEmbedding(n_nodes, h_size)
60 |
61 | self.x_skip = nn.ModuleList()
62 | self.encoder, self.readout = nn.ModuleList(), nn.ModuleList()
63 | for l in range(n_layers):
64 | x_skip = nn.Linear(input_size, h_size)
65 | encoder = HierarchicalTemporalGraphAttention(
66 | h_size=h_size, z_size=z_size,
67 | msg_size=h_size,
68 | msg_layers=message_layers,
69 | reweight=reweight,
70 | mask_temporal=True,
71 | mask_spatial=l < eta,
72 | update_z_cross=update_z_cross,
73 | norm=norm,
74 | root_weight=True,
75 | aggr=spatial_aggr,
76 | dropout=0.0
77 | )
78 | readout = MLP(h_size, z_size, output_size,
79 | n_layers=2)
80 | self.x_skip.append(x_skip)
81 | self.encoder.append(encoder)
82 | self.readout.append(readout)
83 |
84 | def forward(self, x: Tensor, u: Tensor, mask: Tensor,
85 | edge_index: Tensor, edge_weight: OptTensor = None,
86 | node_index: OptTensor = None, target_nodes: OptTensor = None):
87 | if target_nodes is None:
88 | target_nodes = slice(None)
89 | if node_index is None:
90 | node_index = slice(None)
91 |
92 | # POSITIONAL ENCODING #################################################
93 | # Obtain spatio-temporal positional encoding for every node-step pair #
94 | # in both observed and target sets. Encodings are obtained by jointly #
95 | # processing node and time positional encoding. #
96 | # Also condition embeddings Z on V. #
97 |
98 | v_nodes = self.v(token_index=node_index)
99 | z = self.z[..., node_index, :] + self.lin_v(v_nodes)
100 |
101 | # Build (node, timestamp) encoding
102 | q = self.u_enc(u, node_index=node_index, node_emb=v_nodes)
103 | # Condition value on key
104 | h = self.h_enc(x) + q
105 |
106 | # ENCODER #############################################################
107 | # Obtain representations h^i_t for every (i, t) node-step pair by #
108 | # only taking into account valid data in the representation set. #
109 |
110 | # Replace H in missing entries with queries Q. Then, condition H on two
111 | # different embeddings to distinguish valid values from masked ones.
112 | h = torch.where(mask.bool(), h + self.v1(), q + self.m1())
113 | # Normalize features
114 | h, z = self.h_norm(h), self.z_norm(z)
115 |
116 | imputations = []
117 |
118 | for l in range(self.n_layers):
119 | if l == self.eta:
120 | # Condition H on two different embeddings to distinguish
121 | # valid values from masked ones
122 | h = torch.where(mask.bool(), h + self.v2(), h + self.m2())
123 | # Skip connection from input x (applied to observed entries only)
124 | h = h + self.x_skip[l](x) * mask
125 | # Masked hierarchical temporal graph attention (updates both H and Z)
126 | h, z = self.encoder[l](h, z, edge_index, mask=mask)
127 | target_readout = self.readout[l](h[..., target_nodes, :])
128 | imputations.append(target_readout)
129 |
130 | x_hat = imputations.pop(-1)
131 |
132 | return x_hat, imputations
133 |
134 | @staticmethod
135 | def add_model_specific_args(parser):
136 | parser.opt_list('--h-size', type=int, tunable=True, default=32,
137 | options=[16, 32])
138 | parser.opt_list('--z-size', type=int, tunable=True, default=32,
139 | options=[32, 64, 128])
140 | parser.opt_list('--z-heads', type=int, tunable=True, default=2,
141 | options=[1, 2, 4, 6])
142 | parser.add_argument('--u-size', type=int, default=None)
143 | parser.add_argument('--output-size', type=int, default=None)
144 | parser.opt_list('--encoder-layers', type=int, tunable=True, default=2,
145 | options=[1, 2, 3, 4])
146 | parser.opt_list('--decoder-layers', type=int, tunable=True, default=2,
147 | options=[1, 2, 3, 4])
148 | parser.add_argument('--message-layers', type=int, default=1)
149 | parser.opt_list('--reweight', type=str, tunable=True, default='softmax',
150 | options=[None, 'softmax'])
151 | parser.add_argument('--update-z-cross', type=bool, default=True)
152 | parser.opt_list('--norm', type=bool, default=True, tunable=True,
153 | options=[True, False])
154 | parser.opt_list('--spatial-aggr', type=str, tunable=True,
155 | default='add', options=['add', 'softmax'])
156 | return parser
157 |
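The forward pass above combines two masking mechanisms. The first `eta` encoder layers are built with `mask_spatial=True`, and at layer `eta` the representations are re-conditioned with a second pair of embeddings (`v2`/`m2`). Only the last readout becomes `x_hat`; the earlier ones are returned as intermediate imputations. The pure-Python sketch below mirrors just this control flow and the elementwise `torch.where` conditioning on nested lists. It is an illustration under stated assumptions, not the model's implementation: `layer_plan` and `condition` are hypothetical helpers, the batch and feature dimensions are dropped (rows are time steps, columns are nodes), and scalars stand in for the `h_size`-dimensional node embeddings.

```python
from typing import Dict, List


def layer_plan(n_layers: int, eta: int) -> List[Dict[str, bool]]:
    """Hypothetical helper: per-layer flags mirroring SPINHierarchicalModel."""
    plan = []
    for l in range(n_layers):
        plan.append({
            # Encoder layers before `eta` are built with mask_spatial=True.
            "mask_spatial": l < eta,
            # At layer `eta`, H is re-conditioned with the v2/m2 embeddings.
            "recondition": l == eta,
            # Only the last readout is popped as x_hat.
            "is_final_output": l == n_layers - 1,
        })
    return plan


def condition(h: List[List[float]], q: List[List[float]],
              mask: List[List[int]], v: float, m: float) -> List[List[float]]:
    """Elementwise analogue of torch.where(mask.bool(), h + v, q + m).

    Rows are time steps, columns are nodes; the scalars `v` and `m` stand in
    for the embeddings of valid and missing entries, respectively.
    """
    return [[h_tn + v if m_tn else q_tn + m
             for h_tn, q_tn, m_tn in zip(h_t, q_t, mask_t)]
            for h_t, q_t, mask_t in zip(h, q, mask)]


plan = layer_plan(n_layers=5, eta=3)
print([p["mask_spatial"] for p in plan])  # [True, True, True, False, False]
print([l for l, p in enumerate(plan) if p["recondition"]])  # [3]

h = [[1.0, 2.0], [3.0, 4.0]]
q = [[10.0, 20.0], [30.0, 40.0]]
mask = [[1, 0], [0, 1]]
print(condition(h, q, mask, v=0.5, m=-0.5))  # [[1.5, 19.5], [29.5, 4.5]]
```

The first conditioning step replaces missing entries with the positional queries `q`, whereas the one at layer `eta` keeps `h` in both branches, which corresponds to `condition(h, h, mask, v2, m2)` in this sketch. If `eta >= n_layers`, no re-conditioning happens and every layer keeps spatial masking.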
--------------------------------------------------------------------------------
/spin/scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from torch.optim import Optimizer
4 | from torch.optim.lr_scheduler import LambdaLR
5 |
6 |
7 | class CosineSchedulerWithRestarts(LambdaLR):
8 |
9 | def __init__(self, optimizer: Optimizer,
10 | num_warmup_steps: int,
11 | num_training_steps: int,
12 | min_factor: float = 0.1,
13 | linear_decay: float = 0.67,
14 | num_cycles: int = 1,
15 | last_epoch: int = -1):
16 | """From https://github.com/huggingface/transformers/blob/v4.18.0/src/transformers/optimization.py#L138
17 |
18 | Create a schedule with a learning rate that decreases following the values
19 | of the cosine function between the initial lr set in the optimizer to 0,
20 | with several hard restarts, after a warmup period during which it increases
21 | linearly between 0 and the initial lr set in the optimizer.
22 |
23 | Args:
24 | optimizer ([`~torch.optim.Optimizer`]):
25 | The optimizer for which to schedule the learning rate.
26 | num_warmup_steps (`int`):
27 | The number of steps for the warmup phase.
28 | num_training_steps (`int`):
29 | The total number of training steps.
30 | num_cycles (`int`, *optional*, defaults to 1):
31 | The number of hard restarts to use.
32 | last_epoch (`int`, *optional*, defaults to -1):
33 | The index of the last epoch when resuming training.
34 | Return:
35 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
36 | """
37 |
38 | def lr_lambda(current_step):
39 | if current_step < num_warmup_steps:
40 | factor = float(current_step) / float(max(1, num_warmup_steps))
41 | return max(min_factor, factor)
42 | progress = float(current_step - num_warmup_steps)
43 | progress /= float(max(1, num_training_steps - num_warmup_steps))
44 | if progress >= 1.0:
45 | return 0.0
46 | factor = (float(num_cycles) * progress) % 1.0
47 | cos = 0.5 * (1.0 + math.cos(math.pi * factor))
48 | lin = 1.0 - (progress * linear_decay)
49 | return max(min_factor, cos * lin)
50 |
51 | super(CosineSchedulerWithRestarts, self).__init__(optimizer, lr_lambda,
52 | last_epoch)
53 |
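To see what the schedule actually produces, the sketch below copies the `lr_lambda` closure above into a standalone function and evaluates it at a few steps. It uses only the standard library and assumes the constructor's default arguments; `lr_factor` is a hypothetical name, and the returned values are the multiplicative factors that `LambdaLR` applies to the optimizer's initial lr.

```python
import math


def lr_factor(step: int, num_warmup_steps: int, num_training_steps: int,
              min_factor: float = 0.1, linear_decay: float = 0.67,
              num_cycles: int = 1) -> float:
    """Standalone copy of the lr_lambda logic in CosineSchedulerWithRestarts."""
    if step < num_warmup_steps:
        # Linear warmup, but never below min_factor.
        return max(min_factor, float(step) / float(max(1, num_warmup_steps)))
    progress = float(step - num_warmup_steps)
    progress /= float(max(1, num_training_steps - num_warmup_steps))
    if progress >= 1.0:
        # Past the end of training the factor drops to 0.
        return 0.0
    # Position within the current cosine cycle (wraps around on a restart).
    cycle_pos = (float(num_cycles) * progress) % 1.0
    cos = 0.5 * (1.0 + math.cos(math.pi * cycle_pos))
    # Linear damping of the cosine envelope over the whole run.
    lin = 1.0 - progress * linear_decay
    return max(min_factor, cos * lin)


for step in (0, 5, 10, 55, 100, 110):
    print(step, round(lr_factor(step, num_warmup_steps=10,
                                num_training_steps=110), 4))
```

With these settings the factor starts at `min_factor` (0.1), reaches 1.0 at the end of warmup, and is clamped back to 0.1 well before training ends. With `num_cycles=2`, the factor jumps back up at the midpoint of training (the hard restart), but the linear term keeps the second peak lower than the first: about 0.665 instead of 1.0 with the default `linear_decay`.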
--------------------------------------------------------------------------------
/spin/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import numpy as np
4 | from numpy.random import choice
5 | from torch_geometric.utils import k_hop_subgraph
6 | from tsl.data import Batch
7 | from tsl.ops.connectivity import weighted_degree
8 |
9 |
10 | def k_hop_subgraph_sampler(batch: Batch, k: int, num_nodes: int,
11 | max_edges: Optional[int] = None,
12 | cut_edges_uniformly: bool = False):
13 | N = batch.x.size(-2)
14 | roots = choice(np.arange(N), num_nodes, replace=False).tolist()
15 | subgraph = k_hop_subgraph(roots, k, batch.edge_index, relabel_nodes=True,
16 | num_nodes=N, flow='target_to_source')
17 | node_idx, edge_index, node_map, edge_mask = subgraph
18 |
19 | col = edge_index[1]
20 | if max_edges is not None and max_edges < edge_index.size(1):
21 | if not cut_edges_uniformly:
22 | in_degree = weighted_degree(col, num_nodes=len(node_idx))
23 | deg = (1 / in_degree)[col].cpu().numpy()
24 | p = deg / deg.sum()
25 | else:
26 | p = None
27 | keep_edges = sorted(choice(len(col), max_edges, replace=False, p=p))
28 | else:
29 | keep_edges = slice(None)
30 | for key, pattern in batch.pattern.items():
31 | if key in batch.target or key == 'eval_mask':
32 | batch[key] = batch[key][..., roots, :]
33 | elif 'n' in pattern:
34 | batch[key] = batch[key][..., node_idx, :]
35 | elif 'e' in pattern and key != 'edge_index':
36 | batch[key] = batch[key][edge_mask][keep_edges]
37 | batch.input.node_index = node_idx # index of nodes in subgraph
38 | batch.input.target_nodes = node_map # index of roots in subgraph
39 | batch.edge_index = edge_index[:, keep_edges]
40 | return batch
41 |
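When `max_edges` is set and `cut_edges_uniformly=False`, `k_hop_subgraph_sampler` draws the edges to keep without replacement, with weights proportional to the inverse (weighted) degree of each edge's `col` node. A consequence is that every node appearing in `col` receives the same total probability mass, so high-degree nodes lose proportionally more of their edges. The sketch below reproduces this weighting on plain Python lists, assuming unit edge weights (so the degree is just a count). `inverse_in_degree_probs` and `sample_edges` are hypothetical helpers. For the sampling step it swaps in the Efraimidis-Spirakis key method for weighted sampling without replacement instead of `numpy.random.choice`, so individual draws will not match the original function even with the same seed.

```python
import random
from collections import Counter
from typing import List, Sequence


def inverse_in_degree_probs(col: Sequence[int]) -> List[float]:
    """Per-edge weights 1 / in_degree(col node), normalized to sum to 1.

    Assumes unit edge weights, so the in-degree is just a count over `col`.
    """
    in_degree = Counter(col)
    weights = [1.0 / in_degree[c] for c in col]
    total = sum(weights)
    return [w / total for w in weights]


def sample_edges(p: Sequence[float], max_edges: int, seed: int = 0) -> List[int]:
    """Weighted sampling without replacement (Efraimidis-Spirakis keys).

    Each edge gets the key u ** (1 / p_e) with u ~ U(0, 1); the `max_edges`
    largest keys are kept. Indices are returned sorted, like `keep_edges`.
    """
    rng = random.Random(seed)
    keys = [(rng.random() ** (1.0 / w), i) for i, w in enumerate(p)]
    top = sorted(keys, reverse=True)[:max_edges]
    return sorted(i for _, i in top)


# Node 0 has three incoming edges, node 1 has one.
col = [0, 0, 0, 1]
p = inverse_in_degree_probs(col)
print(p)  # roughly [0.1667, 0.1667, 0.1667, 0.5]
print(sample_edges(p, max_edges=2))
```

Here node 0 and node 1 each receive a total weight of 0.5, even though node 0 has three times as many incoming edges.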
--------------------------------------------------------------------------------
/tsl_config.yaml:
--------------------------------------------------------------------------------
1 | config_dir: 'config/'
2 | log_dir: 'log/'
--------------------------------------------------------------------------------