├── code ├── configs.pyc ├── configs.py ├── dataset.py ├── utils.py ├── decoder.py ├── attention.py ├── model.py ├── main.py ├── environment.py └── GCN.py ├── .gitignore └── readme.md /code/configs.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fulcrum-zou/VRP-GCN-NPEC/HEAD/code/configs.pyc -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | resources 2 | 3 | dataset 4 | 5 | *.ipynb 6 | 7 | result 8 | 9 | code/ortools_cvrp.py 10 | 11 | report.md 12 | 13 | 实验三.md -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # VRP-GCN-NPEC 2 | 3 | This is an implementation of the GCN-NPEC model from [Efficiently Solving the Practical Vehicle Routing Problem: A Novel Joint Learning Approach](https://dl.acm.org/doi/pdf/10.1145/3394486.3403356). It solves the vehicle routing problem using deep learning combined with reinforcement learning. The implementation uses PyTorch. 
# ---------------------------------------------------------------------------
# code/configs.py
# Global settings; the other modules pull these names in with
# `from configs import *`.
# ---------------------------------------------------------------------------
import torch

file_path = '../dataset/'
file_name = 'G-20'
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Derived from `device` so the flag and the selected device cannot disagree.
use_cuda = device.type == 'cuda'

node_hidden_dim = 100
edge_hidden_dim = 100
gcn_num_layers = 2

num_epochs = 10
batch_size = 64
beta = 1
learning_rate = 1e-4
weight_decay = 0.96

node_num = 20  # number of customers
initial_capacity = 1  # initial capacity of vehicles
k = 10  # number of nearest neighbors
alpha = 1

# ---------------------------------------------------------------------------
# code/dataset.py
# (the original module also starts with `from configs import *`)
# ---------------------------------------------------------------------------
import numpy as np
from torch.utils.data import DataLoader, TensorDataset


class MyDataloader():
    ''' Opens the training/testing ``.npz`` archives named by the settings
    ``file_path`` / ``file_name`` and serves them as PyTorch DataLoaders.
    '''
    def __init__(self):
        prefix = file_path + 'data/' + file_name
        self.train = np.load(prefix + '-training.npz')
        self.test = np.load(prefix + '-testing.npz')

    def load_data(self, data, shuffle=True):
        ''' Wrap the three arrays stored in ``data`` in a batched DataLoader.
        @param data: archive object exposing ``files`` and item access (as
                     returned by ``np.load`` for an ``.npz`` file); it must hold
                     exactly three arrays, read in stored order as
                     graph, demand, distance
        @param shuffle: whether the loader reshuffles the samples
        @return: DataLoader yielding float32 (graph, demand, distance) batches
                 of ``batch_size`` samples; a trailing partial batch is dropped
        '''
        # Unpacking raises ValueError if the archive does not hold three arrays.
        graph, demand, distance = (data[name] for name in data.files)
        tensors = [torch.as_tensor(array, dtype=torch.float32)
                   for array in (graph, demand, distance)]
        dataset = TensorDataset(*tensors)
        return DataLoader(dataset=dataset, batch_size=batch_size,
                          shuffle=shuffle, drop_last=True)

    def dataloader(self):
        ''' @return: (train_loader, test_loader); only the training loader shuffles '''
        train_loader = self.load_data(self.train)
        test_loader = self.load_data(self.test, shuffle=False)
        return train_loader, test_loader


# ---------------------------------------------------------------------------
# code/utils.py
# (the original module star-imports `configs` and imports numpy as `np`)
# ---------------------------------------------------------------------------
import re

# One decimal number with an optional exponent, e.g. "12.50", "-3", "1e-05".
_NUMBER_RE = re.compile(r'-?\d+\.?\d*(?:[eE][-+]?\d+)?')


def _append_record(file_name, epoch, value):
    ''' Write one "<epoch> <value>" line to ``../result/<file_name>``.
    The file is truncated when ``epoch`` is 0 and appended to otherwise.
    '''
    file_path = '../result/' + file_name
    mode = 'w' if epoch == 0 else 'a'
    with open(file_path, mode) as f:
        f.write('%d %.2f\n' % (epoch, value))


def write_loss(file_name, epoch, loss):
    ''' Log the loss of ``epoch`` to ``../result/<file_name>``. '''
    _append_record(file_name, epoch, loss)


def write_distance(file_name, epoch, dist):
    ''' Log the route distance of ``epoch`` to ``../result/<file_name>``. '''
    _append_record(file_name, epoch, dist)


def _save_curve(values, title, ylabel, image_name):
    ''' Plot ``values`` against their index and save the figure to
    ``../result/<image_name>``.
    '''
    # Imported lazily so the text-file helpers of this module stay usable on
    # machines without matplotlib.
    import matplotlib.pyplot as plt

    plt.cla()
    plt.plot(values, color='skyblue', linewidth=1)
    plt.title(title)
    plt.xlabel('epochs')
    plt.ylabel(ylabel)
    plt.savefig('../result/' + image_name)


def plot_loss(loss):
    ''' Save the training-loss curve to ``../result/loss.png``. '''
    _save_curve(loss, 'Train Loss', 'loss', 'loss.png')


def plot_dist(dist):
    ''' Save the training-distance curve to ``../result/dist.png``. '''
    _save_curve(dist, 'Train Distance', 'distance', 'dist.png')


def write_file(file_name, train_result, test_result):
    ''' Append alternating train/test rows to ``../result/<file_name>``.
    @param train_result: sequence of (value, value) pairs
    @param test_result: sequence of pairs, at least as long as ``train_result``
    Each pair becomes one "%.2f %.2f" line; the train line of an index is
    followed by the test line of the same index.
    '''
    file_path = '../result/' + file_name
    with open(file_path, 'a') as f:
        for i, train_row in enumerate(train_result):
            f.write('%.2f %.2f\n' % (train_row[0], train_row[1]))
            test_row = test_result[i]
            f.write('%.2f %.2f\n' % (test_row[0], test_row[1]))


def read_file(file_name):
    ''' Read back a file of alternating train/test rows.
    @param file_name: name of a file inside ``result/``
    @return: (train_result, test_result); even lines (0-based) go to the first
             list, odd lines to the second, each as [float, float]
    '''
    # NOTE(review): this reads from 'result/' while the writers in this module
    # use '../result/' -- presumably it is meant to be called from the
    # repository root; confirm against the caller.
    file_path = 'result/' + file_name
    train_result, test_result = [], []
    with open(file_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            numbers = _NUMBER_RE.findall(line)
            pair = [float(numbers[0]), float(numbers[1])]
            if i % 2 == 0:
                train_result.append(pair)
            else:
                test_result.append(pair)
    return train_result, test_result


# ---------------------------------------------------------------------------
# code/decoder.py
# ---------------------------------------------------------------------------
configs import * 2 | from attention import * 3 | import torch 4 | import torch.nn as nn 5 | 6 | class SequencialDecoder(nn.Module): 7 | def __init__(self, hidden_dim, decode_type, use_cuda=False): 8 | super(SequencialDecoder, self).__init__() 9 | self.hidden_dim = hidden_dim 10 | self.decode_type = decode_type 11 | 12 | self.softmax = nn.Softmax(dim=1) 13 | self.gru = nn.GRU(hidden_dim, hidden_dim, num_layers=2) 14 | self.tanh = nn.Tanh() 15 | self.h = nn.Linear(hidden_dim, 1) 16 | self.W = nn.Linear(2, 1) 17 | self.pointer = AttentionPointer(hidden_dim, use_tanh=True, use_cuda=use_cuda) 18 | 19 | def forward(self, x, last_node, hidden, mask): 20 | ''' 21 | @param x: (batch_size, node_num, hidden_dim) 22 | @param last_node: (batch_size, 1) 23 | @param hidden: (2, batch_size, hidden_dim) 24 | @param mask: (batch_size, node_num) 25 | ''' 26 | batch_size = x.size(0) 27 | batch_idx = torch.arange(start=0, end=batch_size).unsqueeze(1) 28 | if use_cuda: 29 | batch_idx = batch_idx.to(device) 30 | last_x = x[batch_idx, last_node].permute(1, 0, 2) 31 | _, hidden = self.gru(last_x, hidden) 32 | z = hidden[-1] 33 | # Eq 15 34 | _, u = self.pointer(z, x.permute(1, 0, 2)) 35 | # Eq 16 36 | u = u.masked_fill_(mask, -np.inf) 37 | probs = self.softmax(u) 38 | if self.decode_type == 'sample': 39 | # SampleRollout 40 | idx = torch.multinomial(probs, num_samples=1) 41 | elif self.decode_type == 'greedy': 42 | # GreedyRollout 43 | idx = torch.max(probs, dim=1)[1].unsqueeze(1) 44 | prob = probs[batch_idx, idx].squeeze(1) 45 | 46 | return idx, prob, hidden 47 | 48 | class ClassificationDecoder(nn.Module): 49 | def __init__(self, hidden_dim): 50 | super(ClassificationDecoder, self).__init__() 51 | self.MLP = nn.Sequential( 52 | nn.Linear(hidden_dim, 256), 53 | nn.ReLU(), 54 | nn.Linear(256, 256), 55 | nn.ReLU(), 56 | nn.Linear(256, 2), 57 | ) 58 | self.softmax = nn.Softmax(-1) 59 | 60 | def forward(self, e): 61 | ''' 62 | @param e: (batch_size, node_num, node_num, hidden_dim) 63 | ''' 64 
| a = self.MLP(e) 65 | a = a.squeeze(-1) 66 | out = self.softmax(a) 67 | return out -------------------------------------------------------------------------------- /code/attention.py: -------------------------------------------------------------------------------- 1 | from configs import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import math 7 | 8 | class AttentionEncoder(nn.Module): 9 | def __init__(self, hidden_dim): 10 | super(AttentionEncoder, self).__init__() 11 | self.hidden_dim = hidden_dim 12 | 13 | def forward(self, x, neighbor): 14 | ''' 15 | @param x: (batch_size, node_num, hidden_dim) 16 | @param neighbor: (batch_size, node_num, k, hidden_dim) 17 | ''' 18 | # scaled dot-product attention 19 | x = x.unsqueeze(2) 20 | neighbor = neighbor.permute(0, 1, 3, 2) 21 | attn_score = F.softmax(torch.matmul(x, neighbor) / np.sqrt(self.hidden_dim), dim=-1) # (batch_size, node_num, 1, k) 22 | weighted_neighbor = attn_score * neighbor 23 | 24 | # aggregation 25 | agg = x.squeeze(2) + torch.sum(weighted_neighbor, dim=-1) 26 | 27 | return agg 28 | 29 | class AttentionPointer(nn.Module): 30 | def __init__(self, hidden_dim, use_tanh=False, use_cuda=False): 31 | super(AttentionPointer, self).__init__() 32 | self.hidden_dim = hidden_dim 33 | self.use_tanh = use_tanh 34 | 35 | self.project_hidden = nn.Linear(hidden_dim, hidden_dim) 36 | self.project_x = nn.Conv1d(hidden_dim, hidden_dim, 1, 1) 37 | self.C = 10 38 | self.tanh = nn.Tanh() 39 | 40 | v = torch.FloatTensor(hidden_dim) 41 | if use_cuda: 42 | v = v.cuda() 43 | self.v = nn.Parameter(v) 44 | self.v.data.uniform_(-(1. / math.sqrt(hidden_dim)) , 1. 
/ math.sqrt(hidden_dim)) 45 | 46 | def forward(self, hidden, x): 47 | ''' 48 | @param hidden: (batch_size, hidden_dim) 49 | @param x: (node_num, batch_size, hidden_dim) 50 | ''' 51 | x = x.permute(1, 2, 0) 52 | q = self.project_hidden(hidden).unsqueeze(2) # batch_size x hidden_dim x 1 53 | e = self.project_x(x) # batch_size x hidden_dim x node_num 54 | # expand the hidden by node_num 55 | # batch_size x hidden_dim x node_num 56 | expanded_q = q.repeat(1, 1, e.size(2)) 57 | # batch x 1 x hidden_dim 58 | v_view = self.v.unsqueeze(0).expand(expanded_q.size(0), len(self.v)).unsqueeze(1) 59 | # (batch_size x 1 x hidden_dim) * (batch_size x hidden_dim x node_num) 60 | u = torch.bmm(v_view, self.tanh(expanded_q + e)).squeeze(1) 61 | if self.use_tanh: 62 | logits = self.C * self.tanh(u) 63 | else: 64 | logits = u 65 | return e, logits -------------------------------------------------------------------------------- /code/model.py: -------------------------------------------------------------------------------- 1 | from configs import * 2 | from GCN import * 3 | from decoder import * 4 | import torch.nn as nn 5 | 6 | class Model(nn.Module): 7 | def __init__(self, 8 | node_hidden_dim, 9 | edge_hidden_dim, 10 | gcn_num_layers, 11 | k): 12 | super(Model, self).__init__() 13 | 14 | self.GCN = GCN(node_hidden_dim, edge_hidden_dim, 15 | gcn_num_layers, k) 16 | self.sequencialDecoderSample = SequencialDecoder(node_hidden_dim, decode_type='sample', use_cuda=use_cuda) 17 | self.sequencialDecoderGreedy = SequencialDecoder(node_hidden_dim, decode_type='greedy', use_cuda=use_cuda) 18 | self.classificationDecoder = ClassificationDecoder(edge_hidden_dim) 19 | 20 | def seqDecoderForward(self, env, h_node, decode_type='sample'): 21 | # initialize last_node, hidden, mask & reset the environment 22 | env.reset() 23 | last_node = torch.zeros((batch_size, 1)).long().to(device) 24 | hidden = torch.zeros((2, batch_size, node_hidden_dim)).to(device) 25 | mask = torch.zeros((batch_size, 
node_num+1), dtype=torch.bool).to(device) 26 | mask[:, 0] = True 27 | log_prob = 0 28 | while env.all_visited() == False: 29 | # idx: (batch_size, 1) 30 | # prob: (batch_size) 31 | # hidden: (2, batch_size, hidden_dim) 32 | if decode_type=='sample': 33 | idx, prob, hidden = self.sequencialDecoderSample(h_node, last_node, hidden, mask) 34 | elif decode_type=='greedy': 35 | idx, prob, hidden = self.sequencialDecoderGreedy(h_node, last_node, hidden, mask) 36 | env.step(idx) 37 | last_node = idx 38 | log_prob = log_prob + torch.log(prob) 39 | mask = env.get_mask(idx) 40 | total_dist = env.calc_distance() 41 | matrix = env.decode_routes() 42 | 43 | return total_dist, log_prob, matrix 44 | 45 | def forward(self, env): 46 | x_c = env.graph 47 | x_d = env.demand 48 | m = env.distance 49 | 50 | # GCN encoder 51 | h_node, h_edge = self.GCN(x_c, x_d, m) 52 | batch_size, node_num, node_hidden_dim = h_node.shape 53 | 54 | # sequencial decoder 55 | # SampleRollout 56 | sample_distance, sample_logprob, target_matrix = self.seqDecoderForward(env, h_node, decode_type='sample') 57 | # print('sample:', env.routes[0]) 58 | # GreedyRollout 59 | greedy_distance, _, _ = self.seqDecoderForward(env, h_node, decode_type='greedy') 60 | # print('greedy:', env.routes[0]) 61 | 62 | # classification decoder 63 | predict_matrix = self.classificationDecoder(h_edge) 64 | 65 | return sample_logprob, sample_distance, greedy_distance, target_matrix, predict_matrix -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | from configs import * 2 | from utils import * 3 | from environment import * 4 | from model import * 5 | from dataset import * 6 | import numpy as np 7 | import tqdm 8 | from torch.nn import CrossEntropyLoss 9 | 10 | def train(): 11 | loss_per_epoch = 0 12 | dist_per_epoch = 0 13 | batch_num = 0 14 | mean_dist_sample = 0 15 | mean_dist_greedy = 0 16 | for item in 
tqdm.tqdm(train_loader, 'train'): 17 | batch_num += 1 18 | graph, demand, distance = item[0].to(device), item[1].to(device), item[2].to(device) 19 | env = Environment(graph, demand, distance) 20 | sample_logprob, sample_distance, greedy_distance, target_matrix, predict_matrix = model(env) 21 | predict_matrix = predict_matrix.view(-1, 2) 22 | target_matrix = target_matrix.view(-1) 23 | classification_loss = criterion(predict_matrix.to(device), target_matrix.to(device)) 24 | advantage = (sample_distance - greedy_distance).detach() 25 | mean_dist_sample += torch.mean(sample_distance) 26 | mean_dist_greedy += torch.mean(greedy_distance) 27 | reinforce = advantage * sample_logprob 28 | sequancial_loss = reinforce.sum() 29 | loss = alpha * sequancial_loss + beta * classification_loss 30 | optimizer.zero_grad() 31 | loss.backward() 32 | optimizer.step() 33 | loss_per_epoch += loss 34 | if batch_num == 5: 35 | break 36 | loss_per_epoch /= (batch_size * batch_num) 37 | dist_per_epoch = mean_dist_sample / batch_num 38 | 39 | print('sample: %.2f greedy: %.2f' %(mean_dist_sample / batch_num, mean_dist_greedy / batch_num)) 40 | return loss_per_epoch, dist_per_epoch, (mean_dist_sample.sum() < mean_dist_greedy.sum()) 41 | 42 | def test(): 43 | loss_per_epoch = 0 44 | dist_per_epoch = 0 45 | batch_num = 0 46 | mean_dist_sample = 0 47 | mean_dist_greedy = 0 48 | with torch.no_grad(): 49 | for item in tqdm.tqdm(test_loader, 'test '): 50 | batch_num += 1 51 | graph, demand, distance = item[0].to(device), item[1].to(device), item[2].to(device) 52 | env = Environment(graph, demand, distance) 53 | sample_logprob, sample_distance, greedy_distance, target_matrix, predict_matrix = model(env) 54 | predict_matrix = predict_matrix.view(-1, 2) 55 | target_matrix = target_matrix.view(-1) 56 | classification_loss = criterion(predict_matrix.to(device), target_matrix.to(device)) 57 | advantage = (sample_distance - greedy_distance).detach() 58 | mean_dist_sample += torch.mean(sample_distance) 59 | 
mean_dist_greedy += torch.mean(greedy_distance) 60 | reinforce = advantage * sample_logprob 61 | sequancial_loss = reinforce.sum() 62 | loss = alpha * sequancial_loss + beta * classification_loss 63 | loss_per_epoch += loss 64 | if batch_num == 5: 65 | break 66 | 67 | loss_per_epoch /= (batch_size * batch_num) 68 | dist_per_epoch = mean_dist_sample / batch_num 69 | return loss_per_epoch, dist_per_epoch 70 | 71 | 72 | if __name__ == '__main__': 73 | myDataloader = MyDataloader() 74 | train_loader, test_loader = myDataloader.dataloader() 75 | 76 | model = Model(node_hidden_dim, edge_hidden_dim, gcn_num_layers, k).to(device) 77 | criterion = CrossEntropyLoss() 78 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) 79 | train_loss, test_loss = [], [] 80 | train_dist, test_dist = [], [] 81 | 82 | for i in range(num_epochs): 83 | if i == 5: 84 | break 85 | train_loss_per_epoch, train_dist_per_epoch, update = train() 86 | test_loss_per_epoch, test_dist_per_epoch = test() 87 | 88 | train_loss.append(train_loss_per_epoch) 89 | train_dist.append(train_dist_per_epoch) 90 | test_loss.append(test_loss_per_epoch) 91 | test_dist.append(test_dist_per_epoch) 92 | print('epoch: %d -train loss: %.2f -distance: %.2f -test loss: %.2f -distance: %.2f' %(i, train_loss[-1], train_dist[-1], test_loss[-1], test_dist[-1])) 93 | write_loss('train_loss.txt', i, train_loss[-1]) 94 | write_distance('train_dist.txt', i, train_dist[-1]) 95 | write_loss('test_loss.txt', i, test_loss[-1]) 96 | write_distance('test_dist.txt', i, test_dist[-1]) 97 | # torch.save(model.state_dict(), '../result/params.pkl') 98 | plot_loss(train_loss) 99 | plot_dist(train_dist) 100 | 101 | if update: 102 | print('update') 103 | model.sequencialDecoderGreedy.load_state_dict(model.sequencialDecoderSample.state_dict()) -------------------------------------------------------------------------------- /code/environment.py: 
# (code/environment.py -- the original module star-imports the project modules
#  `configs` and `model`)
from copy import deepcopy

import torch


class Environment:
    ''' Batched vehicle-routing state: which nodes were visited, the route so
    far, the remaining vehicle capacity and the remaining customer demands.
    Node 0 is treated as the warehouse.
    '''
    def __init__(self, graph, demand, distance):
        '''
        @param graph: (batch_size, node_num(N+1), 2)
        @param demand: (batch_size, node_num(N+1))
        @param distance: (batch_size, node_num(N+1), node_num(N+1))
        '''
        self.graph = graph
        self.demand = demand
        self.distance = distance

        # Sizes follow the tensors actually passed in, so the state always
        # matches the batch.
        self.batch_size = demand.size(0)
        self.node_num = demand.size(1) - 1
        self.initial_capacity = initial_capacity
        self.k = k

        self.visited, self.routes, self.remaining_capacity, self.remaining_demands = self.init_state()
        self.time_step = 0

    def init_state(self):
        '''
        visited: (batch_size, node_num+1)
        routes: (batch_size, 1)
        remaining_capacity: (batch_size, 1)
        remaining_demands: (batch_size, node_num+1)
        All tensors are created on the device of `demand`.
        '''
        dev = self.demand.device
        visited = torch.zeros(self.batch_size, self.node_num + 1, dtype=torch.bool, device=dev)
        visited[:, 0] = True
        routes = torch.zeros((self.batch_size, 1), dtype=torch.long, device=dev)
        remaining_capacity = torch.full((self.batch_size, 1), self.initial_capacity,
                                        dtype=torch.float, device=dev)
        remaining_demands = self.demand.clone().float()
        return visited, routes, remaining_capacity, remaining_demands

    def reset(self):
        ''' Restore the initial state so another rollout can start. '''
        self.visited, self.routes, self.remaining_capacity, self.remaining_demands = self.init_state()
        self.time_step = 0

    def step(self, action):
        ''' update customer and vehicle states
        @param action: (batch_size, idx(1))
        1. visited[idx] = True
        2. routes += action
        3. remaining_capacity =
            if idx == 0: initial_capacity
            otherwise: max(0, remaining_capacity - demands[idx])
        4. remaining_demands[idx] = 0
        5. time_step += 1
        '''
        index = action.squeeze(-1).unsqueeze(1)  # (batch_size, 1)
        # 1.
        self.visited.scatter_(1, index, True)
        # 2.
        self.routes = torch.cat((self.routes, index), dim=1)
        # 3.
        curr_demands = self.remaining_demands.gather(1, index)
        refilled = torch.full_like(self.remaining_capacity, self.initial_capacity)
        consumed = torch.clamp(self.remaining_capacity - curr_demands, min=0)
        self.remaining_capacity = torch.where(index == 0, refilled, consumed)
        # 4.
        self.remaining_demands.scatter_(1, index, 0)
        # 5.
        self.time_step = self.time_step + 1

    def get_mask(self, last_action):
        ''' compute the mask for current states
        @param last_action: (batch_size, 1)
        @return: (batch_size, node_num+1), True marks nodes that may not be chosen
        1. if remaining_demands[idx] == 0 or
            remaining_demands[idx] >= remaining_capacity: set idx mask True
        2. if last_idx == 0: set the warehouse mask True
        3. if mask is all True: set the warehouse mask False
        '''
        mask = self.visited.clone()
        last_action = last_action.squeeze(-1)
        # 1.
        # NOTE(review): `>=` also masks a customer whose demand exactly equals
        # the remaining capacity -- confirm that `>` is not intended.
        mask[self.remaining_demands >= self.remaining_capacity] = True
        # 2.
        mask[last_action == 0, 0] = True
        # 3.
        mask[mask.all(dim=1), 0] = False
        return mask

    def all_visited(self):
        ''' @return: 0-dim bool tensor, True once every node of every sample is visited '''
        return self.visited.all()

    def dist_per_step(self, prev_step, curr_step):
        '''
        @param prev_step: (batch_size, 1)
        @param curr_step: (batch_size, 1)
        @return: distance of single step (batch_size, 1)
        '''
        idx = torch.arange(self.batch_size, device=self.distance.device).unsqueeze(1)
        return self.distance[idx, prev_step, curr_step]

    def get_reward(self):
        '''
        @return: routing distance after last action (batch_size, 1)
        '''
        prev_step = self.routes[:, -2:-1]
        curr_step = self.routes[:, -1:]
        return self.dist_per_step(prev_step, curr_step)

    def calc_distance(self):
        '''
        @return: total distance of the routes (batch_size, 1)
        '''
        # Look up every consecutive (from, to) pair of the route at once and
        # add the legs up; a route with a single node yields 0.
        legs = self.dist_per_step(self.routes[:, :-1], self.routes[:, 1:])
        return legs.sum(dim=1, keepdim=True)

    def decode_routes(self):
        ''' decode route sequence into a matrix
        @return: (batch_size, node_num+1, node_num+1) long tensor on the device
                 of the routes; entry [b, i, j] is 1 when sample b travels
                 directly from node i to node j
        '''
        dev = self.routes.device
        matrix = torch.zeros(self.batch_size, self.node_num + 1, self.node_num + 1,
                             dtype=torch.long, device=dev)
        idx = torch.arange(self.batch_size, device=dev).unsqueeze(1)
        matrix[idx, self.routes[:, :-1], self.routes[:, 1:]] = 1
        return matrix


# ---------------------------------------------------------------------------
# code/GCN.py
# (the original module star-imports the project modules `configs` and
#  `attention`)
# ---------------------------------------------------------------------------
import torch.nn as nn
import numpy as np


class GCN(nn.Module):
    ''' Graph encoder: embeds nodes (coordinates + demand) and edges (distance
    + k-nearest-neighbor adjacency) and refines both with stacked GCN layers.
    The same hidden size is passed to every GCNLayer for nodes and edges, so
    node_hidden_dim and edge_hidden_dim are expected to be equal.
    '''
    def __init__(self,
                 node_hidden_dim,
                 edge_hidden_dim,
                 gcn_num_layers,
                 k):
        super(GCN, self).__init__()

        self.node_hidden_dim = node_hidden_dim
        self.edge_hidden_dim = edge_hidden_dim
        self.gcn_num_layers = gcn_num_layers
        self.k = k

        self.W1 = nn.Linear(2, self.node_hidden_dim)        # node_W1
        self.W2 = nn.Linear(2, self.node_hidden_dim // 2)   # node_W2
        self.W3 = nn.Linear(1, self.node_hidden_dim // 2)   # node_W3
        self.W4 = nn.Linear(1, self.edge_hidden_dim // 2)   # edge_W4
        self.W5 = nn.Linear(1, self.edge_hidden_dim // 2)   # edge_W5

        self.node_embedding = nn.Linear(self.node_hidden_dim, self.node_hidden_dim, bias=False)  # Eq5
        self.edge_embedding = nn.Linear(self.edge_hidden_dim, self.edge_hidden_dim, bias=False)  # Eq6

        self.gcn_layers = nn.ModuleList([GCNLayer(self.node_hidden_dim) for i in range(self.gcn_num_layers)])

        self.relu = nn.ReLU()

    def adjacency(self, m):
        '''
        @param m: distance (node_num, node_num)
        @return: (node_num, node_num) with 1 for the k nearest neighbors of each
                 row, -1 on the diagonal and 0 elsewhere
        '''
        a = torch.zeros_like(m)
        # The first sorted entry is skipped; with a zero diagonal that is the
        # node itself.
        idx = torch.argsort(m, dim=1)[:, 1:(self.k + 1)]
        a.scatter_(1, idx, 1)
        a.fill_diagonal_(-1)

        return a

    def find_neighbors(self, m):
        ''' find index of neighbors for each node
        @param m: distance (batch_size, node_num, node_num)
        @return: (batch_size, node_num, k) long tensor on the device of m
        '''
        # Sorting the whole batch at once keeps the work on the tensor's own
        # device (no NumPy round trip).
        return torch.argsort(m, dim=-1)[:, :, 1:(self.k + 1)]

    def forward(self, x_c, x_d, m):
        '''
        @param x_c: coordination (batch_size, node_num(N+1), 2)
        @param x_d: demand (batch_size, node_num(N+1))
        @param m: distance (batch_size, node_num(N+1), node_num(N+1))
        @return: h_node (batch_size, node_num(N+1), node_hidden_dim),
                 h_edge (batch_size, node_num(N+1), node_num(N+1), edge_hidden_dim)
        '''
        # Eq 2
        x0 = self.relu(self.W1(x_c[:, :1, :]))  # (batch_size, 1, node_hidden_dim)
        xi = self.relu(torch.cat((self.W2(x_c[:, 1:, :]), self.W3(x_d.unsqueeze(2)[:, 1:, :])), dim=-1))  # (batch_size, node_num(N), node_hidden_dim)
        x = torch.cat((x0, xi), dim=1)
        # Eq 3
        a = torch.stack([self.adjacency(m[i, :, :]) for i in range(m.shape[0])])
        # Eq 4
        y = self.relu(torch.cat((self.W4(m.unsqueeze(3)), self.W5(a.unsqueeze(3))), dim=-1))
        # Eq 5
        h_node = self.node_embedding(x)
        # Eq 6
        h_edge = self.edge_embedding(y)

        # index of neighbors
        N = self.find_neighbors(m)

        # GCN layers
        for gcn_layer in self.gcn_layers:
            h_node, h_edge = gcn_layer(h_node, h_edge, N)

        return h_node, h_edge


class GCNLayer(nn.Module):
    ''' One GCN layer updating node and edge embeddings side by side. '''
    def __init__(self, hidden_dim):
        super(GCNLayer, self).__init__()

        # node GCN layers
        self.W_node = nn.Linear(hidden_dim, hidden_dim)
        self.V_node_in = nn.Linear(hidden_dim, hidden_dim)
        self.V_node = nn.Linear(2 * hidden_dim, hidden_dim)
        self.attn = AttentionEncoder(hidden_dim)
        self.relu = nn.ReLU()
        self.ln1_node = nn.LayerNorm(hidden_dim)
        self.ln2_node = nn.LayerNorm(hidden_dim)

        # edge GCN layers
        self.W_edge = nn.Linear(hidden_dim, hidden_dim)
        self.V_edge_in = nn.Linear(hidden_dim, hidden_dim)
        self.V_edge = nn.Linear(2 * hidden_dim, hidden_dim)
        self.W1_edge = nn.Linear(hidden_dim, hidden_dim)
        self.W2_edge = nn.Linear(hidden_dim, hidden_dim)
        self.W3_edge = nn.Linear(hidden_dim, hidden_dim)
        self.ln1_edge = nn.LayerNorm(hidden_dim)
        self.ln2_edge = nn.LayerNorm(hidden_dim)

        self.hidden_dim = hidden_dim

    def forward(self, x, e, neighbor_index):
        '''
        @param x: (batch_size, node_num(N+1), node_hidden_dim)
        @param e: (batch_size, node_num(N+1), node_num(N+1), edge_hidden_dim)
        @param neighbor_index: (batch_size, node_num(N+1), k)
        @return: h_node, h_edge with the same shapes as x and e
        '''
        # node embedding
        node_num = x.size(1)
        node_hidden_dim = x.size(-1)
        # all_nodes[b, i, j] = x[b, j]; expand creates views instead of the
        # copies that repeat would allocate.
        all_nodes = x.unsqueeze(1).expand(-1, node_num, -1, -1)
        gather_index = neighbor_index.unsqueeze(3).expand(-1, -1, -1, node_hidden_dim)
        neighbor = all_nodes.gather(2, gather_index)  # (batch_size, node_num, k, node_hidden_dim)

        # Eq 7/9
        h_nb_node = self.ln1_node(x + self.relu(self.W_node(self.attn(x, neighbor))))
        # Eq 12, Eq 8
        h_node = self.ln2_node(h_nb_node + self.relu(self.V_node(torch.cat([self.V_node_in(x), h_nb_node], dim=-1))))

        # edge embedding
        # Project the nodes once and broadcast: edge (i, j) receives the
        # "from" projection of node i and the "to" projection of node j.
        from_term = self.W2_edge(x).unsqueeze(2)  # (batch_size, node_num, 1, hidden_dim)
        to_term = self.W3_edge(x).unsqueeze(1)    # (batch_size, 1, node_num, hidden_dim)
        # Eq 7/10, Eq 11
        h_nb_edge = self.ln1_edge(e + self.relu(self.W_edge(self.W1_edge(e) + from_term + to_term)))
        # Eq 13, Eq 8
        h_edge = self.ln2_edge(h_nb_edge + self.relu(self.V_edge(torch.cat((self.V_edge_in(e), h_nb_edge), dim=-1))))

        return h_node, h_edge