├── README.md ├── figs ├── framework.png └── scconv.png ├── hierarchical_graph_conv.py ├── share.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Semi-Supervised Hierarchical Recurrent Graph Neural Network for City-Wide Parking Availability Prediction 2 | This is a PyTorch implementation of the SHARE architecture, as described in the paper [Semi-Supervised Hierarchical Recurrent Graph Neural Network for City-Wide Parking Availability Prediction](https://arxiv.org/pdf/1911.10516). 3 | 4 |

5 | 6 |

7 | 8 | If you take advantage of the SHARE model in your research, please cite the following: 9 | 10 | ``` 11 | @inproceedings{zhang2019semi, 12 | title={Semi-Supervised Hierarchical Recurrent Graph Neural Network for City-Wide Parking Availability Prediction}, 13 | author={Zhang, Weijia and Liu, Hao and Liu, Yanchi and Zhou, Jingbo and Xiong, Hui}, 14 | booktitle={Proceedings of the Thirty-Fourth AAAI Conference on Artificial Intelligence}, 15 | pages={1186--1193}, 16 | year={2020} 17 | } 18 | ``` 19 | 20 | # Requirements 21 | This code is based on Python3 (>= 3.6). There are a few dependencies to run the code. The major libraries are listed as follows: 22 | * Pytorch (0.4.1) 23 | * [dgl](https://github.com/dmlc/dgl) (0.4.1) 24 | 25 | 26 | -------------------------------------------------------------------------------- /figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willzhang3/SHARE-parking_availability_prediction-Pytorch/dccfc82200a468529455ca16b6690156c9bbc3f7/figs/framework.png -------------------------------------------------------------------------------- /figs/scconv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willzhang3/SHARE-parking_availability_prediction-Pytorch/dccfc82200a468529455ca16b6690156c9bbc3f7/figs/scconv.png -------------------------------------------------------------------------------- /hierarchical_graph_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import dgl 8 | 9 | """CxtConv and PropConv""" 10 | class SpGraphAttentionLayer(nn.Module): 11 | def __init__(self, in_features, out_features, dropout, alpha, ret_adj=False, pa_prop=False): 12 | super(SpGraphAttentionLayer, self).__init__() 13 | 
self.in_features = in_features 14 | self.out_features = out_features 15 | self.pa_prop = pa_prop 16 | self.w_key = nn.Linear(in_features, out_features, bias=True) 17 | self.w_value = nn.Linear(in_features, out_features, bias=True) 18 | self.leakyrelu = nn.LeakyReLU(alpha) 19 | self.cosinesimilarity = nn.CosineSimilarity(dim=-1, eps=1e-8) 20 | 21 | def edge_attention(self, edges): 22 | # edge UDF 23 | # dot-product attention 24 | att_sim = torch.sum(torch.mul(edges.src['h_key'], edges.dst['h_key']),dim=-1) 25 | # att_sim = self.cosinesimilarity(edges.src['h_key'], edges.dst['h_key']) 26 | return {'att_sim': att_sim} 27 | 28 | def message_func(self, edges): 29 | # message UDF 30 | return {'h_value': edges.src['h_value'], 'att_sim': edges.data['att_sim']} 31 | 32 | def reduce_func(self, nodes): 33 | # reduce UDF 34 | alpha = F.softmax(nodes.mailbox['att_sim'], dim=1) # (# of nodes, # of neibors) 35 | alpha = alpha.unsqueeze(-1) 36 | h_att = torch.sum(alpha * nodes.mailbox['h_value'], dim=1) 37 | return {'h_att': h_att} 38 | 39 | def forward(self, X_key, X_value, g): 40 | """ 41 | :param X_key: X_key data of shape (batch_size(B), num_nodes(N), in_features_1). 42 | :param X_value: X_value dasta of shape (batch_size, num_nodes(N), in_features_2). 43 | :param g: sparse graph. 44 | :return: Output data of shape (batch_size, num_nodes(N), out_features). 
45 | """ 46 | B,N,in_features = X_key.size() 47 | h_key = self.w_key(X_key) # (B,N,out_features) 48 | h_key = h_key.view(B*N,-1) # (B*N,out_features) 49 | h_value = X_value if(self.pa_prop == True) else self.w_value(X_value) 50 | h_value = h_value.view(B*N,-1) 51 | g.ndata['h_key'] = h_key 52 | g.ndata['h_value']= h_value 53 | g.apply_edges(self.edge_attention) 54 | g.update_all(self.message_func, self.reduce_func) 55 | h_att = g.ndata.pop('h_att').view(B,N,-1) # (B,N,out_features) 56 | h_conv = h_att if(self.pa_prop == True) else self.leakyrelu(h_att) 57 | return h_conv 58 | 59 | class GAT(nn.Module): 60 | def __init__(self, in_feat, nhid=32, dropout=0, alpha=0.2, hopnum=2, pa_prop=False): 61 | """sparse GAT.""" 62 | super(GAT, self).__init__() 63 | self.pa_prop = pa_prop 64 | self.dropout = nn.Dropout(dropout) 65 | if(pa_prop == True): hopnum = 1 66 | print('hopnum_gat:',hopnum) 67 | self.gat_stacks = nn.ModuleList() 68 | for i in range(hopnum): 69 | if(i > 0): in_feat = nhid 70 | att_layer = SpGraphAttentionLayer(in_feat, nhid, dropout=dropout, alpha=alpha, pa_prop=pa_prop) 71 | self.gat_stacks.append(att_layer) 72 | 73 | def forward(self, X_key, X_value, adj): 74 | out = X_key 75 | for att_layer in self.gat_stacks: 76 | if(self.pa_prop == True): 77 | out = att_layer(out, X_value, adj) 78 | else: 79 | out = att_layer(out, out, adj) 80 | return out 81 | 82 | """SCConv""" 83 | class SCConv(nn.Module): 84 | def __init__(self, in_features, out_features, dropout, alpha, latend_num, gcn_hop): 85 | super(SCConv, self).__init__() 86 | self.in_features = in_features 87 | self.dropout = nn.Dropout(dropout) 88 | self.leakyrelu = nn.LeakyReLU(alpha) 89 | self.conv_block_after_pool = GCN(in_features=self.in_features, out_features=out_features, \ 90 | dropout=dropout, alpha=alpha, hop = gcn_hop) 91 | self.w_classify = nn.Linear(self.in_features, latend_num, bias=True) 92 | 93 | def apply_bn(self, x): 94 | # Batch normalization of 3D tensor x 95 | bn_module = 
nn.BatchNorm1d(x.size()[1]).cuda() 96 | x = bn_module(x) 97 | return x 98 | 99 | def forward(self, X_lots, adj): 100 | """ 101 | :param X_lots: Concat of the outputs of CxtConv and PA_approximation (batch_size, N, in_features). 102 | :param adj: adj_merge (N, N). 103 | :return: Output soft clustering representation for each parking lot of shape (batch_size, N, out_features). 104 | """ 105 | B, N, in_features = X_lots.size() 106 | h_now = self.dropout(X_lots) # (B, N, F) 107 | S = self.w_classify(h_now) # (B, N, latend_num(K)) 108 | S = F.softmax(S,dim=-1) # (B, N, K) 109 | h_c = torch.bmm(S.permute(0,2,1),h_now) # (B, K, F) 110 | h_c = self.apply_bn(h_c) 111 | adj = torch.bmm(torch.bmm(S.permute(0,2,1),adj),S) # (B, K, K) 112 | # GCN 113 | h_latent = self.dropout(self.conv_block_after_pool(h_c,adj)) # (B, K, F) 114 | h_sc = torch.bmm(S,h_latent) 115 | return h_sc 116 | 117 | class GCN(nn.Module): 118 | def __init__(self, in_features, out_features, dropout, alpha, hop = 1): 119 | super(GCN, self).__init__() 120 | self.in_features = in_features 121 | self.hop = hop 122 | self.dropout = nn.Dropout(dropout) 123 | self.leakyrelu = nn.LeakyReLU(alpha) 124 | self.w_lot = nn.ModuleList() 125 | for i in range(hop): 126 | in_features = (self.in_features) if(i==0) else out_features 127 | self.w_lot.append(nn.Linear(in_features, out_features, bias=True)) 128 | 129 | def forward(self, h_c, adj): 130 | # adj normalize 131 | adj_rowsum = torch.sum(adj,dim=-1,keepdim=True) 132 | adj = adj.div(torch.where(adj_rowsum>1e-8, adj_rowsum, 1e-8*torch.ones(1,1).cuda())) # row normalize 133 | # weight aggregate 134 | for i in range(self.hop): 135 | h_c = torch.bmm(adj,h_c) 136 | h_c = self.leakyrelu(self.w_lot[i](h_c)) #(B, N, F) 137 | return h_c 138 | 139 | -------------------------------------------------------------------------------- /share.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import time 4 | import 
torch.nn as nn 5 | import torch.nn.functional as F 6 | from hierarchical_graph_conv import GAT, SCConv 7 | 8 | class FeatureEmb(nn.Module): 9 | def __init__(self): 10 | super(FeatureEmb, self).__init__() 11 | # time embedding 12 | # month,day,hour,minute,dayofweek 13 | self.time_emb = nn.ModuleList([nn.Embedding(feature_size, 4) for feature_size in [12,31,24,4,7]]) 14 | for ele in self.time_emb: 15 | nn.init.xavier_uniform_(ele.weight.data, gain=math.sqrt(2.0)) 16 | 17 | def forward(self, X, pa_onehot): 18 | B, N, T_in, F = X.size() # (batch_size, N, T_in, F) 19 | X_time =torch.cat([emb(X[:,:,:,i+4].long()) for i,emb in enumerate(self.time_emb)],dim=-1) # time F = 4*5 = 20 20 | X_cxt = X[...,2:4] # contextual features 21 | X_pa = X[...,:1].long() # PA, 0,1,...,49 22 | pa_scatter = pa_onehot.clone() 23 | X_pa = pa_scatter.scatter_(-1,X_pa,1.0) # discretize to one-hot , F = 50 24 | return X_cxt, X_pa, X_time 25 | 26 | class SHARE(nn.Module): 27 | def __init__(self,args,t_in,t_out,latend_num,train_num,dropout=0.5,alpha=0.2,hid_dim=32,\ 28 | gat_hop=2,device=torch.device('cuda')): 29 | super(SHARE, self).__init__() 30 | self.device = device 31 | # number of context features (here set 2 for test) 32 | self.nfeat = 2 33 | self.hid_dim = hid_dim 34 | self.train_num = train_num 35 | # Feature embedding 36 | self.feature_embedding = FeatureEmb() 37 | # FC layers 38 | self.output_fc = nn.Linear(hid_dim*2, t_out, bias=True) 39 | self.w_pred = nn.Linear(hid_dim*2, 50, bias=True) 40 | self.leakyrelu = nn.LeakyReLU(alpha) 41 | 42 | # Spatial blocks 43 | # CxtConv 44 | self.CxtConv = GAT(in_feat=self.nfeat, nhid=hid_dim, dropout=dropout, alpha=alpha, hopnum=gat_hop, pa_prop=False) 45 | # PropConv 46 | self.PropConv = GAT(in_feat=self.nfeat, nhid=hid_dim, dropout=dropout, alpha=alpha, hopnum=1, pa_prop=True) 47 | # SCConv 48 | self.SCConv = SCConv(in_features=hid_dim+50, out_features=hid_dim, dropout=dropout,\ 49 | alpha=alpha, latend_num=latend_num, gcn_hop = 1) 50 | 51 | # GRU 
Cell 52 | self.GRU = nn.GRUCell(2*hid_dim+50+20, hid_dim*2, bias=True) 53 | nn.init.xavier_uniform_(self.GRU.weight_ih,gain=math.sqrt(2.0)) 54 | nn.init.xavier_uniform_(self.GRU.weight_hh,gain=math.sqrt(2.0)) 55 | 56 | # Parameter initialization 57 | for ele in self.modules(): 58 | if isinstance(ele, nn.Linear): 59 | nn.init.xavier_uniform_(ele.weight,gain=math.sqrt(2.0)) 60 | 61 | def forward(self, adjs, X, h_t, pa_onehot): 62 | """ 63 | :param adjs: CxtConv, PropConv and SCconv adj. 64 | :param X: Input data of shape (batch_size, num_nodes(N), T_in, num_features(F)). 65 | :param h_t: To init GRU hidden state with shape (N, 2*hid_dim). 66 | :param pa_onehot: be used to discretize y for PA approximation 67 | :return: predicted PA and CE_loss 68 | """ 69 | adj,adj_label,adj_dense = adjs 70 | B,N,T,F_feat = X.size() 71 | X_cxt,X_pa,X_time = self.feature_embedding(X, pa_onehot) 72 | # GRU and Spatial blocks 73 | CE_loss = 0.0 74 | for i in range(T): 75 | y_t = F.softmax(self.w_pred(h_t),dim=-1) # (B, N, p=50) 76 | if(i==T-1): 77 | CE_loss += F.binary_cross_entropy(y_t[:,:self.train_num,:].reshape(B*self.train_num,-1),\ 78 | X_pa[:,:self.train_num,i,:].reshape(B*self.train_num,-1)) 79 | # PropConv 80 | y_att = self.PropConv(X_cxt[:,:,i,:],X_pa[:,:,i,:], adj_label) # (B, N, p=50) 81 | if(i==T-1): 82 | y_att[:,:self.train_num,:] = torch.where(y_att[:,:self.train_num,:]<1.,y_att[:,:self.train_num,:],\ 83 | (1.-1e-8)*torch.ones(1,1).cuda()) 84 | CE_loss += F.binary_cross_entropy(y_att[:,:self.train_num,:].reshape(B*self.train_num,-1),\ 85 | X_pa[:,:self.train_num,i,:].reshape(B*self.train_num,-1)) 86 | # PA approximation 87 | en_yt = torch.exp(torch.sum(y_t*torch.log\ 88 | (torch.where(y_t>1e-8,y_t,1e-8*torch.ones(1,1).cuda())),dim=-1,keepdim=True)) 89 | en_yatt = torch.exp(torch.sum(y_att*torch.log\ 90 | (torch.where(y_att>1e-8,y_att,1e-8*torch.ones(1,1).cuda())),dim=-1,keepdim=True)) 91 | en_yatt = 
torch.where(torch.sum(y_att,dim=-1,keepdim=True)>1e-8,en_yatt,torch.zeros(1,1).cuda()) 92 | pseudo_y = (en_yt*y_t + en_yatt*y_att)/(en_yt+en_yatt) 93 | if(self.training == False): 94 | pseudo_y[:,:self.train_num,:] = X_pa[:,:self.train_num,i,:] 95 | # CxtConv 96 | h_cxt = self.CxtConv(X_cxt[:,:,i,:],None,adj) # (B, N, tmp_hid) 97 | # SCConv 98 | h_sc = self.SCConv(torch.cat([h_cxt,pseudo_y],dim=-1),adj_dense) 99 | X_feat = torch.cat([h_cxt,pseudo_y,h_sc,X_time[...,i,:]],dim=-1) 100 | h_t = self.GRU(X_feat.view(-1,2*self.hid_dim+50+20), h_t.view(-1,self.hid_dim*2)) # (B*N, 2*tmp_hid) 101 | h_t = h_t.view(B,N,-1) 102 | 103 | out = torch.sigmoid(self.output_fc(h_t)) # (B, N, T_out) 104 | return out, CE_loss 105 | 106 | 107 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import time 8 | import dgl 9 | import torch.nn as nn 10 | import logging 11 | import random 12 | from share import SHARE 13 | from utils import * 14 | 15 | which_gpu = "0" 16 | os.environ["CUDA_VISIBLE_DEVICES"] = which_gpu 17 | parser = argparse.ArgumentParser(description='SHARE') 18 | parser.add_argument('--enable_cuda', action='store_true', default=True, help='GPU training.') 19 | parser.add_argument('--loss', type=str,default='mse', help='Loss function.') 20 | parser.add_argument('--seed', type=int, default=33, help='Random seed.') 21 | parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs to train.') 22 | parser.add_argument('--batch_size', type=int, default=4, help='Number of batch to train and test.') 23 | parser.add_argument('--t_in', type=int, default=12, help='Input time step.') 24 | parser.add_argument('--t_out', type=int, default=3, help='Output time step.') 25 | parser.add_argument('--lr', type=float, 
default=0.001, help='Initial learning rate.') 26 | parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay (L2 loss on parameters).') 27 | parser.add_argument('--alpha', type=float, default=0.2, help='Alpha for the leaky_relu.') 28 | parser.add_argument('--hid_dim', type=int, default=32, help='Dim of hidden units.') 29 | parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).') 30 | parser.add_argument('--hc_ratio', type=float, default=0.1, help='ratio in Hierarchical graph.') 31 | parser.add_argument('--topk', type=int, default=10, help='k nearest neighbors in Propagating graph.') 32 | parser.add_argument('--disteps', type=int, default=1000, help='Farthest neighbors distance in Context graph.') 33 | parser.add_argument('--gat_hop', type=int, default=2, help='Number of hop in CxtConv.') 34 | parser.add_argument('--train_num', type=int, default=-1, help='Number of parking lots for train.') 35 | parser.add_argument('--train_ratio', type=float, default=0.3, help='Parking lots ratio for train.') 36 | parser.add_argument('--patience', type=int, default=30, help='Patience') 37 | parser.add_argument('--beta', type=float, default=0.5, help='Beta of CE_loss .') 38 | 39 | args = parser.parse_args() 40 | args.train_num = int(1965*args.train_ratio+0.5) 41 | logging.basicConfig(level = logging.INFO,filename='./log',format = '%(asctime)s - %(process)s - %(levelname)s - %(message)s') 42 | logger = logging.getLogger(__name__) 43 | 44 | random.seed(args.seed) 45 | np.random.seed(args.seed) 46 | torch.manual_seed(args.seed) 47 | if args.enable_cuda and torch.cuda.is_available(): 48 | torch.cuda.manual_seed(args.seed) 49 | args.device = torch.device('cuda') 50 | else: 51 | args.device = torch.device('cpu') 52 | print(args) 53 | logger.info(args) 54 | 55 | def train_epoch(loader_train,h_t,prr_onehot): 56 | """ 57 | Trains one epoch with the given data. 58 | :param loader_train: Training data (X,Y). 
59 | :return: Average loss of last bacth. 60 | """ 61 | for i,(X_batch,Y_batch) in enumerate(loader_train): 62 | net.train() 63 | X_batch = X_batch.to(device=args.device) 64 | Y_batch = Y_batch.to(device=args.device) 65 | now_batch = X_batch.shape[0] 66 | optimizer.zero_grad() 67 | if(now_batch == args.batch_size): 68 | y_pred,prr_loss = net(adjs, X_batch, h_t, prr_onehot) # (B, N , T_out) 69 | else: 70 | adj,adj_label,adj_dense = adjs 71 | adjs_copy = (spadj_expand(g_adj,now_batch),spadj_expand(g_adj_label,now_batch),adj_dense[:now_batch]) 72 | y_pred,prr_loss = net(adjs_copy,X_batch,h_t[:now_batch],prr_onehot[:now_batch]) 73 | loss = loss_criterion(y_pred[:,idx_train,:], Y_batch[:,idx_train,:]) + prr_loss*args.beta 74 | loss.backward() 75 | optimizer.step() 76 | if (i*args.batch_size % 80 == 0): 77 | print(i*args.batch_size) 78 | print("train loss:{:.4f}".format(loss.detach().cpu().numpy())) 79 | return loss.detach().cpu().numpy() 80 | 81 | def test_epoch(loader_val,h_t,prr_onehot): 82 | """ 83 | Test one epoch with the given data. 84 | :param loader_val: Valuation or Test data (X,Y). 85 | :return: Loss and MAE. 
86 | """ 87 | val_loss = [] 88 | val_mae = [] 89 | for i,(X_batch,Y_batch) in enumerate(loader_val): 90 | if (i*args.batch_size % 80 == 0): 91 | print(i*args.batch_size) 92 | net.eval() 93 | X_batch = X_batch.to(device=args.device) 94 | Y_batch = Y_batch.to(device=args.device) # (B,N,T_out) 95 | now_batch = X_batch.shape[0] 96 | if(now_batch == args.batch_size): 97 | y_pred,prr_loss = net(adjs, X_batch, h_t, prr_onehot) # (B,N,T_out) 98 | else: 99 | adj,adj_label,adj_dense = adjs 100 | adjs_copy = (spadj_expand(g_adj,now_batch),spadj_expand(g_adj_label,now_batch),adj_dense[:now_batch]) 101 | y_pred,prr_loss = net(adjs_copy, X_batch, h_t[:now_batch], prr_onehot[:now_batch]) 102 | loss_val = loss_criterion(y_pred[:,idx_val,:], Y_batch[:,idx_val,:]) + prr_loss*args.beta 103 | val_loss.append(np.asscalar(loss_val.detach().cpu().numpy())) 104 | mae = np.absolute(y_pred[:,idx_val,:].detach().cpu().numpy()*total_park[:,idx_val]\ 105 | -Y_batch[:,idx_val,:].detach().cpu().numpy()*total_park[:,idx_val]) # (B,N,T_out) 106 | val_mae.append(mae) 107 | return np.asarray(val_loss),np.concatenate(val_mae,axis=0) 108 | 109 | def spadj_expand(adj, batch_size): 110 | adj = dgl.batch([adj]*batch_size) 111 | return adj 112 | 113 | def print_log(mae,mse,loss,stage): 114 | mae_o = [np.mean(mae[:,:,i]) for i in range(args.t_out)] 115 | mse_o = [np.mean(mse[:,:,i]) for i in range(args.t_out)] 116 | rmse_o = [np.sqrt(ele) for ele in mse_o] 117 | stage_str = "{} - mean metrics: mae,mse,rmse,loss".format(stage) 118 | mean_str = "mean metric values: {},{},{},{}".format(np.mean(mae_o),np.mean(mse_o),np.mean(rmse_o),np.mean(loss)) 119 | mae_str = "MAE: {}".format(','.join(str(ele) for ele in mae_o)) 120 | mse_str = "MSE: {}".format(','.join(str(ele) for ele in mse_o)) 121 | rmse_str = "RMSE: {}".format(','.join(str(ele) for ele in rmse_o)) 122 | print(stage_str) 123 | print(mean_str) 124 | print(mae_str) 125 | print(mse_str) 126 | print(rmse_str) 127 | logger.info(stage_str) 128 | 
logger.info(mean_str) 129 | logger.info(mae_str) 130 | logger.info(mse_str) 131 | logger.info(rmse_str) 132 | 133 | if __name__ == '__main__': 134 | 135 | adjs,loader_train,loader_val,loader_test,idx_train,idx_val,total_park = \ 136 | load_data(args.t_in,args.t_out,args.batch_size,args.train_num,args.topk,args.disteps) 137 | N = total_park.shape[1] # total number of parking lots 138 | adj,adj_label = adjs 139 | latend_num = int(N*args.hc_ratio+0.5) # latent node number 140 | print('latend num:',latend_num) 141 | adj_edgenum = adj.shape[1] 142 | adj_label_edgenum = adj_label.shape[1] 143 | adj = torch.from_numpy(adj).long() 144 | adj_label = torch.from_numpy(adj_label).long() 145 | # Merge 2 graph as scconv's adj 146 | adj_dense = torch.sparse_coo_tensor(adj,torch.ones((adj.shape[1])),torch.Size([N,N])).to_dense() 147 | adj_dense_label = torch.sparse_coo_tensor(adj_label,torch.ones((adj_label.shape[1])),torch.Size([N,N])).to_dense() 148 | adj_dense = adj_dense + adj_dense_label 149 | adj_dense = torch.where(adj_dense<1e-8,adj_dense,torch.ones(1,1)) 150 | adj_merge = adj_dense.to(device=args.device).repeat(args.batch_size,1,1) 151 | g_adj = dgl.DGLGraph() 152 | g_adj.add_nodes(N) 153 | g_adj.add_edges(adj[0],adj[1]) 154 | # expand for batch training 155 | adj = spadj_expand(g_adj,args.batch_size) 156 | g_adj_label = dgl.DGLGraph() 157 | g_adj_label.add_nodes(N) 158 | g_adj_label.add_edges(adj_label[0],adj_label[1]) 159 | adj_label = spadj_expand(g_adj_label,args.batch_size) 160 | adjs = (adj,adj_label,adj_merge) 161 | # to init GRU hidden state 162 | h_t = torch.zeros(args.batch_size,N,args.hid_dim*2).to(device=args.device) 163 | # to discretize y for PA approximation 164 | prr_onehot = torch.zeros(args.batch_size,N,args.t_in,50).to(device=args.device) 165 | 166 | # model 167 | net = SHARE(args = args, 168 | t_in=args.t_in, 169 | t_out=args.t_out, 170 | latend_num = latend_num, 171 | train_num = args.train_num, 172 | dropout=args.dropout, 173 | alpha=args.alpha, 174 
| hid_dim=args.hid_dim, 175 | gat_hop = args.gat_hop, 176 | device=args.device).to(device=args.device) 177 | 178 | optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay) 179 | 180 | if(args.loss == 'mse'): 181 | loss_criterion = nn.MSELoss() 182 | elif(args.loss == 'mae'): 183 | loss_criterion = nn.L1Loss() 184 | 185 | best_epoch = 0 186 | min_mse = 1e15 187 | st_epoch = best_epoch 188 | for epoch in range(st_epoch,args.epochs): 189 | st_time = time.time() 190 | # training 191 | print('training......') 192 | loss_train = train_epoch(loader_train,h_t,prr_onehot) 193 | # validating 194 | with torch.no_grad(): 195 | print('validating......') 196 | val_loss,val_mae = test_epoch(loader_val,h_t,prr_onehot) 197 | val_mse = val_mae**2 198 | # testing 199 | with torch.no_grad(): 200 | print('testing......') 201 | test_loss,test_mae = test_epoch(loader_test,h_t,prr_onehot) 202 | test_mse = test_mae**2 203 | 204 | val_meanmse = np.mean(val_mse) 205 | if(val_meanmse < min_mse): 206 | min_mse = val_meanmse 207 | best_epoch = epoch + 1 208 | best_mae = test_mae.copy() 209 | best_mse = test_mse.copy() 210 | best_loss = test_loss.copy() 211 | # log 212 | try: 213 | print("Epoch: {}".format(epoch+1)) 214 | logger.info("Epoch: {}".format(epoch+1)) 215 | print("Train loss: {}".format(loss_train)) 216 | logger.info("Train loss: {}".format(loss_train)) 217 | print_log(val_mae,val_mse,val_loss,'Validation') 218 | print_log(test_mae,test_mse,test_loss,'Test') 219 | print_log(best_mae,best_mse,best_loss,'Best Epoch-{}'.format(best_epoch)) 220 | print('time: {:.4f}s'.format(time.time() - st_time)) 221 | logger.info('time: {:.4f}s\n'.format(time.time() - st_time)) 222 | except: 223 | print("log error...") 224 | 225 | # early stop 226 | if(epoch+1 - best_epoch >= args.patience): 227 | sys.exit(0) 228 | -------------------------------------------------------------------------------- /utils.py: 
"""Data utilities for SHARE: sliding-window datasets, graph edge lists and data loaders."""
import os

import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader


class GetDataset(Dataset):
    """Index-aligned pair of sample containers (X[i], Y[i]).

    When built by ``make_dataset`` both members are float torch tensors:
    X is (num_data, num_nodes(N), T_in, F2) and Y is (num_data, num_nodes(N), T_out).
    """

    def __init__(self, X, Y):
        self.X = X  # model inputs, one window per entry along axis 0
        self.Y = Y  # prediction targets aligned with X along axis 0

    def __getitem__(self, index):
        return self.X[index], self.Y[index]

    def __len__(self):
        return len(self.X)


def make_dataset(rawdata, T_in, T_out, num_bins=50):
    """
    Cut a (num_nodes, T, F1) series into sliding windows.

    :param rawdata: numpy array (num_nodes(N), T, F1); feature 0 is the normalized PA.
    :param T_in: number of input time steps per sample.
    :param T_out: number of target time steps per sample.
    :param num_bins: number of discrete PA classes prepended as feature 0 of X.
    :return: GetDataset with X (num_data, N, T_in, F1+1) and Y (num_data, N, T_out).
             num_data is 0 when the series is shorter than T_in+T_out.
    """
    num_nodes, T_all, num_feat = rawdata.shape
    # Class index used to discretize PA for the PA approximation. The 1e-8 shift keeps
    # exact bin boundaries in the lower bin (so PA == 1.0 lands in the last class).
    # The clip guards against that shift being lost in float32 arithmetic and against
    # PA values slightly outside [0, 1], both of which would otherwise yield an index
    # outside [0, num_bins-1].
    pa_class = ((rawdata[:, :, :1] - 1e-8) * num_bins).astype(int)
    pa_class = np.clip(pa_class, 0, num_bins - 1)
    pdata = np.concatenate([pa_class, rawdata], axis=-1)

    num_data = T_all - (T_in + T_out) + 1
    if num_data > 0:
        X = np.stack([pdata[:, i:i + T_in, :] for i in range(num_data)])
        Y = np.stack([rawdata[:, i + T_in:i + T_in + T_out, 0] for i in range(num_data)])
    else:
        # Too short for a single window: keep the documented ranks with zero samples.
        X = np.empty((0, num_nodes, T_in, num_feat + 1))
        Y = np.empty((0, num_nodes, T_out))
    X = torch.from_numpy(X).float()
    Y = torch.from_numpy(Y).float()
    print('X shape',X.shape)
    print('Y shape',Y.shape)
    return GetDataset(X, Y)


def adj_process(adj,train_num,topk,disteps):
    """
    Build the two sparse edge lists from a dense distance matrix.

    :param adj: (N, N) array of pairwise distances.
    :param train_num: only the first train_num columns are candidate neighbors
                      in the propagating graph.
    :param topk: minimum number of neighbors collected per node in the propagating graph.
    :param disteps: distance threshold.
    :return: (edge_adj, adj_label), integer arrays of shape (2, E):
             the context-graph edges and the propagating-graph edges.
    """
    adj = np.asarray(adj)
    num_rows, num_cols = adj.shape

    # Context graph: self loops plus every pair whose distance is within disteps.
    # np.nonzero walks the mask in row-major order, i.e. the order of a nested i/j loop.
    within = adj <= disteps
    within |= np.eye(num_rows, num_cols, dtype=bool)
    edge_adj = np.asarray(np.nonzero(within), dtype=int)

    # Propagating graph: for each node walk the candidate columns by increasing distance
    # and stop once at least topk neighbors are collected and the distance exceeds disteps.
    src, dst = [], []
    candidates = adj[:, :train_num]
    for i in range(num_rows):
        row = candidates[i]
        cnt = 0
        for j in np.argsort(row, kind='stable'):  # stable: ties keep column order
            if i != j:
                src.append(i)
                dst.append(int(j))
                cnt += 1
            if cnt >= topk and row[j] > disteps:
                break
    adj_label = np.asarray([src, dst], dtype=int)
    return edge_adj, adj_label


def load_data(T_in, T_out, Batch_Size, train_num, topk, disteps, data_dir='../data'):
    """
    Load the arrays from data_dir and build loaders, node index ranges and edge lists.

    :param T_in: input time steps per sample.
    :param T_out: output time steps per sample.
    :param Batch_Size: batch size of all three loaders.
    :param train_num: number of labeled parking lots (the first train_num nodes).
    :param topk: passed to adj_process.
    :param disteps: passed to adj_process.
    :param data_dir: directory containing adj.npy, padata.npy and total_park.npy.
    :return: adjs, loader_train, loader_val, loader_test, idx_train, idx_val,
             total_park reshaped to (1, N, 1).
    """
    # adjacency matrix
    adj = np.load(os.path.join(data_dir, 'adj.npy'))
    print('adj shape:',adj.shape) # (N, N)
    # parking availability dataset (including PA, time and contextual data)
    # [normalized_PA, contextual_feature1, contextual_feature2, month, day, hour, minute, dayofweek]
    # An example [0.21, 0.14, 0.35, 5, 15, 13, 30, 3]
    # Note the normalized_PA is in [0,1].
    padata = np.load(os.path.join(data_dir, 'padata.npy')) # (N, T_all, F)
    N, T_all, _ = padata.shape
    print('X shape:',padata.shape) # (N, T_all, F)
    # total parking spots
    total_park = np.load(os.path.join(data_dir, 'total_park.npy')) # (N, )

    # chronological 60% / 20% / 20% split along the time axis
    t_train, t_val = int(T_all*0.6), int(T_all*0.8)
    dataset_train = make_dataset(padata[:,:t_train],T_in,T_out)
    print('len of dataset_train:',len(dataset_train))
    dataset_val = make_dataset(padata[:,t_train:t_val],T_in,T_out)
    print('len of dataset_val:',len(dataset_val))
    dataset_test = make_dataset(padata[:,t_val:],T_in,T_out)
    print('len of dataset_test:',len(dataset_test))

    # Pinned host memory only helps host-to-GPU copies; skip it when CUDA is unavailable.
    pin = torch.cuda.is_available()
    loader_train = DataLoader(dataset=dataset_train, batch_size=Batch_Size, shuffle=True, pin_memory=pin,num_workers=1)
    loader_val = DataLoader(dataset=dataset_val, batch_size=Batch_Size, shuffle=False, pin_memory=pin,num_workers=1)
    loader_test = DataLoader(dataset=dataset_test, batch_size=Batch_Size, shuffle=False, pin_memory=pin,num_workers=1)
    idx_train = range(0,train_num) # labeled parking lots
    idx_val = range(train_num, N) # unlabeled parking lots
    # adj process
    adjs = adj_process(adj,train_num,topk,disteps) # return CxtConv and PropConv adj
    print("load_data finished.")
    return adjs,loader_train,loader_val,loader_test,idx_train,idx_val,total_park.reshape(1,N,1)