├── LICENSE
├── README.md
├── data
│   └── sensor_graph
│       ├── adj_mx.pkl
│       └── adj_mx_bay.pkl
├── generate_training_data.py
├── layer.py
├── net.py
├── requirements.txt
├── train_multi_step.py
├── train_single_step.py
├── trainer.py
└── util.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Zonghan Wu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MTGNN
2 | This is a PyTorch implementation of the paper: [Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks](https://arxiv.org/abs/2005.11650), published in KDD-2020.
3 | 
4 | ## Requirements
5 | The model is implemented using Python 3, with dependencies specified in `requirements.txt`.
6 | ## Data Preparation
7 | ### Multivariate time series datasets
8 | 
9 | Download the Solar-Energy, Traffic, Electricity, and Exchange-Rate datasets from [https://github.com/laiguokun/multivariate-time-series-data](https://github.com/laiguokun/multivariate-time-series-data). Uncompress them and move them into the `data` folder.
10 | 
11 | ### Traffic datasets
12 | Download the METR-LA and PEMS-BAY datasets from [Google Drive](https://drive.google.com/open?id=10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX) or [Baidu Yun](https://pan.baidu.com/s/14Yy9isAIZYdU__OYEQGa_g), provided by [Li et al.](https://github.com/liyaguang/DCRNN.git). Move them into the `data` folder.
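Each run of `generate_training_data.py` below writes `train.npz`, `val.npz`, and `test.npz` (containing the arrays `x`, `y`, `x_offsets`, and `y_offsets`) into the chosen output directory. After the commands finish, an optional sanity check — a minimal Python snippet, assuming the METR-LA output directory used below — is to print the generated shapes:

```
import numpy as np

for split in ("train", "val", "test"):
    d = np.load("data/METR-LA/%s.npz" % split)
    print(split, "x:", d["x"].shape, "y:", d["y"].shape)
```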
13 | 
14 | ```
15 | 
16 | # Create data directories
17 | mkdir -p data/{METR-LA,PEMS-BAY}
18 | 
19 | # METR-LA
20 | python generate_training_data.py --output_dir=data/METR-LA --traffic_df_filename=data/metr-la.h5
21 | 
22 | # PEMS-BAY
23 | python generate_training_data.py --output_dir=data/PEMS-BAY --traffic_df_filename=data/pems-bay.h5
24 | 
25 | ```
26 | 
27 | ## Model Training
28 | 
29 | ### Single-step
30 | 
31 | * Solar-Energy
32 | 
33 | ```
34 | python train_single_step.py --save ./model-solar-3.pt --data ./data/solar_AL.txt --num_nodes 137 --batch_size 4 --epochs 30 --horizon 3
35 | # sampling
36 | python train_single_step.py --num_split 3 --save ./model-solar-sampling-3.pt --data ./data/solar_AL.txt --num_nodes 137 --batch_size 16 --epochs 30 --horizon 3
37 | ```
38 | * Traffic
39 | 
40 | ```
41 | python train_single_step.py --save ./model-traffic-3.pt --data ./data/traffic.txt --num_nodes 862 --batch_size 16 --epochs 30 --horizon 3
42 | # sampling
43 | python train_single_step.py --num_split 3 --save ./model-traffic-sampling-3.pt --data ./data/traffic.txt --num_nodes 862 --batch_size 16 --epochs 30 --horizon 3
44 | ```
45 | 
46 | * Electricity
47 | 
48 | ```
49 | python train_single_step.py --save ./model-electricity-3.pt --data ./data/electricity.txt --num_nodes 321 --batch_size 4 --epochs 30 --horizon 3
50 | # sampling
51 | python train_single_step.py --num_split 3 --save ./model-electricity-sampling-3.pt --data ./data/electricity.txt --num_nodes 321 --batch_size 16 --epochs 30 --horizon 3
52 | ```
53 | 
54 | * Exchange-Rate
55 | 
56 | ```
57 | python train_single_step.py --save ./model-exchange-3.pt --data ./data/exchange_rate.txt --num_nodes 8 --subgraph_size 8 --batch_size 4 --epochs 30 --horizon 3
58 | # sampling
59 | python train_single_step.py --num_split 3 --save ./model-exchange-sampling-3.pt --data ./data/exchange_rate.txt --num_nodes 8 --subgraph_size 2 --batch_size 16 --epochs 30 --horizon 3
60 | ```
61 | ### Multi-step
62 | * METR-LA
63 | 
64 | ```
65 | python train_multi_step.py --adj_data ./data/sensor_graph/adj_mx.pkl --data ./data/METR-LA --num_nodes 207
66 | ```
67 | * PEMS-BAY
68 | 
69 | ```
70 | python train_multi_step.py --adj_data ./data/sensor_graph/adj_mx_bay.pkl --data ./data/PEMS-BAY --num_nodes 325
71 | ```
72 | 
73 | ## Citation
74 | 
75 | ```
76 | @inproceedings{wu2020connecting,
77 |   title={Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks},
78 |   author={Wu, Zonghan and Pan, Shirui and Long, Guodong and Jiang, Jing and Chang, Xiaojun and Zhang, Chengqi},
79 |   booktitle={Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery \& Data Mining},
80 |   year={2020}
81 | }
82 | ```
83 | 
--------------------------------------------------------------------------------
/generate_training_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 | 
6 | import argparse
7 | import numpy as np
8 | import os
9 | import pandas as pd
10 | 
11 | 
12 | def generate_graph_seq2seq_io_data(
13 |         df, x_offsets, y_offsets, add_time_in_day=True, add_day_in_week=False, scaler=None
14 | ):
15 |     """
16 |     Generate samples from the given dataframe.
17 |     :param df:
18 |     :param x_offsets:
19 |     :param y_offsets:
20 |     :param add_time_in_day:
21 |     :param add_day_in_week:
22 |     :param scaler:
23 |     :return:
24 |     # x: (epoch_size, input_length, num_nodes, input_dim)
25 |     # y:
(epoch_size, output_length, num_nodes, output_dim) 26 | """ 27 | 28 | num_samples, num_nodes = df.shape 29 | data = np.expand_dims(df.values, axis=-1) 30 | data_list = [data] 31 | if add_time_in_day: 32 | time_ind = (df.index.values - df.index.values.astype("datetime64[D]")) / np.timedelta64(1, "D") 33 | time_in_day = np.tile(time_ind, [1, num_nodes, 1]).transpose((2, 1, 0)) 34 | data_list.append(time_in_day) 35 | if add_day_in_week: 36 | day_in_week = np.zeros(shape=(num_samples, num_nodes, 7)) 37 | day_in_week[np.arange(num_samples), :, df.index.dayofweek] = 1 38 | data_list.append(day_in_week) 39 | 40 | data = np.concatenate(data_list, axis=-1) 41 | # epoch_len = num_samples + min(x_offsets) - max(y_offsets) 42 | x, y = [], [] 43 | # t is the index of the last observation. 44 | min_t = abs(min(x_offsets)) 45 | max_t = abs(num_samples - abs(max(y_offsets))) # Exclusive 46 | for t in range(min_t, max_t): 47 | x_t = data[t + x_offsets, ...] 48 | y_t = data[t + y_offsets, ...] 49 | x.append(x_t) 50 | y.append(y_t) 51 | x = np.stack(x, axis=0) 52 | y = np.stack(y, axis=0) 53 | return x, y 54 | 55 | 56 | def generate_train_val_test(args): 57 | df = pd.read_hdf(args.traffic_df_filename) 58 | # 0 is the latest observed sample. 59 | x_offsets = np.sort( 60 | # np.concatenate(([-week_size + 1, -day_size + 1], np.arange(-11, 1, 1))) 61 | np.concatenate((np.arange(-11, 1, 1),)) 62 | ) 63 | # Predict the next one hour 64 | y_offsets = np.sort(np.arange(1, 13, 1)) 65 | # x: (num_samples, input_length, num_nodes, input_dim) 66 | # y: (num_samples, output_length, num_nodes, output_dim) 67 | x, y = generate_graph_seq2seq_io_data( 68 | df, 69 | x_offsets=x_offsets, 70 | y_offsets=y_offsets, 71 | add_time_in_day=True, 72 | add_day_in_week=False, 73 | ) 74 | 75 | print("x shape: ", x.shape, ", y shape: ", y.shape) 76 | # Write the data into npz file. 77 | # num_test = 6831, using the last 6831 examples as testing. 78 | # for the rest: 7/8 is used for training, and 1/8 is used for validation. 79 | num_samples = x.shape[0] 80 | num_test = round(num_samples * 0.2) 81 | num_train = round(num_samples * 0.7) 82 | num_val = num_samples - num_test - num_train 83 | 84 | # train 85 | x_train, y_train = x[:num_train], y[:num_train] 86 | # val 87 | x_val, y_val = ( 88 | x[num_train: num_train + num_val], 89 | y[num_train: num_train + num_val], 90 | ) 91 | # test 92 | x_test, y_test = x[-num_test:], y[-num_test:] 93 | 94 | for cat in ["train", "val", "test"]: 95 | _x, _y = locals()["x_" + cat], locals()["y_" + cat] 96 | print(cat, "x: ", _x.shape, "y:", _y.shape) 97 | np.savez_compressed( 98 | os.path.join(args.output_dir, "%s.npz" % cat), 99 | x=_x, 100 | y=_y, 101 | x_offsets=x_offsets.reshape(list(x_offsets.shape) + [1]), 102 | y_offsets=y_offsets.reshape(list(y_offsets.shape) + [1]), 103 | ) 104 | 105 | 106 | def main(args): 107 | print("Generating training data") 108 | generate_train_val_test(args) 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument( 114 | "--output_dir", type=str, default="data/", help="Output directory." 
115 | ) 116 | parser.add_argument( 117 | "--traffic_df_filename", 118 | type=str, 119 | default="data/metr-la.h5", 120 | help="Raw traffic readings.", 121 | ) 122 | args = parser.parse_args() 123 | main(args) 124 | -------------------------------------------------------------------------------- /layer.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | import numbers 6 | import torch.nn.functional as F 7 | 8 | 9 | class nconv(nn.Module): 10 | def __init__(self): 11 | super(nconv,self).__init__() 12 | 13 | def forward(self,x, A): 14 | x = torch.einsum('ncwl,vw->ncvl',(x,A)) 15 | return x.contiguous() 16 | 17 | class dy_nconv(nn.Module): 18 | def __init__(self): 19 | super(dy_nconv,self).__init__() 20 | 21 | def forward(self,x, A): 22 | x = torch.einsum('ncvl,nvwl->ncwl',(x,A)) 23 | return x.contiguous() 24 | 25 | class linear(nn.Module): 26 | def __init__(self,c_in,c_out,bias=True): 27 | super(linear,self).__init__() 28 | self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0,0), stride=(1,1), bias=bias) 29 | 30 | def forward(self,x): 31 | return self.mlp(x) 32 | 33 | 34 | class prop(nn.Module): 35 | def __init__(self,c_in,c_out,gdep,dropout,alpha): 36 | super(prop, self).__init__() 37 | self.nconv = nconv() 38 | self.mlp = linear(c_in,c_out) 39 | self.gdep = gdep 40 | self.dropout = dropout 41 | self.alpha = alpha 42 | 43 | def forward(self,x,adj): 44 | adj = adj + torch.eye(adj.size(0)).to(x.device) 45 | d = adj.sum(1) 46 | h = x 47 | dv = d 48 | a = adj / dv.view(-1, 1) 49 | for i in range(self.gdep): 50 | h = self.alpha*x + (1-self.alpha)*self.nconv(h,a) 51 | ho = self.mlp(h) 52 | return ho 53 | 54 | 55 | class mixprop(nn.Module): 56 | def __init__(self,c_in,c_out,gdep,dropout,alpha): 57 | super(mixprop, self).__init__() 58 | self.nconv = nconv() 59 | self.mlp = linear((gdep+1)*c_in,c_out) 60 | self.gdep = gdep 61 | self.dropout = dropout 62 | self.alpha = alpha 63 | 64 | 65 | def forward(self,x,adj): 66 | adj = adj + torch.eye(adj.size(0)).to(x.device) 67 | d = adj.sum(1) 68 | h = x 69 | out = [h] 70 | a = adj / d.view(-1, 1) 71 | for i in range(self.gdep): 72 | h = self.alpha*x + (1-self.alpha)*self.nconv(h,a) 73 | out.append(h) 74 | ho = torch.cat(out,dim=1) 75 | ho = self.mlp(ho) 76 | return ho 77 | 78 | class dy_mixprop(nn.Module): 79 | def __init__(self,c_in,c_out,gdep,dropout,alpha): 80 | super(dy_mixprop, self).__init__() 81 | self.nconv = dy_nconv() 82 | self.mlp1 = linear((gdep+1)*c_in,c_out) 83 | self.mlp2 = linear((gdep+1)*c_in,c_out) 84 | 85 | self.gdep = gdep 86 | self.dropout = dropout 87 | self.alpha = alpha 88 | self.lin1 = linear(c_in,c_in) 89 | self.lin2 = linear(c_in,c_in) 90 | 91 | 92 | def forward(self,x): 93 | #adj = adj + torch.eye(adj.size(0)).to(x.device) 94 | #d = adj.sum(1) 95 | x1 = torch.tanh(self.lin1(x)) 96 | x2 = torch.tanh(self.lin2(x)) 97 | adj = self.nconv(x1.transpose(2,1),x2) 98 | adj0 = torch.softmax(adj, dim=2) 99 | adj1 = torch.softmax(adj.transpose(2,1), dim=2) 100 | 101 | h = x 102 | out = [h] 103 | for i in range(self.gdep): 104 | h = self.alpha*x + (1-self.alpha)*self.nconv(h,adj0) 105 | out.append(h) 106 | ho = torch.cat(out,dim=1) 107 | ho1 = self.mlp1(ho) 108 | 109 | 110 | h = x 111 | out = [h] 112 | for i in range(self.gdep): 113 | h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj1) 114 | out.append(h) 115 | ho = torch.cat(out, dim=1) 116 | ho2 = self.mlp2(ho) 117 | 118 | 
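        # ho1 and ho2 are the mix-hop aggregations along the two mutually transposed
        # adjacency matrices (adj0 and adj1) inferred from the input, so information is
        # propagated in both directions before the two outputs are summed below.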
return ho1+ho2 119 | 120 | 121 | 122 | class dilated_1D(nn.Module): 123 | def __init__(self, cin, cout, dilation_factor=2): 124 | super(dilated_1D, self).__init__() 125 | self.tconv = nn.ModuleList() 126 | self.kernel_set = [2,3,6,7] 127 | self.tconv = nn.Conv2d(cin,cout,(1,7),dilation=(1,dilation_factor)) 128 | 129 | def forward(self,input): 130 | x = self.tconv(input) 131 | return x 132 | 133 | class dilated_inception(nn.Module): 134 | def __init__(self, cin, cout, dilation_factor=2): 135 | super(dilated_inception, self).__init__() 136 | self.tconv = nn.ModuleList() 137 | self.kernel_set = [2,3,6,7] 138 | cout = int(cout/len(self.kernel_set)) 139 | for kern in self.kernel_set: 140 | self.tconv.append(nn.Conv2d(cin,cout,(1,kern),dilation=(1,dilation_factor))) 141 | 142 | def forward(self,input): 143 | x = [] 144 | for i in range(len(self.kernel_set)): 145 | x.append(self.tconv[i](input)) 146 | for i in range(len(self.kernel_set)): 147 | x[i] = x[i][...,-x[-1].size(3):] 148 | x = torch.cat(x,dim=1) 149 | return x 150 | 151 | 152 | class graph_constructor(nn.Module): 153 | def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): 154 | super(graph_constructor, self).__init__() 155 | self.nnodes = nnodes 156 | if static_feat is not None: 157 | xd = static_feat.shape[1] 158 | self.lin1 = nn.Linear(xd, dim) 159 | self.lin2 = nn.Linear(xd, dim) 160 | else: 161 | self.emb1 = nn.Embedding(nnodes, dim) 162 | self.emb2 = nn.Embedding(nnodes, dim) 163 | self.lin1 = nn.Linear(dim,dim) 164 | self.lin2 = nn.Linear(dim,dim) 165 | 166 | self.device = device 167 | self.k = k 168 | self.dim = dim 169 | self.alpha = alpha 170 | self.static_feat = static_feat 171 | 172 | def forward(self, idx): 173 | if self.static_feat is None: 174 | nodevec1 = self.emb1(idx) 175 | nodevec2 = self.emb2(idx) 176 | else: 177 | nodevec1 = self.static_feat[idx,:] 178 | nodevec2 = nodevec1 179 | 180 | nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) 181 | nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2)) 182 | 183 | a = torch.mm(nodevec1, nodevec2.transpose(1,0))-torch.mm(nodevec2, nodevec1.transpose(1,0)) 184 | adj = F.relu(torch.tanh(self.alpha*a)) 185 | mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device) 186 | mask.fill_(float('0')) 187 | s1,t1 = (adj + torch.rand_like(adj)*0.01).topk(self.k,1) 188 | mask.scatter_(1,t1,s1.fill_(1)) 189 | adj = adj*mask 190 | return adj 191 | 192 | def fullA(self, idx): 193 | if self.static_feat is None: 194 | nodevec1 = self.emb1(idx) 195 | nodevec2 = self.emb2(idx) 196 | else: 197 | nodevec1 = self.static_feat[idx,:] 198 | nodevec2 = nodevec1 199 | 200 | nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) 201 | nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2)) 202 | 203 | a = torch.mm(nodevec1, nodevec2.transpose(1,0))-torch.mm(nodevec2, nodevec1.transpose(1,0)) 204 | adj = F.relu(torch.tanh(self.alpha*a)) 205 | return adj 206 | 207 | class graph_global(nn.Module): 208 | def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): 209 | super(graph_global, self).__init__() 210 | self.nnodes = nnodes 211 | self.A = nn.Parameter(torch.randn(nnodes, nnodes).to(device), requires_grad=True).to(device) 212 | 213 | def forward(self, idx): 214 | return F.relu(self.A) 215 | 216 | 217 | class graph_undirected(nn.Module): 218 | def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): 219 | super(graph_undirected, self).__init__() 220 | self.nnodes = nnodes 221 | if static_feat is not None: 222 | xd = static_feat.shape[1] 223 | self.lin1 = 
nn.Linear(xd, dim) 224 | else: 225 | self.emb1 = nn.Embedding(nnodes, dim) 226 | self.lin1 = nn.Linear(dim,dim) 227 | 228 | self.device = device 229 | self.k = k 230 | self.dim = dim 231 | self.alpha = alpha 232 | self.static_feat = static_feat 233 | 234 | def forward(self, idx): 235 | if self.static_feat is None: 236 | nodevec1 = self.emb1(idx) 237 | nodevec2 = self.emb1(idx) 238 | else: 239 | nodevec1 = self.static_feat[idx,:] 240 | nodevec2 = nodevec1 241 | 242 | nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) 243 | nodevec2 = torch.tanh(self.alpha*self.lin1(nodevec2)) 244 | 245 | a = torch.mm(nodevec1, nodevec2.transpose(1,0)) 246 | adj = F.relu(torch.tanh(self.alpha*a)) 247 | mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device) 248 | mask.fill_(float('0')) 249 | s1,t1 = adj.topk(self.k,1) 250 | mask.scatter_(1,t1,s1.fill_(1)) 251 | adj = adj*mask 252 | return adj 253 | 254 | 255 | 256 | class graph_directed(nn.Module): 257 | def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): 258 | super(graph_directed, self).__init__() 259 | self.nnodes = nnodes 260 | if static_feat is not None: 261 | xd = static_feat.shape[1] 262 | self.lin1 = nn.Linear(xd, dim) 263 | self.lin2 = nn.Linear(xd, dim) 264 | else: 265 | self.emb1 = nn.Embedding(nnodes, dim) 266 | self.emb2 = nn.Embedding(nnodes, dim) 267 | self.lin1 = nn.Linear(dim,dim) 268 | self.lin2 = nn.Linear(dim,dim) 269 | 270 | self.device = device 271 | self.k = k 272 | self.dim = dim 273 | self.alpha = alpha 274 | self.static_feat = static_feat 275 | 276 | def forward(self, idx): 277 | if self.static_feat is None: 278 | nodevec1 = self.emb1(idx) 279 | nodevec2 = self.emb2(idx) 280 | else: 281 | nodevec1 = self.static_feat[idx,:] 282 | nodevec2 = nodevec1 283 | 284 | nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) 285 | nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2)) 286 | 287 | a = torch.mm(nodevec1, nodevec2.transpose(1,0)) 288 | adj = F.relu(torch.tanh(self.alpha*a)) 289 | mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device) 290 | mask.fill_(float('0')) 291 | s1,t1 = adj.topk(self.k,1) 292 | mask.scatter_(1,t1,s1.fill_(1)) 293 | adj = adj*mask 294 | return adj 295 | 296 | 297 | class LayerNorm(nn.Module): 298 | __constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine'] 299 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): 300 | super(LayerNorm, self).__init__() 301 | if isinstance(normalized_shape, numbers.Integral): 302 | normalized_shape = (normalized_shape,) 303 | self.normalized_shape = tuple(normalized_shape) 304 | self.eps = eps 305 | self.elementwise_affine = elementwise_affine 306 | if self.elementwise_affine: 307 | self.weight = nn.Parameter(torch.Tensor(*normalized_shape)) 308 | self.bias = nn.Parameter(torch.Tensor(*normalized_shape)) 309 | else: 310 | self.register_parameter('weight', None) 311 | self.register_parameter('bias', None) 312 | self.reset_parameters() 313 | 314 | 315 | def reset_parameters(self): 316 | if self.elementwise_affine: 317 | init.ones_(self.weight) 318 | init.zeros_(self.bias) 319 | 320 | def forward(self, input, idx): 321 | if self.elementwise_affine: 322 | return F.layer_norm(input, tuple(input.shape[1:]), self.weight[:,idx,:], self.bias[:,idx,:], self.eps) 323 | else: 324 | return F.layer_norm(input, tuple(input.shape[1:]), self.weight, self.bias, self.eps) 325 | 326 | def extra_repr(self): 327 | return '{normalized_shape}, eps={eps}, ' \ 328 | 
'elementwise_affine={elementwise_affine}'.format(**self.__dict__) 329 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | from layer import * 2 | 3 | 4 | class gtnet(nn.Module): 5 | def __init__(self, gcn_true, buildA_true, gcn_depth, num_nodes, device, predefined_A=None, static_feat=None, dropout=0.3, subgraph_size=20, node_dim=40, dilation_exponential=1, conv_channels=32, residual_channels=32, skip_channels=64, end_channels=128, seq_length=12, in_dim=2, out_dim=12, layers=3, propalpha=0.05, tanhalpha=3, layer_norm_affline=True): 6 | super(gtnet, self).__init__() 7 | self.gcn_true = gcn_true 8 | self.buildA_true = buildA_true 9 | self.num_nodes = num_nodes 10 | self.dropout = dropout 11 | self.predefined_A = predefined_A 12 | self.filter_convs = nn.ModuleList() 13 | self.gate_convs = nn.ModuleList() 14 | self.residual_convs = nn.ModuleList() 15 | self.skip_convs = nn.ModuleList() 16 | self.gconv1 = nn.ModuleList() 17 | self.gconv2 = nn.ModuleList() 18 | self.norm = nn.ModuleList() 19 | self.start_conv = nn.Conv2d(in_channels=in_dim, 20 | out_channels=residual_channels, 21 | kernel_size=(1, 1)) 22 | self.gc = graph_constructor(num_nodes, subgraph_size, node_dim, device, alpha=tanhalpha, static_feat=static_feat) 23 | 24 | self.seq_length = seq_length 25 | kernel_size = 7 26 | if dilation_exponential>1: 27 | self.receptive_field = int(1+(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1)) 28 | else: 29 | self.receptive_field = layers*(kernel_size-1) + 1 30 | 31 | for i in range(1): 32 | if dilation_exponential>1: 33 | rf_size_i = int(1 + i*(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1)) 34 | else: 35 | rf_size_i = i*layers*(kernel_size-1)+1 36 | new_dilation = 1 37 | for j in range(1,layers+1): 38 | if dilation_exponential > 1: 39 | rf_size_j = int(rf_size_i + (kernel_size-1)*(dilation_exponential**j-1)/(dilation_exponential-1)) 40 | else: 41 | rf_size_j = rf_size_i+j*(kernel_size-1) 42 | 43 | self.filter_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation)) 44 | self.gate_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation)) 45 | self.residual_convs.append(nn.Conv2d(in_channels=conv_channels, 46 | out_channels=residual_channels, 47 | kernel_size=(1, 1))) 48 | if self.seq_length>self.receptive_field: 49 | self.skip_convs.append(nn.Conv2d(in_channels=conv_channels, 50 | out_channels=skip_channels, 51 | kernel_size=(1, self.seq_length-rf_size_j+1))) 52 | else: 53 | self.skip_convs.append(nn.Conv2d(in_channels=conv_channels, 54 | out_channels=skip_channels, 55 | kernel_size=(1, self.receptive_field-rf_size_j+1))) 56 | 57 | if self.gcn_true: 58 | self.gconv1.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha)) 59 | self.gconv2.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha)) 60 | 61 | if self.seq_length>self.receptive_field: 62 | self.norm.append(LayerNorm((residual_channels, num_nodes, self.seq_length - rf_size_j + 1),elementwise_affine=layer_norm_affline)) 63 | else: 64 | self.norm.append(LayerNorm((residual_channels, num_nodes, self.receptive_field - rf_size_j + 1),elementwise_affine=layer_norm_affline)) 65 | 66 | new_dilation *= dilation_exponential 67 | 68 | self.layers = layers 69 | self.end_conv_1 = nn.Conv2d(in_channels=skip_channels, 70 | out_channels=end_channels, 71 | 
kernel_size=(1,1), 72 | bias=True) 73 | self.end_conv_2 = nn.Conv2d(in_channels=end_channels, 74 | out_channels=out_dim, 75 | kernel_size=(1,1), 76 | bias=True) 77 | if self.seq_length > self.receptive_field: 78 | self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.seq_length), bias=True) 79 | self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, self.seq_length-self.receptive_field+1), bias=True) 80 | 81 | else: 82 | self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.receptive_field), bias=True) 83 | self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, 1), bias=True) 84 | 85 | 86 | self.idx = torch.arange(self.num_nodes).to(device) 87 | 88 | 89 | def forward(self, input, idx=None): 90 | seq_len = input.size(3) 91 | assert seq_len==self.seq_length, 'input sequence length not equal to preset sequence length' 92 | 93 | if self.seq_length 0: 95 | # shrinkage = self.max_grad_norm / grad_norm 96 | # else: 97 | # shrinkage = 1. 98 | # 99 | # for param in self.params: 100 | # if shrinkage < 1: 101 | # param.grad.data.mul_(shrinkage) 102 | self.optimizer.step() 103 | return grad_norm 104 | 105 | # decay learning rate if val perf does not improve or we hit the start_decay_at limit 106 | def updateLearningRate(self, ppl, epoch): 107 | if self.start_decay_at is not None and epoch >= self.start_decay_at: 108 | self.start_decay = True 109 | if self.last_ppl is not None and ppl > self.last_ppl: 110 | self.start_decay = True 111 | 112 | if self.start_decay: 113 | self.lr = self.lr * self.lr_decay 114 | print("Decaying learning rate to %g" % self.lr) 115 | #only decay for one epoch 116 | self.start_decay = False 117 | 118 | self.last_ppl = ppl 119 | 120 | self._makeOptimizer() 121 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import os 4 | import scipy.sparse as sp 5 | import torch 6 | from scipy.sparse import linalg 7 | from torch.autograd import Variable 8 | 9 | def normal_std(x): 10 | return x.std() * np.sqrt((len(x) - 1.)/(len(x))) 11 | 12 | class DataLoaderS(object): 13 | # train and valid is the ratio of training set and validation set. test = 1 - train - valid 14 | def __init__(self, file_name, train, valid, device, horizon, window, normalize=2): 15 | self.P = window 16 | self.h = horizon 17 | fin = open(file_name) 18 | self.rawdat = np.loadtxt(fin, delimiter=',') 19 | self.dat = np.zeros(self.rawdat.shape) 20 | self.n, self.m = self.dat.shape 21 | self.normalize = 2 22 | self.scale = np.ones(self.m) 23 | self._normalized(normalize) 24 | self._split(int(train * self.n), int((train + valid) * self.n), self.n) 25 | 26 | self.scale = torch.from_numpy(self.scale).float() 27 | tmp = self.test[1] * self.scale.expand(self.test[1].size(0), self.m) 28 | 29 | self.scale = self.scale.to(device) 30 | self.scale = Variable(self.scale) 31 | 32 | self.rse = normal_std(tmp) 33 | self.rae = torch.mean(torch.abs(tmp - torch.mean(tmp))) 34 | 35 | self.device = device 36 | 37 | def _normalized(self, normalize): 38 | # normalized by the maximum value of entire matrix. 39 | 40 | if (normalize == 0): 41 | self.dat = self.rawdat 42 | 43 | if (normalize == 1): 44 | self.dat = self.rawdat / np.max(self.rawdat) 45 | 46 | # normlized by the maximum value of each row(sensor). 
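        # Note: each sensor occupies a column of self.rawdat; its max absolute value is
        # stored in self.scale so the scaling can be undone later (e.g., for the rse/rae
        # statistics computed from self.test[1] in __init__).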
47 | if (normalize == 2): 48 | for i in range(self.m): 49 | self.scale[i] = np.max(np.abs(self.rawdat[:, i])) 50 | self.dat[:, i] = self.rawdat[:, i] / np.max(np.abs(self.rawdat[:, i])) 51 | 52 | def _split(self, train, valid, test): 53 | 54 | train_set = range(self.P + self.h - 1, train) 55 | valid_set = range(train, valid) 56 | test_set = range(valid, self.n) 57 | self.train = self._batchify(train_set, self.h) 58 | self.valid = self._batchify(valid_set, self.h) 59 | self.test = self._batchify(test_set, self.h) 60 | 61 | def _batchify(self, idx_set, horizon): 62 | n = len(idx_set) 63 | X = torch.zeros((n, self.P, self.m)) 64 | Y = torch.zeros((n, self.m)) 65 | for i in range(n): 66 | end = idx_set[i] - self.h + 1 67 | start = end - self.P 68 | X[i, :, :] = torch.from_numpy(self.dat[start:end, :]) 69 | Y[i, :] = torch.from_numpy(self.dat[idx_set[i], :]) 70 | return [X, Y] 71 | 72 | def get_batches(self, inputs, targets, batch_size, shuffle=True): 73 | length = len(inputs) 74 | if shuffle: 75 | index = torch.randperm(length) 76 | else: 77 | index = torch.LongTensor(range(length)) 78 | start_idx = 0 79 | while (start_idx < length): 80 | end_idx = min(length, start_idx + batch_size) 81 | excerpt = index[start_idx:end_idx] 82 | X = inputs[excerpt] 83 | Y = targets[excerpt] 84 | X = X.to(self.device) 85 | Y = Y.to(self.device) 86 | yield Variable(X), Variable(Y) 87 | start_idx += batch_size 88 | 89 | class DataLoaderM(object): 90 | def __init__(self, xs, ys, batch_size, pad_with_last_sample=True): 91 | """ 92 | :param xs: 93 | :param ys: 94 | :param batch_size: 95 | :param pad_with_last_sample: pad with the last sample to make number of samples divisible to batch_size. 96 | """ 97 | self.batch_size = batch_size 98 | self.current_ind = 0 99 | if pad_with_last_sample: 100 | num_padding = (batch_size - (len(xs) % batch_size)) % batch_size 101 | x_padding = np.repeat(xs[-1:], num_padding, axis=0) 102 | y_padding = np.repeat(ys[-1:], num_padding, axis=0) 103 | xs = np.concatenate([xs, x_padding], axis=0) 104 | ys = np.concatenate([ys, y_padding], axis=0) 105 | self.size = len(xs) 106 | self.num_batch = int(self.size // self.batch_size) 107 | self.xs = xs 108 | self.ys = ys 109 | 110 | def shuffle(self): 111 | permutation = np.random.permutation(self.size) 112 | xs, ys = self.xs[permutation], self.ys[permutation] 113 | self.xs = xs 114 | self.ys = ys 115 | 116 | def get_iterator(self): 117 | self.current_ind = 0 118 | def _wrapper(): 119 | while self.current_ind < self.num_batch: 120 | start_ind = self.batch_size * self.current_ind 121 | end_ind = min(self.size, self.batch_size * (self.current_ind + 1)) 122 | x_i = self.xs[start_ind: end_ind, ...] 123 | y_i = self.ys[start_ind: end_ind, ...] 124 | yield (x_i, y_i) 125 | self.current_ind += 1 126 | 127 | return _wrapper() 128 | 129 | class StandardScaler(): 130 | """ 131 | Standard the input 132 | """ 133 | def __init__(self, mean, std): 134 | self.mean = mean 135 | self.std = std 136 | def transform(self, data): 137 | return (data - self.mean) / self.std 138 | def inverse_transform(self, data): 139 | return (data * self.std) + self.mean 140 | 141 | 142 | def sym_adj(adj): 143 | """Symmetrically normalize adjacency matrix.""" 144 | adj = sp.coo_matrix(adj) 145 | rowsum = np.array(adj.sum(1)) 146 | d_inv_sqrt = np.power(rowsum, -0.5).flatten() 147 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 
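    # nodes with zero degree yield inf in d_inv_sqrt; reset those entries to 0 so D^(-1/2) stays finite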
148 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 149 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).astype(np.float32).todense() 150 | 151 | def asym_adj(adj): 152 | """Asymmetrically normalize adjacency matrix.""" 153 | adj = sp.coo_matrix(adj) 154 | rowsum = np.array(adj.sum(1)).flatten() 155 | d_inv = np.power(rowsum, -1).flatten() 156 | d_inv[np.isinf(d_inv)] = 0. 157 | d_mat= sp.diags(d_inv) 158 | return d_mat.dot(adj).astype(np.float32).todense() 159 | 160 | def calculate_normalized_laplacian(adj): 161 | """ 162 | # L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2 163 | # D = diag(A 1) 164 | :param adj: 165 | :return: 166 | """ 167 | adj = sp.coo_matrix(adj) 168 | d = np.array(adj.sum(1)) 169 | d_inv_sqrt = np.power(d, -0.5).flatten() 170 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 171 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 172 | normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 173 | return normalized_laplacian 174 | 175 | def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True): 176 | if undirected: 177 | adj_mx = np.maximum.reduce([adj_mx, adj_mx.T]) 178 | L = calculate_normalized_laplacian(adj_mx) 179 | if lambda_max is None: 180 | lambda_max, _ = linalg.eigsh(L, 1, which='LM') 181 | lambda_max = lambda_max[0] 182 | L = sp.csr_matrix(L) 183 | M, _ = L.shape 184 | I = sp.identity(M, format='csr', dtype=L.dtype) 185 | L = (2 / lambda_max * L) - I 186 | return L.astype(np.float32).todense() 187 | 188 | 189 | def load_pickle(pickle_file): 190 | try: 191 | with open(pickle_file, 'rb') as f: 192 | pickle_data = pickle.load(f) 193 | except UnicodeDecodeError as e: 194 | with open(pickle_file, 'rb') as f: 195 | pickle_data = pickle.load(f, encoding='latin1') 196 | except Exception as e: 197 | print('Unable to load data ', pickle_file, ':', e) 198 | raise 199 | return pickle_data 200 | 201 | def load_adj(pkl_filename): 202 | sensor_ids, sensor_id_to_ind, adj = load_pickle(pkl_filename) 203 | return adj 204 | 205 | 206 | def load_dataset(dataset_dir, batch_size, valid_batch_size= None, test_batch_size=None): 207 | data = {} 208 | for category in ['train', 'val', 'test']: 209 | cat_data = np.load(os.path.join(dataset_dir, category + '.npz')) 210 | data['x_' + category] = cat_data['x'] 211 | data['y_' + category] = cat_data['y'] 212 | scaler = StandardScaler(mean=data['x_train'][..., 0].mean(), std=data['x_train'][..., 0].std()) 213 | # Data format 214 | for category in ['train', 'val', 'test']: 215 | data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0]) 216 | 217 | data['train_loader'] = DataLoaderM(data['x_train'], data['y_train'], batch_size) 218 | data['val_loader'] = DataLoaderM(data['x_val'], data['y_val'], valid_batch_size) 219 | data['test_loader'] = DataLoaderM(data['x_test'], data['y_test'], test_batch_size) 220 | data['scaler'] = scaler 221 | return data 222 | 223 | 224 | 225 | def masked_mse(preds, labels, null_val=np.nan): 226 | if np.isnan(null_val): 227 | mask = ~torch.isnan(labels) 228 | else: 229 | mask = (labels!=null_val) 230 | mask = mask.float() 231 | mask /= torch.mean((mask)) 232 | mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) 233 | loss = (preds-labels)**2 234 | loss = loss * mask 235 | loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) 236 | return torch.mean(loss) 237 | 238 | def masked_rmse(preds, labels, null_val=np.nan): 239 | return torch.sqrt(masked_mse(preds=preds, labels=labels, null_val=null_val)) 240 | 241 | 242 | 
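# The masked_* metrics below follow the same pattern as masked_mse above: entries equal to
# null_val (or NaN) are masked out, the mask is renormalized by its mean so the surviving
# entries keep unit weight on average, and any NaNs produced along the way are replaced by zeros.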
def masked_mae(preds, labels, null_val=np.nan): 243 | if np.isnan(null_val): 244 | mask = ~torch.isnan(labels) 245 | else: 246 | mask = (labels!=null_val) 247 | mask = mask.float() 248 | mask /= torch.mean((mask)) 249 | mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) 250 | loss = torch.abs(preds-labels) 251 | loss = loss * mask 252 | loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) 253 | return torch.mean(loss) 254 | 255 | def masked_mape(preds, labels, null_val=np.nan): 256 | if np.isnan(null_val): 257 | mask = ~torch.isnan(labels) 258 | else: 259 | mask = (labels!=null_val) 260 | mask = mask.float() 261 | mask /= torch.mean((mask)) 262 | mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) 263 | loss = torch.abs(preds-labels)/labels 264 | loss = loss * mask 265 | loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) 266 | return torch.mean(loss) 267 | 268 | 269 | def metric(pred, real): 270 | mae = masked_mae(pred,real,0.0).item() 271 | mape = masked_mape(pred,real,0.0).item() 272 | rmse = masked_rmse(pred,real,0.0).item() 273 | return mae,mape,rmse 274 | 275 | 276 | def load_node_feature(path): 277 | fi = open(path) 278 | x = [] 279 | for li in fi: 280 | li = li.strip() 281 | li = li.split(",") 282 | e = [float(t) for t in li[1:]] 283 | x.append(e) 284 | x = np.array(x) 285 | mean = np.mean(x,axis=0) 286 | std = np.std(x,axis=0) 287 | z = torch.tensor((x-mean)/std,dtype=torch.float) 288 | return z 289 | 290 | 291 | def normal_std(x): 292 | return x.std() * np.sqrt((len(x) - 1.) / (len(x))) 293 | 294 | 295 | 296 | --------------------------------------------------------------------------------