├── .gitattributes
├── data
│   └── readme.txt
├── README.md
├── LICENSE
├── dataset.py
├── train.py
└── model.py

/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto
--------------------------------------------------------------------------------
/data/readme.txt:
--------------------------------------------------------------------------------
https://drive.google.com/file/d/1I_vpbLJhOJpNh-TpLdSWsaG3xCpzMVSQ/view?usp=sharing
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GAGNN

The source code of "Group-Aware Graph Neural Network for Nationwide City Air Quality Forecasting".
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Friger

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | 5 | import torch 6 | import torch.utils.data as Data 7 | 8 | 9 | path = './data' 10 | 11 | class trainDataset(Data.Dataset): 12 | def __init__(self, transform=None, train=True): 13 | self.x = np.load(os.path.join(path,'train_x.npy'),allow_pickle=True) 14 | self.u = np.load(os.path.join(path,'train_u.npy'),allow_pickle=True) 15 | self.y = np.load(os.path.join(path,'train_y.npy'),allow_pickle=True) 16 | self.edge_w = np.load(os.path.join(path,'edge_w.npy'),allow_pickle=True) 17 | self.edge_index = np.load(os.path.join(path,'edge_index.npy'),allow_pickle=True) 18 | self.loc = np.load(os.path.join(path,'loc_filled.npy'),allow_pickle=True) 19 | self.loc = self.loc.astype(np.float) 20 | 21 | 22 | def __getitem__(self, index): 23 | x = torch.FloatTensor(self.x[index]) 24 | x = x.transpose(0,1) 25 | y = torch.FloatTensor(self.y[index]) 26 | y = y.transpose(0,1) 27 | u = torch.tensor(self.u[index]) 28 | edge_index = torch.tensor(self.edge_index) 29 | # edge_index = edge_index.expand((x.size[0],edge_index.size[0],edge_index.size[1])) 30 | edge_w = torch.FloatTensor(self.edge_w) 31 | # edge_w = edge_w.expand((x.size[0],edge_w.size[0])) 32 | loc = torch.FloatTensor(self.loc) 33 | 34 | return [x,u,y,edge_index,edge_w,loc] 35 | 36 | def __len__(self): 37 | return self.x.shape[0] 38 | 39 | class valDataset(Data.Dataset): 40 | def __init__(self, transform=None, train=True): 41 | self.x = np.load(os.path.join(path,'val_x.npy'),allow_pickle=True) 42 | self.u = 
np.load(os.path.join(path,'val_u.npy'),allow_pickle=True) 43 | self.y = np.load(os.path.join(path,'val_y.npy'),allow_pickle=True) 44 | self.edge_w = np.load(os.path.join(path,'edge_w.npy'),allow_pickle=True) 45 | self.edge_index = np.load(os.path.join(path,'edge_index.npy'),allow_pickle=True) 46 | self.loc = np.load(os.path.join(path,'loc_filled.npy'),allow_pickle=True) 47 | self.loc = self.loc.astype(np.float) 48 | 49 | 50 | def __getitem__(self, index): 51 | x = torch.FloatTensor(self.x[index]) 52 | x = x.transpose(0,1) 53 | y = torch.FloatTensor(self.y[index]) 54 | y = y.transpose(0,1) 55 | u = torch.tensor(self.u[index]) 56 | edge_index = torch.tensor(self.edge_index) 57 | # edge_index = edge_index.expand((x.size[0],edge_index.size[0],edge_index.size[1])) 58 | edge_w = torch.FloatTensor(self.edge_w) 59 | # edge_w = edge_w.expand((x.size[0],edge_w.size[0])) 60 | loc = torch.FloatTensor(self.loc) 61 | 62 | return [x,u,y,edge_index,edge_w,loc] 63 | 64 | def __len__(self): 65 | return self.x.shape[0] 66 | 67 | class testDataset(Data.Dataset): 68 | def __init__(self, transform=None, train=True): 69 | self.x = np.load(os.path.join(path,'test_x.npy'),allow_pickle=True) 70 | self.u = np.load(os.path.join(path,'test_u.npy'),allow_pickle=True) 71 | self.y = np.load(os.path.join(path,'test_y.npy'),allow_pickle=True) 72 | self.edge_w = np.load(os.path.join(path,'edge_w.npy'),allow_pickle=True) 73 | self.edge_index = np.load(os.path.join(path,'edge_index.npy'),allow_pickle=True) 74 | self.loc = np.load(os.path.join(path,'loc_filled.npy'),allow_pickle=True) 75 | self.loc = self.loc.astype(np.float) 76 | 77 | 78 | def __getitem__(self, index): 79 | x = torch.FloatTensor(self.x[index]) 80 | x = x.transpose(0,1) 81 | y = torch.FloatTensor(self.y[index]) 82 | y = y.transpose(0,1) 83 | u = torch.tensor(self.u[index]) 84 | edge_index = torch.tensor(self.edge_index) 85 | # edge_index = edge_index.expand((x.size[0],edge_index.size[0],edge_index.size[1])) 86 | edge_w = 
torch.FloatTensor(self.edge_w) 87 | # edge_w = edge_w.expand((x.size[0],edge_w.size[0])) 88 | loc = torch.FloatTensor(self.loc) 89 | 90 | return [x,u,y,edge_index,edge_w,loc] 91 | 92 | def __len__(self): 93 | return self.x.shape[0] 94 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.data as Data 8 | from sklearn.cluster import KMeans 9 | from torch_geometric.nn import MetaLayer 10 | 11 | from model import Model 12 | from dataset import trainDataset,valDataset,testDataset 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser(description='Multi-city AQI forecasting') 16 | parser.add_argument('--device',type=str,default='cuda',help='') 17 | parser.add_argument('--mode',type=str,default='full',help='') 18 | parser.add_argument('--encoder',type=str,default='self',help='') 19 | parser.add_argument('--w_init',type=str,default='rand',help='') 20 | parser.add_argument('--mark',type=str,default='',help='') 21 | parser.add_argument('--run_times',type=int,default=5,help='') 22 | parser.add_argument('--epoch',type=int,default=300,help='') 23 | parser.add_argument('--batch_size',type=int,default=64,help='') 24 | parser.add_argument('--w_rate',type=int,default=50,help='') 25 | parser.add_argument('--city_num',type=int,default=209,help='') 26 | parser.add_argument('--group_num',type=int,default=15,help='') 27 | parser.add_argument('--gnn_h',type=int,default=32,help='') 28 | parser.add_argument('--gnn_layer',type=int,default=2,help='') 29 | parser.add_argument('--x_em',type=int,default=32,help='x embedding') 30 | parser.add_argument('--date_em',type=int,default=4,help='date embedding') 31 | parser.add_argument('--loc_em',type=int,default=12,help='loc embedding') 32 | 
parser.add_argument('--edge_h',type=int,default=12,help='edge h') 33 | parser.add_argument('--lr',type=float,default=0.001,help='lr') 34 | parser.add_argument('--wd',type=float,default=0.001,help='weight decay') 35 | parser.add_argument('--pred_step',type=int,default=6,help='step') 36 | args = parser.parse_args() 37 | print(args) 38 | 39 | train_dataset = trainDataset() 40 | val_dataset = valDataset() 41 | test_dataset = testDataset() 42 | print(len(train_dataset)+len(val_dataset)+len(test_dataset)) 43 | train_loader = Data.DataLoader(train_dataset, batch_size=args.batch_size, 44 | shuffle=True, num_workers=8, pin_memory=True) 45 | val_loader = Data.DataLoader(val_dataset, batch_size=args.batch_size, 46 | shuffle=False, num_workers=8, pin_memory=True) 47 | test_loader = Data.DataLoader(test_dataset, batch_size=args.batch_size, 48 | shuffle=False, num_workers=8, pin_memory=True) 49 | device = args.device 50 | # city_index = [0,2,30,32,43] 51 | path = './data' 52 | 53 | 54 | for _ in range(args.run_times): 55 | start = time.time() 56 | 57 | w = None 58 | if args.w_init == 'group': 59 | city_loc = np.load(os.path.join(path,'loc_filled.npy'),allow_pickle=True) 60 | kmeans = KMeans(n_clusters=args.group_num, random_state=0).fit(city_loc) 61 | group_list = kmeans.labels_.tolist() 62 | w = np.random.randn(args.city_num,args.group_num) 63 | w = w * 0.1 64 | for i in range(len(group_list)): 65 | w[i,group_list[i]] = 1.0 66 | w = torch.FloatTensor(w).to(device,non_blocking=True) 67 | 68 | city_model = Model(args.mode,args.encoder,args.w_init,w,args.x_em,args.date_em,args.loc_em,args.edge_h,args.gnn_h, 69 | args.gnn_layer,args.city_num,args.group_num,args.pred_step,device).to(device) 70 | city_num = sum(p.numel() for p in city_model.parameters() if p.requires_grad) 71 | print('city_model:', 'Trainable,', city_num) 72 | # print(city_model) 73 | criterion = nn.L1Loss(reduction = 'sum') 74 | all_params = city_model.parameters() 75 | w_params = [] 76 | other_params = [] 77 | for 
pname, p in city_model.named_parameters(): 78 | if pname == 'w': 79 | w_params += [p] 80 | params_id = list(map(id, w_params)) 81 | other_params = list(filter(lambda p: id(p) not in params_id, all_params)) 82 | # print(len(w_params),len(other_params)) 83 | optimizer = torch.optim.Adam([ 84 | {'params': other_params}, 85 | {'params': w_params, 'lr': args.lr * args.w_rate} 86 | ], lr=args.lr, weight_decay=args.wd) 87 | 88 | val_loss_min = np.inf 89 | for epoch in range(args.epoch): 90 | for i,data in enumerate(train_loader): 91 | data = [item.to(device,non_blocking=True) for item in data] 92 | x,u,y,edge_index,edge_w,loc = data 93 | outputs = city_model(x,u,edge_index,edge_w,loc) 94 | loss = criterion(y,outputs) 95 | city_model.zero_grad() 96 | loss.backward() 97 | optimizer.step() 98 | 99 | if epoch % 10 == 0 and i % 100 == 0: 100 | print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 101 | .format(epoch, args.epoch, i, int(len(train_dataset)/args.batch_size), loss.item())) 102 | 103 | if epoch % 5 == 0: 104 | with torch.no_grad(): 105 | val_loss = 0 106 | for j, data_val in enumerate(val_loader): 107 | data_val = [item.to(device,non_blocking=True) for item in data_val] 108 | x_val,u_val,y_val,edge_index_val,edge_w_val,loc_val = data_val 109 | outputs_val = city_model(x_val,u_val,edge_index_val,edge_w_val,loc_val) 110 | batch_loss = criterion(y_val,outputs_val) 111 | val_loss += batch_loss.item() 112 | print('Epoch:',epoch,', val_loss:',val_loss) 113 | if val_loss < val_loss_min: 114 | torch.save(city_model.state_dict(),args.encoder+'_para_'+args.mark+'.ckpt') 115 | val_loss_min = val_loss 116 | print('parameters have been updated during epoch ',epoch) 117 | 118 | mae_loss = torch.zeros(args.city_num,args.pred_step).to(device) 119 | rmse_loss = torch.zeros(args.city_num,args.pred_step).to(device) 120 | 121 | def cal_loss(outputs,y): 122 | global mae_loss, rmse_loss 123 | temp_loss = torch.abs(outputs-y) 124 | mae_loss = torch.add(mae_loss,temp_loss.sum(dim=0)) 125 | 
126 | temp_loss = torch.pow(temp_loss,2) 127 | rmse_loss = torch.add(rmse_loss,temp_loss.sum(dim=0)) 128 | 129 | 130 | with torch.no_grad(): 131 | city_model.load_state_dict(torch.load(args.encoder+'_para_'+args.mark+'.ckpt')) 132 | w_weight = city_model.state_dict()['w'] 133 | w_weight = F.softmax(w_weight) 134 | _,w_weight = torch.max(w_weight,dim=-1) 135 | print(w_weight.cpu().tolist()) 136 | 137 | for i, data in enumerate(test_loader): 138 | data = [item.to(device,non_blocking=True) for item in data] 139 | x,u,y,edge_index,edge_w,loc = data 140 | outputs = city_model(x,u,edge_index,edge_w,loc) 141 | cal_loss(outputs,y) 142 | 143 | mae_loss = mae_loss/(len(test_dataset)) 144 | rmse_loss = rmse_loss/(len(test_dataset)) 145 | mae_loss = mae_loss.mean(dim=0) 146 | rmse_loss = rmse_loss.mean(dim=0) 147 | 148 | end = time.time() 149 | print('Running time: %s Seconds'%(end-start)) 150 | 151 | mae_loss = mae_loss.cpu() 152 | rmse_loss = rmse_loss.cpu() 153 | 154 | print('mae:', np.array(mae_loss)) 155 | print('rmse:', np.sqrt(np.array(rmse_loss))) 156 | 157 | for i, data in enumerate(Data.DataLoader(test_dataset, batch_size=1,shuffle=False, pin_memory=True)): 158 | data = [item.to(device,non_blocking=True) for item in data] 159 | x,u,y,edge_index,edge_w,loc = data 160 | outputs = city_model(x,u,edge_index,edge_w,loc) 161 | if i == 305: 162 | print(x[:,0]) 163 | print(outputs[:,0]) 164 | 165 | 166 | 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.nn import Sequential as Seq, Linear as Lin, ReLU 6 | from torch.nn import TransformerEncoderLayer, TransformerEncoder 7 | from torch.nn.parameter import Parameter 8 | from torch_scatter import scatter_mean 9 | from torch_geometric.nn import MetaLayer 10 | 11 | 12 | TIME_WINDOW = 24 13 
| PRED_LEN = 6 14 | 15 | class Model(nn.Module): 16 | def __init__(self,mode,encoder,w_init,w,x_em,date_em,loc_em,edge_h,gnn_h,gnn_layer,city_num,group_num,pred_step,device): 17 | super(Model, self).__init__() 18 | self.device = device 19 | self.mode = mode 20 | self.encoder = encoder 21 | self.w_init = w_init 22 | self.city_num = city_num 23 | self.group_num = group_num 24 | self.edge_h = edge_h 25 | self.gnn_layer = gnn_layer 26 | self.pred_step = pred_step 27 | if self.encoder == 'self': 28 | self.encoder_layer = TransformerEncoderLayer(8, nhead=4, dim_feedforward=256) 29 | # self.x_embed = Lin(8, x_em) 30 | self.x_embed = Lin(TIME_WINDOW*8, x_em) 31 | 32 | elif self.encoder == 'lstm': 33 | self.input_LSTM = nn.LSTM(8,x_em,num_layers=1,batch_first=True) 34 | if self.w_init == 'rand': 35 | self.w = Parameter(torch.randn(city_num,group_num).to(device,non_blocking=True),requires_grad=True) 36 | elif self.w_init == 'group': 37 | self.w = Parameter(w,requires_grad=True) 38 | self.loc_embed = Lin(2, loc_em) 39 | self.u_embed1 = nn.Embedding(12, date_em) #month 40 | self.u_embed2 = nn.Embedding(7, date_em) #week 41 | self.u_embed3 = nn.Embedding(24, date_em) #hour 42 | self.edge_inf = Seq(Lin(x_em*2+date_em*3+loc_em*2,edge_h),ReLU(inplace=True)) 43 | self.group_gnn = nn.ModuleList([NodeModel(x_em+loc_em,edge_h,gnn_h)]) 44 | for i in range(self.gnn_layer-1): 45 | self.group_gnn.append(NodeModel(gnn_h,edge_h,gnn_h)) 46 | self.global_gnn = nn.ModuleList([NodeModel(x_em+gnn_h,1,gnn_h)]) 47 | for i in range(self.gnn_layer-1): 48 | self.global_gnn.append(NodeModel(gnn_h,1,gnn_h)) 49 | if self.mode == 'ag': 50 | self.decoder = DecoderModule(x_em,edge_h,gnn_h,gnn_layer,city_num,group_num,device) 51 | self.predMLP = Seq(Lin(gnn_h,16),ReLU(inplace=True),Lin(16,1),ReLU(inplace=True)) 52 | if self.mode == 'full': 53 | self.decoder = DecoderModule(x_em,edge_h,gnn_h,gnn_layer,city_num,group_num,device) 54 | self.predMLP = 
Seq(Lin(gnn_h,16),ReLU(inplace=True),Lin(16,self.pred_step),ReLU(inplace=True)) 55 | 56 | def batchInput(self,x,edge_w,edge_index): 57 | sta_num = x.shape[1] 58 | x = x.reshape(-1,x.shape[-1]) 59 | edge_w = edge_w.reshape(-1,edge_w.shape[-1]) 60 | for i in range(edge_index.size(0)): 61 | edge_index[i,:] = torch.add(edge_index[i,:], i*sta_num) 62 | # print(edge_index.shape) 63 | edge_index = edge_index.transpose(0,1) 64 | # print(edge_index.shape) 65 | edge_index = edge_index.reshape(2,-1) 66 | return x, edge_w, edge_index 67 | 68 | def forward(self,x,u,edge_index,edge_w,loc): 69 | x = x.reshape(-1,x.shape[2],x.shape[3]) 70 | if self.encoder == 'self': 71 | # [S,B,E] 72 | # print(x.shape) 73 | x = x.transpose(0,1) 74 | x = self.encoder_layer(x) 75 | x = x.transpose(0,1) 76 | # print(x.shape) 77 | x = x.reshape(-1,self.city_num,TIME_WINDOW*x.shape[-1]) 78 | x = self.x_embed(x) 79 | # x = x.reshape(-1,self.city_num,TIME_WINDOW,x.shape[-1]) 80 | # x = torch.max(x,dim=-2).values 81 | # print(x.shape) 82 | elif self.encoder == 'lstm': 83 | _,(x,_) = self.input_LSTM(x) 84 | x = x.reshape(-1,self.city_num,x.shape[-1]) 85 | # print(x.shape) 86 | # print(x.shape) 87 | 88 | # graph pooling 89 | # print(self.w[10]) 90 | w = F.softmax(self.w) 91 | w1 = w.transpose(0,1) 92 | w1 = w1.unsqueeze(dim=0) 93 | w1 = w1.repeat_interleave(x.size(0), dim=0) 94 | # print(w.shape,x.shape) 95 | # print(loc.shape) 96 | loc = self.loc_embed(loc) 97 | x_loc = torch.cat([x,loc],dim=-1) 98 | g_x = torch.bmm(w1,x_loc) 99 | # print(g_x.shape) 100 | 101 | # group gnn 102 | u_em1 = self.u_embed1(u[:,0]) 103 | u_em2 = self.u_embed2(u[:,1]) 104 | u_em3 = self.u_embed3(u[:,2]) 105 | u_em = torch.cat([u_em1,u_em2,u_em3],dim=-1) 106 | # print(u_em.shape) 107 | for i in range(self.group_num): 108 | for j in range(self.group_num): 109 | if i == j: continue 110 | g_edge_input = torch.cat([g_x[:,i],g_x[:,j],u_em],dim=-1) 111 | tmp_g_edge_w = self.edge_inf(g_edge_input) 112 | tmp_g_edge_w = 
tmp_g_edge_w.unsqueeze(dim=0) 113 | tmp_g_edge_index = torch.tensor([i,j]).unsqueeze(dim=0).to(self.device,non_blocking=True) 114 | if i == 0 and j == 1: 115 | g_edge_w = tmp_g_edge_w 116 | g_edge_index = tmp_g_edge_index 117 | else: 118 | g_edge_w = torch.cat([g_edge_w,tmp_g_edge_w],dim=0) 119 | g_edge_index = torch.cat([g_edge_index,tmp_g_edge_index],dim=0) 120 | # print(g_edge_w.shape,g_edge_index.shape) 121 | g_edge_w = g_edge_w.transpose(0,1) 122 | g_edge_index = g_edge_index.unsqueeze(dim=0) 123 | g_edge_index = g_edge_index.repeat_interleave(u_em.shape[0],dim=0) 124 | g_edge_index = g_edge_index.transpose(1,2) 125 | # print(g_x.shape,g_edge_w.shape,g_edge_index.shape) 126 | g_x, g_edge_w, g_edge_index = self.batchInput(g_x, g_edge_w, g_edge_index) 127 | # print(g_x.shape,g_edge_w.shape,g_edge_index.shape) 128 | for i in range(self.gnn_layer): 129 | g_x = self.group_gnn[i](g_x,g_edge_index,g_edge_w) 130 | 131 | g_x = g_x.reshape(-1,self.group_num,g_x.shape[-1]) 132 | # print(g_x.shape,self.w.shape) 133 | w2 = w.unsqueeze(dim=0) 134 | w2 = w2.repeat_interleave(g_x.size(0), dim=0) 135 | new_x = torch.bmm(w2,g_x) 136 | # print(new_x.shape,x.shape) 137 | new_x = torch.cat([x,new_x],dim=-1) 138 | edge_w = edge_w.unsqueeze(dim=-1) 139 | # print(new_x.shape,edge_w.shape,edge_index.shape) 140 | new_x, edge_w, edge_index = self.batchInput(new_x, edge_w, edge_index) 141 | # print(new_x.shape,edge_w.shape,edge_index.shape) 142 | for i in range(self.gnn_layer): 143 | new_x = self.global_gnn[i](new_x,edge_index,edge_w) 144 | # print(new_x.shape) 145 | if self.mode == 'ag': 146 | for i in range(self.pred_step): 147 | new_x = self.decoder(new_x,self.w,g_edge_index,g_edge_w,edge_index,edge_w) 148 | tmp_res = self.predMLP(new_x) 149 | tmp_res = tmp_res.reshape(-1,self.city_num) 150 | tmp_res = tmp_res.unsqueeze(dim=-1) 151 | if i == 0: 152 | res = tmp_res 153 | else: 154 | res = torch.cat([res,tmp_res],dim=-1) 155 | if self.mode == 'full': 156 | new_x = 
self.decoder(new_x,self.w,g_edge_index,g_edge_w,edge_index,edge_w) 157 | res = self.predMLP(new_x) 158 | res = res.reshape(-1,self.city_num,self.pred_step) 159 | 160 | # print(res.shape) 161 | return res 162 | 163 | class DecoderModule(nn.Module): 164 | def __init__(self,x_em,edge_h,gnn_h,gnn_layer,city_num,group_num,device): 165 | super(DecoderModule, self).__init__() 166 | self.device = device 167 | self.city_num = city_num 168 | self.group_num = group_num 169 | self.gnn_layer = gnn_layer 170 | self.x_embed = Lin(gnn_h, x_em) 171 | self.group_gnn = nn.ModuleList([NodeModel(x_em,edge_h,gnn_h)]) 172 | for i in range(self.gnn_layer-1): 173 | self.group_gnn.append(NodeModel(gnn_h,edge_h,gnn_h)) 174 | self.global_gnn = nn.ModuleList([NodeModel(x_em+gnn_h,1,gnn_h)]) 175 | for i in range(self.gnn_layer-1): 176 | self.global_gnn.append(NodeModel(gnn_h,1,gnn_h)) 177 | 178 | def forward(self,x,trans_w,g_edge_index,g_edge_w,edge_index,edge_w): 179 | x = self.x_embed(x) 180 | x = x.reshape(-1,self.city_num,x.shape[-1]) 181 | w = Parameter(trans_w,requires_grad=False).to(self.device,non_blocking=True) 182 | w1 = w.transpose(0,1) 183 | w1 = w1.unsqueeze(dim=0) 184 | w1 = w1.repeat_interleave(x.size(0), dim=0) 185 | g_x = torch.bmm(w1,x) 186 | g_x = g_x.reshape(-1,g_x.shape[-1]) 187 | for i in range(self.gnn_layer): 188 | g_x = self.group_gnn[i](g_x,g_edge_index,g_edge_w) 189 | g_x = g_x.reshape(-1,self.group_num,g_x.shape[-1]) 190 | w2 = w.unsqueeze(dim=0) 191 | w2 = w2.repeat_interleave(g_x.size(0), dim=0) 192 | new_x = torch.bmm(w2,g_x) 193 | new_x = torch.cat([x,new_x],dim=-1) 194 | new_x = new_x.reshape(-1,new_x.shape[-1]) 195 | # print(new_x.shape,edge_w.shape,edge_index.shape) 196 | for i in range(self.gnn_layer): 197 | new_x = self.global_gnn[i](new_x,edge_index,edge_w) 198 | 199 | return new_x 200 | 201 | 202 | class NodeModel(torch.nn.Module): 203 | def __init__(self,node_h,edge_h,gnn_h): 204 | super(NodeModel, self).__init__() 205 | self.node_mlp_1 = 
Seq(Lin(node_h+edge_h,gnn_h), ReLU(inplace=True)) 206 | self.node_mlp_2 = Seq(Lin(node_h+gnn_h,gnn_h), ReLU(inplace=True)) 207 | 208 | def forward(self, x, edge_index, edge_attr): 209 | # x: [N, F_x], where N is the number of nodes. 210 | # edge_index: [2, E] with max entry N - 1. 211 | # edge_attr: [E, F_e] 212 | row, col = edge_index 213 | out = torch.cat([x[row], edge_attr], dim=1) 214 | out = self.node_mlp_1(out) 215 | out = scatter_mean(out, col, dim=0, dim_size=x.size(0)) 216 | out = torch.cat([x, out], dim=1) 217 | return self.node_mlp_2(out) --------------------------------------------------------------------------------