├── fig
│   ├── model.pdf
│   └── model.png
├── requirements.txt
├── LICENSE
├── README.md
├── engine.py
├── generate_training_data.py
├── test.py
├── train.py
├── util.py
└── model.py

/fig/model.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nnzhan/Graph-WaveNet/HEAD/fig/model.pdf
--------------------------------------------------------------------------------
/fig/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nnzhan/Graph-WaveNet/HEAD/fig/model.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | numpy
3 | scipy
4 | pandas
5 | torch
6 | argparse
7 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | Copyright (c) <2019>
 2 | 
 3 | Permission is hereby granted, free of charge, to any person obtaining a copy
 4 | of this software and associated documentation files (the "Software"), to deal
 5 | in the Software without restriction, including without limitation the rights
 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 7 | copies of the Software, and to permit persons to whom the Software is
 8 | furnished to do so, subject to the following conditions:
 9 | 
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 | 
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 | 
21 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph WaveNet for Deep Spatial-Temporal Graph Modeling
2 | 
3 | This is the original PyTorch implementation of Graph WaveNet as described in the following paper:
4 | [Graph WaveNet for Deep Spatial-Temporal Graph Modeling, IJCAI 2019](https://arxiv.org/abs/1906.00121). A nice improvement over Graph WaveNet is presented by Shleifer et al. [paper](https://arxiv.org/abs/1912.07390) [code](https://github.com/sshleifer/Graph-WaveNet).
5 | 
6 | 
7 | 
8 | 

9 | ![Graph WaveNet model architecture](fig/model.png)
10 | 

11 | 12 | ## Requirements 13 | - python 3 14 | - see `requirements.txt` 15 | 16 | 17 | ## Data Preparation 18 | 19 | ### Step1: Download METR-LA and PEMS-BAY data from [Google Drive](https://drive.google.com/open?id=10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX) or [Baidu Yun](https://pan.baidu.com/s/14Yy9isAIZYdU__OYEQGa_g) links provided by [DCRNN](https://github.com/liyaguang/DCRNN). 20 | 21 | ### Step2: Process raw data 22 | 23 | ``` 24 | # Create data directories 25 | mkdir -p data/{METR-LA,PEMS-BAY} 26 | 27 | # METR-LA 28 | python generate_training_data.py --output_dir=data/METR-LA --traffic_df_filename=data/metr-la.h5 29 | 30 | # PEMS-BAY 31 | python generate_training_data.py --output_dir=data/PEMS-BAY --traffic_df_filename=data/pems-bay.h5 32 | 33 | ``` 34 | ## Train Commands 35 | 36 | ``` 37 | python train.py --gcn_bool --adjtype doubletransition --addaptadj --randomadj 38 | ``` 39 | 40 | 41 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from model import * 3 | import util 4 | class trainer(): 5 | def __init__(self, scaler, in_dim, seq_length, num_nodes, nhid , dropout, lrate, wdecay, device, supports, gcn_bool, addaptadj, aptinit): 6 | self.model = gwnet(device, num_nodes, dropout, supports=supports, gcn_bool=gcn_bool, addaptadj=addaptadj, aptinit=aptinit, in_dim=in_dim, out_dim=seq_length, residual_channels=nhid, dilation_channels=nhid, skip_channels=nhid * 8, end_channels=nhid * 16) 7 | self.model.to(device) 8 | self.optimizer = optim.Adam(self.model.parameters(), lr=lrate, weight_decay=wdecay) 9 | self.loss = util.masked_mae 10 | self.scaler = scaler 11 | self.clip = 5 12 | 13 | def train(self, input, real_val): 14 | self.model.train() 15 | self.optimizer.zero_grad() 16 | input = nn.functional.pad(input,(1,0,0,0)) 17 | output = self.model(input) 18 | output = output.transpose(1,3) 19 | #output = [batch_size,12,num_nodes,1] 20 | real = torch.unsqueeze(real_val,dim=1) 21 | predict = self.scaler.inverse_transform(output) 22 | 23 | loss = self.loss(predict, real, 0.0) 24 | loss.backward() 25 | if self.clip is not None: 26 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip) 27 | self.optimizer.step() 28 | mape = util.masked_mape(predict,real,0.0).item() 29 | rmse = util.masked_rmse(predict,real,0.0).item() 30 | return loss.item(),mape,rmse 31 | 32 | def eval(self, input, real_val): 33 | self.model.eval() 34 | input = nn.functional.pad(input,(1,0,0,0)) 35 | output = self.model(input) 36 | output = output.transpose(1,3) 37 | #output = [batch_size,12,num_nodes,1] 38 | real = torch.unsqueeze(real_val,dim=1) 39 | predict = self.scaler.inverse_transform(output) 40 | loss = self.loss(predict, real, 0.0) 41 | mape = util.masked_mape(predict,real,0.0).item() 42 | rmse = util.masked_rmse(predict,real,0.0).item() 43 | return loss.item(),mape,rmse 44 | -------------------------------------------------------------------------------- /generate_training_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import argparse 7 | import numpy as np 8 | import os 9 | import pandas as pd 10 | 11 | 12 | def generate_graph_seq2seq_io_data( 13 | df, x_offsets, y_offsets, add_time_in_day=True, add_day_in_week=False, scaler=None 14 | 
): 15 | """ 16 | Generate samples from 17 | :param df: 18 | :param x_offsets: 19 | :param y_offsets: 20 | :param add_time_in_day: 21 | :param add_day_in_week: 22 | :param scaler: 23 | :return: 24 | # x: (epoch_size, input_length, num_nodes, input_dim) 25 | # y: (epoch_size, output_length, num_nodes, output_dim) 26 | """ 27 | 28 | num_samples, num_nodes = df.shape 29 | data = np.expand_dims(df.values, axis=-1) 30 | feature_list = [data] 31 | if add_time_in_day: 32 | time_ind = (df.index.values - df.index.values.astype("datetime64[D]")) / np.timedelta64(1, "D") 33 | time_in_day = np.tile(time_ind, [1, num_nodes, 1]).transpose((2, 1, 0)) 34 | feature_list.append(time_in_day) 35 | if add_day_in_week: 36 | dow = df.index.dayofweek 37 | dow_tiled = np.tile(dow, [1, num_nodes, 1]).transpose((2, 1, 0)) 38 | feature_list.append(dow_tiled) 39 | 40 | data = np.concatenate(feature_list, axis=-1) 41 | x, y = [], [] 42 | min_t = abs(min(x_offsets)) 43 | max_t = abs(num_samples - abs(max(y_offsets))) # Exclusive 44 | for t in range(min_t, max_t): # t is the index of the last observation. 45 | x.append(data[t + x_offsets, ...]) 46 | y.append(data[t + y_offsets, ...]) 47 | x = np.stack(x, axis=0) 48 | y = np.stack(y, axis=0) 49 | return x, y 50 | 51 | 52 | def generate_train_val_test(args): 53 | seq_length_x, seq_length_y = args.seq_length_x, args.seq_length_y 54 | df = pd.read_hdf(args.traffic_df_filename) 55 | # 0 is the latest observed sample. 56 | x_offsets = np.sort(np.concatenate((np.arange(-(seq_length_x - 1), 1, 1),))) 57 | # Predict the next one hour 58 | y_offsets = np.sort(np.arange(args.y_start, (seq_length_y + 1), 1)) 59 | # x: (num_samples, input_length, num_nodes, input_dim) 60 | # y: (num_samples, output_length, num_nodes, output_dim) 61 | x, y = generate_graph_seq2seq_io_data( 62 | df, 63 | x_offsets=x_offsets, 64 | y_offsets=y_offsets, 65 | add_time_in_day=True, 66 | add_day_in_week=args.dow, 67 | ) 68 | 69 | print("x shape: ", x.shape, ", y shape: ", y.shape) 70 | # Write the data into npz file. 71 | num_samples = x.shape[0] 72 | num_test = round(num_samples * 0.2) 73 | num_train = round(num_samples * 0.7) 74 | num_val = num_samples - num_test - num_train 75 | x_train, y_train = x[:num_train], y[:num_train] 76 | x_val, y_val = ( 77 | x[num_train: num_train + num_val], 78 | y[num_train: num_train + num_val], 79 | ) 80 | x_test, y_test = x[-num_test:], y[-num_test:] 81 | 82 | for cat in ["train", "val", "test"]: 83 | _x, _y = locals()["x_" + cat], locals()["y_" + cat] 84 | print(cat, "x: ", _x.shape, "y:", _y.shape) 85 | np.savez_compressed( 86 | os.path.join(args.output_dir, f"{cat}.npz"), 87 | x=_x, 88 | y=_y, 89 | x_offsets=x_offsets.reshape(list(x_offsets.shape) + [1]), 90 | y_offsets=y_offsets.reshape(list(y_offsets.shape) + [1]), 91 | ) 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument("--output_dir", type=str, default="data/METR-LA", help="Output directory.") 97 | parser.add_argument("--traffic_df_filename", type=str, default="data/metr-la.h5", help="Raw traffic readings.",) 98 | parser.add_argument("--seq_length_x", type=int, default=12, help="Sequence Length.",) 99 | parser.add_argument("--seq_length_y", type=int, default=12, help="Sequence Length.",) 100 | parser.add_argument("--y_start", type=int, default=1, help="Y pred start", ) 101 | parser.add_argument("--dow", action='store_true',) 102 | 103 | args = parser.parse_args() 104 | if os.path.exists(args.output_dir): 105 | reply = str(input(f'{args.output_dir} exists. 
Do you want to overwrite it? (y/n)')).lower().strip() 106 | if reply[0] != 'y': exit 107 | else: 108 | os.makedirs(args.output_dir) 109 | generate_train_val_test(args) 110 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import util 2 | import argparse 3 | from model import * 4 | import numpy as np 5 | import pandas as pd 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--device',type=str,default='cuda:3',help='') 11 | parser.add_argument('--data',type=str,default='data/METR-LA',help='data path') 12 | parser.add_argument('--adjdata',type=str,default='data/sensor_graph/adj_mx.pkl',help='adj data path') 13 | parser.add_argument('--adjtype',type=str,default='doubletransition',help='adj type') 14 | parser.add_argument('--gcn_bool',action='store_true',help='whether to add graph convolution layer') 15 | parser.add_argument('--aptonly',action='store_true',help='whether only adaptive adj') 16 | parser.add_argument('--addaptadj',action='store_true',help='whether add adaptive adj') 17 | parser.add_argument('--randomadj',action='store_true',help='whether random initialize adaptive adj') 18 | parser.add_argument('--seq_length',type=int,default=12,help='') 19 | parser.add_argument('--nhid',type=int,default=32,help='') 20 | parser.add_argument('--in_dim',type=int,default=2,help='inputs dimension') 21 | parser.add_argument('--num_nodes',type=int,default=207,help='number of nodes') 22 | parser.add_argument('--batch_size',type=int,default=64,help='batch size') 23 | parser.add_argument('--learning_rate',type=float,default=0.001,help='learning rate') 24 | parser.add_argument('--dropout',type=float,default=0.3,help='dropout rate') 25 | parser.add_argument('--weight_decay',type=float,default=0.0001,help='weight decay rate') 26 | parser.add_argument('--checkpoint',type=str,help='') 27 | parser.add_argument('--plotheatmap',type=str,default='True',help='') 28 | 29 | 30 | args = parser.parse_args() 31 | 32 | 33 | 34 | 35 | def main(): 36 | device = torch.device(args.device) 37 | 38 | _, _, adj_mx = util.load_adj(args.adjdata,args.adjtype) 39 | supports = [torch.tensor(i).to(device) for i in adj_mx] 40 | if args.randomadj: 41 | adjinit = None 42 | else: 43 | adjinit = supports[0] 44 | 45 | if args.aptonly: 46 | supports = None 47 | 48 | model = gwnet(device, args.num_nodes, args.dropout, supports=supports, gcn_bool=args.gcn_bool, addaptadj=args.addaptadj, aptinit=adjinit) 49 | model.to(device) 50 | model.load_state_dict(torch.load(args.checkpoint)) 51 | model.eval() 52 | 53 | 54 | print('model load successfully') 55 | 56 | dataloader = util.load_dataset(args.data, args.batch_size, args.batch_size, args.batch_size) 57 | scaler = dataloader['scaler'] 58 | outputs = [] 59 | realy = torch.Tensor(dataloader['y_test']).to(device) 60 | realy = realy.transpose(1,3)[:,0,:,:] 61 | 62 | for iter, (x, y) in enumerate(dataloader['test_loader'].get_iterator()): 63 | testx = torch.Tensor(x).to(device) 64 | testx = testx.transpose(1,3) 65 | with torch.no_grad(): 66 | preds = model(testx).transpose(1,3) 67 | outputs.append(preds.squeeze()) 68 | 69 | yhat = torch.cat(outputs,dim=0) 70 | yhat = yhat[:realy.size(0),...] 
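    # yhat holds the model's normalized outputs for every test window; the loop below
    # inverse-transforms them and reports MAE/MAPE/RMSE per prediction horizon (1-12 steps ahead).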
71 | 72 | 73 | amae = [] 74 | amape = [] 75 | armse = [] 76 | for i in range(12): 77 | pred = scaler.inverse_transform(yhat[:,:,i]) 78 | real = realy[:,:,i] 79 | metrics = util.metric(pred,real) 80 | log = 'Evaluate best model on test data for horizon {:d}, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}' 81 | print(log.format(i+1, metrics[0], metrics[1], metrics[2])) 82 | amae.append(metrics[0]) 83 | amape.append(metrics[1]) 84 | armse.append(metrics[2]) 85 | 86 | log = 'On average over 12 horizons, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}' 87 | print(log.format(np.mean(amae),np.mean(amape),np.mean(armse))) 88 | 89 | 90 | if args.plotheatmap == "True": 91 | adp = F.softmax(F.relu(torch.mm(model.nodevec1, model.nodevec2)), dim=1) 92 | device = torch.device('cpu') 93 | adp.to(device) 94 | adp = adp.cpu().detach().numpy() 95 | adp = adp*(1/np.max(adp)) 96 | df = pd.DataFrame(adp) 97 | sns.heatmap(df, cmap="RdYlBu") 98 | plt.savefig("./emb"+ '.pdf') 99 | 100 | y12 = realy[:,99,11].cpu().detach().numpy() 101 | yhat12 = scaler.inverse_transform(yhat[:,99,11]).cpu().detach().numpy() 102 | 103 | y3 = realy[:,99,2].cpu().detach().numpy() 104 | yhat3 = scaler.inverse_transform(yhat[:,99,2]).cpu().detach().numpy() 105 | 106 | df2 = pd.DataFrame({'real12':y12,'pred12':yhat12, 'real3': y3, 'pred3':yhat3}) 107 | df2.to_csv('./wave.csv',index=False) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import argparse 4 | import time 5 | import util 6 | import matplotlib.pyplot as plt 7 | from engine import trainer 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--device',type=str,default='cuda:3',help='') 11 | parser.add_argument('--data',type=str,default='data/METR-LA',help='data path') 12 | parser.add_argument('--adjdata',type=str,default='data/sensor_graph/adj_mx.pkl',help='adj data path') 13 | parser.add_argument('--adjtype',type=str,default='doubletransition',help='adj type') 14 | parser.add_argument('--gcn_bool',action='store_true',help='whether to add graph convolution layer') 15 | parser.add_argument('--aptonly',action='store_true',help='whether only adaptive adj') 16 | parser.add_argument('--addaptadj',action='store_true',help='whether add adaptive adj') 17 | parser.add_argument('--randomadj',action='store_true',help='whether random initialize adaptive adj') 18 | parser.add_argument('--seq_length',type=int,default=12,help='') 19 | parser.add_argument('--nhid',type=int,default=32,help='') 20 | parser.add_argument('--in_dim',type=int,default=2,help='inputs dimension') 21 | parser.add_argument('--num_nodes',type=int,default=207,help='number of nodes') 22 | parser.add_argument('--batch_size',type=int,default=64,help='batch size') 23 | parser.add_argument('--learning_rate',type=float,default=0.001,help='learning rate') 24 | parser.add_argument('--dropout',type=float,default=0.3,help='dropout rate') 25 | parser.add_argument('--weight_decay',type=float,default=0.0001,help='weight decay rate') 26 | parser.add_argument('--epochs',type=int,default=100,help='') 27 | parser.add_argument('--print_every',type=int,default=50,help='') 28 | #parser.add_argument('--seed',type=int,default=99,help='random seed') 29 | parser.add_argument('--save',type=str,default='./garage/metr',help='save path') 30 | 
parser.add_argument('--expid',type=int,default=1,help='experiment id') 31 | 32 | args = parser.parse_args() 33 | 34 | 35 | 36 | 37 | def main(): 38 | #set seed 39 | #torch.manual_seed(args.seed) 40 | #np.random.seed(args.seed) 41 | #load data 42 | device = torch.device(args.device) 43 | sensor_ids, sensor_id_to_ind, adj_mx = util.load_adj(args.adjdata,args.adjtype) 44 | dataloader = util.load_dataset(args.data, args.batch_size, args.batch_size, args.batch_size) 45 | scaler = dataloader['scaler'] 46 | supports = [torch.tensor(i).to(device) for i in adj_mx] 47 | 48 | print(args) 49 | 50 | if args.randomadj: 51 | adjinit = None 52 | else: 53 | adjinit = supports[0] 54 | 55 | if args.aptonly: 56 | supports = None 57 | 58 | 59 | 60 | engine = trainer(scaler, args.in_dim, args.seq_length, args.num_nodes, args.nhid, args.dropout, 61 | args.learning_rate, args.weight_decay, device, supports, args.gcn_bool, args.addaptadj, 62 | adjinit) 63 | 64 | 65 | print("start training...",flush=True) 66 | his_loss =[] 67 | val_time = [] 68 | train_time = [] 69 | for i in range(1,args.epochs+1): 70 | #if i % 10 == 0: 71 | #lr = max(0.000002,args.learning_rate * (0.1 ** (i // 10))) 72 | #for g in engine.optimizer.param_groups: 73 | #g['lr'] = lr 74 | train_loss = [] 75 | train_mape = [] 76 | train_rmse = [] 77 | t1 = time.time() 78 | dataloader['train_loader'].shuffle() 79 | for iter, (x, y) in enumerate(dataloader['train_loader'].get_iterator()): 80 | trainx = torch.Tensor(x).to(device) 81 | trainx= trainx.transpose(1, 3) 82 | trainy = torch.Tensor(y).to(device) 83 | trainy = trainy.transpose(1, 3) 84 | metrics = engine.train(trainx, trainy[:,0,:,:]) 85 | train_loss.append(metrics[0]) 86 | train_mape.append(metrics[1]) 87 | train_rmse.append(metrics[2]) 88 | if iter % args.print_every == 0 : 89 | log = 'Iter: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}' 90 | print(log.format(iter, train_loss[-1], train_mape[-1], train_rmse[-1]),flush=True) 91 | t2 = time.time() 92 | train_time.append(t2-t1) 93 | #validation 94 | valid_loss = [] 95 | valid_mape = [] 96 | valid_rmse = [] 97 | 98 | 99 | s1 = time.time() 100 | for iter, (x, y) in enumerate(dataloader['val_loader'].get_iterator()): 101 | testx = torch.Tensor(x).to(device) 102 | testx = testx.transpose(1, 3) 103 | testy = torch.Tensor(y).to(device) 104 | testy = testy.transpose(1, 3) 105 | metrics = engine.eval(testx, testy[:,0,:,:]) 106 | valid_loss.append(metrics[0]) 107 | valid_mape.append(metrics[1]) 108 | valid_rmse.append(metrics[2]) 109 | s2 = time.time() 110 | log = 'Epoch: {:03d}, Inference Time: {:.4f} secs' 111 | print(log.format(i,(s2-s1))) 112 | val_time.append(s2-s1) 113 | mtrain_loss = np.mean(train_loss) 114 | mtrain_mape = np.mean(train_mape) 115 | mtrain_rmse = np.mean(train_rmse) 116 | 117 | mvalid_loss = np.mean(valid_loss) 118 | mvalid_mape = np.mean(valid_mape) 119 | mvalid_rmse = np.mean(valid_rmse) 120 | his_loss.append(mvalid_loss) 121 | 122 | log = 'Epoch: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}, Valid Loss: {:.4f}, Valid MAPE: {:.4f}, Valid RMSE: {:.4f}, Training Time: {:.4f}/epoch' 123 | print(log.format(i, mtrain_loss, mtrain_mape, mtrain_rmse, mvalid_loss, mvalid_mape, mvalid_rmse, (t2 - t1)),flush=True) 124 | torch.save(engine.model.state_dict(), args.save+"_epoch_"+str(i)+"_"+str(round(mvalid_loss,2))+".pth") 125 | print("Average Training Time: {:.4f} secs/epoch".format(np.mean(train_time))) 126 | print("Average Inference Time: {:.4f} secs".format(np.mean(val_time))) 127 | 128 | 
#testing 129 | bestid = np.argmin(his_loss) 130 | engine.model.load_state_dict(torch.load(args.save+"_epoch_"+str(bestid+1)+"_"+str(round(his_loss[bestid],2))+".pth")) 131 | 132 | 133 | outputs = [] 134 | realy = torch.Tensor(dataloader['y_test']).to(device) 135 | realy = realy.transpose(1,3)[:,0,:,:] 136 | 137 | for iter, (x, y) in enumerate(dataloader['test_loader'].get_iterator()): 138 | testx = torch.Tensor(x).to(device) 139 | testx = testx.transpose(1,3) 140 | with torch.no_grad(): 141 | preds = engine.model(testx).transpose(1,3) 142 | outputs.append(preds.squeeze()) 143 | 144 | yhat = torch.cat(outputs,dim=0) 145 | yhat = yhat[:realy.size(0),...] 146 | 147 | 148 | print("Training finished") 149 | print("The valid loss on best model is", str(round(his_loss[bestid],4))) 150 | 151 | 152 | amae = [] 153 | amape = [] 154 | armse = [] 155 | for i in range(12): 156 | pred = scaler.inverse_transform(yhat[:,:,i]) 157 | real = realy[:,:,i] 158 | metrics = util.metric(pred,real) 159 | log = 'Evaluate best model on test data for horizon {:d}, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}' 160 | print(log.format(i+1, metrics[0], metrics[1], metrics[2])) 161 | amae.append(metrics[0]) 162 | amape.append(metrics[1]) 163 | armse.append(metrics[2]) 164 | 165 | log = 'On average over 12 horizons, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}' 166 | print(log.format(np.mean(amae),np.mean(amape),np.mean(armse))) 167 | torch.save(engine.model.state_dict(), args.save+"_exp"+str(args.expid)+"_best_"+str(round(his_loss[bestid],2))+".pth") 168 | 169 | 170 | 171 | if __name__ == "__main__": 172 | t1 = time.time() 173 | main() 174 | t2 = time.time() 175 | print("Total time spent: {:.4f}".format(t2-t1)) 176 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import os 4 | import scipy.sparse as sp 5 | import torch 6 | from scipy.sparse import linalg 7 | 8 | 9 | class DataLoader(object): 10 | def __init__(self, xs, ys, batch_size, pad_with_last_sample=True): 11 | """ 12 | :param xs: 13 | :param ys: 14 | :param batch_size: 15 | :param pad_with_last_sample: pad with the last sample to make number of samples divisible to batch_size. 16 | """ 17 | self.batch_size = batch_size 18 | self.current_ind = 0 19 | if pad_with_last_sample: 20 | num_padding = (batch_size - (len(xs) % batch_size)) % batch_size 21 | x_padding = np.repeat(xs[-1:], num_padding, axis=0) 22 | y_padding = np.repeat(ys[-1:], num_padding, axis=0) 23 | xs = np.concatenate([xs, x_padding], axis=0) 24 | ys = np.concatenate([ys, y_padding], axis=0) 25 | self.size = len(xs) 26 | self.num_batch = int(self.size // self.batch_size) 27 | self.xs = xs 28 | self.ys = ys 29 | 30 | def shuffle(self): 31 | permutation = np.random.permutation(self.size) 32 | xs, ys = self.xs[permutation], self.ys[permutation] 33 | self.xs = xs 34 | self.ys = ys 35 | 36 | def get_iterator(self): 37 | self.current_ind = 0 38 | 39 | def _wrapper(): 40 | while self.current_ind < self.num_batch: 41 | start_ind = self.batch_size * self.current_ind 42 | end_ind = min(self.size, self.batch_size * (self.current_ind + 1)) 43 | x_i = self.xs[start_ind: end_ind, ...] 44 | y_i = self.ys[start_ind: end_ind, ...] 
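                # Slices are yielded one mini-batch at a time; the padding in __init__
                # (pad_with_last_sample=True, the default) keeps every batch full.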
45 | yield (x_i, y_i) 46 | self.current_ind += 1 47 | 48 | return _wrapper() 49 | 50 | class StandardScaler(): 51 | """ 52 | Standard the input 53 | """ 54 | 55 | def __init__(self, mean, std): 56 | self.mean = mean 57 | self.std = std 58 | 59 | def transform(self, data): 60 | return (data - self.mean) / self.std 61 | 62 | def inverse_transform(self, data): 63 | return (data * self.std) + self.mean 64 | 65 | 66 | 67 | def sym_adj(adj): 68 | """Symmetrically normalize adjacency matrix.""" 69 | adj = sp.coo_matrix(adj) 70 | rowsum = np.array(adj.sum(1)) 71 | d_inv_sqrt = np.power(rowsum, -0.5).flatten() 72 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 73 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 74 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).astype(np.float32).todense() 75 | 76 | def asym_adj(adj): 77 | adj = sp.coo_matrix(adj) 78 | rowsum = np.array(adj.sum(1)).flatten() 79 | d_inv = np.power(rowsum, -1).flatten() 80 | d_inv[np.isinf(d_inv)] = 0. 81 | d_mat= sp.diags(d_inv) 82 | return d_mat.dot(adj).astype(np.float32).todense() 83 | 84 | def calculate_normalized_laplacian(adj): 85 | """ 86 | # L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2 87 | # D = diag(A 1) 88 | :param adj: 89 | :return: 90 | """ 91 | adj = sp.coo_matrix(adj) 92 | d = np.array(adj.sum(1)) 93 | d_inv_sqrt = np.power(d, -0.5).flatten() 94 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 95 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 96 | normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 97 | return normalized_laplacian 98 | 99 | def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True): 100 | if undirected: 101 | adj_mx = np.maximum.reduce([adj_mx, adj_mx.T]) 102 | L = calculate_normalized_laplacian(adj_mx) 103 | if lambda_max is None: 104 | lambda_max, _ = linalg.eigsh(L, 1, which='LM') 105 | lambda_max = lambda_max[0] 106 | L = sp.csr_matrix(L) 107 | M, _ = L.shape 108 | I = sp.identity(M, format='csr', dtype=L.dtype) 109 | L = (2 / lambda_max * L) - I 110 | return L.astype(np.float32).todense() 111 | 112 | def load_pickle(pickle_file): 113 | try: 114 | with open(pickle_file, 'rb') as f: 115 | pickle_data = pickle.load(f) 116 | except UnicodeDecodeError as e: 117 | with open(pickle_file, 'rb') as f: 118 | pickle_data = pickle.load(f, encoding='latin1') 119 | except Exception as e: 120 | print('Unable to load data ', pickle_file, ':', e) 121 | raise 122 | return pickle_data 123 | 124 | def load_adj(pkl_filename, adjtype): 125 | sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename) 126 | if adjtype == "scalap": 127 | adj = [calculate_scaled_laplacian(adj_mx)] 128 | elif adjtype == "normlap": 129 | adj = [calculate_normalized_laplacian(adj_mx).astype(np.float32).todense()] 130 | elif adjtype == "symnadj": 131 | adj = [sym_adj(adj_mx)] 132 | elif adjtype == "transition": 133 | adj = [asym_adj(adj_mx)] 134 | elif adjtype == "doubletransition": 135 | adj = [asym_adj(adj_mx), asym_adj(np.transpose(adj_mx))] 136 | elif adjtype == "identity": 137 | adj = [np.diag(np.ones(adj_mx.shape[0])).astype(np.float32)] 138 | else: 139 | error = 0 140 | assert error, "adj type not defined" 141 | return sensor_ids, sensor_id_to_ind, adj 142 | 143 | 144 | def load_dataset(dataset_dir, batch_size, valid_batch_size= None, test_batch_size=None): 145 | data = {} 146 | for category in ['train', 'val', 'test']: 147 | cat_data = np.load(os.path.join(dataset_dir, category + '.npz')) 148 | data['x_' + category] = cat_data['x'] 149 | data['y_' + category] = cat_data['y'] 
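    # The scaler is fit on the training split's first input channel only (index 0)
    # and then applied to that channel of the train/val/test inputs below.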
150 | scaler = StandardScaler(mean=data['x_train'][..., 0].mean(), std=data['x_train'][..., 0].std()) 151 | # Data format 152 | for category in ['train', 'val', 'test']: 153 | data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0]) 154 | data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size) 155 | data['val_loader'] = DataLoader(data['x_val'], data['y_val'], valid_batch_size) 156 | data['test_loader'] = DataLoader(data['x_test'], data['y_test'], test_batch_size) 157 | data['scaler'] = scaler 158 | return data 159 | 160 | def masked_mse(preds, labels, null_val=np.nan): 161 | if np.isnan(null_val): 162 | mask = ~torch.isnan(labels) 163 | else: 164 | mask = (labels!=null_val) 165 | mask = mask.float() 166 | mask /= torch.mean((mask)) 167 | mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) 168 | loss = (preds-labels)**2 169 | loss = loss * mask 170 | loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) 171 | return torch.mean(loss) 172 | 173 | def masked_rmse(preds, labels, null_val=np.nan): 174 | return torch.sqrt(masked_mse(preds=preds, labels=labels, null_val=null_val)) 175 | 176 | 177 | def masked_mae(preds, labels, null_val=np.nan): 178 | if np.isnan(null_val): 179 | mask = ~torch.isnan(labels) 180 | else: 181 | mask = (labels!=null_val) 182 | mask = mask.float() 183 | mask /= torch.mean((mask)) 184 | mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) 185 | loss = torch.abs(preds-labels) 186 | loss = loss * mask 187 | loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) 188 | return torch.mean(loss) 189 | 190 | 191 | def masked_mape(preds, labels, null_val=np.nan): 192 | if np.isnan(null_val): 193 | mask = ~torch.isnan(labels) 194 | else: 195 | mask = (labels!=null_val) 196 | mask = mask.float() 197 | mask /= torch.mean((mask)) 198 | mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) 199 | loss = torch.abs(preds-labels)/labels 200 | loss = loss * mask 201 | loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) 202 | return torch.mean(loss) 203 | 204 | 205 | def metric(pred, real): 206 | mae = masked_mae(pred,real,0.0).item() 207 | mape = masked_mape(pred,real,0.0).item() 208 | rmse = masked_rmse(pred,real,0.0).item() 209 | return mae,mape,rmse 210 | 211 | 212 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import sys 6 | 7 | 8 | class nconv(nn.Module): 9 | def __init__(self): 10 | super(nconv,self).__init__() 11 | 12 | def forward(self,x, A): 13 | x = torch.einsum('ncvl,vw->ncwl',(x,A)) 14 | return x.contiguous() 15 | 16 | class linear(nn.Module): 17 | def __init__(self,c_in,c_out): 18 | super(linear,self).__init__() 19 | self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0,0), stride=(1,1), bias=True) 20 | 21 | def forward(self,x): 22 | return self.mlp(x) 23 | 24 | class gcn(nn.Module): 25 | def __init__(self,c_in,c_out,dropout,support_len=3,order=2): 26 | super(gcn,self).__init__() 27 | self.nconv = nconv() 28 | c_in = (order*support_len+1)*c_in 29 | self.mlp = linear(c_in,c_out) 30 | self.dropout = dropout 31 | self.order = order 32 | 33 | def forward(self,x,support): 34 | out = [x] 35 | for a in support: 36 | x1 = self.nconv(x,a) 37 | out.append(x1) 38 | for k in range(2, 
self.order + 1):
39 |                 x2 = self.nconv(x1,a)
40 |                 out.append(x2)
41 |                 x1 = x2
42 | 
43 |         h = torch.cat(out,dim=1)
44 |         h = self.mlp(h)
45 |         h = F.dropout(h, self.dropout, training=self.training)
46 |         return h
47 | 
48 | 
49 | class gwnet(nn.Module):
50 |     def __init__(self, device, num_nodes, dropout=0.3, supports=None, gcn_bool=True, addaptadj=True, aptinit=None, in_dim=2,out_dim=12,residual_channels=32,dilation_channels=32,skip_channels=256,end_channels=512,kernel_size=2,blocks=4,layers=2):
51 |         super(gwnet, self).__init__()
52 |         self.dropout = dropout
53 |         self.blocks = blocks
54 |         self.layers = layers
55 |         self.gcn_bool = gcn_bool
56 |         self.addaptadj = addaptadj
57 | 
58 |         self.filter_convs = nn.ModuleList()
59 |         self.gate_convs = nn.ModuleList()
60 |         self.residual_convs = nn.ModuleList()
61 |         self.skip_convs = nn.ModuleList()
62 |         self.bn = nn.ModuleList()
63 |         self.gconv = nn.ModuleList()
64 | 
65 |         self.start_conv = nn.Conv2d(in_channels=in_dim,
66 |                                     out_channels=residual_channels,
67 |                                     kernel_size=(1,1))
68 |         self.supports = supports
69 | 
70 |         receptive_field = 1
71 | 
72 |         self.supports_len = 0
73 |         if supports is not None:
74 |             self.supports_len += len(supports)
75 | 
76 |         if gcn_bool and addaptadj:
77 |             if aptinit is None:
78 |                 if supports is None:
79 |                     self.supports = []
80 |                 self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True).to(device)
81 |                 self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True).to(device)
82 |                 self.supports_len +=1
83 |             else:
84 |                 if supports is None:
85 |                     self.supports = []
86 |                 m, p, n = torch.svd(aptinit)
87 |                 initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
88 |                 initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
89 |                 self.nodevec1 = nn.Parameter(initemb1, requires_grad=True).to(device)
90 |                 self.nodevec2 = nn.Parameter(initemb2, requires_grad=True).to(device)
91 |                 self.supports_len += 1
92 | 
93 | 
94 | 
95 | 
96 |         for b in range(blocks):
97 |             additional_scope = kernel_size - 1
98 |             new_dilation = 1
99 |             for i in range(layers):
100 |                 # dilated convolutions
101 |                 self.filter_convs.append(nn.Conv2d(in_channels=residual_channels,
102 |                                                    out_channels=dilation_channels,
103 |                                                    kernel_size=(1,kernel_size),dilation=new_dilation))
104 | 
105 |                 self.gate_convs.append(nn.Conv1d(in_channels=residual_channels,
106 |                                                  out_channels=dilation_channels,
107 |                                                  kernel_size=(1, kernel_size), dilation=new_dilation))
108 | 
109 |                 # 1x1 convolution for residual connection
110 |                 self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels,
111 |                                                      out_channels=residual_channels,
112 |                                                      kernel_size=(1, 1)))
113 | 
114 |                 # 1x1 convolution for skip connection
115 |                 self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels,
116 |                                                  out_channels=skip_channels,
117 |                                                  kernel_size=(1, 1)))
118 |                 self.bn.append(nn.BatchNorm2d(residual_channels))
119 |                 new_dilation *=2
120 |                 receptive_field += additional_scope
121 |                 additional_scope *= 2
122 |                 if self.gcn_bool:
123 |                     self.gconv.append(gcn(dilation_channels,residual_channels,dropout,support_len=self.supports_len))
124 | 
125 | 
126 | 
127 |         self.end_conv_1 = nn.Conv2d(in_channels=skip_channels,
128 |                                     out_channels=end_channels,
129 |                                     kernel_size=(1,1),
130 |                                     bias=True)
131 | 
132 |         self.end_conv_2 = nn.Conv2d(in_channels=end_channels,
133 |                                     out_channels=out_dim,
134 |                                     kernel_size=(1,1),
135 |                                     bias=True)
136 | 
137 |         self.receptive_field = receptive_field
138 | 
139 | 
140 | 
141 |     def forward(self, input):
142 |         in_len = input.size(3)
143 |         if in_len < self.receptive_field:
144 |             x = nn.functional.pad(input,(self.receptive_field-in_len,0,0,0))
145 |         else:
146 |             x = input
147 |         x = self.start_conv(x)
148 |         skip = 0
149 | 
150 |         # calculate the current adaptive adj matrix once per iteration
151 |         new_supports = None
152 |         if self.gcn_bool and self.addaptadj and self.supports is not None:
153 |             adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
154 |             new_supports = self.supports + [adp]
155 | 
156 |         # WaveNet layers
157 |         for i in range(self.blocks * self.layers):
158 | 
159 |             #            |----------------------------------------|     *residual*
160 |             #            |                                         |
161 |             #            |    |-- conv -- tanh --|                 |
162 |             # -> dilate -|----|                  * 
----|-- 1x1 -- + --> *input* 163 | # |-- conv -- sigm --| | 164 | # 1x1 165 | # | 166 | # ---------------------------------------> + -------------> *skip* 167 | 168 | #(dilation, init_dilation) = self.dilations[i] 169 | 170 | #residual = dilation_func(x, dilation, init_dilation, i) 171 | residual = x 172 | # dilated convolution 173 | filter = self.filter_convs[i](residual) 174 | filter = torch.tanh(filter) 175 | gate = self.gate_convs[i](residual) 176 | gate = torch.sigmoid(gate) 177 | x = filter * gate 178 | 179 | # parametrized skip connection 180 | 181 | s = x 182 | s = self.skip_convs[i](s) 183 | try: 184 | skip = skip[:, :, :, -s.size(3):] 185 | except: 186 | skip = 0 187 | skip = s + skip 188 | 189 | 190 | if self.gcn_bool and self.supports is not None: 191 | if self.addaptadj: 192 | x = self.gconv[i](x, new_supports) 193 | else: 194 | x = self.gconv[i](x,self.supports) 195 | else: 196 | x = self.residual_convs[i](x) 197 | 198 | x = x + residual[:, :, :, -x.size(3):] 199 | 200 | 201 | x = self.bn[i](x) 202 | 203 | x = F.relu(skip) 204 | x = F.relu(self.end_conv_1(x)) 205 | x = self.end_conv_2(x) 206 | return x 207 | 208 | 209 | 210 | 211 | 212 | --------------------------------------------------------------------------------
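The snippet below is not part of the original repository; it is a minimal smoke-test sketch, assuming only the `gwnet` defaults defined in `model.py` above (in_dim=2, out_dim=12, blocks=4, layers=2, receptive field 13), that instantiates the network on CPU and checks the output shape of one forward pass.

```
# smoke_test.py -- hypothetical helper, not part of the original repo.
# Instantiates gwnet with its defaults and runs one random batch through it.
import torch
from model import gwnet

device = torch.device('cpu')
model = gwnet(device, num_nodes=207)   # METR-LA has 207 sensors; defaults: in_dim=2, out_dim=12
model.eval()

# Input layout is (batch, features, nodes, time); 13 time steps match the
# receptive field (train.py feeds 12 steps and engine.py left-pads one more).
x = torch.randn(8, 2, 207, 13)
with torch.no_grad():
    y = model(x)
print(y.shape)                         # expected: torch.Size([8, 12, 207, 1])
```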