├── README.md
├── figs
├── framework.png
└── scconv.png
├── hierarchical_graph_conv.py
├── share.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Semi-Supervised Hierarchical Recurrent Graph Neural Network for City-Wide Parking Availability Prediction
2 | This is a PyTorch implementation of the SHARE architecture, as described in the paper [Semi-Supervised Hierarchical Recurrent Graph Neural Network for City-Wide Parking Availability Prediction](https://arxiv.org/pdf/1911.10516).
3 |
4 |
5 |
6 |
7 |
8 | If you take advantage of the SHARE model in your research, please cite the following:
9 |
10 | ```
11 | @inproceedings{zhang2019semi,
12 | title={Semi-Supervised Hierarchical Recurrent Graph Neural Network for City-Wide Parking Availability Prediction},
13 | author={Zhang, Weijia and Liu, Hao and Liu, Yanchi and Zhou, Jingbo and Xiong, Hui},
14 | booktitle={Proceedings of the Thirty-Fourth AAAI Conference on Artificial Intelligence},
15 | pages={1186--1193},
16 | year={2020}
17 | }
18 | ```
19 |
20 | # Requirements
21 | This code is based on Python 3 (>= 3.6). It needs a few dependencies; the major libraries are:
22 | * PyTorch (0.4.1)
23 | * [dgl](https://github.com/dmlc/dgl) (0.4.1)
24 |
25 |
26 |
--------------------------------------------------------------------------------
/figs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/willzhang3/SHARE-parking_availability_prediction-Pytorch/dccfc82200a468529455ca16b6690156c9bbc3f7/figs/framework.png
--------------------------------------------------------------------------------
/figs/scconv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/willzhang3/SHARE-parking_availability_prediction-Pytorch/dccfc82200a468529455ca16b6690156c9bbc3f7/figs/scconv.png
--------------------------------------------------------------------------------
/hierarchical_graph_conv.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import dgl
8 |
9 | """CxtConv and PropConv"""
class SpGraphAttentionLayer(nn.Module):
    """Sparse dot-product attention layer shared by CxtConv and PropConv.

    Keys are projected with ``w_key`` and each edge is scored by the dot
    product of its endpoint keys. Every destination node softmax-normalises
    the scores of its incoming edges and sums the neighbours' values with
    those weights.

    * ``pa_prop`` true  (PropConv): values are used as given and the aggregate
      is returned without an activation.
    * ``pa_prop`` false (CxtConv): values are projected with ``w_value`` and
      the aggregate goes through a LeakyReLU.

    ``dropout`` and ``ret_adj`` are accepted for interface compatibility but
    are not used by this layer.
    """

    def __init__(self, in_features, out_features, dropout, alpha, ret_adj=False, pa_prop=False):
        super(SpGraphAttentionLayer, self).__init__()
        self.in_features, self.out_features = in_features, out_features
        self.pa_prop = pa_prop
        # Registration order (w_key, then w_value) fixes the order of random initialisation.
        self.w_key = nn.Linear(in_features, out_features, bias=True)
        self.w_value = nn.Linear(in_features, out_features, bias=True)
        self.leakyrelu = nn.LeakyReLU(alpha)
        # Alternative similarity measure; not used by edge_attention.
        self.cosinesimilarity = nn.CosineSimilarity(dim=-1, eps=1e-8)

    def edge_attention(self, edges):
        """Edge UDF: unnormalised dot-product score between source and destination keys."""
        src_key, dst_key = edges.src['h_key'], edges.dst['h_key']
        return {'att_sim': (src_key * dst_key).sum(dim=-1)}

    def message_func(self, edges):
        """Message UDF: send each source value along with the edge's score."""
        return dict(h_value=edges.src['h_value'], att_sim=edges.data['att_sim'])

    def reduce_func(self, nodes):
        """Reduce UDF: softmax the incoming scores per node and take the weighted sum of values."""
        weights = F.softmax(nodes.mailbox['att_sim'], dim=1).unsqueeze(-1)  # (nodes, neighbours, 1)
        return {'h_att': (weights * nodes.mailbox['h_value']).sum(dim=1)}

    def forward(self, X_key, X_value, g):
        """
        :param X_key: key features of shape (batch_size(B), num_nodes(N), in_features_1).
        :param X_value: value features of shape (B, N, in_features_2).
        :param g: sparse (batched) DGL graph over the B*N flattened nodes.
        :return: output of shape (B, N, out_features).
        """
        batch, num_nodes, _ = X_key.size()
        propagate_only = (self.pa_prop == True)
        keys = self.w_key(X_key).view(batch * num_nodes, -1)  # (B*N, out_features)
        values = X_value if propagate_only else self.w_value(X_value)
        g.ndata['h_key'] = keys
        g.ndata['h_value'] = values.view(batch * num_nodes, -1)
        g.apply_edges(self.edge_attention)
        g.update_all(self.message_func, self.reduce_func)
        aggregated = g.ndata.pop('h_att').view(batch, num_nodes, -1)  # (B, N, out_features)
        if propagate_only:
            return aggregated
        return self.leakyrelu(aggregated)
58 |
class GAT(nn.Module):
    """Stack of sparse attention layers.

    With ``pa_prop`` false (CxtConv) the stack has ``hopnum`` layers. Each layer
    attends over the previous layer's output, using it as both key and value.
    With ``pa_prop`` true (PropConv) the stack is forced to a single layer. That
    layer takes keys from ``X_key`` and uses ``X_value`` unchanged as values.
    """

    def __init__(self, in_feat, nhid=32, dropout=0, alpha=0.2, hopnum=2, pa_prop=False):
        super(GAT, self).__init__()
        self.pa_prop = pa_prop
        self.dropout = nn.Dropout(dropout)  # registered, but not applied in forward
        num_layers = 1 if pa_prop == True else hopnum
        print('hopnum_gat:', num_layers)
        # Only the first layer sees the raw input width; deeper layers consume nhid features.
        self.gat_stacks = nn.ModuleList(
            SpGraphAttentionLayer(in_feat if depth == 0 else nhid, nhid,
                                  dropout=dropout, alpha=alpha, pa_prop=pa_prop)
            for depth in range(num_layers)
        )

    def forward(self, X_key, X_value, adj):
        """Run the stack on graph ``adj``; returns the last layer's output."""
        hidden = X_key
        for layer in self.gat_stacks:
            values = X_value if self.pa_prop == True else hidden
            hidden = layer(hidden, values, adj)
        return hidden
81 |
82 | """SCConv"""
class SCConv(nn.Module):
    """Soft-clustering convolution (SCConv).

    Each parking lot is softly assigned to ``latend_num`` latent nodes by a
    softmax classifier S. Latent features (S^T h) and a latent graph
    (S^T A S) are pooled from the lots, a GCN runs on the latent graph, and
    the result is mapped back to the lots through S.
    """

    def __init__(self, in_features, out_features, dropout, alpha, latend_num, gcn_hop):
        super(SCConv, self).__init__()
        self.in_features = in_features
        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(alpha)  # registered, but not applied in forward
        self.conv_block_after_pool = GCN(in_features=self.in_features, out_features=out_features,
                                         dropout=dropout, alpha=alpha, hop=gcn_hop)
        self.w_classify = nn.Linear(self.in_features, latend_num, bias=True)

    def apply_bn(self, x):
        """Batch-normalise a 3-D tensor ``x`` over its channel dimension (dim 1).

        This is numerically what a freshly built ``nn.BatchNorm1d(x.size(1))``
        in training mode computes: batch statistics, unit scale, zero shift.
        The functional form avoids allocating a new module on every call and
        runs on ``x``'s own device instead of assuming CUDA.
        NOTE(review): there are no learned affine parameters or running
        statistics, so batch statistics are used even in eval mode. This
        matches the original behaviour.
        """
        return F.batch_norm(x, None, None, weight=None, bias=None,
                            training=True, momentum=0.1, eps=1e-5)

    def forward(self, X_lots, adj):
        """
        :param X_lots: concat of the CxtConv output and the PA approximation, shape (batch_size(B), N, in_features).
        :param adj: merged adjacency, batched as (B, N, N) (required by ``torch.bmm``).
        :return: soft-clustering representation for each parking lot, shape (B, N, out_features).
        """
        B, N, in_features = X_lots.size()
        h_now = self.dropout(X_lots)  # (B, N, F)
        S = F.softmax(self.w_classify(h_now), dim=-1)  # (B, N, latend_num(K)) soft assignment
        S_t = S.permute(0, 2, 1)  # (B, K, N)
        h_c = self.apply_bn(torch.bmm(S_t, h_now))  # (B, K, F) pooled latent features
        adj = torch.bmm(torch.bmm(S_t, adj), S)  # (B, K, K) pooled latent graph
        # GCN on the latent graph
        h_latent = self.dropout(self.conv_block_after_pool(h_c, adj))  # (B, K, F_out)
        h_sc = torch.bmm(S, h_latent)  # (B, N, F_out) back to parking lots
        return h_sc
116 |
class GCN(nn.Module):
    """Row-normalised graph convolution with ``hop`` stacked linear + LeakyReLU layers."""

    def __init__(self, in_features, out_features, dropout, alpha, hop=1):
        super(GCN, self).__init__()
        self.in_features = in_features
        self.hop = hop
        self.dropout = nn.Dropout(dropout)  # registered, but not applied in forward
        self.leakyrelu = nn.LeakyReLU(alpha)
        self.w_lot = nn.ModuleList()
        for i in range(hop):
            layer_in = self.in_features if i == 0 else out_features
            self.w_lot.append(nn.Linear(layer_in, out_features, bias=True))

    def forward(self, h_c, adj):
        """
        :param h_c: node features of shape (B, N, in_features).
        :param adj: batched adjacency of shape (B, N, N).
        :return: features of shape (B, N, out_features).
        """
        # Row-normalise adj. Rows whose sum is <= 1e-8 are divided by 1e-8 instead,
        # so empty rows stay zero rather than becoming NaN. The floor is built with
        # full_like so it lives on adj's device/dtype (previously hard-coded to CUDA).
        adj_rowsum = torch.sum(adj, dim=-1, keepdim=True)
        safe_rowsum = torch.where(adj_rowsum > 1e-8, adj_rowsum, torch.full_like(adj_rowsum, 1e-8))
        adj = adj.div(safe_rowsum)
        # weight aggregate
        for i in range(self.hop):
            h_c = torch.bmm(adj, h_c)
            h_c = self.leakyrelu(self.w_lot[i](h_c))  # (B, N, F)
        return h_c
138 |
139 |
--------------------------------------------------------------------------------
/share.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import time
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from hierarchical_graph_conv import GAT, SCConv
7 |
class FeatureEmb(nn.Module):
    """Embed calendar features and one-hot encode the discretised PA value.

    Input feature layout along the last axis of ``X``, as read here:
    index 0 = PA bucket index, 2:4 = contextual features,
    4:9 = month, day, hour, minute, dayofweek.
    """

    def __init__(self):
        super(FeatureEmb, self).__init__()
        # One 4-dim embedding table per calendar field: month, day, hour, minute, dayofweek.
        vocab_sizes = [12, 31, 24, 4, 7]
        self.time_emb = nn.ModuleList([nn.Embedding(size, 4) for size in vocab_sizes])
        for table in self.time_emb:
            nn.init.xavier_uniform_(table.weight.data, gain=math.sqrt(2.0))

    def forward(self, X, pa_onehot):
        """
        :param X: input of shape (batch_size, N, T_in, F).
        :param pa_onehot: all-zero template of shape (..., 50); it is cloned, not modified.
        :return: (X_cxt, X_pa, X_time) = contextual features, one-hot PA (50 bins),
                 and concatenated time embeddings (5 * 4 = 20 dims).
        """
        _batch, _nodes, _steps, _feat = X.size()  # enforce a 4-D input
        time_parts = [table(X[:, :, :, 4 + k].long()) for k, table in enumerate(self.time_emb)]
        X_time = torch.cat(time_parts, dim=-1)
        X_cxt = X[..., 2:4]
        pa_index = X[..., :1].long()  # PA bucket, 0..49
        X_pa = pa_onehot.clone().scatter_(-1, pa_index, 1.0)
        return X_cxt, X_pa, X_time
25 |
class SHARE(nn.Module):
    """SHARE model for parking-availability (PA) prediction.

    At every input step the model:
      1. predicts a 50-bin PA distribution ``y_t`` from the GRU state,
      2. propagates the observed one-hot PA over the propagating graph (PropConv),
      3. fuses both estimates into ``pseudo_y``, weighting each by exp(-entropy),
      4. encodes contextual features over the context graph (CxtConv),
      5. applies soft-clustering convolution (SCConv) on the merged graph.
    It then feeds [h_cxt, pseudo_y, h_sc, time embedding] to a GRU cell. The first
    ``train_num`` nodes are treated as labelled. A linear layer plus sigmoid maps
    the final hidden state to ``t_out`` outputs.
    """

    def __init__(self, args, t_in, t_out, latend_num, train_num, dropout=0.5, alpha=0.2, hid_dim=32,
                 gat_hop=2, device=torch.device('cuda')):
        super(SHARE, self).__init__()
        self.device = device
        # number of context features (here set 2 for test)
        self.nfeat = 2
        self.hid_dim = hid_dim
        self.train_num = train_num
        # Feature embedding
        self.feature_embedding = FeatureEmb()
        # FC layers
        self.output_fc = nn.Linear(hid_dim * 2, t_out, bias=True)
        self.w_pred = nn.Linear(hid_dim * 2, 50, bias=True)
        self.leakyrelu = nn.LeakyReLU(alpha)

        # Spatial blocks
        # CxtConv
        self.CxtConv = GAT(in_feat=self.nfeat, nhid=hid_dim, dropout=dropout, alpha=alpha,
                           hopnum=gat_hop, pa_prop=False)
        # PropConv
        self.PropConv = GAT(in_feat=self.nfeat, nhid=hid_dim, dropout=dropout, alpha=alpha,
                            hopnum=1, pa_prop=True)
        # SCConv
        self.SCConv = SCConv(in_features=hid_dim + 50, out_features=hid_dim, dropout=dropout,
                             alpha=alpha, latend_num=latend_num, gcn_hop=1)

        # GRU cell input: h_cxt (hid_dim) + pseudo_y (50) + h_sc (hid_dim) + time embedding (20)
        self.GRU = nn.GRUCell(2 * hid_dim + 50 + 20, hid_dim * 2, bias=True)
        nn.init.xavier_uniform_(self.GRU.weight_ih, gain=math.sqrt(2.0))
        nn.init.xavier_uniform_(self.GRU.weight_hh, gain=math.sqrt(2.0))

        # Parameter initialization
        for ele in self.modules():
            if isinstance(ele, nn.Linear):
                nn.init.xavier_uniform_(ele.weight, gain=math.sqrt(2.0))

    @staticmethod
    def _entropy_weight(p):
        """Return exp(sum(p * log p)) = exp(-entropy) over the last axis (keepdim).

        ``p`` is floored at 1e-8 inside the log to avoid log(0). The floor is built
        with ``full_like``, so it follows ``p``'s device and dtype.
        """
        floored = torch.where(p > 1e-8, p, torch.full_like(p, 1e-8))
        return torch.exp(torch.sum(p * torch.log(floored), dim=-1, keepdim=True))

    def forward(self, adjs, X, h_t, pa_onehot):
        """
        :param adjs: (CxtConv graph, PropConv graph, SCConv dense adj).
        :param X: input of shape (batch_size(B), num_nodes(N), T_in, num_features(F)).
        :param h_t: initial GRU hidden state, (B, N, 2*hid_dim); flattened internally.
        :param pa_onehot: all-zero template used to one-hot the discretised PA.
        :return: (predicted PA of shape (B, N, t_out) in (0, 1), CE_loss on the labelled nodes)
        """
        adj, adj_label, adj_dense = adjs
        B, N, T, F_feat = X.size()
        X_cxt, X_pa, X_time = self.feature_embedding(X, pa_onehot)
        n_lab = self.train_num
        # GRU and Spatial blocks
        CE_loss = 0.0
        for i in range(T):
            last_step = (i == T - 1)
            y_t = F.softmax(self.w_pred(h_t), dim=-1)  # (B, N, p=50)
            if last_step:
                CE_loss += F.binary_cross_entropy(y_t[:, :n_lab, :].reshape(B * n_lab, -1),
                                                  X_pa[:, :n_lab, i, :].reshape(B * n_lab, -1))
            # PropConv
            y_att = self.PropConv(X_cxt[:, :, i, :], X_pa[:, :, i, :], adj_label)  # (B, N, p=50)
            if last_step:
                labelled = y_att[:, :n_lab, :]
                # NOTE(review): in float32, 1. - 1e-8 rounds to exactly 1.0, so this cap is
                # effectively a no-op (kept as-is to preserve results). BCE's own clamping
                # of log to >= -100 is what prevents infinities here.
                y_att[:, :n_lab, :] = torch.where(labelled < 1., labelled,
                                                  torch.full_like(labelled, 1. - 1e-8))
                CE_loss += F.binary_cross_entropy(y_att[:, :n_lab, :].reshape(B * n_lab, -1),
                                                  X_pa[:, :n_lab, i, :].reshape(B * n_lab, -1))
            # PA approximation: entropy-weighted fusion of the two estimates
            en_yt = self._entropy_weight(y_t)
            en_yatt = self._entropy_weight(y_att)
            # Drop the propagated estimate where PropConv produced (near) zero mass,
            # presumably nodes that received no messages.
            has_mass = torch.sum(y_att, dim=-1, keepdim=True) > 1e-8
            en_yatt = torch.where(has_mass, en_yatt, torch.zeros_like(en_yatt))
            pseudo_y = (en_yt * y_t + en_yatt * y_att) / (en_yt + en_yatt)
            if not self.training:
                # At evaluation time labelled lots use their observed PA directly.
                pseudo_y[:, :n_lab, :] = X_pa[:, :n_lab, i, :]
            # CxtConv
            h_cxt = self.CxtConv(X_cxt[:, :, i, :], None, adj)  # (B, N, hid_dim)
            # SCConv
            h_sc = self.SCConv(torch.cat([h_cxt, pseudo_y], dim=-1), adj_dense)
            X_feat = torch.cat([h_cxt, pseudo_y, h_sc, X_time[..., i, :]], dim=-1)
            h_t = self.GRU(X_feat.view(-1, 2 * self.hid_dim + 50 + 20),
                           h_t.view(-1, self.hid_dim * 2))  # (B*N, 2*hid_dim)
            h_t = h_t.view(B, N, -1)

        out = torch.sigmoid(self.output_fc(h_t))  # (B, N, T_out)
        return out, CE_loss
105 |
106 |
107 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import torch
7 | import time
8 | import dgl
9 | import torch.nn as nn
10 | import logging
11 | import random
12 | from share import SHARE
13 | from utils import *
14 |
which_gpu = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = which_gpu
parser = argparse.ArgumentParser(description='SHARE')
parser.add_argument('--enable_cuda', action='store_true', default=True, help='GPU training.')
# --enable_cuda is a store_true flag that already defaults to True, so it can never turn
# CUDA off; --disable_cuda provides the missing way to force CPU execution.
parser.add_argument('--disable_cuda', action='store_true', default=False, help='Force CPU training.')
parser.add_argument('--loss', type=str, default='mse', choices=['mse', 'mae'], help='Loss function.')
parser.add_argument('--seed', type=int, default=33, help='Random seed.')
parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs to train.')
parser.add_argument('--batch_size', type=int, default=4, help='Number of batch to train and test.')
parser.add_argument('--t_in', type=int, default=12, help='Input time step.')
parser.add_argument('--t_out', type=int, default=3, help='Output time step.')
parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay (L2 loss on parameters).')
parser.add_argument('--alpha', type=float, default=0.2, help='Alpha for the leaky_relu.')
parser.add_argument('--hid_dim', type=int, default=32, help='Dim of hidden units.')
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
parser.add_argument('--hc_ratio', type=float, default=0.1, help='ratio in Hierarchical graph.')
parser.add_argument('--topk', type=int, default=10, help='k nearest neighbors in Propagating graph.')
parser.add_argument('--disteps', type=int, default=1000, help='Farthest neighbors distance in Context graph.')
parser.add_argument('--gat_hop', type=int, default=2, help='Number of hop in CxtConv.')
parser.add_argument('--train_num', type=int, default=-1,
                    help='Number of parking lots for train (negative: derive from --train_ratio).')
parser.add_argument('--train_ratio', type=float, default=0.3, help='Parking lots ratio for train.')
parser.add_argument('--patience', type=int, default=30, help='Patience')
parser.add_argument('--beta', type=float, default=0.5, help='Beta of CE_loss .')

args = parser.parse_args()
# Previously --train_num was always overwritten; now it is only derived when not given.
if args.train_num < 0:
    # NOTE(review): 1965 is presumably the total number of parking lots in the original
    # dataset; adjust it if a different dataset is used.
    args.train_num = int(1965 * args.train_ratio + 0.5)
if args.disable_cuda:
    args.enable_cuda = False
logging.basicConfig(level=logging.INFO, filename='./log',
                    format='%(asctime)s - %(process)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.enable_cuda and torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')
print(args)
logger.info(args)
54 |
def train_epoch(loader_train, h_t, prr_onehot):
    """
    Run one optimisation pass over the training loader.

    Uses the module-level ``net``, ``optimizer``, ``loss_criterion``, ``adjs``,
    ``g_adj``, ``g_adj_label``, ``idx_train`` and ``args`` set up in ``__main__``.

    :param loader_train: training data loader yielding (X, Y) batches.
    :param h_t: initial GRU hidden state sized for a full batch.
    :param prr_onehot: all-zero one-hot template sized for a full batch.
    :return: loss of the last batch (numpy scalar).
    """
    for step, (X_batch, Y_batch) in enumerate(loader_train):
        net.train()
        X_batch = X_batch.to(device=args.device)
        Y_batch = Y_batch.to(device=args.device)
        batch_len = X_batch.shape[0]
        optimizer.zero_grad()
        if batch_len != args.batch_size:
            # Short final batch: rebuild batched graphs and trim per-batch tensors to fit.
            _, _, dense_adj = adjs
            batch_adjs = (spadj_expand(g_adj, batch_len),
                          spadj_expand(g_adj_label, batch_len),
                          dense_adj[:batch_len])
            y_pred, prr_loss = net(batch_adjs, X_batch, h_t[:batch_len], prr_onehot[:batch_len])
        else:
            y_pred, prr_loss = net(adjs, X_batch, h_t, prr_onehot)  # (B, N, T_out)
        loss = loss_criterion(y_pred[:, idx_train, :], Y_batch[:, idx_train, :]) + prr_loss * args.beta
        loss.backward()
        optimizer.step()
        seen = step * args.batch_size
        if seen % 80 == 0:
            print(seen)
            print("train loss:{:.4f}".format(loss.detach().cpu().numpy()))
    return loss.detach().cpu().numpy()
80 |
def test_epoch(loader_val, h_t, prr_onehot):
    """
    Evaluate one pass over a validation or test loader.

    Uses the module-level ``net``, ``loss_criterion``, ``adjs``, ``g_adj``,
    ``g_adj_label``, ``idx_val``, ``total_park`` and ``args`` set up in ``__main__``.
    Callers wrap this in ``torch.no_grad()``.

    :param loader_val: validation or test data loader yielding (X, Y) batches.
    :param h_t: initial GRU hidden state sized for a full batch.
    :param prr_onehot: all-zero one-hot template sized for a full batch.
    :return: (per-batch losses as an array, absolute errors in parking spots for
              the unlabelled lots, concatenated over batches: (num_samples, N_val, T_out)).
    """
    val_loss = []
    val_mae = []
    net.eval()
    park_val = total_park[:, idx_val]
    for i, (X_batch, Y_batch) in enumerate(loader_val):
        if i * args.batch_size % 80 == 0:
            print(i * args.batch_size)
        X_batch = X_batch.to(device=args.device)
        Y_batch = Y_batch.to(device=args.device)  # (B, N, T_out)
        now_batch = X_batch.shape[0]
        if now_batch == args.batch_size:
            y_pred, prr_loss = net(adjs, X_batch, h_t, prr_onehot)  # (B, N, T_out)
        else:
            # Short final batch: rebuild batched graphs and trim per-batch tensors to fit.
            _, _, adj_dense = adjs
            adjs_copy = (spadj_expand(g_adj, now_batch), spadj_expand(g_adj_label, now_batch),
                         adj_dense[:now_batch])
            y_pred, prr_loss = net(adjs_copy, X_batch, h_t[:now_batch], prr_onehot[:now_batch])
        loss_val = loss_criterion(y_pred[:, idx_val, :], Y_batch[:, idx_val, :]) + prr_loss * args.beta
        # np.asscalar was deprecated in NumPy 1.16 and removed in 1.23; .item() gives the same Python float.
        val_loss.append(loss_val.item())
        # Scale normalised PA back to parking-spot counts before taking the error.
        pred_np = y_pred[:, idx_val, :].detach().cpu().numpy()
        true_np = Y_batch[:, idx_val, :].detach().cpu().numpy()
        val_mae.append(np.absolute(pred_np * park_val - true_np * park_val))  # (B, N_val, T_out)
    return np.asarray(val_loss), np.concatenate(val_mae, axis=0)
108 |
def spadj_expand(adj, batch_size):
    """Return one DGL batched graph made of ``batch_size`` copies of ``adj``.

    NOTE(review): this relies on ``dgl.batch`` numbering nodes copy-by-copy
    (copy b, node k -> b*N + k). The attention layers' (B*N) flattening assumes
    that layout; confirm against the DGL version in use.
    """
    replicas = [adj for _ in range(batch_size)]
    return dgl.batch(replicas)
112 |
def print_log(mae, mse, loss, stage):
    """Print, then log, per-horizon MAE/MSE/RMSE and their means for one stage.

    :param mae: absolute errors; ``mae[:, :, t]`` holds output step t (t < args.t_out).
    :param mse: squared errors with the same layout as ``mae``.
    :param loss: per-batch losses; only their mean is reported.
    :param stage: label used in the header line (e.g. 'Validation').
    """
    horizons = range(args.t_out)
    mae_o = [np.mean(mae[:, :, t]) for t in horizons]
    mse_o = [np.mean(mse[:, :, t]) for t in horizons]
    rmse_o = list(map(np.sqrt, mse_o))
    report = [
        "{} - mean metrics: mae,mse,rmse,loss".format(stage),
        "mean metric values: {},{},{},{}".format(np.mean(mae_o), np.mean(mse_o), np.mean(rmse_o), np.mean(loss)),
        "MAE: {}".format(','.join(map(str, mae_o))),
        "MSE: {}".format(','.join(map(str, mse_o))),
        "RMSE: {}".format(','.join(map(str, rmse_o))),
    ]
    # All lines go to stdout first, then the same lines to the log file.
    for line in report:
        print(line)
    for line in report:
        logger.info(line)
if __name__ == '__main__':

    adjs, loader_train, loader_val, loader_test, idx_train, idx_val, total_park = \
        load_data(args.t_in, args.t_out, args.batch_size, args.train_num, args.topk, args.disteps)
    N = total_park.shape[1]  # total number of parking lots
    adj, adj_label = adjs  # edge lists of shape (2, E): context graph, propagating graph
    latend_num = int(N * args.hc_ratio + 0.5)  # latent node number
    print('latend num:', latend_num)
    adj = torch.from_numpy(adj).long()
    adj_label = torch.from_numpy(adj_label).long()
    # Merge 2 graph as scconv's adj: union of both edge sets as a dense 0/1 (N, N) matrix
    adj_dense = torch.sparse_coo_tensor(adj, torch.ones((adj.shape[1])), torch.Size([N, N])).to_dense()
    adj_dense_label = torch.sparse_coo_tensor(adj_label, torch.ones((adj_label.shape[1])),
                                              torch.Size([N, N])).to_dense()
    adj_dense = adj_dense + adj_dense_label
    adj_dense = torch.where(adj_dense < 1e-8, adj_dense, torch.ones(1, 1))
    adj_merge = adj_dense.to(device=args.device).repeat(args.batch_size, 1, 1)
    g_adj = dgl.DGLGraph()
    g_adj.add_nodes(N)
    g_adj.add_edges(adj[0], adj[1])
    # expand for batch training
    adj = spadj_expand(g_adj, args.batch_size)
    g_adj_label = dgl.DGLGraph()
    g_adj_label.add_nodes(N)
    g_adj_label.add_edges(adj_label[0], adj_label[1])
    adj_label = spadj_expand(g_adj_label, args.batch_size)
    adjs = (adj, adj_label, adj_merge)
    # to init GRU hidden state
    h_t = torch.zeros(args.batch_size, N, args.hid_dim * 2).to(device=args.device)
    # to discretize y for PA approximation
    prr_onehot = torch.zeros(args.batch_size, N, args.t_in, 50).to(device=args.device)

    # model
    net = SHARE(args=args,
                t_in=args.t_in,
                t_out=args.t_out,
                latend_num=latend_num,
                train_num=args.train_num,
                dropout=args.dropout,
                alpha=args.alpha,
                hid_dim=args.hid_dim,
                gat_hop=args.gat_hop,
                device=args.device).to(device=args.device)

    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.loss == 'mse':
        loss_criterion = nn.MSELoss()
    elif args.loss == 'mae':
        loss_criterion = nn.L1Loss()
    else:
        # Fail fast instead of hitting a NameError on loss_criterion mid-training.
        raise ValueError("Unsupported --loss '{}': expected 'mse' or 'mae'.".format(args.loss))

    best_epoch = 0
    min_mse = 1e15
    # Test metrics of the best validation epoch; None until validation MSE first improves.
    best_mae = best_mse = best_loss = None
    st_epoch = best_epoch
    for epoch in range(st_epoch, args.epochs):
        st_time = time.time()
        # training
        print('training......')
        loss_train = train_epoch(loader_train, h_t, prr_onehot)
        # validating
        with torch.no_grad():
            print('validating......')
            val_loss, val_mae = test_epoch(loader_val, h_t, prr_onehot)
        val_mse = val_mae ** 2
        # testing
        with torch.no_grad():
            print('testing......')
            test_loss, test_mae = test_epoch(loader_test, h_t, prr_onehot)
        test_mse = test_mae ** 2

        val_meanmse = np.mean(val_mse)
        if val_meanmse < min_mse:
            min_mse = val_meanmse
            best_epoch = epoch + 1
            best_mae = test_mae.copy()
            best_mse = test_mse.copy()
            best_loss = test_loss.copy()
        # log
        try:
            print("Epoch: {}".format(epoch + 1))
            logger.info("Epoch: {}".format(epoch + 1))
            print("Train loss: {}".format(loss_train))
            logger.info("Train loss: {}".format(loss_train))
            print_log(val_mae, val_mse, val_loss, 'Validation')
            print_log(test_mae, test_mse, test_loss, 'Test')
            if best_mae is not None:
                print_log(best_mae, best_mse, best_loss, 'Best Epoch-{}'.format(best_epoch))
            print('time: {:.4f}s'.format(time.time() - st_time))
            logger.info('time: {:.4f}s\n'.format(time.time() - st_time))
        except Exception:
            # Logging is best-effort and must not stop training, but record the traceback
            # instead of silently discarding it (and no longer swallow KeyboardInterrupt).
            print("log error...")
            logger.exception("log error...")

        # early stop
        if epoch + 1 - best_epoch >= args.patience:
            sys.exit(0)
228 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset,DataLoader
4 |
class GetDataset(Dataset):
    """Map-style dataset over pre-built inputs ``X`` and targets ``Y``.

    Both are indexed along their first axis. ``make_dataset`` in this module
    passes float tensors shaped (num_data, N, T_in, F) and (num_data, N, T_out).
    """

    def __init__(self, X, Y):
        self.X, self.Y = X, Y

    def __getitem__(self, index):
        # Returns the (input, target) pair for one sample.
        return self.X[index], self.Y[index]

    def __len__(self):
        return len(self.X)
21 |
def make_dataset(rawdata, T_in, T_out):
    """
    Slice a node time series into sliding (input, target) windows.

    A discretised PA column is prepended to the features. The normalised PA
    (feature 0, expected in [0, 1]) is mapped to the bucket index
    ``int((pa - 1e-8) * 50)`` and clipped to [0, 49], for later use as a
    one-hot index over 50 bins.

    :input: rawdata (num_nodes(N), T, F1)
    :return: GetDataset whose
        X is (num_data, num_nodes(N), T_in, F2 = F1 + 1), float32, and
        Y is (num_data, num_nodes(N), T_out), the normalised PA at the target steps.
    """
    N, T_all = rawdata.shape[0], rawdata.shape[1]
    # be used to discretize y for PA approximation (y has been normalized to [0,1])
    pa_bucket = ((rawdata[:, :, :1] - 1e-8) * 50).astype(int)
    # Values slightly outside [0, 1] would produce indices outside [0, 49] and break the
    # later one-hot scatter; clip them into the valid bucket range.
    pa_bucket = np.clip(pa_bucket, 0, 49)
    pdata = np.concatenate([pa_bucket, rawdata], axis=-1)
    num_data = max(T_all - (T_in + T_out) + 1, 0)
    if num_data > 0:
        X = np.stack([pdata[:, i:i + T_in, :] for i in range(num_data)])
        Y = np.stack([rawdata[:, i + T_in:i + T_in + T_out, :1] for i in range(num_data)])
    else:
        # Series too short for a single window: keep the documented ranks instead of shape (0,).
        X = np.empty((0, N, T_in, pdata.shape[-1]))
        Y = np.empty((0, N, T_out, 1))
    X = torch.from_numpy(X).float()
    Y = torch.from_numpy(Y).float().squeeze(-1)
    print('X shape', X.shape)
    print('Y shape', Y.shape)
    return GetDataset(X, Y)
42 |
def adj_process(adj, train_num, topk, disteps):
    """
    Build the sparse context graph and the sparse propagating graph from ``adj``.

    Entries of ``adj`` are treated as pairwise distances.

    * Context graph: edge (i, j) when i == j or adj[i, j] <= disteps,
      listed in row-major order.
    * Propagating graph: for each node i, only the first ``train_num`` columns
      (labelled lots) are candidates, scanned by ascending distance with ties
      broken by column index. Every candidate j != i is added until at least
      ``topk`` edges exist AND the last scanned distance exceeds ``disteps``.
      All candidates within ``disteps`` are therefore kept, plus at least
      ``topk`` neighbours.

    return: sparse CxtConv and sparse PropConv adj, each an int array of shape (2, E)
    """
    adj = np.asarray(adj)
    # sparse context graph adj (2,E): vectorised instead of an O(N^2) Python loop
    in_range = (adj <= disteps) | np.eye(adj.shape[0], adj.shape[1], dtype=bool)
    rows, cols = np.nonzero(in_range)  # row-major order, same as the nested loop
    edge_adj = np.asarray([rows, cols], dtype=int)

    # sparse propagating adj (2,E)
    edge_1 = []
    edge_2 = []
    for i in range(adj.shape[0]):
        dists = adj[i, :train_num]
        cnt = 0
        # Stable sort keeps equal distances in column order.
        for j in np.argsort(dists, kind='stable'):
            j = int(j)
            dis = dists[j]
            if i != j:
                edge_1.append(i)
                edge_2.append(j)
                cnt += 1
            if cnt >= topk and dis > disteps:
                break
    adj_label = np.asarray([edge_1, edge_2], dtype=int)
    return edge_adj, adj_label
73 |
def load_data(T_in, T_out, Batch_Size, train_num, topk, disteps, data_dir='../data'):
    """
    Load the parking data, build chronological train/val/test loaders and the sparse graphs.

    Files read from ``data_dir`` (defaults to the original '../data'):
      * adj.npy        - (N, N) matrix, treated as pairwise distances by adj_process
      * padata.npy     - (N, T_all, F) rows of
                         [normalized_PA, contextual_feature1, contextual_feature2,
                          month, day, hour, minute, dayofweek]; normalized_PA is in [0, 1]
      * total_park.npy - (N,) total parking spots per lot
    NOTE(review): FeatureEmb embeds these calendar fields with vocabularies of size
    12/31/24/4/7. Presumably they are stored as 0-based indices (e.g. minute as a
    quarter-hour bucket); confirm against the preprocessing.

    Time is split 60% / 20% / 20% into train / val / test. The first ``train_num``
    lots are labelled (idx_train) and the rest unlabelled (idx_val).

    :return: (adjs, loader_train, loader_val, loader_test, idx_train, idx_val,
              total_park reshaped to (1, N, 1))
    """
    import os

    # adjacency matrix
    adj = np.load(os.path.join(data_dir, 'adj.npy'))
    print('adj shape:', adj.shape)  # (N, N)
    # parking availability dataset (including PA, time and contextual data)
    padata = np.load(os.path.join(data_dir, 'padata.npy'))  # (N, T_all, F)
    N, T_all, _ = padata.shape
    print('X shape:', padata.shape)  # (N, T_all, F)
    # total parking spots
    total_park = np.load(os.path.join(data_dir, 'total_park.npy'))  # (N, )
    train_end, val_end = int(T_all * 0.6), int(T_all * 0.8)
    dataset_train = make_dataset(padata[:, :train_end], T_in, T_out)
    print('len of dataset_train:', len(dataset_train))
    dataset_val = make_dataset(padata[:, train_end:val_end], T_in, T_out)
    print('len of dataset_val:', len(dataset_val))
    dataset_test = make_dataset(padata[:, val_end:], T_in, T_out)
    print('len of dataset_test:', len(dataset_test))
    loader_kwargs = dict(batch_size=Batch_Size, pin_memory=True, num_workers=1)
    loader_train = DataLoader(dataset=dataset_train, shuffle=True, **loader_kwargs)
    loader_val = DataLoader(dataset=dataset_val, shuffle=False, **loader_kwargs)
    loader_test = DataLoader(dataset=dataset_test, shuffle=False, **loader_kwargs)
    idx_train = range(0, train_num)  # labeled parking lots
    idx_val = range(train_num, N)  # unlabeled parking lots
    # adj process
    adjs = adj_process(adj, train_num, topk, disteps)  # return CxtConv and PropConv adj
    print("load_data finished.")
    return adjs, loader_train, loader_val, loader_test, idx_train, idx_val, total_park.reshape(1, N, 1)
102 |
103 |
104 |
105 |
--------------------------------------------------------------------------------