├── .gitignore
├── README.md
├── attention_dynamic_model.py
├── attention_graph_encoder.py
├── backup_results_VRP_20_2021-08-31.csv
├── checkpts
│   └── .gitignore
├── dvrp_env.yml
├── environment.py
├── layers.py
├── learning_curve_plot_VRP_20_2021-08-31_start=0, end=40.jpg
├── reinforce_baseline.py
├── test_20n_1024bs.ipynb
├── test_results_VRP_20_2021-08-31_start=0, end=40.xlsx
├── train.py
├── train_model.ipynb
├── train_model_loading_from_past.ipynb
├── train_results_VRP_20_2021-08-31_start=0, end=40.xlsx
├── utils.py
├── utils_demo.py
└── valsets
    └── .gitignore

/.gitignore:
--------------------------------------------------------------------------------
1 | /.ipynb_checkpoints
2 | **/__pycache__
3 | /.vscode
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dynamic-Attention-Model-for-VRP---Pytorch
2 | 
3 | This is a PyTorch implementation of the paper "A Deep Reinforcement Learning Algorithm Using Dynamic Attention Model for Vehicle Routing Problems" (https://arxiv.org/abs/2002.03282), based on the TensorFlow implementation at https://github.com/d-eremeev/ADM-VRP.
4 | 
5 | PyTorch is well known to be faster than TensorFlow in many cases when used properly. In our tests we observed speed increases of +400% over the TensorFlow implementation without the memory-efficient gradient trick (and a bit less with it).
6 | 
7 | To run the code, take a look at **./test_20n_1024bs.ipynb**. This simple notebook was used to train CVRP_20 for 40 epochs, following the setup proposed by Nazari et al. (as done in the paper) with a batch size of 1024. Results for this notebook (except for the valsets and checkpts) are included for demonstration purposes; e.g. take a look at ./backup_results_VRP_20_2021-08-31.csv.
8 | 
9 | About the memory-efficient gradient trick:
10 | 
11 | I noticed that the gradient formula can be rearranged so that, when translated to code, we sacrifice some runtime in order to stop storing huge computation graphs in CUDA memory. This trick significantly reduces memory usage; if you want to see exactly what it does, please take a look at the commit itself (https://github.com/Roberto09/Dynamic-Attention-Model-for-VRP---Pytorch/commit/fc1f9a8b6650fcd3cae23fb18db826147a29c3a1). The explanation of how/why this works can be found in the "Gradient Computation Improvement for AM-D" section of a paper we wrote (https://arxiv.org/abs/2211.13922).
12 | 
--------------------------------------------------------------------------------
/attention_dynamic_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions.categorical import Categorical
5 | import math
6 | import numpy as np
7 | 
8 | from attention_graph_encoder import GraphAttentionEncoder
9 | from layers import scaled_attention
10 | from environment import AgentVRP
11 | from utils import get_dev_of_mod
12 | 
13 | def set_decode_type(model, decode_type):
14 |     model.set_decode_type(decode_type)
15 | 
16 | class AttentionDynamicModel(nn.Module):
17 | 
18 |     def __init__(self,
19 |                  embedding_dim,
20 |                  n_encode_layers=2,
21 |                  n_heads=8,
22 |                  tanh_clipping=10.
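                 # tanh_clipping bounds the decoder's final single-head logits to
                 # [-tanh_clipping, tanh_clipping] via C * tanh(u) before the
                 # log-softmax (see get_log_p below)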
23 | ): 24 | 25 | super().__init__() 26 | 27 | # attributes for MHA 28 | self.embedding_dim = embedding_dim 29 | self.n_encode_layers = n_encode_layers 30 | self.decode_type = None 31 | 32 | # attributes for VRP problem 33 | self.problem = AgentVRP 34 | self.n_heads = n_heads 35 | 36 | # Encoder part 37 | self.embedder = GraphAttentionEncoder(input_dim=self.embedding_dim, 38 | num_heads=self.n_heads, 39 | num_layers=self.n_encode_layers 40 | ) 41 | 42 | # Decoder part 43 | 44 | self.output_dim = self.embedding_dim 45 | self.num_heads = n_heads 46 | 47 | self.head_depth = self.output_dim // self.num_heads 48 | self.dk_mha_decoder = float(self.head_depth) # for decoding in mha_decoder 49 | self.dk_get_loc_p = float(self.output_dim) # for decoding in mha_decoder 50 | 51 | if self.output_dim % self.num_heads != 0: 52 | raise ValueError("number of heads must divide d_model=output_dim") 53 | 54 | self.tanh_clipping = tanh_clipping 55 | 56 | # we split projection matrix Wq into 2 matrices: Wq*[h_c, h_N, D] = Wq_context*h_c + Wq_step_context[h_N, D] 57 | self.wq_context = nn.Linear(self.embedding_dim, self.output_dim) # (d_q_context, output_dim) 58 | self.wq_step_context = nn.Linear(self.embedding_dim + 1, self.output_dim, bias=False) # (d_q_step_context, output_dim) 59 | 60 | # we need two Wk projections since there is MHA followed by 1-head attention - they have different keys K 61 | self.wk = nn.Linear(self.embedding_dim, self.output_dim, bias=False) # (d_k, output_dim) 62 | self.wk_tanh = nn.Linear(self.embedding_dim, self.output_dim, bias=False) # (d_k_tanh, output_dim) 63 | 64 | # we dont need Wv projection for 1-head attention: only need attention weights as outputs 65 | self.wv = nn.Linear(self.embedding_dim, self.output_dim, bias=False) # (d_v, output_dim) 66 | 67 | # we dont need wq for 1-head tanh attention, since we can absorb it into w_out 68 | self.w_out = nn.Linear(self.embedding_dim, self.output_dim, bias=False) # (d_model, d_model) 69 | 70 | self.dev = None 71 | 72 | def set_decode_type(self, decode_type): 73 | self.decode_type = decode_type 74 | 75 | def split_heads(self, tensor, batch_size): 76 | """Function for computing attention on several heads simultaneously 77 | Splits tensor to be multi headed. 78 | """ 79 | # (batch_size, seq_len, output_dim) -> (batch_size, seq_len, n_heads, head_depth) 80 | splitted_tensor = tensor.view(batch_size, -1, self.n_heads, self.head_depth) 81 | return splitted_tensor.transpose(1, 2) # (batch_size, n_heads, seq_len, head_depth) 82 | 83 | def _select_node(self, logits): 84 | """Select next node based on decoding type. 85 | """ 86 | 87 | # assert tf.reduce_all(logits) == logits, "Probs should not contain any nans" 88 | 89 | if self.decode_type == "greedy": 90 | selected = torch.argmax(logits, dim=-1) # (batch_size, 1) 91 | 92 | elif self.decode_type == "sampling": 93 | # logits has a shape of (batch_size, 1, n_nodes), we have to squeeze it 94 | # to (batch_size, n_nodes) since tf.random.categorical requires matrix 95 | cat_dist = Categorical(logits=logits[:, 0, :]) # creates categorical distribution from tensor (batch_size) 96 | selected = cat_dist.sample() # takes a single sample from distribution 97 | else: 98 | assert False, "Unknown decode type" 99 | 100 | return torch.squeeze(selected, -1) # (batch_size,) 101 | 102 | def get_step_context(self, state, embeddings): 103 | """Takes a state and graph embeddings, 104 | Returns a part [h_N, D] of context vector [h_c, h_N, D], 105 | that is related to RL Agent last step. 
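        Here h_N is the embedding of the node visited at the previous step and
        D is the remaining vehicle capacity (VEHICLE_CAPACITY - used_capacity).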
106 | """ 107 | # index of previous node 108 | prev_node = state.prev_a.to(self.dev) # (batch_size, 1) 109 | 110 | # from embeddings=(batch_size, n_nodes, input_dim) select embeddings of previous nodes 111 | cur_embedded_node = embeddings.gather(1, prev_node.view(prev_node.shape[0], -1, 1) 112 | .repeat_interleave(embeddings.shape[-1], -1)) # (batch_size, 1, input_dim) 113 | 114 | # add remaining capacity 115 | step_context = torch.cat([cur_embedded_node, (self.problem.VEHICLE_CAPACITY - state.used_capacity[:, :, None]).to(self.dev)], dim=-1) 116 | 117 | return step_context # (batch_size, 1, input_dim + 1) 118 | 119 | def decoder_mha(self, Q, K, V, mask=None): 120 | """ Computes Multi-Head Attention part of decoder 121 | Args: 122 | mask: a mask for visited nodes, 123 | has shape (batch_size, seq_len_q, seq_len_k), seq_len_q = 1 for context vector attention in decoder 124 | Q: query (context vector for decoder) 125 | has shape (batch_size, n_heads, seq_len_q, head_depth) with seq_len_q = 1 for context_vector attention in decoder 126 | K, V: key, value (projections of nodes embeddings) 127 | have shape (batch_size, n_heads, seq_len_k, head_depth), (batch_size, n_heads, seq_len_v, head_depth), 128 | with seq_len_k = seq_len_v = n_nodes for decoder 129 | """ 130 | 131 | # Add dimension to mask so that it can be broadcasted across heads 132 | # (batch_size, seq_len_q, seq_len_k) --> (batch_size, 1, seq_len_q, seq_len_k) 133 | if mask is not None: 134 | mask = mask.unsqueeze(1) 135 | 136 | attention = scaled_attention(Q, K, V, mask) # (batch_size, n_heads, seq_len_q, head_depth) 137 | # transpose attention to (batch_size, seq_len_q, n_heads, head_depth) 138 | attention = attention.transpose(1, 2).contiguous() 139 | # concatenate results of all heads (batch_size, seq_len_q, self.output_dim) 140 | attention = attention.view(self.batch_size, -1, self.output_dim) 141 | 142 | output = self.w_out(attention) 143 | 144 | return output 145 | 146 | def get_log_p(self, Q, K, mask=None): 147 | """Single-Head attention sublayer in decoder, 148 | computes log-probabilities for node selection. 
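        Compatibilities are computed as tanh(Q @ K^T / sqrt(d)) * tanh_clipping,
        masked entries are set to -1e9, and a log-softmax over nodes is returned.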
149 | 150 | Args: 151 | mask: mask for nodes 152 | Q: query (output of mha layer) 153 | has shape (batch_size, seq_len_q, output_dim), seq_len_q = 1 for context attention in decoder 154 | K: key (projection of node embeddings) 155 | has shape (batch_size, seq_len_k, output_dim), seq_len_k = n_nodes for decoder 156 | """ 157 | 158 | compatibility = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.shape[-1]) 159 | compatibility = torch.tanh(compatibility) * self.tanh_clipping 160 | if mask is not None: compatibility = compatibility.masked_fill(mask == 1, -1e9) 161 | 162 | log_p = F.log_softmax(compatibility, dim=-1) # (batch_size, seq_len_q, seq_len_k) 163 | 164 | return log_p 165 | 166 | def get_likelihood_selection(self, _log_p, a): 167 | 168 | # Get log_p corresponding to selected actions for every batch 169 | indices = a.view(a.shape[0], -1) 170 | select = _log_p.gather(-1, indices) 171 | return select.view(-1) 172 | 173 | 174 | def get_projections(self, embeddings, context_vectors): 175 | 176 | # we compute some projections (common for each policy step) before decoding loop for efficiency 177 | K = self.wk(embeddings) # (batch_size, n_nodes, output_dim) 178 | K_tanh = self.wk_tanh(embeddings) # (batch_size, n_nodes, output_dim) 179 | V = self.wv(embeddings) # (batch_size, n_nodes, output_dim) 180 | Q_context = self.wq_context(context_vectors[:, None, :]) # (batch_size, 1, output_dim) 181 | 182 | # we dont need to split K_tanh since there is only 1 head; Q will be split in decoding loop 183 | K = self.split_heads(K, self.batch_size) # (batch_size, num_heads, n_nodes, head_depth) 184 | V = self.split_heads(V, self.batch_size) # (batch_size, num_heads, n_nodes, head_depth) 185 | 186 | return K_tanh, Q_context, K, V 187 | 188 | 189 | def fwd_rein_loss(self, inputs, baseline, bl_vals, num_batch, return_pi=False): 190 | """ 191 | Forward and calculate loss for REINFORCE algorithm in a memory efficient way. 192 | This sacrifices a bit of performance but is way better in memory terms and works 193 | by reordering the terms in the gradient formula such that we don't store gradients 194 | for all the seguence for a long time which hence produces a lot of memory consumption. 195 | """ 196 | 197 | on_training = self.training 198 | self.eval() 199 | with torch.no_grad(): 200 | cost, log_likelihood, seq = self(inputs, True) 201 | bl_val = bl_vals[num_batch] if bl_vals is not None else baseline.eval(inputs, cost) 202 | pre_cost = cost - bl_val.detach() 203 | detached_loss = torch.mean((pre_cost) * log_likelihood) 204 | 205 | if on_training: self.train() 206 | return detached_loss, self(inputs, return_pi, seq, pre_cost) 207 | 208 | def forward(self, inputs, return_pi=False, pre_selects=None, pre_cost=None): 209 | """ 210 | Forward method. Works as expected except and as described on the paper, however 211 | if pre_selects is None which hence implies that pre_cost should be none it's because 212 | fwd_rein_loss is calling it; check that method for a description of why this is useful. 
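        When pre_selects is given, the decoder replays that fixed sequence of selections
        instead of sampling, and each step's log-likelihood (weighted by the detached
        advantage pre_cost) is backpropagated immediately, so graphs from earlier decoding
        steps can be freed instead of being kept alive until the end of the rollout.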
213 | """ 214 | 215 | self.batch_size = inputs[0].shape[0] 216 | 217 | state = self.problem(inputs) # use CPU inputs for state 218 | inputs = self.set_input_device(inputs) # sent inputs to GPU for training if it's being used 219 | sequences = [] 220 | ll = torch.zeros(self.batch_size) 221 | 222 | if pre_selects is not None: 223 | pre_selects = pre_selects.transpose(0, 1) 224 | # Perform decoding steps 225 | pre_select_idx = 0 226 | while not state.all_finished(): 227 | 228 | state.i = torch.zeros(1, dtype=torch.int64) 229 | att_mask, cur_num_nodes = state.get_att_mask() 230 | att_mask, cur_num_nodes = att_mask.to(self.dev), cur_num_nodes.to(self.dev) 231 | embeddings, context_vectors = self.embedder(inputs, att_mask, cur_num_nodes) 232 | K_tanh, Q_context, K, V = self.get_projections(embeddings, context_vectors) 233 | 234 | while not state.partial_finished(): 235 | 236 | step_context = self.get_step_context(state, embeddings) # (batch_size, 1, input_dim + 1) 237 | Q_step_context = self.wq_step_context(step_context) # (batch_size, 1, output_dim) 238 | Q = Q_context + Q_step_context 239 | 240 | # split heads for Q 241 | Q = self.split_heads(Q, self.batch_size) # (batch_size, num_heads, 1, head_depth) 242 | 243 | # get current mask 244 | mask = state.get_mask().to(self.dev) # (batch_size, 1, n_nodes) True -> mask, i.e. agent can NOT go 245 | 246 | # compute MHA decoder vectors for current mask 247 | mha = self.decoder_mha(Q, K, V, mask) # (batch_size, 1, output_dim) 248 | 249 | # compute probabilities 250 | log_p = self.get_log_p(mha, K_tanh, mask) # (batch_size, 1, n_nodes) 251 | 252 | # next step is to select node 253 | if pre_selects is None: 254 | selected = self._select_node(log_p.detach()) # (batch_size,) 255 | else: 256 | selected = pre_selects[pre_select_idx] 257 | 258 | state.step(selected.detach().cpu()) 259 | 260 | curr_ll = self.get_likelihood_selection(log_p[:, 0, :].cpu(), selected.detach().cpu()) 261 | if pre_selects is not None: 262 | curr_loss = (curr_ll * pre_cost).sum() / self.batch_size 263 | curr_loss.backward(retain_graph=True) 264 | curr_ll = curr_ll.detach() 265 | ll += curr_ll 266 | 267 | sequences.append(selected.detach().cpu()) 268 | pre_select_idx += 1 269 | # torch.cuda.empty_cache() 270 | # torch.cuda.empty_cache() 271 | 272 | pi = torch.stack(sequences, dim=1) # (batch_size, len(outputs)) 273 | cost = self.problem.get_costs((inputs[0].detach().cpu(), inputs[1].detach().cpu(), inputs[2].detach().cpu()), pi) 274 | 275 | ret = [cost, ll] 276 | if return_pi: ret.append(pi) 277 | return ret 278 | 279 | def set_input_device(self, inp_tens): 280 | if self.dev is None: self.dev = get_dev_of_mod(self) 281 | return(inp_tens[0].to(self.dev), inp_tens[1].to(self.dev), inp_tens[2].to(self.dev)) 282 | -------------------------------------------------------------------------------- /attention_graph_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers import MultiHeadAttention 5 | 6 | class MultiHeadAttentionLayer(nn.Module): 7 | """Feed-Forward Sublayer: fully-connected Feed-Forward network, 8 | built based on MHA vectors from MultiHeadAttention layer with skip-connections 9 | 10 | Args: 11 | num_heads: number of attention heads in MHA layers. 12 | input_dim: embedding size that will be used as d_model in MHA layers. 13 | feed_forward_hidden: number of neuron units in each FF layer. 
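        Note: each sublayer here uses a skip connection followed by a tanh activation
        (see forward below).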
14 | 15 | Call arguments: 16 | x: batch of shape (batch_size, n_nodes, node_embedding_size). 17 | mask: mask for MHA layer 18 | 19 | Returns: 20 | outputs of shape (batch_size, n_nodes, input_dim) 21 | 22 | """ 23 | 24 | def __init__(self, input_dim, num_heads, feed_forward_hidden=512, **kwargs): 25 | super().__init__(**kwargs) 26 | self.mha = MultiHeadAttention(n_heads=num_heads, d_model=input_dim) 27 | 28 | self.ff1 = nn.Linear(input_dim, feed_forward_hidden) 29 | self.ff2 = nn.Linear(feed_forward_hidden, input_dim) 30 | 31 | def forward(self, x, mask=None): 32 | mha_out = self.mha(x, x, x, mask) 33 | sc1_out = torch.add(x, mha_out) 34 | tanh1_out = torch.tanh(sc1_out) 35 | 36 | ff1_out = self.ff1(tanh1_out) 37 | relu1_out = F.relu(ff1_out) 38 | ff2_out = self.ff2(relu1_out) 39 | sc2_out = torch.add(tanh1_out, ff2_out) 40 | tanh2_out = torch.tanh(sc2_out) 41 | 42 | return tanh2_out 43 | 44 | class GraphAttentionEncoder(nn.Module): 45 | """Graph Encoder, which uses MultiHeadAttentionLayer sublayer. 46 | 47 | Args: 48 | input_dim: embedding size that will be used as d_model in MHA layers. 49 | num_heads: number of attention heads in MHA layers. 50 | num_layers: number of attention layers that will be used in encoder. 51 | feed_forward_hidden: number of neuron units in each FF layer. 52 | 53 | Call arguments: 54 | x: tuples of 3 tensors: (batch_size, 2), (batch_size, n_nodes-1, 2), (batch_size, n_nodes-1) 55 | First tensor contains coordinates for depot, second one is for coordinates of other nodes, 56 | Last tensor is for normalized demands for nodes except depot 57 | 58 | mask: mask for MHA layer 59 | 60 | Returns: 61 | Embedding for all nodes + mean embedding for graph. 62 | Tuples ((batch_size, n_nodes, input_dim), (batch_size, input_dim)) 63 | """ 64 | 65 | def __init__(self, input_dim, num_heads, num_layers, feed_forward_hidden=512): 66 | super().__init__() 67 | 68 | self.input_dim = input_dim 69 | self.num_layers = num_layers 70 | self.num_heads = num_heads 71 | self.feed_forward_hidden = feed_forward_hidden 72 | 73 | # initial embeddings (batch_size, n_nodes-1, 2) --> (batch-size, input_dim), separate for depot and other nodes 74 | self.init_embed_depot = nn.Linear(2, self.input_dim) # nn.Linear(2, embedding_dim) 75 | self.init_embed = nn.Linear(3, self.input_dim) 76 | 77 | self.mha_layers = [MultiHeadAttentionLayer(self.input_dim, self.num_heads, self.feed_forward_hidden) 78 | for _ in range(self.num_layers)] 79 | self.mha_layers = nn.ModuleList(self.mha_layers) 80 | 81 | def forward(self, x, mask=None, cur_num_nodes=None): 82 | 83 | x = torch.cat((self.init_embed_depot(x[0])[:, None, :], # (batch_size, 2) --> (batch_size, 1, 2) 84 | self.init_embed(torch.cat((x[1], x[2][:, :, None]), -1)) # (batch_size, n_nodes-1, 2) + (batch_size, n_nodes-1) 85 | ), 1) # (batch_size, n_nodes, input_dim) 86 | 87 | # stack attention layers 88 | for i in range(self.num_layers): 89 | x = self.mha_layers[i](x, mask) 90 | 91 | if mask is not None: 92 | output = (x, torch.sum(x, 1) / cur_num_nodes) 93 | else: 94 | output = (x, torch.mean(x, 1)) 95 | 96 | return output # (embeds of nodes, avg graph embed)=((batch_size, n_nodes, input), (batch_size, input_dim)) -------------------------------------------------------------------------------- /backup_results_VRP_20_2021-08-31.csv: -------------------------------------------------------------------------------- 1 | epochs,train_loss,train_cost,val_cost 2 | 0,0.057725433,8.126152,7.142528 3 | 1,0.43523985,7.0198774,6.8331623 4 | 
2,-0.05244048,6.799927,6.692892 5 | 3,-0.21513543,6.7059703,6.6481056 6 | 4,-0.16910644,6.6476564,6.585553 7 | 5,-0.21084756,6.606468,6.55772 8 | 6,-0.20123357,6.5790243,6.533435 9 | 7,-0.20451881,6.5594954,6.523542 10 | 8,-0.17670833,6.5427947,6.505786 11 | 9,-0.17904146,6.530585,6.4873476 12 | 10,-0.19809744,6.519344,6.4940104 13 | 11,-0.15416735,6.508296,6.478488 14 | 12,-0.16621172,6.4995356,6.4869556 15 | 13,-0.13755234,6.4940825,6.47636 16 | 14,-0.11532383,6.4876075,6.460555 17 | 15,-0.14074893,6.4802,6.4586205 18 | 16,-0.13124135,6.474483,6.4573145 19 | 17,-0.11284375,6.4697905,6.4460807 20 | 18,-0.12861899,6.464823,6.437971 21 | 19,-0.13243529,6.459164,6.4339194 22 | 20,-0.13290825,6.4535875,6.438004 23 | 21,-0.12190188,6.44997,6.443571 24 | 22,-0.11239375,6.447536,6.420986 25 | 23,-0.12801744,6.443129,6.431984 26 | 24,-0.11823546,6.440166,6.42777 27 | 25,-0.107929625,6.436949,6.418355 28 | 26,-0.09892246,6.4335003,6.4274316 29 | 27,-0.09260898,6.430896,6.412428 30 | 28,-0.11238246,6.429055,6.4071865 31 | 29,-0.11275884,6.424093,6.4058414 32 | 30,-0.11065548,6.4216547,6.401159 33 | 31,-0.10574161,6.4185834,6.3972263 34 | 32,-0.1168384,6.415805,6.399093 35 | 33,-0.11321704,6.4131637,6.389376 36 | 34,-0.1172516,6.4110584,6.393464 37 | 35,-0.11503467,6.410205,6.3980913 38 | 36,-0.109255955,6.4086895,6.392976 39 | 37,-0.10428008,6.4064713,6.39075 40 | 38,-0.0974484,6.4040713,6.3854094 41 | 39,-0.108119674,6.403427,6.39277 42 | -------------------------------------------------------------------------------- /checkpts/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore -------------------------------------------------------------------------------- /dvrp_env.yml: -------------------------------------------------------------------------------- 1 | name: dl_ml 2 | channels: 3 | - pytorch 4 | - plotly 5 | - numba 6 | - conda-forge 7 | - anaconda 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=main 11 | - _tflow_select=2.3.0=mkl 12 | - absl-py=0.11.0=pyhd3eb1b0_1 13 | - aiohttp=3.6.3=py38h7b6447c_0 14 | - alembic=1.4.3=pyh9f0ad1d_0 15 | - argon2-cffi=20.1.0=py38h25fe258_2 16 | - astunparse=1.6.3=py_0 17 | - async-timeout=3.0.1=py38_0 18 | - async_generator=1.10=py_0 19 | - attrs=20.3.0=pyhd3eb1b0_0 20 | - backcall=0.2.0=pyh9f0ad1d_0 21 | - backports=1.0=py_2 22 | - backports.functools_lru_cache=1.6.1=py_0 23 | - blas=1.0=mkl 24 | - bleach=3.2.1=pyh9f0ad1d_0 25 | - blinker=1.4=py38_0 26 | - brotlipy=0.7.0=py38h27cfd23_1003 27 | - c-ares=1.17.1=h27cfd23_0 28 | - ca-certificates=2021.1.19=h06a4308_0 29 | - cachetools=4.2.0=pyhd3eb1b0_0 30 | - certifi=2020.12.5=py38h06a4308_0 31 | - cffi=1.14.4=py38h261ae71_0 32 | - chardet=3.0.4=py38h06a4308_1003 33 | - click=7.1.2=py_0 34 | - cliff=3.5.0=pyhd8ed1ab_0 35 | - cmaes=0.7.0=pyhac0dd68_0 36 | - cmd2=0.9.22=py38h32f6830_1 37 | - colorama=0.4.4=pyh9f0ad1d_0 38 | - colorlog=4.0.2=py38_1000 39 | - cryptography=3.3.1=py38h3c74f83_0 40 | - cudatoolkit=11.0.221=h6bb024c_0 41 | - cycler=0.10.0=py_2 42 | - dbus=1.13.18=hb2f20db_0 43 | - decorator=4.4.2=py_0 44 | - defusedxml=0.6.0=py_0 45 | - entrypoints=0.3=pyhd8ed1ab_1003 46 | - et_xmlfile=1.0.1=py_1001 47 | - expat=2.2.10=he6710b0_2 48 | - fontconfig=2.13.1=he4413a7_1000 49 | - freetype=2.10.4=h5ab3b9f_0 50 | - fsspec=0.8.5=pyhd8ed1ab_0 51 | - future=0.18.2=py38h578d9bd_3 52 | - gast=0.3.3=py_0 53 | - glib=2.66.1=h92f7085_0 54 | - google-auth=1.24.0=pyhd3eb1b0_0 55 | - google-auth-oauthlib=0.4.2=pyhd3eb1b0_2 56 | - 
google-pasta=0.2.0=py_0 57 | - grpcio=1.31.0=py38hf8bcb03_0 58 | - gst-plugins-base=1.14.0=hbbd80ab_1 59 | - gstreamer=1.14.0=h28cd5cc_2 60 | - h5py=2.10.0=py38hd6299e0_1 61 | - hdf5=1.10.6=hb1b8bf9_0 62 | - icu=58.2=hf484d3e_1000 63 | - idna=2.10=py_0 64 | - importlib-metadata=2.0.0=py_1 65 | - importlib_metadata=2.0.0=1 66 | - intel-openmp=2020.2=254 67 | - ipykernel=5.4.2=py38h81c977d_0 68 | - ipython=7.12.0=py38h5ca1d4c_0 69 | - ipython_genutils=0.2.0=py_1 70 | - jdcal=1.4.1=py_0 71 | - jedi=0.17.2=py38h06a4308_1 72 | - jinja2=2.11.2=pyh9f0ad1d_0 73 | - joblib=1.0.0=pyhd3eb1b0_0 74 | - jpeg=9b=h024ee3a_2 75 | - jsonschema=3.2.0=py_2 76 | - jupyter_client=6.1.11=pyhd8ed1ab_1 77 | - jupyter_contrib_core=0.3.3=py_2 78 | - jupyter_contrib_nbextensions=0.5.1=py38h32f6830_1 79 | - jupyter_core=4.7.0=py38h578d9bd_0 80 | - jupyter_highlight_selected_word=0.2.0=py38h32f6830_1002 81 | - jupyter_latex_envs=1.4.6=py38h32f6830_1001 82 | - jupyter_nbextensions_configurator=0.4.1=py38h32f6830_2 83 | - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 84 | - keras-preprocessing=1.1.0=py_1 85 | - kiwisolver=1.3.1=py38h82cb98a_0 86 | - lcms2=2.11=h396b838_0 87 | - ld_impl_linux-64=2.33.1=h53a641e_7 88 | - libedit=3.1.20191231=h14c3975_1 89 | - libffi=3.3=he6710b0_2 90 | - libgcc-ng=9.1.0=hdf63c60_0 91 | - libgfortran-ng=7.3.0=hdf63c60_0 92 | - libpng=1.6.37=hbc83047_0 93 | - libprotobuf=3.13.0.1=hd408876_0 94 | - libsodium=1.0.18=h36c2ea0_1 95 | - libstdcxx-ng=9.1.0=hdf63c60_0 96 | - libtiff=4.1.0=h2733197_1 97 | - libuuid=2.32.1=h14c3975_1000 98 | - libuv=1.40.0=h7b6447c_0 99 | - libxcb=1.13=h14c3975_1002 100 | - libxml2=2.9.10=hb55368b_3 101 | - libxslt=1.1.34=hc22bd24_0 102 | - llvmlite=0.35.0=py38hf484d3e_0 103 | - lxml=4.6.2=py38h9120a33_0 104 | - lz4-c=1.9.2=heb0550a_3 105 | - mako=1.1.4=pyh44b312d_0 106 | - markdown=3.3.3=py38h06a4308_0 107 | - markupsafe=1.1.1=py38h8df0ef7_2 108 | - matplotlib=3.3.3=py38h578d9bd_0 109 | - matplotlib-base=3.3.3=py38h5c7f4ab_0 110 | - mistune=0.8.4=py38h25fe258_1002 111 | - mkl=2020.2=256 112 | - mkl-service=2.3.0=py38he904b0f_0 113 | - mkl_fft=1.2.0=py38h23d657b_0 114 | - mkl_random=1.1.1=py38h0573a6f_0 115 | - multidict=4.7.6=py38h7b6447c_1 116 | - nbclient=0.5.1=py_0 117 | - nbconvert=6.0.7=py38_0 118 | - nbformat=5.1.2=pyhd8ed1ab_0 119 | - ncurses=6.2=he6710b0_1 120 | - nest-asyncio=1.4.3=pyhd8ed1ab_0 121 | - ninja=1.10.2=py38hff7bd54_0 122 | - notebook=6.2.0=py38h578d9bd_0 123 | - numba=0.52.0=np1.11py3.8h04863e7_g18825058a_0 124 | - numpy=1.19.2=py38h54aff64_0 125 | - numpy-base=1.19.2=py38hfa32c7d_0 126 | - oauthlib=3.1.0=py_0 127 | - olefile=0.46=py_0 128 | - openpyxl=3.0.5=py_0 129 | - openssl=1.1.1j=h27cfd23_0 130 | - opt_einsum=3.1.0=py_0 131 | - optuna=2.4.0=pyhd8ed1ab_0 132 | - packaging=20.8=pyhd3deb0d_0 133 | - pandas=1.2.0=py38ha9443f7_0 134 | - pandoc=2.11.3.2=h7f98852_0 135 | - pandocfilters=1.4.2=py_1 136 | - parso=0.7.0=py_0 137 | - pbr=5.5.1=pyh9f0ad1d_0 138 | - pcre=8.44=he1b5a44_0 139 | - pexpect=4.8.0=pyh9f0ad1d_2 140 | - pickleshare=0.7.5=py_1003 141 | - pillow=8.1.0=py38he98fc37_0 142 | - pip=20.3.3=py38h06a4308_0 143 | - plotly=4.14.3=py_0 144 | - prettytable=2.0.0=pyhd8ed1ab_0 145 | - prometheus_client=0.9.0=pyhd3deb0d_0 146 | - prompt-toolkit=3.0.10=pyha770c72_0 147 | - prompt_toolkit=3.0.10=hd8ed1ab_0 148 | - protobuf=3.13.0.1=py38he6710b0_1 149 | - pthread-stubs=0.4=h36c2ea0_1001 150 | - ptyprocess=0.7.0=pyhd3deb0d_0 151 | - pyasn1=0.4.8=py_0 152 | - pyasn1-modules=0.2.8=py_0 153 | - pycparser=2.20=py_2 154 | - pygments=2.7.4=pyhd8ed1ab_0 155 | 
- pyjwt=1.7.1=py38_0 156 | - pyopenssl=20.0.1=pyhd3eb1b0_1 157 | - pyparsing=2.4.7=pyh9f0ad1d_0 158 | - pyperclip=1.8.1=pyhd3deb0d_0 159 | - pyqt=5.9.2=py38h05f1152_4 160 | - pyrsistent=0.17.3=py38h25fe258_1 161 | - pysocks=1.7.1=py38h06a4308_0 162 | - python=3.8.5=h7579374_1 163 | - python-dateutil=2.8.1=py_0 164 | - python-editor=1.0.4=py_0 165 | - python_abi=3.8=1_cp38 166 | - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0 167 | - pytorch-lightning=1.1.4=pyhd8ed1ab_0 168 | - pytz=2020.5=pyhd3eb1b0_0 169 | - pyyaml=5.3.1=py38h8df0ef7_1 170 | - pyzmq=19.0.2=py38ha71036d_2 171 | - qt=5.9.7=h5867ecd_1 172 | - readline=8.0=h7b6447c_0 173 | - requests=2.25.1=pyhd3eb1b0_0 174 | - requests-oauthlib=1.3.0=py_0 175 | - retrying=1.3.3=py_2 176 | - rsa=4.7=pyhd3eb1b0_1 177 | - scikit-learn=0.23.2=py38h0573a6f_0 178 | - scipy=1.5.2=py38h0b6359f_0 179 | - seaborn=0.11.0=py_0 180 | - send2trash=1.5.0=py_0 181 | - setuptools=51.1.2=py38h06a4308_4 182 | - sip=4.19.13=py38he6710b0_0 183 | - six=1.15.0=py38h06a4308_0 184 | - sqlalchemy=1.3.21=py38h27cfd23_0 185 | - sqlite=3.33.0=h62c20be_0 186 | - stevedore=3.3.0=py38h578d9bd_1 187 | - tensorboard=2.3.0=pyh4dce500_0 188 | - tensorboard-plugin-wit=1.6.0=py_0 189 | - tensorflow=2.2.0=mkl_py38h6d3daf0_0 190 | - tensorflow-base=2.2.0=mkl_py38h5059a2d_0 191 | - tensorflow-estimator=2.2.0=pyh208ff02_0 192 | - termcolor=1.1.0=py38_1 193 | - terminado=0.9.2=py38h578d9bd_0 194 | - testpath=0.4.4=py_0 195 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 196 | - tk=8.6.10=hbc83047_0 197 | - torchaudio=0.7.2=py38 198 | - torchvision=0.8.2=py38_cu110 199 | - tornado=6.1=py38h25fe258_0 200 | - tqdm=4.56.0=pyhd8ed1ab_0 201 | - traitlets=5.0.5=py_0 202 | - typing_extensions=3.7.4.3=py_0 203 | - urllib3=1.26.2=pyhd3eb1b0_0 204 | - wcwidth=0.2.5=pyh9f0ad1d_2 205 | - webencodings=0.5.1=py_1 206 | - werkzeug=1.0.1=py_0 207 | - wheel=0.36.2=pyhd3eb1b0_0 208 | - wrapt=1.12.1=py38h7b6447c_1 209 | - xorg-libxau=1.0.9=h14c3975_0 210 | - xorg-libxdmcp=1.1.3=h516909a_0 211 | - xz=5.2.5=h7b6447c_0 212 | - yaml=0.2.5=h516909a_0 213 | - yarl=1.6.3=py38h27cfd23_0 214 | - zeromq=4.3.3=h58526e2_3 215 | - zipp=3.4.0=pyhd3eb1b0_0 216 | - zlib=1.2.11=h7b6447c_3 217 | - zstd=1.4.5=h9ceee32_0 218 | prefix: /home/roberto/anaconda3/envs/dl_ml 219 | 220 | -------------------------------------------------------------------------------- /environment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class AgentVRP(): 4 | VEHICLE_CAPACITY = 1.0 5 | 6 | def __init__(self, input): 7 | depot = input[0] # (batch_size, 2) 8 | loc = input[1] # (batch_size, n_nodes, 2) 9 | self.demand = input[2] # (batch_size, n_nodes) 10 | 11 | self.batch_size, self.n_loc, _ = loc.shape 12 | 13 | # Coordinates of depot + other nodes -> (batch_size, 1+n_nodes, 2) 14 | self.coords = torch.cat((depot[:, None, :], loc), dim=-2) 15 | 16 | # Indices of graphs in batch 17 | self.ids = torch.arange(self.batch_size) # (batch_size) 18 | 19 | # State 20 | self.prev_a = torch.zeros(self.batch_size, 1, dtype=torch.int64) 21 | self.from_depot = self.prev_a == 0 22 | self.used_capacity = torch.zeros(self.batch_size, 1) 23 | 24 | # Nodes that have been visited will be marked with 1 25 | self.visited = torch.zeros(self.batch_size, 1, self.n_loc+1) 26 | 27 | # Step counter 28 | self.i = torch.zeros(1, dtype=torch.int64) 29 | 30 | @staticmethod 31 | def outer_pr(a, b): 32 | """ Outer product of a and b row vectors. 
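        For a and b of shape (batch_size, n), the result has shape (batch_size, n, n):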
33 | result[k] = matmul( a[k].t(), b[k] ) 34 | """ 35 | return torch.einsum('ki,kj->kij', a, b) 36 | 37 | def get_att_mask(self): 38 | """ Mask (batchsize, n_nodes, n_nodes) for attention encoder. 39 | We maks alredy visited nodes except for depot (can be visited multiple times). 40 | 41 | True -> should mask (can NOT visit) 42 | False -> shouldn't mask (can visit) 43 | """ 44 | # Remove depot from mask (1st column) 45 | att_mask = torch.squeeze(self.visited, dim=-2)[:, 1:] # (batch_size, 1, n_nodes) -> (batch_size, n_nodes-1) 46 | 47 | # Number of nodes in new instance after masking 48 | cur_num_nodes = self.n_loc + 1 - att_mask.sum(dim=1, keepdims=True) # (batch_size, 1) 49 | 50 | att_mask = torch.cat((torch.zeros(att_mask.shape[0], 1), att_mask), dim=-1) # add depot -> (batch_size, n_nodes) 51 | 52 | ones_mask = torch.ones_like(att_mask) 53 | 54 | # Create square attention mask. 55 | # In a (n_nodes, n_nodes) matrix this masks all rows and columns of visited nodes 56 | att_mask = AgentVRP.outer_pr(att_mask, ones_mask) \ 57 | + AgentVRP.outer_pr(ones_mask, att_mask) \ 58 | - AgentVRP.outer_pr(att_mask, att_mask) # (batch_size, n_nodes, n_nodes) 59 | return att_mask == 1, cur_num_nodes 60 | 61 | def all_finished(self): 62 | """ Checks if all routes are finished 63 | """ 64 | return torch.all(self.visited == 1).item() 65 | 66 | def partial_finished(self): 67 | """Checks if partial solution for all graphs has been built; i.e. all agents came back to depot 68 | """ 69 | return (torch.all(self.from_depot == 1) and self.i != 0).item() 70 | 71 | def get_mask(self): 72 | """ Returns a mask (batch_size, 1, n_nodes) with available actions. 73 | Impossible nodes are masked. 74 | 75 | True -> should mask (can NOT visit) 76 | False -> shouldn't mask (can visit) 77 | """ 78 | 79 | # Exclude depot 80 | visited_loc = self.visited[:, :, 1:] 81 | 82 | # Mark nodes which exceed vehicle capacity 83 | exceeds_cap = self.demand + self.used_capacity > self.VEHICLE_CAPACITY 84 | 85 | # Maks nodes that area already visited or have too much demand or when they arrived to depot 86 | mask_loc = (visited_loc == 1) | (exceeds_cap[:, None, :]) \ 87 | | ((self.i > 0) & self.from_depot[:, None, :]) 88 | 89 | # We can choose depot if we are not in depot OR all nodes are visited 90 | # equivalent to: we mask the depot if we are in it AND there're still mode nodes to visit 91 | mask_depot = self.from_depot[:, None, :] & ((mask_loc == False).sum(dim=-1, keepdims=True) > 0) 92 | 93 | return torch.cat([mask_depot, mask_loc], dim=-1) 94 | 95 | def step(self, action): 96 | 97 | # Update current state 98 | selected = action[:, None] 99 | self.prev_a = selected 100 | self.from_depot = self.prev_a == 0 101 | 102 | # Shift indices by 1 since self.demand doesn't consider depot 103 | selected_demand = self.demand.gather(-1, (self.prev_a - 1).clamp_min(0).view(-1, 1)) # (batch_size, 1) 104 | 105 | # Add current node capacity to used capacity and set it to 0 if we return from depot 106 | self.used_capacity = (self.used_capacity + selected_demand) * (self.from_depot == False) 107 | 108 | # Update visited nodes (set 1 to visited nodes) 109 | self.visited[self.ids, [0], action] = 1 110 | 111 | self.i += 1 112 | 113 | @staticmethod 114 | def get_costs(dataset, pi): 115 | 116 | # Place nodes with coordinates in order of decoder tour 117 | loc_with_depot = torch.cat((dataset[0][:, None, :], dataset[1]), dim=1) # (batch_size, n_nodes, 2) 118 | d = loc_with_depot.gather(1, pi.view(pi.shape[0], -1, 1).repeat_interleave(2, -1)) 119 | 120 | # 
Calculation of total distance 121 | # Note: first element of pi is not depot, but the first selected node in path 122 | # and last element from longest path is not depot 123 | 124 | return ((torch.norm(d[:, 1:] - d[:, :-1], dim=-1)).sum(dim=-1) # intra node distances 125 | + (torch.norm(d[:, 0] - dataset[0], dim=-1)) # distance from depot to first 126 | + (torch.norm(d[:, -1] - dataset[0], dim=-1))) # distance from last node of longest path to depot -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | def scaled_attention(query, key, value, mask=None): 7 | """ Function that performs scaled attention given q, k, v and mask. 8 | q, k, v can have multiple batches and heads, defined across the first dimensions 9 | and the last 2 dimensions for a given sample of them are in row vector format. 10 | matmul is brodcasted across batches. 11 | """ 12 | qk = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.shape[-1]) 13 | if mask is not None: qk = qk.masked_fill(mask == 1, -1e9) 14 | qk = F.softmax(qk, dim=-1) 15 | return torch.matmul(qk, value) 16 | 17 | class MultiHeadAttention(nn.Module): 18 | """ Attention Layer - multi-head scaled dot product attention (for encoder and decoder) 19 | Observation: This MHA is currently implemented to only support singe-gpu machines 20 | 21 | Args: 22 | num_heads: number of attention heads which will be computed in parallel 23 | d_model: embedding size of output AND input features 24 | * in reality it shouldn't be neccesary that input and ouptut features are the same dimension 25 | but its the current case for this class. 26 | 27 | Call arguments: 28 | q: query, shape (..., seq_len_q, depth_q) 29 | k: key, shape == (..., seq_len_k, depth_k) 30 | v: value, shape == (..., seq_len_v, depth_v) 31 | mask: Float tensor with shape broadcastable to (..., seq_len_q, seq_len_k) or None. 32 | 33 | Since we use scaled-product attention, we assume seq_len_k = seq_len_v 34 | 35 | Returns: 36 | attention outputs of shape (batch_size, seq_len_q, d_model) 37 | """ 38 | def __init__(self, n_heads, d_model, **kwargs): 39 | super(MultiHeadAttention, self).__init__() 40 | self.n_heads, self.d_model = n_heads, d_model 41 | self.head_depth = self.d_model // self.n_heads 42 | 43 | assert self.d_model % self.n_heads == 0 44 | 45 | # define weight matrices 46 | self.wq = nn.Linear(self.d_model, self.d_model, bias=False) 47 | self.wk = nn.Linear(self.d_model, self.d_model, bias=False) 48 | self.wv = nn.Linear(self.d_model, self.d_model, bias=False) 49 | 50 | self.w_out = nn.Linear(self.d_model, self.d_model, bias=False) 51 | 52 | def split_heads(self, tensor, batch_size): 53 | """ Function that splits the heads. This happens in the same tensor since this class doesn't 54 | support multiple-gpu. Observe inline comments for more details on shapes. 
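        For example, with d_model=128 and n_heads=8, a (batch_size, seq_len, 128)
        input becomes (batch_size, 8, seq_len, 16).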
55 | """ 56 | # (batch_size, seq_len, d_model) -> (batch_size, seq_len, n_heads, head_depth) 57 | splitted_tensor = tensor.view(batch_size, -1, self.n_heads, self.head_depth) 58 | return splitted_tensor.transpose(1, 2) # (batch_size, n_heads, seq_len, head_depth) 59 | 60 | def forward(self, query, key, value, mask=None): 61 | # shape of q: (batch_size, seq_len_q, d_query) 62 | batch_size = query.shape[0] 63 | 64 | # project query, key and value to d_model dimensional space 65 | # this is equivalent to projecting them each to a head_depth dimensional space (for every head) 66 | # but with a single matrix 67 | Q = self.wq(query) # (batch_size, seq_len_q, d_query) -> (batch_size, seq_len_q, d_model) 68 | K = self.wk(key) # ... -> (batch_size, seq_len_k, d_model) 69 | V = self.wv(value) # ... -> (batch_size, seq_len_v, d_model) 70 | 71 | # split individual heads 72 | Q = self.split_heads(Q, batch_size) # ... -> (batch_size, n_heads, seq_len_q, head_depth) 73 | K = self.split_heads(K, batch_size) # ... -> (batch_size, n_heads, seq_len_k, head_depth) 74 | V = self.split_heads(V, batch_size) # ... -> (batch_size, n_heads, seq_len_v, head_depth) 75 | 76 | 77 | # Add dimension to mask so that it can be broadcasted across heads 78 | # (batch_size, seq_len_q, seq_len_k) --> (batch_size, 1, seq_len_q, seq_len_k) 79 | if mask is not None: 80 | mask = mask.unsqueeze(1) 81 | 82 | # perform attention for each q=(seq_len_q, head_depth), k=(seq_len_k, head_depth), v=(seq_len_v, head_depth) 83 | attention = scaled_attention(Q, K, V, mask) # (batch_size, n_heads, seq_len_q, head_depth) 84 | # transpose attention to (batch_size, seq_len_q, n_heads, head_depth) 85 | attention = attention.transpose(1, 2).contiguous() 86 | # concatenate results of all heads (batch_size, seq_len_q, self.d_model) 87 | attention = attention.view(batch_size, -1, self.d_model) 88 | 89 | # project attention to same dimension; observe this is equivalent to summing individual projection 90 | # as sugested in paper 91 | output = self.w_out(attention) # (batch_size, seq_len_q, d_model) 92 | 93 | return output -------------------------------------------------------------------------------- /learning_curve_plot_VRP_20_2021-08-31_start=0, end=40.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Roberto09/Dynamic-Attention-Model-for-VRP---Pytorch/47a2a1219f1890972f9393f0d193b26bb8b8888e/learning_curve_plot_VRP_20_2021-08-31_start=0, end=40.jpg -------------------------------------------------------------------------------- /reinforce_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.stats import ttest_rel 3 | from tqdm import tqdm 4 | import numpy as np 5 | 6 | from attention_dynamic_model import AttentionDynamicModel 7 | from attention_dynamic_model import set_decode_type 8 | from utils import generate_data_onfly, FastTensorDataLoader, get_dev_of_mod 9 | 10 | 11 | def copy_of_pt_model(model, embedding_dim=128, graph_size=20): 12 | """Copy model weights to new model 13 | """ 14 | CAPACITIES = {10: 20., 15 | 20: 30., 16 | 50: 40., 17 | 100: 50. 
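                  # graph_size -> vehicle capacity, used below to normalize the random integer demands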
18 | } 19 | 20 | data_random = [torch.rand((2, 2,), dtype=torch.float32), 21 | torch.rand((2, graph_size, 2), dtype=torch.float32), 22 | torch.randint(low=1, high= 10, size=(2, graph_size), dtype=torch.float32)/CAPACITIES[graph_size]] 23 | 24 | new_model = AttentionDynamicModel(embedding_dim).to(get_dev_of_mod(model)) 25 | set_decode_type(new_model, "sampling") 26 | new_model.eval() 27 | with torch.no_grad(): 28 | _, _ = new_model(data_random) 29 | model_dict = model.state_dict() 30 | new_model.load_state_dict(model_dict) 31 | 32 | new_model.eval() 33 | 34 | return new_model 35 | 36 | def get_costs_rollout(model, train_batches, disable_tqdm): 37 | costs_list = [] 38 | for batch in tqdm(train_batches, disable=disable_tqdm, desc="Rollout greedy execution"): 39 | cost, _ = model(batch) 40 | costs_list.append(cost) 41 | # torch.cuda.empty_cache() 42 | return costs_list 43 | 44 | def rollout(model, dataset, batch_size = 1000, disable_tqdm = False): 45 | # Evaluate model in greedy mode 46 | set_decode_type(model, "greedy") 47 | 48 | train_batches = FastTensorDataLoader(dataset[0],dataset[1],dataset[2], batch_size=batch_size, shuffle=False) 49 | 50 | model_was_training = model.training 51 | model.eval() 52 | 53 | with torch.no_grad(): 54 | costs_list = get_costs_rollout(model, train_batches, disable_tqdm) 55 | 56 | if model_was_training: model.train() # restore original model training state 57 | 58 | return torch.cat(costs_list, dim=0) 59 | 60 | 61 | def validate(dataset, model, batch_size=1000): 62 | """Validates model on given dataset in greedy mode 63 | """ 64 | # rollout will set the model to eval mode and turn it back to it's original mode after it finishes 65 | val_costs = rollout(model, dataset, batch_size=batch_size) 66 | set_decode_type(model, "sampling") 67 | mean_cost = torch.mean(val_costs) 68 | print(f"Validation score: {np.round(mean_cost, 4)}") 69 | return mean_cost 70 | 71 | 72 | class RolloutBaseline: 73 | 74 | def __init__(self, model, filename, 75 | from_checkpoint=False, 76 | path_to_checkpoint=None, 77 | wp_n_epochs=1, 78 | epoch=0, 79 | num_samples=10000, 80 | warmup_exp_beta=0.8, 81 | embedding_dim=128, 82 | graph_size=20 83 | ): 84 | """ 85 | Args: 86 | model: current model 87 | filename: suffix for baseline checkpoint filename 88 | from_checkpoint: start from checkpoint flag 89 | path_to_checkpoint: path to baseline model weights 90 | wp_n_epochs: number of warm-up epochs 91 | epoch: current epoch number 92 | num_samples: number of samples to be generated for baseline dataset 93 | warmup_exp_beta: warmup mixing parameter (exp. 
moving average parameter) 94 | 95 | """ 96 | 97 | self.num_samples = num_samples 98 | self.cur_epoch = epoch 99 | self.wp_n_epochs = wp_n_epochs 100 | self.beta = warmup_exp_beta 101 | 102 | # controls the amount of warmup 103 | self.alpha = 0.0 104 | 105 | self.running_average_cost = None 106 | 107 | # Checkpoint params 108 | self.filename = filename 109 | self.from_checkpoint = from_checkpoint 110 | self.path_to_checkpoint = path_to_checkpoint 111 | 112 | # Problem params 113 | self.embedding_dim = embedding_dim 114 | self.graph_size = graph_size 115 | 116 | # create and evaluate initial baseline 117 | self._update_baseline(model, epoch) 118 | 119 | 120 | def _update_baseline(self, model, epoch): 121 | 122 | # Load or copy baseline model based on self.from_checkpoint condition 123 | if self.from_checkpoint and self.alpha == 0: 124 | print('Baseline model loaded') 125 | self.model = load_pt_model(self.path_to_checkpoint, 126 | embedding_dim=self.embedding_dim, 127 | graph_size=self.graph_size, 128 | device = get_dev_of_mod(model)) 129 | self.model.eval() 130 | else: 131 | self.model = copy_of_pt_model(model, 132 | embedding_dim=self.embedding_dim, 133 | graph_size=self.graph_size) 134 | self.model.eval() 135 | torch.save(self.model.state_dict(),'./checkpts/baseline_checkpoint_epoch_{}_{}'.format(epoch, self.filename)) 136 | 137 | # We generate a new dataset for baseline model on each baseline update to prevent possible overfitting 138 | self.dataset = generate_data_onfly(num_samples=self.num_samples, graph_size=self.graph_size) 139 | 140 | print(f"Evaluating baseline model on baseline dataset (epoch = {epoch})") 141 | self.bl_vals = rollout(self.model, self.dataset) 142 | self.mean = torch.mean(self.bl_vals) 143 | self.cur_epoch = epoch 144 | 145 | def ema_eval(self, cost): 146 | """This is running average of cost through previous batches (only for warm-up epochs) 147 | """ 148 | 149 | if self.running_average_cost is None: 150 | self.running_average_cost = torch.mean(cost) 151 | else: 152 | self.running_average_cost = self.beta * self.running_average_cost + (1. - self.beta) * torch.mean(cost) 153 | 154 | return self.running_average_cost 155 | 156 | def eval(self, batch, cost): 157 | """Evaluates current baseline model on single training batch 158 | """ 159 | 160 | if self.alpha == 0: 161 | return self.ema_eval(cost) 162 | 163 | if self.alpha < 1: 164 | v_ema = self.ema_eval(cost) 165 | else: 166 | v_ema = torch.tensor(0.0) 167 | 168 | with torch.no_grad(): 169 | v_b, _ = self.model(batch) 170 | 171 | # Combination of baseline cost and exp. 
moving average cost 172 | return self.alpha * v_b.detach() + (1 - self.alpha) * v_ema.detach() 173 | 174 | def eval_all(self, dataset): 175 | """Evaluates current baseline model on the whole dataset only for non warm-up epochs 176 | """ 177 | 178 | if self.alpha < 1: 179 | return None 180 | 181 | val_costs = rollout(self.model, dataset, batch_size=2048) 182 | 183 | return val_costs 184 | 185 | def epoch_callback(self, model, epoch): 186 | """Compares current baseline model with the training model and updates baseline if it is improved 187 | """ 188 | 189 | self.cur_epoch = epoch 190 | 191 | print(f"Evaluating candidate model on baseline dataset (callback epoch = {self.cur_epoch})") 192 | candidate_vals = rollout(model, self.dataset) # costs for training model on baseline dataset 193 | candidate_mean = torch.mean(candidate_vals) 194 | 195 | diff = candidate_mean - self.mean 196 | 197 | print(f"Epoch {self.cur_epoch} candidate mean {candidate_mean}, baseline epoch {self.cur_epoch} mean {self.mean}, difference {diff}") 198 | 199 | if diff < 0: 200 | # statistic + p-value 201 | t, p = ttest_rel(candidate_vals, self.bl_vals) 202 | 203 | p_val = p / 2 204 | print(f"p-value: {p_val}") 205 | 206 | if p_val < 0.05: 207 | print('Update baseline') 208 | self._update_baseline(model, self.cur_epoch) 209 | 210 | # alpha controls the amount of warmup 211 | if self.alpha < 1.0: 212 | self.alpha = (self.cur_epoch + 1) / float(self.wp_n_epochs) 213 | print(f"alpha was updated to {self.alpha}") 214 | 215 | 216 | def load_pt_model(path, embedding_dim=128, graph_size=20, n_encode_layers=2, device='cpu'): 217 | """Load model weights from hd5 file 218 | """ 219 | CAPACITIES = {10: 20., 220 | 20: 30., 221 | 50: 40., 222 | 100: 50. 223 | } 224 | 225 | data_random = [torch.rand((2, 2,), dtype=torch.float32), 226 | torch.rand((2, graph_size, 2), dtype=torch.float32), 227 | torch.randint(low=1, high= 10, size=(2, graph_size), dtype=torch.float32)/CAPACITIES[graph_size]] 228 | 229 | model_loaded = AttentionDynamicModel(embedding_dim,n_encode_layers=n_encode_layers).to(device) 230 | 231 | set_decode_type(model_loaded, "greedy") 232 | _, _ = model_loaded(data_random) 233 | 234 | model_loaded.load_state_dict(torch.load(path)) 235 | 236 | return model_loaded 237 | -------------------------------------------------------------------------------- /test_results_VRP_20_2021-08-31_start=0, end=40.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Roberto09/Dynamic-Attention-Model-for-VRP---Pytorch/47a2a1219f1890972f9393f0d193b26bb8b8888e/test_results_VRP_20_2021-08-31_start=0, end=40.xlsx -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import pandas as pd 3 | import torch 4 | 5 | from attention_dynamic_model import set_decode_type 6 | from reinforce_baseline import validate 7 | 8 | from utils import generate_data_onfly, get_results, get_cur_time 9 | from time import gmtime, strftime 10 | from utils import FastTensorDataLoader, get_dev_of_mod 11 | 12 | class IterativeMean(): 13 | def __init__(self): 14 | self.sum = 0 15 | self.n = 0 16 | 17 | def update_state(self, val): 18 | self.sum += val 19 | self.n += 1 20 | 21 | def result(self): 22 | return self.sum / self.n 23 | 24 | def train_model(optimizer, 25 | model_torch, 26 | baseline, 27 | validation_dataset, 28 | samples = 1280000, 29 | batch = 128, 30 | 
val_batch_size = 1000, 31 | start_epoch = 0, 32 | end_epoch = 5, 33 | from_checkpoint = False, 34 | grad_norm_clipping = 1.0, 35 | batch_verbose = 1000, 36 | graph_size = 20, 37 | filename = None, 38 | mem_efficient=True, 39 | ): 40 | 41 | if filename is None: 42 | filename = 'VRP_{}_{}'.format(graph_size, strftime("%Y-%m-%d", gmtime())) 43 | 44 | def rein_loss(model, inputs, baseline, num_batch): 45 | """Calculate loss for REINFORCE algorithm 46 | """ 47 | 48 | # Evaluate model, get costs and log probabilities 49 | cost, log_likelihood = model(inputs) 50 | 51 | # Evaluate baseline 52 | # For first wp_n_epochs we take the combination of baseline and ema for previous batches 53 | # after that we take a slice of precomputed baseline values 54 | bl_val = bl_vals[num_batch] if bl_vals is not None else baseline.eval(inputs, cost) 55 | 56 | # Calculate loss 57 | reinforce_loss = torch.mean((cost - bl_val.detach()) * log_likelihood) 58 | 59 | return reinforce_loss, torch.mean(cost) 60 | 61 | def mem_efficient_rein_loss(model, inputs, baseline, num_batch): 62 | rein_detached_loss, (cost, log_likelihood) = model.fwd_rein_loss(inputs, baseline, bl_vals, num_batch) 63 | return rein_detached_loss, torch.mean(cost) 64 | 65 | def grad(model, inputs, baseline, num_batch): 66 | """Calculate gradients 67 | """ 68 | if mem_efficient: 69 | loss, cost = mem_efficient_rein_loss(model, inputs, baseline, num_batch) 70 | else: 71 | loss, cost = rein_loss(model, inputs, baseline, num_batch) 72 | loss.backward() 73 | grads = [param.grad.view(-1) for param in model.parameters()] 74 | grads = torch.cat(grads) 75 | # we can return detached loss since it's backwarded already above 76 | return loss.detach(), cost, grads 77 | 78 | # For plotting 79 | train_loss_results = [] 80 | train_cost_results = [] 81 | val_cost_avg = [] 82 | 83 | # Training loop 84 | for epoch in range(start_epoch, end_epoch): 85 | 86 | # Create dataset on current epoch 87 | data = generate_data_onfly(num_samples=samples, graph_size=graph_size) 88 | 89 | epoch_loss_avg = IterativeMean() 90 | epoch_cost_avg = IterativeMean() 91 | 92 | # Skip warm-up stage when we continue training from checkpoint 93 | if from_checkpoint and baseline.alpha != 1.0: 94 | print('Skipping warm-up mode') 95 | baseline.alpha = 1.0 96 | 97 | # If epoch > wp_n_epochs then precompute baseline values for the whole dataset else None 98 | bl_vals = baseline.eval_all(data) # (samples, ) or None 99 | bl_vals = torch.reshape(bl_vals, (-1, batch)) if bl_vals is not None else None # (n_batches, batch) or None 100 | 101 | print("Current decode type: {}".format(model_torch.decode_type)) 102 | 103 | train_batches = FastTensorDataLoader(data[0],data[1],data[2], batch_size=batch, shuffle=False) 104 | 105 | for num_batch, x_batch in tqdm(enumerate(train_batches), desc="batch calculation at epoch {}".format(epoch)): 106 | optimizer.zero_grad() 107 | # Optimize the model 108 | loss_value, cost_val, grads = grad(model_torch, x_batch, baseline, num_batch) 109 | 110 | # Clip gradients by grad_norm_clipping 111 | init_global_norm = torch.linalg.norm(grads) 112 | torch.nn.utils.clip_grad_norm_(model_torch.parameters(), grad_norm_clipping) 113 | grads = [param.grad.view(-1) for param in model_torch.parameters()] 114 | grads = torch.cat(grads) 115 | global_norm = torch.linalg.norm(grads) 116 | 117 | if num_batch%batch_verbose == 0: 118 | print("grad_global_norm = {}, clipped_norm = {}".format(init_global_norm.cpu().numpy(), global_norm.cpu().numpy())) 119 | 120 | optimizer.step() 121 | 122 | # Track 
progress 123 | epoch_loss_avg.update_state(loss_value) 124 | epoch_cost_avg.update_state(cost_val) 125 | 126 | if num_batch%batch_verbose == 0: 127 | print("Epoch {} (batch = {}): Loss: {}: Cost: {}".format(epoch, num_batch, epoch_loss_avg.result(), epoch_cost_avg.result())) 128 | 129 | # Update baseline if the candidate model is good enough. In this case also create new baseline dataset 130 | baseline.epoch_callback(model_torch, epoch) 131 | set_decode_type(model_torch, "sampling") 132 | 133 | # Save model weights 134 | torch.save(model_torch.state_dict(),'./checkpts/model_checkpoint_epoch_{}_{}'.format(epoch, filename)) 135 | 136 | # Validate current model 137 | val_cost = validate(validation_dataset, model_torch, val_batch_size) 138 | val_cost_avg.append(val_cost) 139 | 140 | train_loss_results.append(epoch_loss_avg.result()) 141 | train_cost_results.append(epoch_cost_avg.result()) 142 | 143 | pd.DataFrame(data={'epochs': list(range(start_epoch, epoch+1)), 144 | 'train_loss': [x.cpu().numpy() for x in train_loss_results], 145 | 'train_cost': [x.cpu().numpy() for x in train_cost_results], 146 | 'val_cost': [x.cpu().numpy() for x in val_cost_avg] 147 | }).to_csv('backup_results_' + filename + '.csv', index=False) 148 | 149 | print(get_cur_time(), "Epoch {}: Loss: {}: Cost: {}".format(epoch, epoch_loss_avg.result(), epoch_cost_avg.result())) 150 | 151 | # Make plots and save results 152 | filename_for_results = filename + '_start={}, end={}'.format(start_epoch, end_epoch) 153 | get_results([x.numpy() for x in train_loss_results], 154 | [x.numpy() for x in train_cost_results], 155 | [x.numpy() for x in val_cost_avg], 156 | save_results=True, 157 | filename=filename_for_results, 158 | plots=True) 159 | 160 | -------------------------------------------------------------------------------- /train_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from attention_dynamic_model import AttentionDynamicModel, set_decode_type\n", 11 | "from reinforce_baseline import RolloutBaseline\n", 12 | "from train import train_model\n", 13 | "\n", 14 | "from time import strftime, gmtime\n", 15 | "from utils import create_data_on_disk, get_cur_time" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "* change batch sizes" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# Params of model\n", 32 | "SAMPLES = 512# 128*10000\n", 33 | "BATCH = 128\n", 34 | "START_EPOCH = 0\n", 35 | "END_EPOCH = 10\n", 36 | "FROM_CHECKPOINT = False\n", 37 | "embedding_dim = 128\n", 38 | "LEARNING_RATE = 0.0001\n", 39 | "ROLLOUT_SAMPLES = 10000\n", 40 | "NUMBER_OF_WP_EPOCHS = 1\n", 41 | "GRAD_NORM_CLIPPING = 1.0\n", 42 | "BATCH_VERBOSE = 1000\n", 43 | "VAL_BATCH_SIZE = 1000\n", 44 | "VALIDATE_SET_SIZE = 10000\n", 45 | "SEED = 1234\n", 46 | "GRAPH_SIZE = 50\n", 47 | "FILENAME = 'VRP_{}_{}'.format(GRAPH_SIZE, strftime(\"%Y-%m-%d\", gmtime()))" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "2021-03-26 18:56:07 model initialized\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "# Initialize model\n", 65 | "model_pt = AttentionDynamicModel(embedding_dim).cuda()\n", 
66 | "set_decode_type(model_pt, \"sampling\")\n", 67 | "print(get_cur_time(), 'model initialized')" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 4, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "name": "stdout", 77 | "output_type": "stream", 78 | "text": [ 79 | "2021-03-26 18:56:07 validation dataset created and saved on the disk\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "# Create and save validation dataset\n", 85 | "validation_dataset = create_data_on_disk(GRAPH_SIZE,\n", 86 | " VALIDATE_SET_SIZE,\n", 87 | " is_save=True,\n", 88 | " filename=FILENAME,\n", 89 | " is_return=True,\n", 90 | " seed = SEED)\n", 91 | "print(get_cur_time(), 'validation dataset created and saved on the disk')" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 5, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "# Initialize optimizer\n", 101 | "optimizer = torch.optim.Adam(params=model_pt.parameters(), lr=LEARNING_RATE)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 6, 107 | "metadata": { 108 | "scrolled": true 109 | }, 110 | "outputs": [ 111 | { 112 | "name": "stderr", 113 | "output_type": "stream", 114 | "text": [ 115 | "\r", 116 | "Rollout greedy execution: 0%| | 0/10 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Initialize baseline\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m baseline = RolloutBaseline(model_pt,\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mwp_n_epochs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNUMBER_OF_WP_EPOCHS\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mnum_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mROLLOUT_SAMPLES\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 141 | "\u001b[0;32m~/Desktop/projects/papers/dvrp/solving_dvrp_with_transformers/reinforce_baseline.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, model, filename, from_checkpoint, path_to_checkpoint, wp_n_epochs, epoch, num_samples, warmup_exp_beta, embedding_dim, graph_size)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;31m# create and evaluate initial baseline\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_baseline\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 142 | "\u001b[0;32m~/Desktop/projects/papers/dvrp/solving_dvrp_with_transformers/reinforce_baseline.py\u001b[0m in \u001b[0;36m_update_baseline\u001b[0;34m(self, model, epoch)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Evaluating baseline model on baseline dataset (epoch = {epoch})\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 141\u001b[0;31m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbl_vals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrollout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 142\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbl_vals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcur_epoch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 143 | "\u001b[0;32m~/Desktop/projects/papers/dvrp/solving_dvrp_with_transformers/reinforce_baseline.py\u001b[0m in \u001b[0;36mrollout\u001b[0;34m(model, dataset, batch_size, disable_tqdm)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 54\u001b[0;31m \u001b[0mcosts_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_costs_rollout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_batches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdisable_tqdm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 55\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmodel_was_training\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# restore original model training state\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 144 | "\u001b[0;32m~/Desktop/projects/papers/dvrp/solving_dvrp_with_transformers/reinforce_baseline.py\u001b[0m in \u001b[0;36mget_costs_rollout\u001b[0;34m(model, train_batches, disable_tqdm)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0mcosts_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_batches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdisable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdisable_tqdm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdesc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Rollout greedy execution\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 39\u001b[0;31m \u001b[0mcost\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 40\u001b[0m \u001b[0mcosts_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcost\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 41\u001b[0m 
\u001b[0;31m# torch.cuda.empty_cache()\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 145 | "\u001b[0;32m~/anaconda3/envs/dl_ml/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 146 | "\u001b[0;32m~/Desktop/projects/papers/dvrp/solving_dvrp_with_transformers/attention_dynamic_model.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, inputs, return_pi)\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpartial_finished\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 208\u001b[0;31m \u001b[0mstep_context\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0membeddings\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (batch_size, 1, input_dim + 1)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 209\u001b[0m \u001b[0mQ_step_context\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwq_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep_context\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (batch_size, 1, output_dim)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0mQ\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mQ_context\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mQ_step_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 147 | "\u001b[0;32m~/Desktop/projects/papers/dvrp/solving_dvrp_with_transformers/attention_dynamic_model.py\u001b[0m in \u001b[0;36mget_step_context\u001b[0;34m(self, state, embeddings)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;31m# from embeddings=(batch_size, n_nodes, input_dim) select embeddings of previous nodes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m cur_embedded_node = embeddings.gather(1, prev_node.view(prev_node.shape[0], -1, 1)\n\u001b[0m\u001b[1;32m 112\u001b[0m 
.repeat_interleave(embeddings.shape[-1], -1)) # (batch_size, 1, input_dim)\n\u001b[1;32m 113\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 148 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 149 | ] 150 | } 151 | ], 152 | "source": [ 153 | "# Initialize baseline\n", 154 | "baseline = RolloutBaseline(model_pt,\n", 155 | " wp_n_epochs = NUMBER_OF_WP_EPOCHS,\n", 156 | " epoch = 0,\n", 157 | " num_samples=ROLLOUT_SAMPLES,\n", 158 | " filename = FILENAME,\n", 159 | " from_checkpoint = FROM_CHECKPOINT,\n", 160 | " embedding_dim=embedding_dim,\n", 161 | " graph_size=GRAPH_SIZE\n", 162 | " )\n", 163 | "print(get_cur_time(), 'baseline initialized')" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "torch.cuda.empty_cache()" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": { 179 | "scrolled": false 180 | }, 181 | "outputs": [], 182 | "source": [ 183 | "%%time\n", 184 | "train_model(optimizer,\n", 185 | " model_pt,\n", 186 | " baseline,\n", 187 | " validation_dataset,\n", 188 | " samples = SAMPLES,\n", 189 | " batch = BATCH,\n", 190 | " val_batch_size = VAL_BATCH_SIZE,\n", 191 | " start_epoch = START_EPOCH,\n", 192 | " end_epoch = END_EPOCH,\n", 193 | " from_checkpoint = FROM_CHECKPOINT,\n", 194 | " grad_norm_clipping = GRAD_NORM_CLIPPING,\n", 195 | " batch_verbose = BATCH_VERBOSE,\n", 196 | " graph_size = GRAPH_SIZE,\n", 197 | " filename = FILENAME\n", 198 | " )" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [] 207 | } 208 | ], 209 | "metadata": { 210 | "kernelspec": { 211 | "display_name": "Python 3", 212 | "language": "python", 213 | "name": "python3" 214 | }, 215 | "language_info": { 216 | "codemirror_mode": { 217 | "name": "ipython", 218 | "version": 3 219 | }, 220 | "file_extension": ".py", 221 | "mimetype": "text/x-python", 222 | "name": "python", 223 | "nbconvert_exporter": "python", 224 | "pygments_lexer": "ipython3", 225 | "version": "3.8.5" 226 | }, 227 | "varInspector": { 228 | "cols": { 229 | "lenName": 16, 230 | "lenType": 16, 231 | "lenVar": 40 232 | }, 233 | "kernels_config": { 234 | "python": { 235 | "delete_cmd_postfix": "", 236 | "delete_cmd_prefix": "del ", 237 | "library": "var_list.py", 238 | "varRefreshCmd": "print(var_dic_list())" 239 | }, 240 | "r": { 241 | "delete_cmd_postfix": ") ", 242 | "delete_cmd_prefix": "rm(", 243 | "library": "var_list.r", 244 | "varRefreshCmd": "cat(var_dic_list()) " 245 | } 246 | }, 247 | "types_to_exclude": [ 248 | "module", 249 | "function", 250 | "builtin_function_or_method", 251 | "instance", 252 | "_Feature" 253 | ], 254 | "window_display": false 255 | } 256 | }, 257 | "nbformat": 4, 258 | "nbformat_minor": 4 259 | } 260 | -------------------------------------------------------------------------------- /train_model_loading_from_past.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "known-strip", 7 | "metadata": { 8 | "scrolled": false 9 | }, 10 | "outputs": [ 11 | { 12 | "name": "stdout", 13 | "output_type": "stream", 14 | "text": [ 15 | "2021-03-26 19:10:11 model initialized\n", 16 | "2021-03-26 19:10:11 validation dataset created and saved on the disk\n" 17 | ] 18 | }, 19 | { 20 | "name": "stderr", 21 | "output_type": "stream", 22 | "text": [ 23 | "\r", 24 | "Rollout greedy 
execution: 0%| | 0/10 [00:00" 455 | ] 456 | }, 457 | "metadata": { 458 | "needs_background": "light" 459 | }, 460 | "output_type": "display_data" 461 | } 462 | ], 463 | "source": [ 464 | "import torch\n", 465 | "from time import gmtime, strftime\n", 466 | "\n", 467 | "from attention_dynamic_model import AttentionDynamicModel, set_decode_type\n", 468 | "from reinforce_baseline import RolloutBaseline\n", 469 | "from train import train_model\n", 470 | "\n", 471 | "from utils import create_data_on_disk, get_cur_time\n", 472 | "\n", 473 | "# Params of model\n", 474 | "SAMPLES = 512 # 128*10000\n", 475 | "BATCH = 128\n", 476 | "START_EPOCH = 0\n", 477 | "END_EPOCH = 5\n", 478 | "FROM_CHECKPOINT = False\n", 479 | "embedding_dim = 128\n", 480 | "LEARNING_RATE = 0.0001\n", 481 | "ROLLOUT_SAMPLES = 10000\n", 482 | "NUMBER_OF_WP_EPOCHS = 1\n", 483 | "GRAD_NORM_CLIPPING = 1.0\n", 484 | "BATCH_VERBOSE = 1000\n", 485 | "VAL_BATCH_SIZE = 1000\n", 486 | "VALIDATE_SET_SIZE = 10000\n", 487 | "SEED = 1234\n", 488 | "GRAPH_SIZE = 50\n", 489 | "FILENAME = 'VRP_{}_{}'.format(GRAPH_SIZE, strftime(\"%Y-%m-%d\", gmtime()))\n", 490 | "\n", 491 | "# Initialize model\n", 492 | "model = AttentionDynamicModel(embedding_dim).cuda()\n", 493 | "set_decode_type(model, \"sampling\")\n", 494 | "print(get_cur_time(), 'model initialized')\n", 495 | "\n", 496 | "# Create and save validation dataset\n", 497 | "validation_dataset = create_data_on_disk(GRAPH_SIZE,\n", 498 | " VALIDATE_SET_SIZE,\n", 499 | " is_save=True,\n", 500 | " filename=FILENAME,\n", 501 | " is_return=True,\n", 502 | " seed = SEED)\n", 503 | "print(get_cur_time(), 'validation dataset created and saved on the disk')\n", 504 | "\n", 505 | "# Initialize optimizer\n", 506 | "optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)\n", 507 | "\n", 508 | "# Initialize baseline\n", 509 | "baseline = RolloutBaseline(model,\n", 510 | " wp_n_epochs = NUMBER_OF_WP_EPOCHS,\n", 511 | " epoch = 0,\n", 512 | " num_samples=ROLLOUT_SAMPLES,\n", 513 | " filename = FILENAME,\n", 514 | " from_checkpoint = FROM_CHECKPOINT,\n", 515 | " embedding_dim=embedding_dim,\n", 516 | " graph_size=GRAPH_SIZE\n", 517 | " )\n", 518 | "print(get_cur_time(), 'baseline initialized')\n", 519 | "\n", 520 | "train_model(optimizer,\n", 521 | " model,\n", 522 | " baseline,\n", 523 | " validation_dataset,\n", 524 | " samples = SAMPLES,\n", 525 | " batch = BATCH,\n", 526 | " val_batch_size = VAL_BATCH_SIZE,\n", 527 | " start_epoch = START_EPOCH,\n", 528 | " end_epoch = END_EPOCH,\n", 529 | " from_checkpoint = FROM_CHECKPOINT,\n", 530 | " grad_norm_clipping = GRAD_NORM_CLIPPING,\n", 531 | " batch_verbose = BATCH_VERBOSE,\n", 532 | " graph_size = GRAPH_SIZE,\n", 533 | " filename = FILENAME\n", 534 | " )" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 2, 540 | "id": "modular-maria", 541 | "metadata": { 542 | "scrolled": false 543 | }, 544 | "outputs": [ 545 | { 546 | "name": "stderr", 547 | "output_type": "stream", 548 | "text": [ 549 | "\r", 550 | "Rollout greedy execution: 0%| | 0/10 [00:00" 968 | ] 969 | }, 970 | "metadata": { 971 | "needs_background": "light" 972 | }, 973 | "output_type": "display_data" 974 | } 975 | ], 976 | "source": [ 977 | "import torch\n", 978 | "from time import gmtime, strftime\n", 979 | "\n", 980 | "from attention_dynamic_model import set_decode_type\n", 981 | "from reinforce_baseline import RolloutBaseline\n", 982 | "from train import train_model\n", 983 | "\n", 984 | "from utils import get_cur_time\n", 985 | "from 
reinforce_baseline import load_pt_model\n", 986 | "from utils import read_from_pickle\n", 987 | "\n", 988 | "\n", 989 | "SAMPLES = 512 # 128*10000\n", 990 | "BATCH = 128\n", 991 | "LEARNING_RATE = 0.0001\n", 992 | "ROLLOUT_SAMPLES = 10000\n", 993 | "NUMBER_OF_WP_EPOCHS = 1\n", 994 | "GRAD_NORM_CLIPPING = 1.0\n", 995 | "BATCH_VERBOSE = 1000\n", 996 | "VAL_BATCH_SIZE = 1000\n", 997 | "VALIDATE_SET_SIZE = 10000\n", 998 | "SEED = 1234\n", 999 | "GRAPH_SIZE = 50\n", 1000 | "FILENAME = 'VRP_{}_{}'.format(GRAPH_SIZE, strftime(\"%Y-%m-%d\", gmtime()))\n", 1001 | "\n", 1002 | "START_EPOCH = 5\n", 1003 | "END_EPOCH = 10\n", 1004 | "FROM_CHECKPOINT = True\n", 1005 | "embedding_dim = 128\n", 1006 | "MODEL_PATH = 'checkpts/model_checkpoint_epoch_4_VRP_50_2021-03-27'\n", 1007 | "VAL_SET_PATH = 'valsets/Validation_dataset_VRP_50_2021-03-27.pkl'\n", 1008 | "BASELINE_MODEL_PATH = 'checkpts/baseline_checkpoint_epoch_3_VRP_50_2021-03-27'\n", 1009 | "\n", 1010 | "# Initialize model\n", 1011 | "model = load_pt_model(MODEL_PATH,\n", 1012 | " embedding_dim=embedding_dim,\n", 1013 | " graph_size=GRAPH_SIZE, device='cuda')\n", 1014 | "set_decode_type(model, \"sampling\")\n", 1015 | "print(get_cur_time(), 'model loaded')\n", 1016 | "\n", 1017 | "# Create and save validation dataset\n", 1018 | "validation_dataset = read_from_pickle(VAL_SET_PATH)\n", 1019 | "print(get_cur_time(), 'validation dataset loaded')\n", 1020 | "\n", 1021 | "# Initialize optimizer\n", 1022 | "optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)\n", 1023 | "\n", 1024 | "# Initialize baseline\n", 1025 | "baseline = RolloutBaseline(model,\n", 1026 | " wp_n_epochs = NUMBER_OF_WP_EPOCHS,\n", 1027 | " epoch = START_EPOCH,\n", 1028 | " num_samples=ROLLOUT_SAMPLES,\n", 1029 | " filename = FILENAME,\n", 1030 | " from_checkpoint = FROM_CHECKPOINT,\n", 1031 | " embedding_dim=embedding_dim,\n", 1032 | " graph_size=GRAPH_SIZE,\n", 1033 | " path_to_checkpoint = BASELINE_MODEL_PATH)\n", 1034 | "print(get_cur_time(), 'baseline initialized')\n", 1035 | "train_model(optimizer,\n", 1036 | " model,\n", 1037 | " baseline,\n", 1038 | " validation_dataset,\n", 1039 | " samples = SAMPLES,\n", 1040 | " batch = BATCH,\n", 1041 | " val_batch_size = VAL_BATCH_SIZE,\n", 1042 | " start_epoch = START_EPOCH,\n", 1043 | " end_epoch = END_EPOCH,\n", 1044 | " from_checkpoint = FROM_CHECKPOINT,\n", 1045 | " grad_norm_clipping = GRAD_NORM_CLIPPING,\n", 1046 | " batch_verbose = BATCH_VERBOSE,\n", 1047 | " graph_size = GRAPH_SIZE,\n", 1048 | " filename = FILENAME\n", 1049 | " )" 1050 | ] 1051 | } 1052 | ], 1053 | "metadata": { 1054 | "kernelspec": { 1055 | "display_name": "Python 3", 1056 | "language": "python", 1057 | "name": "python3" 1058 | }, 1059 | "language_info": { 1060 | "codemirror_mode": { 1061 | "name": "ipython", 1062 | "version": 3 1063 | }, 1064 | "file_extension": ".py", 1065 | "mimetype": "text/x-python", 1066 | "name": "python", 1067 | "nbconvert_exporter": "python", 1068 | "pygments_lexer": "ipython3", 1069 | "version": "3.8.5" 1070 | }, 1071 | "varInspector": { 1072 | "cols": { 1073 | "lenName": 16, 1074 | "lenType": 16, 1075 | "lenVar": 40 1076 | }, 1077 | "kernels_config": { 1078 | "python": { 1079 | "delete_cmd_postfix": "", 1080 | "delete_cmd_prefix": "del ", 1081 | "library": "var_list.py", 1082 | "varRefreshCmd": "print(var_dic_list())" 1083 | }, 1084 | "r": { 1085 | "delete_cmd_postfix": ") ", 1086 | "delete_cmd_prefix": "rm(", 1087 | "library": "var_list.r", 1088 | "varRefreshCmd": "cat(var_dic_list()) " 1089 | } 1090 | }, 1091 | 
"types_to_exclude": [ 1092 | "module", 1093 | "function", 1094 | "builtin_function_or_method", 1095 | "instance", 1096 | "_Feature" 1097 | ], 1098 | "window_display": false 1099 | } 1100 | }, 1101 | "nbformat": 4, 1102 | "nbformat_minor": 5 1103 | } 1104 | -------------------------------------------------------------------------------- /train_results_VRP_20_2021-08-31_start=0, end=40.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Roberto09/Dynamic-Attention-Model-for-VRP---Pytorch/47a2a1219f1890972f9393f0d193b26bb8b8888e/train_results_VRP_20_2021-08-31_start=0, end=40.xlsx -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | import seaborn as sns 6 | import plotly.graph_objects as go 7 | import numpy as np 8 | from datetime import datetime 9 | import time 10 | 11 | CAPACITIES = { 12 | 10: 20., 13 | 20: 30., 14 | 50: 40., 15 | 100: 50. 16 | } 17 | 18 | def set_random_seed(seed): 19 | torch.manual_seed(seed) 20 | 21 | 22 | def create_data_on_disk(graph_size, num_samples, is_save=True, filename=None, is_return=False, seed=1234): 23 | """Generate validation dataset (with SEED) and save 24 | """ 25 | 26 | set_random_seed(seed) 27 | depo = torch.rand((num_samples, 2)) 28 | 29 | set_random_seed(seed) 30 | graphs = torch.rand((num_samples, graph_size, 2)) 31 | 32 | set_random_seed(seed) 33 | demand = torch.randint(low=1, high=10, size=(num_samples, graph_size), dtype=torch.float32) / CAPACITIES[graph_size] 34 | 35 | if is_save: 36 | save_to_pickle('./valsets/Validation_dataset_{}.pkl'.format(filename), (depo, graphs, demand)) 37 | 38 | if is_return: 39 | return (depo, graphs, demand) 40 | 41 | 42 | def save_to_pickle(filename, item): 43 | """Save to pickle 44 | """ 45 | with open(filename, 'wb') as handle: 46 | pickle.dump(item, handle) 47 | 48 | 49 | def read_from_pickle(path, return_data_set=True, num_samples=None): 50 | """Read dataset from file (pickle) 51 | """ 52 | 53 | objects = [] 54 | with (open(path, "rb")) as openfile: 55 | while True: 56 | try: 57 | objects.append(pickle.load(openfile)) 58 | except EOFError: 59 | break 60 | objects = objects[0] 61 | if return_data_set: 62 | depo, graphs, demand = objects 63 | if num_samples is not None: 64 | return (depo[:num_samples], graphs[:num_samples], demand[:num_samples]) 65 | else: 66 | return (depo, graphs, demand) 67 | else: 68 | return objects 69 | 70 | 71 | def generate_data_onfly(num_samples=10000, graph_size=20): 72 | """Generate temp dataset in memory 73 | """ 74 | 75 | depo = torch.rand((num_samples, 2)) 76 | graphs = torch.rand((num_samples, graph_size, 2)) 77 | demand = torch.randint(low=1, high=10, size=(num_samples, graph_size), dtype=torch.float32) / CAPACITIES[graph_size] 78 | 79 | return (depo, graphs, demand) 80 | 81 | 82 | def get_results(train_loss_results, train_cost_results, val_cost, save_results=True, filename=None, plots=True): 83 | 84 | epochs_num = len(train_loss_results) 85 | 86 | df_train = pd.DataFrame(data={'epochs': list(range(epochs_num)), 87 | 'loss': np.array(train_loss_results), 88 | 'cost': np.array(train_cost_results), 89 | }) 90 | df_test = pd.DataFrame(data={'epochs': list(range(epochs_num)), 91 | 'val_сost': np.array(val_cost)}) 92 | if save_results: 93 | df_train.to_excel('train_results_{}.xlsx'.format(filename), 
index=False) 94 | df_test.to_excel('test_results_{}.xlsx'.format(filename), index=False) 95 | 96 | if plots: 97 | plt.figure(figsize=(15, 9)) 98 | ax = sns.lineplot(x='epochs', y='loss', data=df_train, color='salmon', label='train loss') 99 | ax2 = ax.twinx() 100 | sns.lineplot(x='epochs', y='cost', data=df_train, color='cornflowerblue', label='train cost', ax=ax2) 101 | sns.lineplot(x='epochs', y='val_сost', data=df_test, palette='darkblue', label='val cost').set(ylabel='cost') 102 | ax.legend(loc=(0.75, 0.90), ncol=1) 103 | ax2.legend(loc=(0.75, 0.95), ncol=2) 104 | ax.grid(axis='x') 105 | ax2.grid(True) 106 | plt.savefig('learning_curve_plot_{}.jpg'.format(filename)) 107 | plt.show() 108 | 109 | class FastTensorDataLoader: 110 | """ 111 | A DataLoader-like object for a set of tensors that can be much faster than 112 | TensorDataset + DataLoader because dataloader grabs individual indices of 113 | the dataset and calls cat (slow). 114 | Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 115 | """ 116 | def __init__(self, *tensors, batch_size=32, shuffle=False): 117 | """ 118 | Initialize a FastTensorDataLoader. 119 | :param *tensors: tensors to store. Must have the same length @ dim 0. 120 | :param batch_size: batch size to load. 121 | :param shuffle: if True, shuffle the data *in-place* whenever an 122 | iterator is created out of this object. 123 | :returns: A FastTensorDataLoader. 124 | """ 125 | assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) 126 | self.tensors = tensors 127 | 128 | self.dataset_len = self.tensors[0].shape[0] 129 | self.batch_size = batch_size 130 | self.shuffle = shuffle 131 | 132 | # Calculate # batches 133 | n_batches, remainder = divmod(self.dataset_len, self.batch_size) 134 | if remainder > 0: 135 | n_batches += 1 136 | self.n_batches = n_batches 137 | def __iter__(self): 138 | if self.shuffle: 139 | r = torch.randperm(self.dataset_len) 140 | self.tensors = [t[r] for t in self.tensors] 141 | self.i = 0 142 | return self 143 | 144 | def __next__(self): 145 | if self.i >= self.dataset_len: 146 | raise StopIteration 147 | batch = tuple(t[self.i:self.i+self.batch_size] for t in self.tensors) 148 | self.i += self.batch_size 149 | return batch 150 | 151 | def __len__(self): 152 | return self.n_batches 153 | 154 | def get_journey(batch, pi, title, ind_in_batch=0): 155 | """Plots journey of agent 156 | 157 | Args: 158 | batch: dataset of graphs 159 | pi: paths of agent obtained from model 160 | ind_in_batch: index of graph in batch to be plotted 161 | """ 162 | 163 | # Remove extra zeros 164 | pi_ = get_clean_path(pi[ind_in_batch].numpy()) 165 | 166 | # Unpack variables 167 | depo_coord = batch[0][ind_in_batch].numpy() 168 | points_coords = batch[1][ind_in_batch].numpy() 169 | demands = batch[2][ind_in_batch].numpy() 170 | node_labels = ['(' + str(x[0]) + ', ' + x[1] + ')' for x in enumerate(demands.round(2).astype(str))] 171 | 172 | # Concatenate depot and points 173 | full_coords = np.concatenate((depo_coord.reshape(1, 2), points_coords)) 174 | 175 | # Get list with agent loops in path 176 | list_of_paths = [] 177 | cur_path = [] 178 | for idx, node in enumerate(pi_): 179 | 180 | cur_path.append(node) 181 | 182 | if idx != 0 and node == 0: 183 | if cur_path[0] != 0: 184 | cur_path.insert(0, 0) 185 | list_of_paths.append(cur_path) 186 | cur_path = [] 187 | 188 | list_of_path_traces = [] 189 | for path_counter, path in enumerate(list_of_paths): 190 | coords = full_coords[[int(x) for x in path]] 191 | 192 | # 
Calculate length of each agent loop 193 | lengths = np.sqrt(np.sum(np.diff(coords, axis=0) ** 2, axis=1)) 194 | total_length = np.sum(lengths) 195 | 196 | list_of_path_traces.append(go.Scatter(x=coords[:, 0], 197 | y=coords[:, 1], 198 | mode="markers+lines", 199 | name=f"path_{path_counter}, length={total_length:.2f}", 200 | opacity=1.0)) 201 | 202 | trace_points = go.Scatter(x=points_coords[:, 0], 203 | y=points_coords[:, 1], 204 | mode='markers+text', 205 | name='destinations', 206 | text=node_labels, 207 | textposition='top center', 208 | marker=dict(size=7), 209 | opacity=1.0 210 | ) 211 | 212 | trace_depo = go.Scatter(x=[depo_coord[0]], 213 | y=[depo_coord[1]], 214 | text=['1.0'], textposition='bottom center', 215 | mode='markers+text', 216 | marker=dict(size=15), 217 | name='depot' 218 | ) 219 | 220 | layout = go.Layout(title='Example: {}'.format(title), 221 | xaxis=dict(title='X coordinate'), 222 | yaxis=dict(title='Y coordinate'), 223 | showlegend=True, 224 | width=1000, 225 | height=1000, 226 | template="plotly_white" 227 | ) 228 | 229 | data = [trace_points, trace_depo] + list_of_path_traces 230 | print('Current path: ', pi_) 231 | fig = go.Figure(data=data, layout=layout) 232 | fig.show() 233 | 234 | def get_cur_time(): 235 | """Returns local time as string 236 | """ 237 | ts = time.time() 238 | return datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S') 239 | 240 | 241 | def get_clean_path(arr): 242 | """Returns extra zeros from path. 243 | Dynamical model generates duplicated zeros for several graphs when obtaining partial solutions. 244 | """ 245 | 246 | p1, p2 = 0, 1 247 | output = [] 248 | 249 | while p2 < len(arr): 250 | 251 | if arr[p1] != arr[p2]: 252 | output.append(arr[p1]) 253 | if p2 == len(arr) - 1: 254 | output.append(arr[p2]) 255 | 256 | p1 += 1 257 | p2 += 1 258 | 259 | if output[0] != 0: 260 | output.insert(0, 0.0) 261 | if output[-1] != 0: 262 | output.append(0.0) 263 | 264 | return output 265 | 266 | 267 | def get_dev_of_mod(model): 268 | return next(model.parameters()).device 269 | 270 | 271 | def _open_data(path): 272 | return open(path, 'rb') 273 | 274 | def get_lhk_solved_data(path_instances, path_sols): 275 | """ 276 | - instances[i][0] -> depot(x, y) 277 | - instances[i][1] -> nodes(x, y) * samples 278 | - instances[i][2] -> nodes(demand) * samples 279 | - instances[i][3] -> capacity (of vehicle) (should be the same for all in theory) 280 | 281 | - sols[0][i][0] -> cost 282 | - sols[0][i][1] -> path (doesn't include depot at the end) 283 | - sols[1] -> ? 284 | - sols[0][1][2] -> ? 
285 | """ 286 | 287 | with _open_data(path_instances) as f: 288 | instances_data = pickle.load(f) 289 | with _open_data(path_sols) as f: 290 | sols_data = pickle.load(f) 291 | 292 | capacity_denominator = CAPACITIES[len(instances_data[0][1])] 293 | 294 | depot_locs = (list(map(lambda x: x[0], instances_data))) # (samples, 2) 295 | nodes_locs = (list(map(lambda x: x[1], instances_data))) # (samples, nodes, 2) 296 | nodes_demand = (list(map(lambda x: list(map(lambda d: d/capacity_denominator, x[2])), instances_data))) # (samples, nodes) 297 | instances = (depot_locs, nodes_locs, nodes_demand) 298 | 299 | path_indices = list(map(lambda x: x[1], sols_data[0])) # (samples, path_len) 300 | costs = list(map(lambda x: x[0], sols_data[0])) # (samples) 301 | 302 | capacities = (list(map(lambda x: x[3], instances_data))) # (samples) 303 | 304 | return instances, path_indices, costs, capacities -------------------------------------------------------------------------------- /utils_demo.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import matplotlib.pyplot as plt 3 | import plotly.graph_objects as go 4 | from plotly.subplots import make_subplots 5 | import numpy as np 6 | 7 | 8 | def f_get_results_plot_seaborn(data, title, graph_size=20): 9 | fig = plt.figure(figsize=(15, 9)) 10 | ax = fig.add_subplot() 11 | ax.plot(data['epochs'], data['train_loss'], color='salmon', label='train loss') 12 | ax2 = ax.twinx() 13 | ax2.plot(data['epochs'], data['train_cost'], color='cornflowerblue', label='train cost') 14 | ax2.plot(data['epochs'], data['val_cost'], color='darkblue', label='val cost') 15 | 16 | if graph_size == 20: 17 | am_val = 6.4 18 | else: 19 | am_val = 10.98 20 | 21 | plt.axhline(y=am_val, color='black', linestyle='--', linewidth=1.5, label='AM article best score') 22 | 23 | fig.legend(loc="upper right", bbox_to_anchor=(1,1), bbox_transform=ax.transAxes) 24 | 25 | ax.set_ylabel('Loss') 26 | ax2.set_ylabel('Cost') 27 | ax.set_xlabel('Epochs') 28 | ax.grid(False) 29 | ax2.grid(False) 30 | ax2.set_yticks(np.arange(min(data['val_cost'].min(), data['train_cost'].min())-0.2, 31 | max(data['val_cost'].max(), data['train_cost'].max())+0.1, 32 | 0.1).round(2)) 33 | plt.title('Learning Curve: ' + title) 34 | plt.show() 35 | 36 | 37 | def f_get_results_plot_plotly(data, title, graph_size=20): 38 | # Create figure with secondary y-axis 39 | fig = make_subplots(specs=[[{"secondary_y": True}]]) 40 | 41 | # Add traces 42 | fig.add_trace( 43 | go.Scatter(x=data['epochs'], y=data['train_loss'], name="train loss", marker_color='salmon'), 44 | secondary_y=False, 45 | ) 46 | 47 | fig.add_trace( 48 | go.Scatter(x=data['epochs'], y=data['train_cost'], name="train cost", marker_color='cornflowerblue'), 49 | secondary_y=True, 50 | ) 51 | 52 | fig.add_trace( 53 | go.Scatter(x=data['epochs'], y=data['val_cost'], name="val cost", marker_color='darkblue'), 54 | secondary_y=True, 55 | ) 56 | 57 | # Add figure title 58 | fig.update_layout( 59 | title_text="Learning Curve: " + title, 60 | width=950, 61 | height=650, 62 | # plot_bgcolor='rgba(0,0,0,0)' 63 | template="plotly_white" 64 | ) 65 | 66 | # Set x-axis title 67 | fig.update_xaxes(title_text="Number of epoch") 68 | 69 | # Set y-axes titles 70 | fig.update_yaxes(title_text="Loss", secondary_y=False, showgrid=False, zeroline=False) 71 | fig.update_yaxes(title_text="Cost", secondary_y=True, dtick=0.1)#, nticks=20) 72 | 73 | fig.show() 
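
Usage note (not part of utils_demo.py): a minimal sketch of feeding these plotting helpers with the per-epoch CSV that `train_model` backs up (columns `epochs`, `train_loss`, `train_cost`, `val_cost`, as written by the training loop above). The specific CSV name below is the example run shipped in the repo root; swap in the file from your own run.

```python
import pandas as pd

from utils_demo import f_get_results_plot_seaborn, f_get_results_plot_plotly

# train_model writes 'epochs', 'train_loss', 'train_cost' and 'val_cost' per epoch
# to 'backup_results_<FILENAME>.csv'; this is the example file included in the repo root.
data = pd.read_csv('backup_results_VRP_20_2021-08-31.csv')

f_get_results_plot_seaborn(data, title='VRP_20_2021-08-31', graph_size=20)
f_get_results_plot_plotly(data, title='VRP_20_2021-08-31', graph_size=20)
```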
-------------------------------------------------------------------------------- /valsets/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore --------------------------------------------------------------------------------
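
Usage note: a minimal sketch of restoring a saved checkpoint for greedy evaluation. `load_pt_model`, `set_decode_type` and `generate_data_onfly` are used with the signatures shown in the notebooks and utils.py above; the checkpoint path is the one referenced in train_model_loading_from_past.ipynb (./checkpts is git-ignored, so point it at a checkpoint from your own run), and the assumption that the model moves CPU input tensors to its own device follows how the rollout baseline feeds batches.

```python
import torch

from attention_dynamic_model import set_decode_type
from reinforce_baseline import load_pt_model
from utils import generate_data_onfly

# Checkpoint naming follows train_model's torch.save call; this particular file is the one
# used in train_model_loading_from_past.ipynb. Replace it with a checkpoint you trained.
MODEL_PATH = 'checkpts/model_checkpoint_epoch_4_VRP_50_2021-03-27'

model = load_pt_model(MODEL_PATH, embedding_dim=128, graph_size=50, device='cuda')
model.eval()
set_decode_type(model, "greedy")  # deterministic decoding for evaluation

# Small on-the-fly CVRP-50 batch, left on CPU: the rollout baseline passes CPU tensors to the
# CUDA model, so the forward pass appears to handle the device transfer itself (an assumption;
# move the tensors to CUDA manually if your setup requires it).
depot, graphs, demand = generate_data_onfly(num_samples=4, graph_size=50)
batch = (depot, graphs, demand)

with torch.no_grad():
    cost, _ = model(batch)  # forward returns the tour cost first, as used in the rollout baseline
print(cost)
```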