├── requirements.txt ├── .gitignore ├── LICENSE ├── run.py ├── generate_training_data.py ├── test.py ├── README.md ├── engine.py ├── train.py ├── util.py └── model.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.1 2 | numpy>=1.12.1 3 | pandas>=0.19.2 4 | scipy>=0.19.0 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | experiment/ 3 | log/ 4 | run_exp.py 5 | *.out 6 | **/__pycache__/ 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Hyunwook Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--dataset', default = None, type = str, required = True) 6 | parser.add_argument('--device', default = -1, type = int) 7 | parser.add_argument('--exp_id', dest='i', default = 0, help = 'experiment identifier for who want to run experiment multiple time') 8 | args = parser.parse_args() 9 | 10 | device = 'cuda:' + str(args.device) 11 | 12 | BATCH_DICT = { 13 | 'EXPY-TKY': 8, 14 | 'METR-LA': 16, 15 | 'PEMS-BAY': 32 16 | } 17 | 18 | INIT_DICT = { 19 | 'EXPY-TKY': 0, 20 | 'METR-LA': 3, 21 | 'PEMS-BAY': 5 22 | } 23 | 24 | if args.device < 0: 25 | print("Device ID is not Specified!") 26 | print("Continue training with CPU...") 27 | device = 'cpu' 28 | 29 | if args.dataset not in BATCH_DICT.keys(): 30 | raise ValueError("We do not have default setting for the custom dataset. Please specify the datsets from METR-LA, PEMS-BAY, or EXPY-TKY") 31 | 32 | if not os.path.exists('experiment/{}_{}'.format(args.dataset, args.i)): 33 | os.makedirs('experiment/{}_{}'.format(args.dataset, args.i)) 34 | if args.dataset == 'EXPY-TKY': 35 | log = "python -u train.py --batch_size {} --dropout 0.0 --seed -1 --save ./experiment/{}_{}/TESTAM --data ./data/{} --adjdata ./data/{}/adj_mx.pkl --device {}--warmup_epoch {} --out_dim 1" 36 | else: 37 | log = "python -u train.py --batch_size {} --dropout 0.0 --seed -1 --save ./experiment/{}_{}/TESTAM --data ./data/{} --adjdata ./data/{}/adj_mx.pkl --device {} --n_warmup_steps 4000 --warmup_epoch {} --out_dim 1" 38 | batch_size = BATCH_DICT[args.dataset] 39 | warmup_epoch = INIT_DICT[args.dataset] 40 | print(log.format(batch_size, args.dataset, args.i, args.dataset, args.dataset, device, warmup_epoch)) 41 | #os.system(log.format(batch_size, args.dataset, args.i, args.dataset, args.dataset, device, warmup_epoch)) 42 | 43 | 44 | -------------------------------------------------------------------------------- /generate_training_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import argparse 7 | import numpy as np 8 | import os 9 | import pandas as pd 10 | 11 | 12 | def generate_graph_seq2seq_io_data( 13 | df, x_offsets, y_offsets, add_time_in_day=True, add_day_in_week=False, scaler=None 14 | ): 15 | """ 16 | Generate samples from 17 | :param df: 18 | :param x_offsets: 19 | :param y_offsets: 20 | :param add_time_in_day: 21 | :param add_day_in_week: 22 | :param scaler: 23 | :return: 24 | # x: (epoch_size, input_length, num_nodes, input_dim) 25 | # y: (epoch_size, output_length, num_nodes, output_dim) 26 | """ 27 | 28 | num_samples, num_nodes = df.shape 29 | data = np.expand_dims(df.values, axis=-1) 30 | feature_list = [data] 31 | if add_time_in_day: 32 | try: 33 | time_ind = (df.index.values - df.index.values.astype("datetime64[D]")) / np.timedelta64(1, "D") 34 | except: 35 | time_ind = (df.index.values%288)/288 36 | time_in_day = np.tile(time_ind, [1, num_nodes, 1]).transpose((2, 1, 0)) 37 | feature_list.append(time_in_day) 38 | if add_day_in_week: 39 | dow = df.index.dayofweek 40 | dow_tiled = np.tile(dow, [1, num_nodes, 1]).transpose((2, 1, 0)) 41 | feature_list.append(dow_tiled) 42 | 43 | data = np.concatenate(feature_list, axis=-1) 44 | x, y = [], [] 45 | min_t = abs(min(x_offsets)) 46 | max_t = abs(num_samples - abs(max(y_offsets))) # Exclusive 47 | for t in range(min_t, max_t): # t is the index of the last observation. 48 | x.append(data[t + x_offsets, ...]) 49 | y.append(data[t + y_offsets, ...]) 50 | x = np.stack(x, axis=0) 51 | y = np.stack(y, axis=0) 52 | return x, y 53 | 54 | 55 | def generate_train_val_test(args): 56 | seq_length_x, seq_length_y = args.seq_length_x, args.seq_length_y 57 | try: 58 | df = pd.read_hdf(args.traffic_df_filename) 59 | except: 60 | df = pd.read_csv(args.traffic_df_filename, index_col = 0) 61 | df.index = pd.to_datetime(df.index) 62 | #df = pd.read_csv(args.traffic_df_filename,header=None) 63 | # 0 is the latest observed sample. 64 | x_offsets = np.sort(np.concatenate((np.arange(-(seq_length_x - 1), 1, 1),))) 65 | # Predict the next one hour 66 | y_offsets = np.sort(np.arange(args.y_start, (seq_length_y + 1), 1)) 67 | # x: (num_samples, input_length, num_nodes, input_dim) 68 | # y: (num_samples, output_length, num_nodes, output_dim) 69 | x, y = generate_graph_seq2seq_io_data( 70 | df, 71 | x_offsets=x_offsets, 72 | y_offsets=y_offsets, 73 | add_time_in_day=True, 74 | add_day_in_week=args.dow, 75 | ) 76 | 77 | print("x shape: ", x.shape, ", y shape: ", y.shape) 78 | # Write the data into npz file. 79 | num_samples = x.shape[0] 80 | num_test = round(num_samples * 0.2) 81 | num_train = round(num_samples * 0.7) 82 | num_val = num_samples - num_test - num_train 83 | x_train, y_train = x[:num_train], y[:num_train] 84 | x_val, y_val = ( 85 | x[num_train: num_train + num_val], 86 | y[num_train: num_train + num_val], 87 | ) 88 | x_test, y_test = x[-num_test:], y[-num_test:] 89 | 90 | for cat in ["train", "val", "test"]: 91 | _x, _y = locals()["x_" + cat], locals()["y_" + cat] 92 | print(cat, "x: ", _x.shape, "y:", _y.shape) 93 | np.savez_compressed( 94 | os.path.join(args.output_dir, f"{cat}.npz"), 95 | x=_x, 96 | y=_y, 97 | x_offsets=x_offsets.reshape(list(x_offsets.shape) + [1]), 98 | y_offsets=y_offsets.reshape(list(y_offsets.shape) + [1]), 99 | ) 100 | 101 | 102 | if __name__ == "__main__": 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument("--output_dir", type=str, default="data/METR-LA", help="Output directory.") 105 | parser.add_argument("--traffic_df_filename", type=str, default="data/metr-la.h5", help="Raw traffic readings.",) 106 | parser.add_argument("--seq_length_x", type=int, default=12, help="Sequence Length.",) 107 | parser.add_argument("--seq_length_y", type=int, default=12, help="Sequence Length.",) 108 | parser.add_argument("--y_start", type=int, default=1, help="Y pred start", ) 109 | parser.add_argument("--dow", action='store_true',) 110 | 111 | args = parser.parse_args() 112 | if os.path.exists(args.output_dir): 113 | reply = str(input(f'{args.output_dir} exists. Do you want to overwrite it? (y/n)')).lower().strip() 114 | if reply[0] != 'y': exit 115 | else: 116 | os.makedirs(args.output_dir) 117 | generate_train_val_test(args) 118 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import argparse 4 | import time, os 5 | import util 6 | from engine import trainer 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--device',type=str,default='cuda:0',help='') 10 | parser.add_argument('--data',type=str,default='data/METR-LA',help='data path') 11 | parser.add_argument('--adjdata',type=str,default=None,help='adj data path') 12 | parser.add_argument('--adjtype',type=str,default='doubletransition',help='adj type') 13 | parser.add_argument('--out_dim',type=int,default=1,help='') 14 | parser.add_argument('--nhid',type=int,default=32,help='') 15 | parser.add_argument('--in_dim',type=int,default=2,help='inputs dimension') 16 | parser.add_argument('--num_nodes',type=int,default=207,help='number of nodes') 17 | parser.add_argument('--batch_size',type=int,default=64,help='batch size') 18 | parser.add_argument('--save',type=str,default=None,help='save path') 19 | parser.add_argument('--load_path', type = str, default = None) 20 | 21 | args = parser.parse_args() 22 | 23 | 24 | def count_parameters(model): 25 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 26 | 27 | 28 | def main(): 29 | device = torch.device(args.device) 30 | if args.adjdata: 31 | if os.path.exists(args.adjdata): 32 | sensor_ids, sensor_id_to_ind, adj_mx = util.load_adj(args.adjdata,args.adjtype) 33 | args.num_nodes = len(sensor_ids) 34 | else: 35 | print("Invalid File Path; utliize user-provided args.num_nodes") 36 | dataloader = util.load_dataset(args.data, args.batch_size, args.batch_size, args.batch_size) 37 | scaler = dataloader['scaler'] 38 | 39 | print(args) 40 | 41 | engine = trainer(scaler, args.in_dim, args.out_dim, args.num_nodes, args.nhid, 0., device) 42 | if args.load_path is None: 43 | raise ValueError 44 | else: 45 | engine.model.load_state_dict(torch.load(args.load_path, map_location = args.device)) 46 | engine.model.eval() 47 | 48 | 49 | outputs = [] 50 | realy = torch.Tensor(dataloader['y_test']).to(device) 51 | realy = realy.transpose(1,3)[:,:args.out_dim,:,:] 52 | output_gates = [] 53 | output_ind = [] 54 | 55 | for iter, (x, y) in enumerate(dataloader['test_loader'].get_iterator()): 56 | testx = torch.Tensor(x).to(device) 57 | testx = testx.transpose(1,3) 58 | with torch.no_grad(): 59 | preds, gate, ind_out = engine.model(testx, gate_out = True) 60 | outputs.append(preds) 61 | output_gates.append(gate) 62 | output_ind.append(ind_out) 63 | 64 | yhat = torch.cat(outputs,dim=0) 65 | yhat = yhat[:realy.size(0),...] 66 | 67 | yhat_gates = torch.cat(output_gates, dim = 0)[:realy.size(0),...].permute(0,3,1,2,4).contiguous() 68 | yhat_ind = torch.cat(output_ind, dim = 0)[:realy.size(0),...].permute(0,3,1,2,4).contiguous() 69 | yhat_ind = scaler.inverse_transform(yhat_ind) 70 | tmp = yhat_gates.argmax(dim = -1) 71 | print("Gates!") 72 | for i in range(yhat_gates.size(-1)): 73 | print((tmp == i).sum()) 74 | cur_ind = yhat_ind[:,:,:,-1,i] 75 | metrics = util.metric(cur_ind, realy[...,-1]) 76 | log = 'Evaluate best model on test data for horizon {:d}, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}' 77 | print(log.format(realy.size(-1), metrics[0], metrics[1], metrics[2])) 78 | print('On average over {} horizons, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}'.format(realy.size(-1), *util.metric(yhat_ind[...,i], realy))) 79 | 80 | 81 | amae = [] 82 | amape = [] 83 | armse = [] 84 | results = {'prediction': [], 'ground_truth':[], 'gate':[],} 85 | from copy import deepcopy as cp 86 | for i in range(realy.size(-1)): 87 | pred = scaler.inverse_transform(yhat[...,i]) 88 | real = realy[...,i] 89 | results['prediction'].append(cp(pred).cpu().numpy()) 90 | results['ground_truth'].append(cp(real).cpu().numpy()) 91 | metrics = util.metric(pred,real) 92 | log = 'Evaluate best model on test data for horizon {:d}, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}' 93 | print(log.format(i+1, metrics[0], metrics[1], metrics[2])) 94 | amae.append(metrics[0]) 95 | amape.append(metrics[1]) 96 | armse.append(metrics[2]) 97 | 98 | log = 'On average over {} horizons, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}' 99 | print(log.format(realy.size(-1), np.mean(amae),np.mean(amape),np.mean(armse))) 100 | results['prediction'] = np.asarray(results['prediction']) 101 | results['ground_truth'] = np.asarray(results['ground_truth']) 102 | results['gate'] = np.asarray(cp(yhat_gates).cpu().numpy()) 103 | results['indi'] = np.asarray(cp(yhat_ind).cpu().numpy()) 104 | if args.save is not None: 105 | np.savez_compressed(args.save+"_prediction.npz", **results) 106 | 107 | 108 | if __name__ == "__main__": 109 | t1 = time.time() 110 | main() 111 | t2 = time.time() 112 | print("Total time spent: {:.4f}".format(t2-t1)) 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TESTAM: A Time-Enhanced Spatio-Temporal Attention Model with Mixture of Experts 2 | This is an official Pytorch implementation of TESTAM in the following paper: [TESTAM: A Time-Enhanced Spatio-Temporal Attention Model with Mixture of Experts](https://openreview.net/forum?id=N0nTk5BSvO), ICLR 2024. 3 | 4 | 5 | ## Updates 6 | Here we describe the changes in official TESTAM code. 7 | (Ongoing updates) 8 | - Revision of gating mechanism and meta node bank for multivariate spatio-temporal data 9 | - Additional implementation of experts -- CNN-based spatial modeling, Graph Attention, etc. 10 | - Providing processed (and also merged) version of EXPY-TKY dataset (original dataset is not processed) 11 | 12 | (2024-08-13) 13 | - We revised TESTAM routing and fixed some issues -- e.g., some of the previous pseudo label-generation process are problematic with deprecated codes 14 | - We partially updated TESTAM to be usable for the multivariate spatio-temporal data modeling. It is now capable for the multitask (or multivariate) forecasting with multiple output dimensions 15 | - We provide multiple additional features such as load balancing loss (as previos MoEs), and uncertainty flag. You may refer to the engine.py 16 | - We additionally revised some of redundancy in the codes 17 | 18 | ## Requirements 19 | - python>=3.8 20 | - torch>=1.7.1 21 | - numpy>=1.12.1 22 | - pandas>=0.19.2 23 | - scipy>=0.19.0 24 | 25 | Dependencies can be installed using the following command: 26 | ``` 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | ## Data Preparation 31 | 32 | ### Download Datasets 33 | The EXPY-TKY dataset can be found in [MegaCRN Github](https://github.com/deepkashiwa20/MegaCRN). 34 | The other datasets, including METR-LA, can be found in [Google Drive](https://drive.google.com/open?id=10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX) or [Baidu Yun](https://pan.baidu.com/s/14Yy9isAIZYdU__OYEQGa_g) links provided by [Li et al. (DCRNN)](https://github.com/liyaguang/DCRNN). 35 | 36 | ### Process Datasets 37 | In the data processing stage, We have the same process as [DCRNN](https://github.com/liyaguang/DCRNN). 38 | ``` 39 | # Create data directories 40 | mkdir -p data/{METR-LA,PEMS-BAY,EXPY-TKY} 41 | 42 | # METR-LA 43 | python generate_training_data.py --output_dir=data/METR-LA --traffic_df_fiilename=data/metr-la.h5 --seq_length_x INPUT_SEQ_LENGTH --seq_length_y PRED_SEQ_LENGTH 44 | 45 | # PEMS-BAY 46 | python generate_training_data.py --output_dir=data/PEMS-BAY --traffic_df_fiilename=data/pems-bay.h5 --seq_length_x INPUT_SEQ_LENGTH --seq_length_y PRED_SEQ_LENGTH 47 | 48 | # EXPY-TKY 49 | python generate_training_data.py --output_dir=data/EXPY-TKY --traffic_df_fiilename=data/expy-tky.csv --seq_length_x INPUT_SEQ_LENGTH --seq_length_y PRED_SEQ_LENGTH 50 | ``` 51 | 52 | ## Usage 53 | 54 | ### Model Training 55 | We provide default training codes in `run.py`. You can train the model as follows: 56 | ``` 57 | # DATASET: {METR-LA, PEMS-BAY, EXPY-TKY} 58 | # DEVICE: {'cpu', 'cuda:0',...,'cuda:N'} 59 | python run.py --dataset DATASET --device DEVICE 60 | ``` 61 | 62 | For more parameter information, please refer to `train.py`. 63 | We provide a more detailed and complete command description for the training code: 64 | 65 | ``` 66 | python -u train.py --device DEVICE --data DATA --adjdata ADJDATA --adjtype ADJTYPE 67 | --nhid NHID --in_dim IN_DIM --seq_length OUTDIM --num_nodes N --batch_size B 68 | --dropout DROPOUT --epochs EPOCHS --print_every PRINT_EVERY --seed SEED 69 | --save SAVE --expid EXPID --load_path LOAD_PATH --patience PATIENCE --lr_mul LR_MUL 70 | --n_warmup_steps N_WARMUP_STEPS --quantile Q --is_quantile IS_QUANTILE --warmup_epoch WARMUP_EPOCH 71 | ``` 72 | 73 | The detailed descriptions of the arguments are as follows: 74 | 75 | | Argument | Description | 76 | |---|---| 77 | |device | Device ID of GPU (default: cuda:0)| 78 | |data | Path to the dataset directory (default: ./data/METR-LA)| 79 | |adjdata | Path to the adjacency matrix file (default: ./data/METR-LA/adj_mx.pkl)| 80 | |adjtype | Type of adjacency matrix. (default: 'doubletransition'). It could be set to 'scalap', 'normlap', 'symnadj', 'transition', 'doubletransition', 'identity'. It is only used to check the number of nodes| 81 | |out_dim | Output dimensionality of TESTAM (default: 1 (i.e., speed)). It is implemented for the better use of TESTAM in the other generic spatio-temporal forecasting problems with multivariate setting.| 82 | |nhid | Dimension of hidden unit (default: 32)| 83 | |in_dim | Dimension of the input signal (default: 2 (speed, tod))| 84 | |num_nodes | Number of total nodes (default: 207). If you provide adjdata, `train.py` will calculate appropriate num_nodes automatically| 85 | |batch_size | The batch size of training input data (default: 64)| 86 | |dropout | The probability of dropout (default: 0.3)| 87 | |epochs | Total number of training epochs (default: 100)| 88 | |print_every | Print out the training loss per P steps (default: 50)| 89 | |seed | Random seed for the debugging (default: -1) -1 means we do not provide seed number| 90 | |save | Path and pre-fix for the model and output files (default: ./experiment/METR-LA_TESTAM)| 91 | |expid | Experiment ID (default: 1)| 92 | |load_path | Path to the pre-trained model. If it exists, continue the training from the saved model (default: None)| 93 | |patience | Patience for the early stopping (default: 15). If validation loss does not improve for previous PATIENCE epochs, the training ends| 94 | |lr_mul | Learning rate multiplier for the CosineWarmupScheduler (default: 1). Please refer to the [Transformer (Vaswani et al. 2017)](https://arxiv.org/pdf/1706.03762.pdf) and [Pytorch documents](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html)| 95 | |n_warmup_steps | Number of steps for the CosineWarmupScheduler (default: 4000). Please refer to the [Transformer (Vaswani et al. 2017)](https://arxiv.org/pdf/1706.03762.pdf) and [Pytorch documents](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html)| 96 | |quantile | Error quantile for the routing loss function (default: 0.7)| 97 | |is_quantile | Flag for the routing loss function. If True, a routing loss function based on the error quantile will be used. Otherwise, a routing function comparing every expert will be used.| 98 | |warmup_epoch | Determines the number of warmup epochs (default: 0). During warmup epochs, routing loss is not calculated, and each expert is trained with all data samples.| 99 | 100 | ### Model Testing 101 | For the testing, you can run the code below: 102 | ``` 103 | python test.py --device DEVICE --data DATA --adjdata ADJDATA --adjtype ADJTYPE 104 | --nhid NHID --in_dim IN_DIM --out_dim OUTDIM --num_nodes N --batch_size B 105 | --save SAVE --load_path LOAD_PATH 106 | ``` 107 | 108 | ## Citation 109 | If you find this repository useful in your research, please consider citing the following paper: 110 | ``` 111 | @inproceedings{lee2024testam, 112 | title = {{TESTAM}: A Time-Enhanced Spatio-Temporal Attention Model with Mixture of Experts}, 113 | author = {Hyunwook Lee and Sungahn Ko}, 114 | booktitle = {The Twelfth International Conference on Learning Representations}, 115 | year = {2024}, 116 | URL = {https://openreview.net/forum?id=N0nTk5BSvO} 117 | } 118 | ``` 119 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from model import * 3 | import util 4 | class trainer(): 5 | def __init__(self, scaler, in_dim, out_dim, num_nodes, nhid, dropout, device, 6 | lr_mul = 1., n_warmup_steps = 2000, quantile = 0.7, is_quantile = False, warmup_epoch = 0, 7 | use_uncertainty = False): 8 | self.model = TESTAM(num_nodes, dropout, in_dim=in_dim, out_dim=out_dim, hidden_size=nhid) 9 | self.model.to(device) 10 | # The learning rate setting below will not affct initial learning rate 11 | self.optimizer = optim.Adam(self.model.parameters(), lr = 1e-3, betas = (0.9, 0.98), eps = 1e-9) 12 | self.schedule = util.CosineWarmupScheduler(self.optimizer, d_model = nhid, n_warmup_steps = n_warmup_steps, lr_mul = lr_mul) 13 | self.loss = util.masked_mae 14 | 15 | self.scaler = scaler 16 | self.clip = 5 17 | self.n_warmup_steps = n_warmup_steps 18 | self.flag = is_quantile 19 | self.quantile = quantile 20 | self.cur_epoch = 0 21 | self.warmup_epoch = warmup_epoch 22 | self.threshold = 0. 23 | self.use_uncertainty = use_uncertainty 24 | 25 | def train(self, input, real, cur_epoch): 26 | self.model.train() 27 | self.schedule.zero_grad() 28 | 29 | output, gate, res = self.model(input) 30 | predict = self.scaler.inverse_transform(output) 31 | #output = [batch_size,out_dim,num_nodes,T] 32 | 33 | ind_loss = self.loss(self.scaler.inverse_transform(res), real.permute(0,2,3,1).unsqueeze(-1), self.threshold, reduce = None) 34 | if self.flag: 35 | gated_loss = self.loss(predict, real, reduce = None).permute(0,2,3,1) 36 | l_worst_avoidance, l_best_choice = self.get_quantile_label(gated_loss, gate, real) 37 | else: 38 | l_worst_avoidance, l_best_choice = self.get_label(ind_loss, gate, real) 39 | 40 | if self.use_uncertainty: 41 | uncertainty = self.get_uncertainty(real.permute(0,2,3,1), threshold = self.threshold) 42 | uncertainty = uncertainty.unsqueeze(dim = -1) 43 | else: 44 | uncertainty = torch.ones_like(gate) 45 | 46 | worst_avoidance = -.5 * l_worst_avoidance * torch.log(gate) * (2 - uncertainty) 47 | best_choice = -.5 * l_best_choice * torch.log(gate) * uncertainty 48 | 49 | if cur_epoch <= self.warmup_epoch: 50 | loss = ind_loss.mean() 51 | else: 52 | loss = ind_loss.mean() + worst_avoidance.mean() + best_choice.mean() 53 | loss.backward() 54 | 55 | if self.clip is not None: 56 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip) 57 | 58 | self.schedule.step_and_update_lr() 59 | mape = util.masked_mape(predict, real, self.threshold).item() 60 | rmse = util.masked_rmse(predict, real, self.threshold).item() 61 | return loss.item(),mape,rmse 62 | 63 | def eval(self, input, real): 64 | self.model.eval() 65 | output = self.model(input) 66 | #output = [batch_size,12,num_nodes,out_dim] 67 | predict = self.scaler.inverse_transform(output) 68 | loss = self.loss(predict, real, self.threshold) 69 | mape = util.masked_mape(predict,real,self.threshold).item() 70 | rmse = util.masked_rmse(predict,real,self.threshold).item() 71 | return loss.item(),mape,rmse 72 | 73 | def lb_loss(self, gate): 74 | n_experts = gate.size(-1) 75 | _, indices = torch.max(gate, dim = -1) 76 | counts = gate.new_tensor([len(torch.eq(indices, i).nonzero(as_tuple=True)[0]) for i in range(n_experts)]) 77 | proxied_lb = (counts / counts.sum()) * gate.mean(dim = (1,2)) 78 | lb_loss = proxied_lb.mean() 79 | return lb_loss 80 | 81 | def get_uncertainty(self, x, mode = 'psd', threshold = 0.0): 82 | 83 | def _acorr(x, dim = -1): 84 | size = x.size(dim) 85 | x_fft = torch.fft.fft(x, dim = dim) 86 | acorr = torch.fft.ifft(x_fft * x_fft.conj(), dim = dim).real 87 | return acorr / (size ** 2) 88 | 89 | def nanstd(x, dim, keepdim = False): 90 | return torch.sqrt( 91 | torch.nanmean( 92 | torch.pow(torch.abs(x - torch.nanmean(x, dim = dim, keepdim = True)), 2), 93 | dim = dim, keepdim = keepdim 94 | ) 95 | ) 96 | 97 | with torch.no_grad(): 98 | if mode == 'acorr': 99 | std = x.std(dim = -2, keepdim = True) 100 | corr = _acorr(x, dim = -2) 101 | x_noise = x + std * torch.randn((1,1,x.size(-2),1), device = x.device) / 2 102 | corr_w_noise = _acorr(x_noise, dim = -2) 103 | corr_changed = torch.abs(corr - corr_w_noise) 104 | uncertainty = torch.ones_like(corr_changed) * (corr_changed > corr_changed.quantile(1 - self.quantile)) 105 | elif mode == 'psd': 106 | from copy import deepcopy as cp 107 | vals = cp(x) 108 | vals[vals <= threshold] = torch.nan 109 | diff = vals[:,:,1:] - vals[:,:,:-1] 110 | corr_changed = torch.nanmean(torch.abs(diff), dim = -2, keepdim = True) / (nanstd(diff, dim = -2, keepdim = True) + 1e-6) 111 | corr_changed[corr_changed != corr_changed] = 0. 112 | uncertainty = torch.ones_like(corr_changed) * (corr_changed < corr_changed.quantile(self.quantile)) 113 | else: 114 | raise NotImplementedError 115 | return uncertainty 116 | 117 | 118 | def get_quantile_label(self, gated_loss, gate, real): 119 | gated_loss = gated_loss.unsqueeze(dim = -1) 120 | real = real.unsqueeze(dim = -1) 121 | max_quantile = gated_loss.quantile(self.quantile) 122 | min_quantile = gated_loss.quantile(1 - self.quantile) 123 | incorrect = (gated_loss > max_quantile).expand_as(gate) 124 | correct = ((gated_loss < min_quantile) & (real.permute(0,2,3,1,4) > self.threshold)).expand_as(gate) 125 | cur_expert = gate.argmax(dim = -1, keepdim = True) 126 | not_chosen = gate.topk(dim = -1, k = 2, largest = False).indices 127 | selected = torch.zeros_like(gate).scatter_(-1, cur_expert, 1.0) 128 | scaling = torch.zeros_like(gate).scatter_(-1, not_chosen, 0.5) 129 | selected[incorrect] = scaling[incorrect] 130 | l_worst_avoidance = selected.detach() 131 | selected = torch.zeros_like(gate).scatter(-1, cur_expert, 1.0) * correct 132 | l_best_choice = selected.detach() 133 | return l_worst_avoidance, l_best_choice 134 | 135 | def get_label(self, ind_loss, gate, real): 136 | n_experts = gate.size(-1) 137 | empty_val = (real.permute(0,2,3,1).unsqueeze(-1).expand_as(gate)) <= self.threshold 138 | max_error = ind_loss.argmax(dim = -1, keepdim = True) 139 | cur_expert = gate.argmax(dim = -1, keepdim = True) 140 | incorrect = max_error == cur_expert 141 | selected = torch.zeros_like(gate).scatter(-1, cur_expert, 1.0) 142 | scaling = torch.ones_like(gate) * ind_loss 143 | scaling = scaling.scatter(-1, max_error, 0.) 144 | scaling = scaling / (scaling.sum(dim = -1, keepdim = True)) * (1 - selected) 145 | l_worst_avoidance = torch.where(incorrect, scaling, selected) 146 | l_worst_avoidance = torch.where(empty_val, torch.zeros_like(gate), l_worst_avoidance) 147 | l_worst_avoidance = l_worst_avoidance.detach() 148 | min_error = ind_loss.argmin(dim = -1, keepdim = True) 149 | correct = min_error == cur_expert 150 | scaling = torch.zeros_like(gate) 151 | scaling = scaling.scatter(-1, min_error, 1.) 152 | l_best_choice = torch.where(correct, selected, scaling) 153 | l_best_choice = torch.where(empty_val, torch.zeros_like(gate), l_best_choice) 154 | l_best_choice = l_best_choice.detach() 155 | return l_worst_avoidance, l_best_choice 156 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import argparse 4 | import time, os 5 | import util 6 | import matplotlib.pyplot as plt 7 | from engine import trainer 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--device',type=str,default='cuda:0',help='') 11 | parser.add_argument('--data',type=str,default='data/METR-LA',help='data path') 12 | parser.add_argument('--adjdata',type=str,default=None,help='adj data path') 13 | parser.add_argument('--adjtype',type=str,default='doubletransition',help='adj type') 14 | parser.add_argument('--out_dim',type=int,default=1,help='') 15 | parser.add_argument('--nhid',type=int,default=32,help='') 16 | parser.add_argument('--in_dim',type=int,default=2,help='inputs dimension') 17 | parser.add_argument('--num_nodes',type=int,default=207,help='number of nodes') 18 | parser.add_argument('--batch_size',type=int,default=64,help='batch size') 19 | parser.add_argument('--dropout',type=float,default=0.3,help='dropout rate') 20 | parser.add_argument('--epochs',type=int,default=100,help='') 21 | parser.add_argument('--print_every',type=int,default=50,help='') 22 | parser.add_argument('--seed',type=int,default=99,help='random seed') 23 | parser.add_argument('--save',type=str,default='./experiment/METR-LA_TESTAM',help='save path') 24 | parser.add_argument('--expid',type=int,default=1,help='experiment id') 25 | parser.add_argument('--load_path', type = str, default = None) 26 | parser.add_argument('--patience', type = int, default = 15) 27 | parser.add_argument('--lr_mul', type = float, default = 1) 28 | parser.add_argument('--n_warmup_steps', type = int, default = 4000) 29 | parser.add_argument('--quantile', type = float, default = 0.7) 30 | parser.add_argument('--is_quantile', action='store_true') 31 | parser.add_argument('--warmup_epoch', type = int, default = 0) 32 | 33 | args = parser.parse_args() 34 | 35 | 36 | def count_parameters(model): 37 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 38 | 39 | 40 | def main(): 41 | #set seed 42 | if args.seed != -1: 43 | print("Start Deterministic Training with seed {}".format(args.seed)) 44 | torch.manual_seed(args.seed) 45 | np.random.seed(args.seed) 46 | #load data 47 | device = torch.device(args.device) 48 | if args.adjdata: 49 | if os.path.exists(args.adjdata): 50 | sensor_ids, sensor_id_to_ind, adj_mx = util.load_adj(args.adjdata,args.adjtype) 51 | args.num_nodes = len(sensor_ids) 52 | else: 53 | print("Invalid File Path; utliize user-provided args.num_nodes") 54 | 55 | dataloader = util.load_dataset(args.data, args.batch_size, args.batch_size, args.batch_size) 56 | dataloader = util.load_dataset(args.data, args.batch_size, args.batch_size, args.batch_size) 57 | scaler = dataloader['scaler'] 58 | 59 | print(args) 60 | 61 | 62 | engine = trainer(scaler, args.in_dim, args.out_dim, args.num_nodes, args.nhid, args.dropout, 63 | device, args.lr_mul, args.n_warmup_steps, args.quantile, args.is_quantile, args.warmup_epoch) 64 | 65 | print("Train the model with {} parameters".format(count_parameters(engine.model))) 66 | 67 | 68 | if args.load_path is not None: 69 | engine.model.load_state_dict(torch.load(args.load_path, map_location=device)) 70 | engine.model.to(device) 71 | 72 | print("start training...",flush=True) 73 | his_loss =[] 74 | val_time = [] 75 | train_time = [] 76 | wait = 0 77 | patience = args.patience 78 | best = 1e9 79 | for i in range(1,args.epochs+1): 80 | train_loss = [] 81 | train_mape = [] 82 | train_rmse = [] 83 | t1 = time.time() 84 | dataloader['train_loader'].shuffle() 85 | for iter, (x, y) in enumerate(dataloader['train_loader'].get_iterator()): 86 | trainx = torch.Tensor(x).to(device) 87 | trainx= trainx.transpose(1, 3) 88 | trainy = torch.Tensor(y).to(device) 89 | trainy = trainy.transpose(1, 3) 90 | metrics = engine.train(trainx, trainy[:,:args.out_dim,:,:], i) 91 | train_loss.append(metrics[0]) 92 | train_mape.append(metrics[1]) 93 | train_rmse.append(metrics[2]) 94 | if iter % args.print_every == 0 : 95 | log = 'Iter: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}' 96 | print(log.format(iter, train_loss[-1], train_mape[-1], train_rmse[-1]),flush=True) 97 | t2 = time.time() 98 | train_time.append(t2-t1) 99 | #validation 100 | valid_loss = [] 101 | valid_mape = [] 102 | valid_rmse = [] 103 | 104 | s1 = time.time() 105 | for iter, (x, y) in enumerate(dataloader['val_loader'].get_iterator()): 106 | testx = torch.Tensor(x).to(device) 107 | testx = testx.transpose(1, 3) 108 | testy = torch.Tensor(y).to(device) 109 | testy = testy.transpose(1, 3) 110 | metrics = engine.eval(testx, testy[:,:args.out_dim,:,:]) 111 | valid_loss.append(metrics[0]) 112 | valid_mape.append(metrics[1]) 113 | valid_rmse.append(metrics[2]) 114 | 115 | s2 = time.time() 116 | log = 'Epoch: {:03d}, Inference Time: {:.4f} secs' 117 | print(log.format(i,(s2-s1))) 118 | val_time.append(s2-s1) 119 | mtrain_loss = np.mean(train_loss) 120 | mtrain_mape = np.mean(train_mape) 121 | mtrain_rmse = np.mean(train_rmse) 122 | 123 | mvalid_loss = np.mean(valid_loss) 124 | mvalid_mape = np.mean(valid_mape) 125 | mvalid_rmse = np.mean(valid_rmse) 126 | his_loss.append(mvalid_loss) 127 | 128 | log = 'Epoch: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}, Valid Loss: {:.4f}, Valid MAPE: {:.4f}, Valid RMSE: {:.4f}, Training Time: {:.4f}/epoch' 129 | print(log.format(i, mtrain_loss, mtrain_mape, mtrain_rmse, mvalid_loss, mvalid_mape, mvalid_rmse, (t2 - t1)),flush=True) 130 | if best > his_loss[-1]: 131 | best = his_loss[-1] 132 | wait = 0 133 | torch.save(engine.model.state_dict(), args.save+"_epoch_"+str(i)+"_"+str(round(mvalid_loss,2))+".pth") 134 | else: 135 | wait = wait + 1 136 | if wait > patience: 137 | print("Early Termination!") 138 | break 139 | print("Average Training Time: {:.4f} secs/epoch".format(np.mean(train_time))) 140 | print("Average Inference Time: {:.4f} secs".format(np.mean(val_time))) 141 | 142 | #testing 143 | bestid = np.argmin(his_loss) 144 | engine.model.load_state_dict(torch.load(args.save+"_epoch_"+str(bestid+1)+"_"+str(round(his_loss[bestid],2))+".pth")) 145 | 146 | 147 | outputs = [] 148 | realy = torch.Tensor(dataloader['y_test']).to(device) 149 | realy = realy.transpose(1,3)[:,:args.out_dim,:,:] 150 | 151 | for iter, (x, y) in enumerate(dataloader['test_loader'].get_iterator()): 152 | testx = torch.Tensor(x).to(device) 153 | testx = testx.transpose(1,3) 154 | with torch.no_grad(): 155 | preds = engine.model(testx) 156 | outputs.append(preds) 157 | 158 | yhat = torch.cat(outputs,dim=0) 159 | yhat = yhat[:realy.size(0),...] 160 | 161 | 162 | print("Training finished") 163 | print("The valid loss on best model is", str(round(his_loss[bestid],4))) 164 | 165 | 166 | amae = [] 167 | amape = [] 168 | armse = [] 169 | results = {'prediction': [], 'ground_truth':[]} 170 | from copy import deepcopy as cp 171 | for i in range(realy.size(-1)): 172 | pred = scaler.inverse_transform(yhat[...,i]) 173 | real = realy[...,i] 174 | results['prediction'].append(cp(pred).cpu().numpy()) 175 | results['ground_truth'].append(cp(real).cpu().numpy()) 176 | metrics = util.metric(pred,real) 177 | log = 'Evaluate best model on test data for horizon {:d}, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}' 178 | print(log.format(i+1, metrics[0], metrics[1], metrics[2])) 179 | amae.append(metrics[0]) 180 | amape.append(metrics[1]) 181 | armse.append(metrics[2]) 182 | 183 | log = 'On average over 12 horizons, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}' 184 | print(log.format(np.mean(amae),np.mean(amape),np.mean(armse))) 185 | results['prediction'] = np.asarray(results['prediction']) 186 | results['ground_truth'] = np.asarray(results['ground_truth']) 187 | np.savez_compressed(args.save+"_prediction.npz", **results) 188 | torch.save(engine.model.state_dict(), args.save+"_exp"+str(args.expid)+"_best_"+str(round(his_loss[bestid],2))+".pth") 189 | 190 | 191 | 192 | if __name__ == "__main__": 193 | t1 = time.time() 194 | main() 195 | t2 = time.time() 196 | print("Total time spent: {:.4f}".format(t2-t1)) 197 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import os 4 | import scipy.sparse as sp 5 | import torch 6 | from scipy.sparse import linalg 7 | 8 | 9 | class ScheduledOptim(): 10 | def __init__(self, optimizer, d_model, n_warmup_steps, lr_mul = 1.0): 11 | self._optimizer = optimizer 12 | self.lr_mul = lr_mul 13 | self.d_model = d_model 14 | self.n_warmup_steps = n_warmup_steps 15 | self.n_steps = 0 16 | 17 | def step_and_update_lr(self): 18 | self._update_lr() 19 | self._optimizer.step() 20 | 21 | def zero_grad(self): 22 | self._optimizer.zero_grad() 23 | 24 | def _get_lr_scale(self): 25 | d_model = self.d_model 26 | n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps 27 | return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5)) 28 | 29 | def _update_lr(self): 30 | self.n_steps += 1 31 | lr = self.lr_mul * self._get_lr_scale() 32 | #if self.n_steps > 2000: 33 | # lr = 3e-3 34 | for param_group in self._optimizer.param_groups: 35 | param_group['lr'] = lr 36 | 37 | 38 | class CosineWarmupScheduler(): 39 | def __init__(self, optimizer, d_model, n_warmup_steps, lr_mul = 1.0): 40 | self._optimizer = optimizer 41 | self.lr_mul = lr_mul 42 | self.d_model = d_model 43 | self.n_warmup_steps = n_warmup_steps 44 | self.n_periodic_steps = n_warmup_steps 45 | self.n_steps = 0 46 | 47 | def step_and_update_lr(self): 48 | self._update_lr() 49 | self._optimizer.step() 50 | 51 | def zero_grad(self): 52 | self._optimizer.zero_grad() 53 | 54 | def _get_lr_scale(self): 55 | d_model = self.d_model 56 | n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps 57 | if n_steps <= self.n_warmup_steps: 58 | return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5)) 59 | else: 60 | base = (d_model ** -0.5) * n_warmup_steps ** (-0.5) * (1 + np.cos(np.pi * ((n_steps - self.n_warmup_steps) % self.n_periodic_steps) / self.n_periodic_steps)) 61 | return base 62 | 63 | def _update_lr(self): 64 | self.n_steps += 1 65 | lr = self.lr_mul * self._get_lr_scale() 66 | for param_group in self._optimizer.param_groups: 67 | param_group['lr'] = lr 68 | 69 | 70 | class DataLoader(object): 71 | def __init__(self, xs, ys, batch_size, pad_with_last_sample=True): 72 | """ 73 | :param xs: 74 | :param ys: 75 | :param batch_size: 76 | :param pad_with_last_sample: pad with the last sample to make number of samples divisible to batch_size. 77 | """ 78 | self.batch_size = batch_size 79 | self.current_ind = 0 80 | if pad_with_last_sample: 81 | num_padding = (batch_size - (len(xs) % batch_size)) % batch_size 82 | x_padding = np.repeat(xs[-1:], num_padding, axis=0) 83 | y_padding = np.repeat(ys[-1:], num_padding, axis=0) 84 | xs = np.concatenate([xs, x_padding], axis=0) 85 | ys = np.concatenate([ys, y_padding], axis=0) 86 | self.size = len(xs) 87 | self.num_batch = int(self.size // self.batch_size) 88 | self.xs = xs 89 | self.ys = ys 90 | 91 | def shuffle(self): 92 | permutation = np.random.permutation(self.size) 93 | xs, ys = self.xs[permutation], self.ys[permutation] 94 | self.xs = xs 95 | self.ys = ys 96 | 97 | def get_iterator(self): 98 | self.current_ind = 0 99 | 100 | def _wrapper(): 101 | while self.current_ind < self.num_batch: 102 | start_ind = self.batch_size * self.current_ind 103 | end_ind = min(self.size, self.batch_size * (self.current_ind + 1)) 104 | x_i = self.xs[start_ind: end_ind, ...] 105 | y_i = self.ys[start_ind: end_ind, ...] 106 | yield (x_i, y_i) 107 | self.current_ind += 1 108 | 109 | return _wrapper() 110 | 111 | 112 | class StandardScaler(): 113 | """ 114 | Standard the input 115 | """ 116 | 117 | def __init__(self, mean, std): 118 | self.mean = mean 119 | self.std = std 120 | 121 | def transform(self, data): 122 | return (data - self.mean) / self.std 123 | 124 | def inverse_transform(self, data): 125 | return (data * self.std) + self.mean 126 | 127 | 128 | def sym_adj(adj): 129 | """Symmetrically normalize adjacency matrix.""" 130 | adj = sp.coo_matrix(adj) 131 | rowsum = np.array(adj.sum(1)) 132 | d_inv_sqrt = np.power(rowsum, -0.5).flatten() 133 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 134 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 135 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).astype(np.float32).todense() 136 | 137 | def asym_adj(adj): 138 | adj = sp.coo_matrix(adj) 139 | rowsum = np.array(adj.sum(1)).flatten() 140 | d_inv = np.power(rowsum, -1).flatten() 141 | d_inv[np.isinf(d_inv)] = 0. 142 | d_mat= sp.diags(d_inv) 143 | return d_mat.dot(adj).astype(np.float32).todense() 144 | 145 | def calculate_normalized_laplacian(adj): 146 | """ 147 | # L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2 148 | # D = diag(A 1) 149 | :param adj: 150 | :return: 151 | """ 152 | adj = sp.coo_matrix(adj) 153 | d = np.array(adj.sum(1)) 154 | d_inv_sqrt = np.power(d, -0.5).flatten() 155 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 156 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 157 | normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 158 | return normalized_laplacian 159 | 160 | 161 | def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True): 162 | if undirected: 163 | adj_mx = np.maximum.reduce([adj_mx, adj_mx.T]) 164 | L = calculate_normalized_laplacian(adj_mx) 165 | if lambda_max is None: 166 | lambda_max, _ = linalg.eigsh(L, 1, which='LM') 167 | lambda_max = lambda_max[0] 168 | L = sp.csr_matrix(L) 169 | M, _ = L.shape 170 | I = sp.identity(M, format='csr', dtype=L.dtype) 171 | L = (2 / lambda_max * L) - I 172 | return L.astype(np.float32).todense() 173 | 174 | def load_pickle(pickle_file): 175 | try: 176 | with open(pickle_file, 'rb') as f: 177 | pickle_data = pickle.load(f) 178 | except UnicodeDecodeError as e: 179 | with open(pickle_file, 'rb') as f: 180 | pickle_data = pickle.load(f, encoding='latin1') 181 | except Exception as e: 182 | print('Unable to load data ', pickle_file, ':', e) 183 | raise 184 | return pickle_data 185 | 186 | 187 | def load_adj(pkl_filename, adjtype): 188 | sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename) 189 | if adjtype == "scalap": 190 | adj = [calculate_scaled_laplacian(adj_mx)] 191 | elif adjtype == "normlap": 192 | adj = [calculate_normalized_laplacian(adj_mx).astype(np.float32).todense()] 193 | elif adjtype == "symnadj": 194 | adj = [sym_adj(adj_mx)] 195 | elif adjtype == "transition": 196 | adj = [asym_adj(adj_mx)] 197 | elif adjtype == "doubletransition": 198 | adj = [asym_adj(adj_mx), asym_adj(np.transpose(adj_mx))] 199 | elif adjtype == "identity": 200 | adj = [np.diag(np.ones(adj_mx.shape[0])).astype(np.float32)] 201 | else: 202 | error = 0 203 | assert error, "adj type not defined" 204 | return sensor_ids, sensor_id_to_ind, adj 205 | 206 | 207 | def load_dataset(dataset_dir, batch_size, valid_batch_size= None, test_batch_size=None): 208 | data = {} 209 | for category in ['train', 'val', 'test']: 210 | cat_data = np.load(os.path.join(dataset_dir, category + '.npz')) 211 | data['x_' + category] = cat_data['x'] 212 | data['y_' + category] = cat_data['y'] 213 | scaler = StandardScaler(mean=data['x_train'][..., 0].mean(), std=data['x_train'][..., 0].std()) 214 | # Data format 215 | for category in ['train', 'val', 'test']: 216 | data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0]) 217 | data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size) 218 | data['val_loader'] = DataLoader(data['x_val'], data['y_val'], valid_batch_size) 219 | data['test_loader'] = DataLoader(data['x_test'], data['y_test'], test_batch_size) 220 | data['scaler'] = scaler 221 | return data 222 | 223 | 224 | def masked_mse(preds, labels, null_val=np.nan, reduce = True): 225 | if np.isnan(null_val): 226 | mask = ~torch.isnan(labels) 227 | else: 228 | mask = (labels > null_val) 229 | mask = mask.float() 230 | mask /= torch.mean((mask)) 231 | mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) 232 | loss = (preds-labels)**2 233 | loss = loss * mask 234 | loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) 235 | if reduce: 236 | return torch.mean(loss) 237 | else: 238 | return loss 239 | 240 | 241 | def masked_rmse(preds, labels, null_val=np.nan): 242 | return torch.sqrt(masked_mse(preds=preds, labels=labels, null_val=null_val)) 243 | 244 | 245 | def masked_mae(preds, labels, null_val=np.nan, reduce = True): 246 | if np.isnan(null_val): 247 | mask = ~torch.isnan(labels) 248 | else: 249 | mask = (labels > null_val) 250 | mask = mask.float() 251 | mask /= torch.mean((mask)) 252 | mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) 253 | loss = torch.abs(preds-labels) 254 | loss = loss * mask 255 | loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) 256 | if reduce: 257 | return torch.mean(loss) 258 | else: 259 | return loss 260 | 261 | 262 | def masked_mape(preds, labels, null_val=np.nan): 263 | if np.isnan(null_val): 264 | mask = ~torch.isnan(labels) 265 | else: 266 | mask = (labels > null_val) 267 | mask = mask.float() 268 | mask /= torch.mean((mask)) 269 | mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) 270 | loss = torch.abs(preds-labels)/labels 271 | loss = loss * mask 272 | loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) 273 | return torch.mean(loss) 274 | 275 | 276 | def metric(pred, real): 277 | mae = masked_mae(pred,real,0.0).item() 278 | mape = masked_mape(pred,real,0.0).item() 279 | rmse = masked_rmse(pred,real,0.0).item() 280 | return mae,mape,rmse 281 | 282 | 283 | if __name__ == "__main__": 284 | data =load_dataset('./data/METR-LA', 64, 64, 64) 285 | print((24000, 12, 207, 2)) 286 | ys = data['train_loader'].ys[:,0] 287 | idxs = (ys[...,1] * 288).astype(int) 288 | spd = ys[...,0] 289 | trends = np.zeros((288, 207)) 290 | for i in range(288): 291 | idx = np.where(idxs == i) 292 | trends[i] = spd[i].mean(axis = 0) 293 | print(idxs.min(), idxs.max()) 294 | print(trends) 295 | print(spd.mean(axis = 0).shape) 296 | 297 | 298 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import sys, math 6 | from copy import deepcopy as cp 7 | 8 | 9 | class nconv(nn.Module): 10 | def __init__(self): 11 | super(nconv,self).__init__() 12 | 13 | def forward(self,x, A): 14 | x = torch.einsum('nvlc,vw->nwlc',(x,A)) 15 | return x.contiguous() 16 | 17 | 18 | class linear(nn.Module): 19 | def __init__(self,c_in,c_out): 20 | super(linear,self).__init__() 21 | self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0,0), stride=(1,1), bias=True) 22 | 23 | def forward(self,x): 24 | return self.mlp(x) 25 | 26 | 27 | class gcn(nn.Module): 28 | def __init__(self,c_in,c_out,dropout,supports_len=3,order=2): 29 | super(gcn,self).__init__() 30 | self.nconv = nconv() 31 | c_in = (order*supports_len+1)*c_in 32 | self.mlp = nn.Linear(c_in,c_out) 33 | self.dropout = dropout 34 | self.order = order 35 | 36 | def forward(self,x,support): 37 | out = [x] 38 | for a in support: 39 | x1 = self.nconv(x,a) 40 | out.append(x1) 41 | for k in range(2, self.order + 1): 42 | x2 = self.nconv(x1,a) 43 | out.append(x2) 44 | x1 = x2 45 | 46 | h = torch.cat(out,dim=-1) 47 | h = self.mlp(h) 48 | h = F.dropout(h, self.dropout, training=self.training) 49 | return h 50 | 51 | 52 | class QKVAttention(nn.Module): 53 | """ 54 | Assume input has shape B, N, T, C or B, T, N, C 55 | Note: Attention map will be B, N, T, T or B, T, N, N 56 | - Could be utilized for both spatial and temporal modeling 57 | - Able to get additional kv-input (for Time-Enhanced Attention) 58 | """ 59 | def __init__(self, in_dim, hidden_size, dropout, num_heads = 4): 60 | super(QKVAttention, self).__init__() 61 | self.query = nn.Linear(in_dim, hidden_size, bias = False) 62 | self.key = nn.Linear(in_dim, hidden_size, bias = False) 63 | self.value = nn.Linear(in_dim, hidden_size, bias = False) 64 | self.num_heads = num_heads 65 | self.proj = nn.Linear(hidden_size, hidden_size) 66 | self.dropout = nn.Dropout(p=dropout) 67 | assert hidden_size % num_heads == 0 68 | 69 | def forward(self, x, kv = None): 70 | if kv is None: 71 | kv = x 72 | query = self.query(x) 73 | key = self.key(kv) 74 | value = self.value(kv) 75 | num_heads = self.num_heads 76 | if num_heads > 1: 77 | query = torch.cat(torch.chunk(query, num_heads, dim = -1), dim = 0) 78 | key = torch.cat(torch.chunk(key, num_heads, dim = -1), dim = 0) 79 | value = torch.cat(torch.chunk(value, num_heads, dim = -1), dim = 0) 80 | d = value.size(-1) 81 | energy = torch.matmul(query, key.transpose(-1,-2)) 82 | energy = energy / (d ** 0.5) 83 | score = torch.softmax(energy, dim = -1) 84 | head_out = torch.matmul(score, value) 85 | out = torch.cat(torch.chunk(head_out, num_heads, dim = 0), dim = -1) 86 | return self.dropout(self.proj(out)) 87 | 88 | 89 | class LayerNorm(nn.Module): 90 | #Assume input has shape B, N, T, C 91 | def __init__(self, normalized_shape, eps = 1e-6): 92 | super(LayerNorm, self).__init__() 93 | self.normalized_shape = normalized_shape 94 | self.eps = eps 95 | self.gamma = nn.Parameter(torch.ones(*normalized_shape)) 96 | self.beta = nn.Parameter(torch.zeros(*normalized_shape)) 97 | 98 | def forward(self, x): 99 | dims = [-(i + 1) for i in range(len(self.normalized_shape))] 100 | #mean --> shape :(B, C, H, W) --> (B) 101 | #mean with keepdims --> shape: (B, C, H, W) --> (B, 1, 1, 1) 102 | mean = x.mean(dim = dims, keepdims = True) 103 | std = x.std(dim = dims, keepdims = True, unbiased = False) 104 | #x_norm = (B, C, H, W) 105 | x_norm = (x - mean) / (std + self.eps) 106 | out = x_norm * self.gamma + self.beta 107 | return out 108 | 109 | 110 | class BatchNorm(nn.Module): 111 | def __init__(self, num_features, momentum = 0.1, eps = 1e-5, track_running_stats = True): 112 | super(BatchNorm, self).__init__() 113 | self.momentum = momentum 114 | self.eps = eps 115 | self.gamma = nn.Parameter(torch.ones(num_features)) 116 | self.beta = nn.Parameter(torch.zeros(num_features)) 117 | if track_running_stats: 118 | self.register_buffer('running_mean', torch.zeros(num_features)) 119 | self.register_buffer('running_var', torch.ones(num_features)) 120 | else: 121 | self.register_buffer('running_mean', None) 122 | self.register_buffer('running_var', None) 123 | 124 | def forward(self, x): 125 | dims = [i for i in range(x.dim() - 1)] 126 | mean = x.mean(dim = dims) 127 | var = x.var(dim = dims, correction = 0) 128 | if (self.training) and (self.running_mean is not None): 129 | avg_factor = self.momentum 130 | moving_avg = lambda prev, cur: (1 - avg_factor) * prev + avg_factor * cur.detach() 131 | dims = [i for i in range(x.dim() - 1)] 132 | self.running_mean = moving_avg(self.running_mean, mean) 133 | self.running_var = moving_avg(self.running_var, var) 134 | mean, var = self.running_mean, self.running_var 135 | 136 | x_norm = (x - mean) / torch.sqrt(var + self.eps) 137 | out = x_norm * self.gamma + self.beta 138 | return out 139 | 140 | 141 | class SkipConnection(nn.Module): 142 | """ 143 | Helper Module to build skip connection 144 | - forward may get auxiliary input to handle multiple inputs (e.g., adjacency matrix or time-enhanced attention) 145 | """ 146 | def __init__(self, module, norm): 147 | super(SkipConnection, self).__init__() 148 | self.module = module 149 | self.norm = norm 150 | 151 | def forward(self, x, aux = None): 152 | return self.norm(x + self.module(x, aux)) 153 | 154 | 155 | class PositionwiseFeedForward(nn.Module): 156 | def __init__(self, in_dim, hidden_size, dropout, activation = nn.GELU()): 157 | super(PositionwiseFeedForward, self).__init__() 158 | self.act = activation 159 | self.l1 = nn.Linear(in_dim, hidden_size) 160 | self.l2 = nn.Linear(hidden_size, in_dim) 161 | self.dropout = nn.Dropout(p = dropout) 162 | 163 | def forward(self, x, kv = None): 164 | return self.dropout(self.l2(self.act(self.l1(x)))) 165 | 166 | 167 | class SwitchPositionwiseFeedForward(nn.Module): 168 | """ 169 | Switch Positionwise Feed Forward module for the normal mixture-of-experts model 170 | - Note: not used for the TESTAM 171 | """ 172 | def __init__(self, in_dim, hidden_size, dropout, activation = nn.ReLU(), n_experts = 4): 173 | super(SwitchPositionwiseFeedForward, self).__init__() 174 | self.n_experts = n_experts 175 | self.activation = activation 176 | self.dropout = nn.Dropout(p = dropout) 177 | expert = PositionwiseFeedForward(in_dim, hidden_size, dropout, activation) 178 | self.experts = nn.ModuleList([cp(expert) for _ in range(n_experts)]) 179 | self.switch = nn.Linear(in_dim, n_experts) 180 | self.softmax = nn.Softmax(dim = -1) 181 | 182 | def forward(self, x, kv = None): 183 | B, N, T, C = x.size() 184 | x = x.view(-1,C) 185 | 186 | route_prob = self.softmax(self.switch(x)) 187 | route_prob_max, routes = torch.max(route_prob, dim = -1) 188 | 189 | # indices: (n_experts, B*T, N) 190 | indices = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)] 191 | 192 | final_output = torch.zeros_like(x) 193 | 194 | for i in range(self.n_experts): 195 | expert_output = self.experts[i](x[indices[i]]) 196 | final_output[indices[i]] = expert_output 197 | 198 | final_output = final_output * (route_prob_max).unsqueeze(dim = -1) 199 | final_output = final_output.view(B,N,T,C) 200 | 201 | return final_output 202 | 203 | 204 | class TemporalInformationEmbedding(nn.Module): 205 | """ 206 | We assume that input shape is B, T 207 | - Only contains temporal information with index 208 | Arguments: 209 | - vocab_size: total number of temporal features (e.g., 7 days) 210 | - freq_act: periodic activation function 211 | - n_freq: number of hidden elements for frequency components 212 | - if 0 or H, it only uses linear or frequency component, respectively 213 | """ 214 | def __init__(self, hidden_size, vocab_size, freq_act = torch.sin, n_freq = 1): 215 | super(TemporalInformationEmbedding, self).__init__() 216 | self.embedding = nn.Embedding(vocab_size, hidden_size) 217 | self.linear = nn.Linear(hidden_size, hidden_size) 218 | self.freq_act = freq_act 219 | self.n_freq = n_freq 220 | 221 | def forward(self, x): 222 | x_emb = self.embedding(x) 223 | x_weight = self.linear(x_emb) 224 | if self.n_freq == 0: 225 | return x_weight 226 | if self.n_freq == x_emb.size(-1): 227 | return self.freq_act(x_weight) 228 | x_linear = x_weight[...,self.n_freq:] 229 | x_act = self.freq_act(x_weight[...,:self.n_freq]) 230 | return torch.cat([x_linear, x_act], dim = -1) 231 | 232 | 233 | class TemporalModel(nn.Module): 234 | """ 235 | Input shape 236 | - x: B, T 237 | - Need modification to use the multiple temporal information with different indexing (e.g., dow and tod) 238 | - speed: B, N, T, in_dim = 1 239 | - Need modification to use them in different dataset 240 | Output shape B, N, T, O 241 | - In the traffic forecasting, O (outdim) is normally one 242 | Arguments: 243 | - vocab_size: total number of temporal features (e.g., 7 days) 244 | - Notes: in the trivial traffic forecasting problem, we have total 288 = 24 * 60 / 5 (5 min interval) 245 | """ 246 | def __init__(self, hidden_size, num_nodes, layers, dropout, in_dim = 1, out_dim = 1, vocab_size = 288, activation = nn.ReLU()): 247 | super(TemporalModel, self).__init__() 248 | self.vocab_size = vocab_size 249 | self.act = activation 250 | self.in_dim = in_dim 251 | self.embedding = TemporalInformationEmbedding(hidden_size, vocab_size = vocab_size) 252 | self.spd_proj = nn.Linear(in_dim, hidden_size) 253 | self.spd_cat = nn.Linear(hidden_size * 2, hidden_size) # Cat speed information and TIM information 254 | 255 | module = QKVAttention(in_dim = hidden_size, hidden_size = hidden_size, dropout = dropout) 256 | ff = PositionwiseFeedForward(in_dim = hidden_size, hidden_size = 4 * hidden_size, dropout = dropout) 257 | norm = LayerNorm(normalized_shape = (hidden_size,)) 258 | 259 | self.node_features = nn.Parameter(torch.randn(num_nodes, hidden_size)) 260 | 261 | self.attn_layers = nn.ModuleList() 262 | self.ff = nn.ModuleList() 263 | for _ in range(layers): 264 | self.attn_layers.append(SkipConnection(cp(module), cp(norm))) 265 | self.ff.append(SkipConnection(cp(ff), cp(norm))) 266 | 267 | self.proj = nn.Linear(hidden_size, out_dim) 268 | 269 | 270 | def forward(self, x, speed = None): 271 | TIM = self.embedding(x) 272 | #For the traffic forecasting, we introduce learnable node features 273 | #The user may modify this node feature into meta-learning based representation, which enables the ability to adopt the model into different dataset 274 | x_nemb = torch.einsum('btc, nc -> bntc', TIM, self.node_features) 275 | if speed is None: 276 | speed = torch.zeros_like(x_nemb[...,:self.in_dim]) 277 | x_spd = self.spd_proj(speed) 278 | x_nemb = self.spd_cat(torch.cat([x_spd, x_nemb], dim = -1)) 279 | 280 | attns = [] 281 | for i, (attn_layer, ff) in enumerate(zip(self.attn_layers, self.ff)): 282 | x_attn = attn_layer(x_nemb) 283 | x_nemb = ff(x_attn) 284 | attns.append(x_nemb) 285 | 286 | out = self.proj(self.act(x_nemb)) 287 | 288 | return out, attns 289 | 290 | 291 | class STModel(nn.Module): 292 | """ 293 | Input shape B, N, T, in_dim 294 | Output shape B, N, T, out_dim 295 | Arguments: 296 | - spatial: Flag that determine when spatial attention will be performed 297 | - True --> spatial first and then temporal attention will be performed 298 | """ 299 | def __init__(self, hidden_size, supports_len, num_nodes, dropout, layers, out_dim = 1, in_dim = 2, spatial = False, activation = nn.ReLU()): 300 | super(STModel, self).__init__() 301 | self.spatial = spatial 302 | self.act = activation 303 | self.out_dim = out_dim 304 | 305 | s_gcn = gcn(c_in = hidden_size, c_out = hidden_size, dropout = dropout, supports_len = supports_len, order = 2) 306 | t_attn = QKVAttention(in_dim = hidden_size, hidden_size = hidden_size, dropout = dropout) 307 | ff = PositionwiseFeedForward(in_dim = hidden_size, hidden_size = 4 * hidden_size, dropout = dropout) 308 | norm = LayerNorm(normalized_shape = (hidden_size, )) 309 | 310 | self.start_linear = nn.Linear(in_dim, hidden_size) 311 | 312 | self.proj = nn.Linear(hidden_size, hidden_size + out_dim) 313 | 314 | self.temporal_layers = nn.ModuleList() 315 | self.spatial_layers = nn.ModuleList() 316 | self.ed_layers = nn.ModuleList() 317 | self.ff = nn.ModuleList() 318 | 319 | for _ in range(layers): 320 | self.temporal_layers.append(SkipConnection(cp(t_attn), cp(norm))) 321 | self.spatial_layers.append(SkipConnection(cp(s_gcn), cp(norm))) 322 | self.ed_layers.append(SkipConnection(cp(t_attn), cp(norm))) 323 | self.ff.append(SkipConnection(cp(ff), cp(norm))) 324 | 325 | def forward(self, x, prev_hidden, supports): 326 | x = self.start_linear(x.permute(0,2,3,1)) 327 | x_start = x 328 | hiddens = [] 329 | for i, (temporal_layer, spatial_layer, ed_layer, ff) in enumerate(zip(self.temporal_layers, self.spatial_layers, self.ed_layers, self.ff)): 330 | if not self.spatial: 331 | x1 = temporal_layer(x) # B, N, T, C 332 | x_attn = spatial_layer(x1, supports) # B, N, T, C 333 | else: 334 | x1 = spatial_layer(x, supports) 335 | x_attn = temporal_layer(x1) 336 | if prev_hidden is not None: 337 | x_attn = ed_layer(x_attn, prev_hidden[-1]) 338 | x = ff(x_attn) 339 | hiddens.append(x) 340 | 341 | out = self.proj(self.act(x)) 342 | res, out = torch.split(out, [out.size(-1) - self.out_dim, self.out_dim], dim = -1) 343 | 344 | return x_start - res, out.contiguous(), hiddens 345 | 346 | 347 | 348 | class AttentionModel(nn.Module): 349 | """ 350 | Input shape B, N, T, in_dim 351 | Output shape B, N, T, out_dim 352 | 353 | """ 354 | def __init__(self, hidden_size, layers, dropout, edproj = False, in_dim = 2, out_dim = 1, spatial = False, activation = nn.ReLU()): 355 | super(AttentionModel, self).__init__() 356 | self.spatial = spatial 357 | self.act = activation 358 | 359 | base_model = SkipConnection(QKVAttention(hidden_size, hidden_size, dropout = dropout), LayerNorm(normalized_shape = (hidden_size, ))) 360 | ff = SkipConnection(PositionwiseFeedForward(hidden_size, 4 * hidden_size, dropout = dropout), LayerNorm(normalized_shape = (hidden_size, ))) 361 | 362 | self.start_linear = nn.Linear(in_dim, hidden_size) 363 | 364 | self.spatial_layers = nn.ModuleList() 365 | self.temporal_layers = nn.ModuleList() 366 | self.ed_layers = nn.ModuleList() 367 | self.ff = nn.ModuleList() 368 | 369 | for i in range(layers): 370 | self.spatial_layers.append(cp(base_model)) 371 | self.temporal_layers.append(cp(base_model)) 372 | self.ed_layers.append(cp(base_model)) 373 | self.ff.append(cp(ff)) 374 | 375 | self.proj = nn.Linear(hidden_size, out_dim) 376 | 377 | 378 | def forward(self, x, prev_hidden = None): 379 | x = self.start_linear(x.permute(0,2,3,1)) 380 | 381 | for i, (s_layer, t_layer, ff) in enumerate(zip(self.spatial_layers, self.temporal_layers, self.ff)): 382 | if not self.spatial: 383 | x1 = t_layer(x) 384 | x_attn = s_layer(x1.transpose(1,2)) 385 | else: 386 | x1 = s_layer(x.transpose(1,2)) 387 | x_attn = t_layer(x1.transpose(1,2)).transpose(1,2) 388 | 389 | if prev_hidden is not None: 390 | x_attn = self.ed_layers[i](x_attn.transpose(1,2), prev_hidden[-1]) 391 | x_attn = x_attn.transpose(1,2) 392 | x = ff(x_attn.transpose(1,2)) 393 | 394 | return self.proj(self.act(x)), x 395 | 396 | 397 | class MemoryGate(nn.Module): 398 | """ 399 | Input 400 | - input: B, N, T, in_dim, original input 401 | - hidden: hidden states from each expert, shape: E-length list of (B, N, T, C) tensors, where E is the number of experts 402 | Output 403 | - similarity score (i.e., routing probability before softmax function) 404 | Arguments 405 | - mem_hid, memory_size: hidden size and total number of memroy units 406 | - sim: similarity function to evaluate routing probability 407 | - nodewise: flag to determine routing level. Traffic forecasting could have a more fine-grained routing, because it has additional dimension for the roads 408 | - True: enables node-wise routing probability calculation, which is coarse-grained one 409 | """ 410 | def __init__(self, hidden_size, num_nodes, mem_hid = 32, in_dim = 2, out_dim = 1, memory_size = 20, sim = nn.CosineSimilarity(dim = -1), nodewise = False, ind_proj = True, attention_type = 'attention'): 411 | super(MemoryGate, self).__init__() 412 | self.attention_type = attention_type 413 | self.sim = sim 414 | self.nodewise = nodewise 415 | self.out_dim = out_dim 416 | 417 | self.memory = nn.Parameter(torch.empty(memory_size, mem_hid)) 418 | 419 | self.hid_query = nn.ParameterList([nn.Parameter(torch.empty(hidden_size, mem_hid)) for _ in range(3)]) 420 | self.key = nn.ParameterList([nn.Parameter(torch.empty(hidden_size, mem_hid)) for _ in range(3)]) 421 | self.value = nn.ParameterList([nn.Parameter(torch.empty(hidden_size, mem_hid)) for _ in range(3)]) 422 | 423 | self.input_query = nn.Parameter(torch.empty(in_dim, mem_hid)) 424 | 425 | self.We1 = nn.Parameter(torch.empty(num_nodes, memory_size)) 426 | self.We2 = nn.Parameter(torch.empty(num_nodes, memory_size)) 427 | 428 | for p in self.parameters(): 429 | if p.dim() > 1: 430 | nn.init.xavier_uniform_(p) 431 | else: 432 | nn.init.zeros_(p) 433 | 434 | def forward(self, input, hidden): 435 | if self.attention_type == 'attention': 436 | attention = self.attention 437 | else: 438 | attention = self.topk_attention 439 | B, N, T, _ = input.size() 440 | memories = self.query_mem(input) 441 | scores = [] 442 | for i, h in enumerate(hidden): 443 | hidden_att = attention(h,i) 444 | scores.append(self.sim(memories, hidden_att)) 445 | 446 | scores = torch.stack(scores, dim = -1) 447 | return scores.unsqueeze(dim = -2).expand(B, N, T, self.out_dim, scores.size(-1)) 448 | 449 | def attention(self, x, i): 450 | B, N, T, _ = x.size() 451 | query = torch.matmul(x, self.hid_query[i]) 452 | key = torch.matmul(x, self.key[i]) 453 | value = torch.matmul(x, self.value[i]) 454 | if self.nodewise: 455 | query = query.sum(dim = -2, keepdim = True) 456 | energy = torch.matmul(query, key.transpose(-1,-2)) 457 | score = torch.softmax(energy, dim = -1) 458 | out = torch.matmul(score, value) 459 | return out.expand_as(value) 460 | 461 | def topk_attention(self, x, i, k = 3): 462 | B, N, T, _ = x.size() 463 | query = torch.matmul(x, self.hid_query[i]) 464 | key = torch.matmul(x, self.key[i]) 465 | value = torch.matmul(x, self.value[i]) 466 | if self.nodewise: 467 | query = query.sum(dim = -2, keepdim = True) 468 | energy = torch.matmul(query, key.transpose(-1,-2)) 469 | values, indices = torch.topk(energy, k = k, dim = -1) 470 | score = energy.zero_().scatter_(-1, indices, torch.relu(values)) 471 | out = torch.matmul(score, value) 472 | return out.expand_as(value) 473 | 474 | def query_mem(self, input): 475 | B, N, T, _ = input.size() 476 | mem = self.memory 477 | query = torch.matmul(input, self.input_query) 478 | energy = torch.matmul(query, mem.T) 479 | score = torch.softmax(energy, dim = -1) 480 | out = torch.matmul(score, mem) 481 | return out 482 | 483 | def reset_queries(self): 484 | with torch.no_grad(): 485 | for p in self.hid_query: 486 | nn.init.xavier_uniform_(p) 487 | nn.init.xavier_uniform_(self.input_query) 488 | 489 | def reset_params(self): 490 | with torch.no_grad(): 491 | for n, p in self.named_parameters(): 492 | if n in "We1 We2 memory".split(): 493 | continue 494 | else: 495 | nn.init.xavier_uniform_(p) 496 | 497 | 498 | class AttnGate(nn.Module): 499 | def __init__(self, hidden_size, num_nodes, in_dim = 2, sim = nn.CosineSimilarity(dim = -1)): 500 | super(AttnGate, self).__init__() 501 | self.in_key = nn.Linear(in_dim, hidden_size, bias = False) 502 | self.hid_query = nn.Linear(hidden_size, hidden_size, bias = False) 503 | self.in_value = nn.Linear(in_dim, hidden_size, bias = False) 504 | sim = lambda x, y: nn.PairwiseDistance()(x, y) * -1 505 | self.sim = sim 506 | self.proj = nn.Linear(hidden_size, 1) 507 | for p in self.parameters(): 508 | if p.dim() > 1: 509 | nn.init.xavier_uniform_(p) 510 | else: 511 | nn.init.zeros_(p) 512 | 513 | def forward(self, input, hidden): 514 | num_heads = 1 515 | key = self.in_key(input) 516 | value = self.in_value(input) 517 | if num_heads > 1: 518 | key = torch.cat(torch.chunk(key, num_heads, dim = -1), dim = 0) 519 | value = torch.cat(torch.chunk(value, num_heads, dim = -1), dim = 0) 520 | scores = [] 521 | for h in hidden: 522 | query = self.hid_query(h) 523 | if num_heads > 1: 524 | head_query = torch.cat(torch.chunk(query, num_heads, dim = -1), dim = 0) 525 | energy = torch.matmul(head_query, key.transpose(-1,-2)) / (head_query.size(-1) ** 0.5) 526 | else: 527 | energy = torch.matmul(query, key.transpose(-1,-2)) / (query.size(-1) ** 0.5) 528 | score = torch.softmax(energy, dim = -1) 529 | head_out = torch.matmul(score, value) 530 | out = torch.cat(torch.chunk(head_out, num_heads, dim = 0), dim = -1) 531 | scores.append(self.sim(query, out)) 532 | return torch.stack(scores,dim = -1) 533 | 534 | 535 | class TESTAM(nn.Module): 536 | """ 537 | TESTAM model 538 | """ 539 | def __init__(self, num_nodes, dropout=0.3, in_dim=2, out_dim = 1, hidden_size = 32, layers = 3, prob_mul = False, max_time_index = 288, **args): 540 | super(TESTAM, self).__init__() 541 | self.dropout = dropout 542 | self.prob_mul = prob_mul 543 | self.supports_len = 2 544 | self.max_time_index = max_time_index 545 | 546 | self.identity_expert = TemporalModel(hidden_size, num_nodes, in_dim = in_dim - 1, out_dim = out_dim, layers = layers, dropout = dropout, vocab_size = max_time_index) 547 | self.adaptive_expert = STModel(hidden_size, self.supports_len, num_nodes, in_dim = in_dim, out_dim = out_dim, layers = layers, dropout = dropout) 548 | self.attention_expert = AttentionModel(hidden_size, in_dim = in_dim, out_dim = out_dim, layers = layers, dropout = dropout) 549 | 550 | self.gate_network = MemoryGate(hidden_size, num_nodes, in_dim = in_dim, out_dim = out_dim) 551 | 552 | for model in [self.identity_expert, self.adaptive_expert, self.attention_expert]: 553 | for n, p in model.named_parameters(): 554 | if p.dim() > 1: 555 | nn.init.xavier_uniform_(p) 556 | 557 | def forward(self, input, gate_out = False): 558 | """ 559 | input: B, in_dim, N, T 560 | - Note: we assume that the last dimeions of in_dim is temporal feature, such as tod or dow (could be represented as integer) 561 | o_identity shape B, N, T, 1 562 | """ 563 | n1 = torch.matmul(self.gate_network.We1, self.gate_network.memory) 564 | n2 = torch.matmul(self.gate_network.We2, self.gate_network.memory) 565 | g1 = torch.softmax(torch.relu(torch.mm(n1, n2.T)), dim = -1) 566 | g2 = torch.softmax(torch.relu(torch.mm(n2, n1.T)), dim = -1) 567 | new_supports = [g1, g2] 568 | 569 | time_index = input[:,-1,0] # B, T 570 | max_t = self.max_time_index 571 | cur_time_index = ((time_index * max_t) % max_t).long() 572 | next_time_index = ((time_index * max_t + time_index.size(-1)) % max_t).long() 573 | o_identity, h_identity = self.identity_expert(cur_time_index, input[:,:-1].permute(0,2,3,1)) 574 | _, h_future = self.identity_expert(next_time_index) 575 | 576 | 577 | _, o_adaptive, h_adaptive = self.adaptive_expert(input, h_future, new_supports) 578 | 579 | o_attention, h_attention = self.attention_expert(input, h_future) 580 | 581 | ind_out = torch.stack([o_identity, o_adaptive, o_attention], dim = -1) 582 | 583 | B, N, T, _ = o_identity.size() 584 | gate_in = [h_identity[-1], h_adaptive[-1], h_attention] 585 | gate = torch.softmax(self.gate_network(input.permute(0,2,3,1), gate_in), dim = -1) 586 | out = torch.zeros_like(o_identity).view(-1,1) 587 | 588 | outs = [o_identity, o_adaptive, o_attention] 589 | counts = [] 590 | 591 | route_prob_max, routes = torch.max(gate, dim = -1) 592 | route_prob_max = route_prob_max.view(-1) 593 | routes = routes.view(-1) 594 | 595 | for i in range(len(outs)): 596 | cur_out = outs[i].view(-1,1) 597 | indices = torch.eq(routes, i).nonzero(as_tuple = True)[0] 598 | out[indices] = cur_out[indices] 599 | counts.append(len(indices)) 600 | if self.prob_mul: 601 | out = out * (route_prob_max).unsqueeze(dim = -1) 602 | 603 | 604 | out = out.view(B,N,T,-1) 605 | 606 | out = out.permute(0,3,1,2) 607 | if self.training or gate_out: 608 | return out, gate, ind_out 609 | else: 610 | return out 611 | 612 | if __name__ == "__main__": 613 | n = 207 614 | model = TESTAM(num_nodes = n, in_dim = 3, out_dim = 2, supports = [torch.zeros(n,n).cuda() for _ in range(2)]) 615 | x = torch.zeros(8,3,n,6).cuda() 616 | x[:,0] = torch.randn(8,n,6).cuda() 617 | model.cuda() 618 | model.eval() 619 | for p in model.parameters(): 620 | if p.dtype != torch.float32: 621 | print(p.dtype) 622 | out, gate, ind_out = model(x,gate_out = True) 623 | print(out.shape, gate.shape, ind_out.shape) 624 | 625 | 626 | --------------------------------------------------------------------------------