├── .gitignore ├── LICENSE ├── README.md ├── beam_search.py ├── img ├── sort10-0.png ├── sort10-1.png ├── sort15-0.png ├── sort15-1.png ├── sort20-0.png ├── sort20-1.png ├── tsp_20_train_reward.png ├── tsp_20_val_reward.png ├── tsp_50_train_reward.png └── tsp_50_val_reward.png ├── main.sh ├── neural_combinatorial_rl.py ├── plot_attention.py ├── scripts ├── hyperparam_search.py ├── plot_reward.py └── tune_hyper.sh ├── sorting_task.py ├── trainer.py └── tsp_task.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | outputs/ 3 | logs/ 4 | __pycache__/ 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Patrick E. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # neural-combinatorial-rl-pytorch 2 | 3 | PyTorch implementation of [Neural Combinatorial Optimization with Reinforcement Learning](https://arxiv.org/abs/1611.09940). 4 | 5 | I have implemented the basic RL pretraining model with greedy decoding from the paper. An implementation of the supervised learning baseline model is available [here](https://github.com/pemami4911/neural-combinatorial-rl-tensorflow). Instead of a critic network, I obtained the TSP results below using an exponential moving average critic. The critic network is simply commented out in my code right now. In correspondence with a few others, we determined that the exponential moving average critic significantly improved results. 6 | 7 | My implementation uses a stochastic decoding policy in the pointer network, realized via PyTorch's `torch.multinomial()`, during training, and beam search (**not yet finished**, only supports 1 beam a.k.a. greedy) for decoding when testing the model. 8 | 9 | Currently, there is support for a sorting task and the planar symmetric Euclidean TSP. 10 | 11 | See `main.sh` for an example of how to run the code. 12 | 13 | Use the `--load_path $LOAD_PATH` and `--is_train False` flags to load a saved model. 14 | 15 | To load a saved model and view the pointer network's attention layer, also use the `--plot_attention True` flag. 16 | 17 | Please feel free to notify me if you encounter any errors, or if you'd like to submit a pull request to improve this implementation. 18 | 19 | ## Adding other tasks 20 | 21 | This implementation can be extended to support other combinatorial optimization problems. See `sorting_task.py` and `tsp_task.py` for examples of how to add one. 
The key thing is to provide a dataset class and a reward function that takes in a sample solution, selected by the pointer network from the input, and returns a scalar reward. For the sorting task, the agent receives a reward proportional to the length of the longest strictly increasing subsequence in the decoded output (e.g., `[1, 3, 5, 2, 4] -> 3/5 = 0.6`). 22 | 23 | ## Dependencies 24 | 25 | * Python=3.6 (should be OK with v >= 3.4) 26 | * PyTorch=0.2 or 0.3 27 | * tqdm 28 | * matplotlib 29 | * [tensorboard_logger](https://github.com/TeamHG-Memex/tensorboard_logger) 30 | 31 | PyTorch 0.4 compatibility is available on branch `pytorch-0.4`. 32 | 33 | ## TSP Results 34 | 35 | Results for 1 random seed over 50 epochs (each epoch is 10,000 batches of size 128). After each epoch, I validated performance on 1000 held-out graphs. I used the same hyperparameters as the paper, as can be seen in `main.sh`. The dashed line shows the value indicated in Table 2 of Bello et al. for comparison. The log-scale x-axis for the training reward is used to show how the tour length drops early on. 36 | 37 | ![TSP 20 Train](img/tsp_20_train_reward.png) 38 | ![TSP 20 Val](img/tsp_20_val_reward.png) 39 | ![TSP 50 Train](img/tsp_50_train_reward.png) 40 | ![TSP 50 Val](img/tsp_50_val_reward.png) 41 | 42 | ## Sort Results 43 | 44 | I trained a model on `sort10` for 4 epochs of 1,000,000 randomly generated samples. I tested it on a dataset of size 10,000. Then, I tested the same model on `sort15` and `sort20` to evaluate its generalization capabilities. 
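For reference, the sorting reward described under "Adding other tasks" (longest strictly increasing subsequence divided by sequence length) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code in `sorting_task.py`: the function name `lis_reward` is hypothetical, and it takes a plain list of numbers rather than the batched tensors that the model passes to its reward function.

```python
from bisect import bisect_left


def lis_reward(decoded):
    """Hypothetical helper: fraction of `decoded` covered by its longest
    strictly increasing subsequence (1.0 means perfectly sorted)."""
    if not decoded:
        return 0.0
    # tails[k] holds the smallest possible last element of a strictly
    # increasing subsequence of length k + 1 seen so far.
    tails = []
    for x in decoded:
        # bisect_left keeps the subsequence strictly increasing:
        # an element equal to tails[pos] replaces it instead of extending.
        pos = bisect_left(tails, x)
        if pos == len(tails):
            tails.append(x)
        else:
            tails[pos] = x
    return len(tails) / len(decoded)


print(lis_reward([1, 3, 5, 2, 4]))   # 0.6, matching the README example
print(lis_reward(list(range(10))))   # 1.0 for a perfectly sorted output
```

A reward function in this repo additionally has to work on a whole batch of decoded sequences at once, so this sketch only shows the per-sequence scoring rule.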
45 | 46 | Test results on 10,000 samples (A reward of 1.0 means the network perfectly sorted the input): 47 | 48 | | task | average reward | variance | 49 | |---|---|---| 50 | | sort10 | 0.9966 | 0.0005 | 51 | | sort15 | 0.7484 | 0.0177 | 52 | | sort20 | 0.5586 | 0.0060 | 53 | 54 | 55 | Example prediction on `sort10`: 56 | 57 | ``` 58 | input: [4, 7, 5, 0, 3, 2, 6, 8, 9, 1] 59 | output: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 60 | ``` 61 | 62 | ### Attention visualization 63 | 64 | Plot the pointer network's attention layer with the argument `--plot_attention True` 65 | 66 | ## TODO 67 | 68 | * [ ] Add RL pretraining-Sampling 69 | * [ ] Add RL pretraining-Active Search 70 | * [ ] Active Search 71 | * [ ] Asynchronous training a la A3C 72 | * [X] Refactor `USE_CUDA` variable 73 | * [ ] Finish implementing beam search decoding to support > 1 beam 74 | * [ ] Add support for variable length inputs 75 | 76 | ## Acknowledgements 77 | 78 | Special thanks to the repos [devsisters/neural-combinatorial-rl-tensorflow](https://github.com/devsisters/neural-combinatorial-rl-tensorflow) and [MaximumEntropy/Seq2Seq-PyTorch](https://github.com/MaximumEntropy/Seq2Seq-PyTorch) for getting me started, and @ricgama for figuring out that weird bug with `clone()` 79 | 80 | -------------------------------------------------------------------------------- /beam_search.py: -------------------------------------------------------------------------------- 1 | # beam search implementation in PyTorch.""" 2 | # 3 | # 4 | # hyp1#-hyp1---hyp1 -hyp1 5 | # \ / 6 | # hyp2 \-hyp2 /-hyp2#hyp2 7 | # / \ 8 | # hyp3#-hyp3---hyp3 -hyp3 9 | # ======================== 10 | # 11 | # Takes care of beams, back pointers, and scores. 
12 | 13 | # Code borrowed from https://github.com/MaximumEntropy/Seq2Seq-PyTorch/blob/master/beam_search.py, 14 | # who borrowed it from PyTorch OpenNMT example 15 | # https://github.com/pytorch/examples/blob/master/OpenNMT/onmt/Beam.py 16 | # :-) 17 | 18 | import torch 19 | 20 | 21 | class Beam(object): 22 | """Ordered beam of candidate outputs. Fixed length.""" 23 | 24 | def __init__(self, size, steps, cuda=False): 25 | """Initialize params.""" 26 | self.size = size 27 | self.done = False 28 | self.pad = -1 29 | self.steps = steps 30 | self.current_step = 0 31 | self.tt = torch.cuda if cuda else torch 32 | 33 | # The score for each translation on the beam. 34 | self.scores = self.tt.FloatTensor(size).zero_() 35 | 36 | # The backpointers at each time-step. 37 | self.prevKs = [] 38 | 39 | # The outputs at each time-step. 40 | self.nextYs = [self.tt.LongTensor(size).fill_(self.pad)] 41 | 42 | # The attentions (matrix) for each time. 43 | self.attn = [] 44 | 45 | # Get the outputs for the current timestep. 46 | def get_current_state(self): 47 | """Get state of beam.""" 48 | return self.nextYs[-1] 49 | 50 | # Get the backpointers for the current timestep. 51 | def get_current_origin(self): 52 | """Get the backpointer to the beam at this step.""" 53 | return self.prevKs[-1] 54 | 55 | # Given prob over words for every last beam `wordLk` and attention 56 | # `attnOut`: Compute and update the beam search. 57 | # 58 | # Parameters: 59 | # 60 | # * `wordLk`- probs of advancing from the last step (K x words) 61 | # * `attnOut`- attention at the last step 62 | # 63 | # Returns: True if beam search is complete. 64 | 65 | def advance(self, workd_lk): 66 | """Advance the beam.""" 67 | num_words = workd_lk.size(1) 68 | 69 | # Sum the previous scores. 
70 | if len(self.prevKs) > 0: 71 | beam_lk = workd_lk + self.scores.unsqueeze(1).expand_as(workd_lk) 72 | else: 73 | beam_lk = workd_lk[0] 74 | 75 | flat_beam_lk = beam_lk.view(-1) 76 | 77 | bestScores, bestScoresId = flat_beam_lk.topk(self.size, 0, True, True) 78 | self.scores = bestScores 79 | 80 | # bestScoresId is flattened beam x word array, so calculate which 81 | # word and beam each score came from 82 | prev_k = bestScoresId / num_words 83 | self.prevKs.append(prev_k) 84 | self.nextYs.append(bestScoresId - prev_k * num_words) 85 | 86 | self.current_step += 1 87 | # End condition: the beam has run for the fixed number of steps. 88 | if self.current_step == self.steps: 89 | self.done = True 90 | 91 | return self.done 92 | 93 | def sort_best(self): 94 | """Sort the beam (descending by score).""" 95 | return torch.sort(self.scores, 0, True) 96 | 97 | # Get the score of the best in the beam. 98 | def get_best(self): 99 | """Get the most likely candidate.""" 100 | scores, ids = self.sort_best() 101 | return scores[0], ids[0] 102 | 103 | # Walk back to construct the full hypothesis. 104 | # 105 | # Parameters. 106 | # 107 | # * `k` - the position in the beam to construct. 108 | # 109 | # Returns. 110 | # 111 | # 1. The hypothesis 112 | # (the attention at each time step is not returned by this version) 
113 | def get_hyp(self, k): 114 | """Get hypotheses.""" 115 | hyp = [] 116 | # print(len(self.prevKs), len(self.nextYs), len(self.attn)) 117 | for j in range(len(self.prevKs) - 1, -1, -1): 118 | hyp.append(self.nextYs[j + 1][k]) 119 | k = self.prevKs[j][k] 120 | 121 | return hyp[::-1] 122 | -------------------------------------------------------------------------------- /img/sort10-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pemami4911/neural-combinatorial-rl-pytorch/fc20970dda459891715b584a2a45fb9c9bdc7b8c/img/sort10-0.png -------------------------------------------------------------------------------- /img/sort10-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pemami4911/neural-combinatorial-rl-pytorch/fc20970dda459891715b584a2a45fb9c9bdc7b8c/img/sort10-1.png -------------------------------------------------------------------------------- /img/sort15-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pemami4911/neural-combinatorial-rl-pytorch/fc20970dda459891715b584a2a45fb9c9bdc7b8c/img/sort15-0.png -------------------------------------------------------------------------------- /img/sort15-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pemami4911/neural-combinatorial-rl-pytorch/fc20970dda459891715b584a2a45fb9c9bdc7b8c/img/sort15-1.png -------------------------------------------------------------------------------- /img/sort20-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pemami4911/neural-combinatorial-rl-pytorch/fc20970dda459891715b584a2a45fb9c9bdc7b8c/img/sort20-0.png -------------------------------------------------------------------------------- /img/sort20-1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pemami4911/neural-combinatorial-rl-pytorch/fc20970dda459891715b584a2a45fb9c9bdc7b8c/img/sort20-1.png -------------------------------------------------------------------------------- /img/tsp_20_train_reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pemami4911/neural-combinatorial-rl-pytorch/fc20970dda459891715b584a2a45fb9c9bdc7b8c/img/tsp_20_train_reward.png -------------------------------------------------------------------------------- /img/tsp_20_val_reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pemami4911/neural-combinatorial-rl-pytorch/fc20970dda459891715b584a2a45fb9c9bdc7b8c/img/tsp_20_val_reward.png -------------------------------------------------------------------------------- /img/tsp_50_train_reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pemami4911/neural-combinatorial-rl-pytorch/fc20970dda459891715b584a2a45fb9c9bdc7b8c/img/tsp_50_train_reward.png -------------------------------------------------------------------------------- /img/tsp_50_val_reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pemami4911/neural-combinatorial-rl-pytorch/fc20970dda459891715b584a2a45fb9c9bdc7b8c/img/tsp_50_val_reward.png -------------------------------------------------------------------------------- /main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TASK='tsp_50' 4 | BEAM_SIZE=1 5 | EMBEDDING_DIM=128 6 | HIDDEN_DIM=128 7 | BATCH_SIZE=128 8 | ACTOR_NET_LR=1e-3 9 | CRITIC_NET_LR=1e-3 10 | ACTOR_LR_DECAY_RATE=0.96 11 | ACTOR_LR_DECAY_STEP=5000 12 | CRITIC_LR_DECAY_RATE=0.96 13 | 
CRITIC_LR_DECAY_STEP=5000 14 | N_PROCESS_BLOCKS=3 15 | N_GLIMPSES=1 16 | N_EPOCHS=50 17 | EPOCH_START=0 18 | MAX_GRAD_NORM=1.0 19 | RANDOM_SEED=$1 20 | RUN_NAME="$ACTOR_NET_LR-seed-$RANDOM_SEED" 21 | TRAIN_SIZE=1280000 22 | VAL_SIZE=1000 23 | LOAD_PATH="outputs/tsp_20/LR3-$ACTOR_NET_LR-seed-$RANDOM_SEED/epoch-5.pt" 24 | USE_CUDA=True 25 | DISABLE_TENSORBOARD=False 26 | REWARD_SCALE=1 27 | USE_TANH=True 28 | CRITIC_BETA=0.8 29 | 30 | ./trainer.py --task $TASK --beam_size $BEAM_SIZE --actor_net_lr $ACTOR_NET_LR --critic_net_lr $CRITIC_NET_LR --n_epochs $N_EPOCHS --random_seed $RANDOM_SEED --max_grad_norm $MAX_GRAD_NORM --run_name $RUN_NAME --epoch_start $EPOCH_START --train_size $TRAIN_SIZE --n_process_blocks $N_PROCESS_BLOCKS --batch_size $BATCH_SIZE --actor_lr_decay_rate $ACTOR_LR_DECAY_RATE --actor_lr_decay_step $ACTOR_LR_DECAY_STEP --critic_lr_decay_rate $CRITIC_LR_DECAY_RATE --critic_lr_decay_step $CRITIC_LR_DECAY_STEP --embedding_dim $EMBEDDING_DIM --hidden_dim $HIDDEN_DIM --val_size $VAL_SIZE --n_glimpses $N_GLIMPSES --use_cuda $USE_CUDA --disable_tensorboard $DISABLE_TENSORBOARD --reward_scale $REWARD_SCALE --use_tanh $USE_TANH --critic_beta $CRITIC_BETA 31 | 32 | -------------------------------------------------------------------------------- /neural_combinatorial_rl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.autograd as autograd 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | import math 7 | import numpy as np 8 | 9 | from beam_search import Beam 10 | 11 | 12 | class Encoder(nn.Module): 13 | """Maps a graph represented as an input sequence 14 | to a hidden vector""" 15 | def __init__(self, input_dim, hidden_dim, use_cuda): 16 | super(Encoder, self).__init__() 17 | self.hidden_dim = hidden_dim 18 | self.lstm = nn.LSTM(input_dim, hidden_dim) 19 | self.use_cuda = use_cuda 20 | self.enc_init_state = self.init_hidden(hidden_dim) 
21 | 22 | def forward(self, x, hidden): 23 | output, hidden = self.lstm(x, hidden) 24 | return output, hidden 25 | 26 | def init_hidden(self, hidden_dim): 27 | """Initial hidden state: zeros, not trainable (requires_grad=False)""" 28 | enc_init_hx = Variable(torch.zeros(hidden_dim), requires_grad=False) 29 | if self.use_cuda: 30 | enc_init_hx = enc_init_hx.cuda() 31 | 32 | #enc_init_hx.data.uniform_(-(1. / math.sqrt(hidden_dim)), 33 | # 1. / math.sqrt(hidden_dim)) 34 | 35 | enc_init_cx = Variable(torch.zeros(hidden_dim), requires_grad=False) 36 | if self.use_cuda: 37 | enc_init_cx = enc_init_cx.cuda() 38 | 39 | #enc_init_cx = nn.Parameter(enc_init_cx) 40 | #enc_init_cx.data.uniform_(-(1. / math.sqrt(hidden_dim)), 41 | # 1. / math.sqrt(hidden_dim)) 42 | return (enc_init_hx, enc_init_cx) 43 | 44 | 45 | class Attention(nn.Module): 46 | """A generic attention module for a decoder in seq2seq""" 47 | def __init__(self, dim, use_tanh=False, C=10, use_cuda=True): 48 | super(Attention, self).__init__() 49 | self.use_tanh = use_tanh 50 | self.project_query = nn.Linear(dim, dim) 51 | self.project_ref = nn.Conv1d(dim, dim, 1, 1) 52 | self.C = C # tanh exploration 53 | self.tanh = nn.Tanh() 54 | 55 | v = torch.FloatTensor(dim) 56 | if use_cuda: 57 | v = v.cuda() 58 | self.v = nn.Parameter(v) 59 | self.v.data.uniform_(-(1. / math.sqrt(dim)) , 1. / math.sqrt(dim)) 60 | 61 | def forward(self, query, ref): 62 | """ 63 | Args: 64 | query: is the hidden state of the decoder at the current 65 | time step. batch x dim 66 | ref: the set of hidden states from the encoder. 
67 | sourceL x batch x hidden_dim 68 | """ 69 | # ref is now [batch_size x hidden_dim x sourceL] 70 | ref = ref.permute(1, 2, 0) 71 | q = self.project_query(query).unsqueeze(2) # batch x dim x 1 72 | e = self.project_ref(ref) # batch_size x hidden_dim x sourceL 73 | # expand the query by sourceL 74 | # batch x dim x sourceL 75 | expanded_q = q.repeat(1, 1, e.size(2)) 76 | # batch x 1 x hidden_dim 77 | v_view = self.v.unsqueeze(0).expand( 78 | expanded_q.size(0), len(self.v)).unsqueeze(1) 79 | # [batch_size x 1 x hidden_dim] * [batch_size x hidden_dim x sourceL] 80 | u = torch.bmm(v_view, self.tanh(expanded_q + e)).squeeze(1) 81 | if self.use_tanh: 82 | logits = self.C * self.tanh(u) 83 | else: 84 | logits = u 85 | return e, logits 86 | 87 | 88 | class Decoder(nn.Module): 89 | def __init__(self, 90 | embedding_dim, 91 | hidden_dim, 92 | max_length, 93 | tanh_exploration, 94 | terminating_symbol, 95 | use_tanh, 96 | decode_type, 97 | n_glimpses=1, 98 | beam_size=0, 99 | use_cuda=True): 100 | super(Decoder, self).__init__() 101 | 102 | self.embedding_dim = embedding_dim 103 | self.hidden_dim = hidden_dim 104 | self.n_glimpses = n_glimpses 105 | self.max_length = max_length 106 | self.terminating_symbol = terminating_symbol 107 | self.decode_type = decode_type 108 | self.beam_size = beam_size 109 | self.use_cuda = use_cuda 110 | 111 | self.input_weights = nn.Linear(embedding_dim, 4 * hidden_dim) 112 | self.hidden_weights = nn.Linear(hidden_dim, 4 * hidden_dim) 113 | 114 | self.pointer = Attention(hidden_dim, use_tanh=use_tanh, C=tanh_exploration, use_cuda=self.use_cuda) 115 | self.glimpse = Attention(hidden_dim, use_tanh=False, use_cuda=self.use_cuda) 116 | self.sm = nn.Softmax() 117 | 118 | def apply_mask_to_logits(self, step, logits, mask, prev_idxs): 119 | if mask is None: 120 | mask = torch.zeros(logits.size()).byte() 121 | if self.use_cuda: 122 | mask = mask.cuda() 123 | 124 | maskk = mask.clone() 125 | 126 | # to prevent them from being reselected. 
127 | # Or, allow re-selection and penalize in the objective function 128 | if prev_idxs is not None: 129 | # set most recently selected idx values to 1 130 | maskk[[x for x in range(logits.size(0))], 131 | prev_idxs.data] = 1 132 | logits[maskk] = -np.inf 133 | return logits, maskk 134 | 135 | def forward(self, decoder_input, embedded_inputs, hidden, context): 136 | """ 137 | Args: 138 | decoder_input: The initial input to the decoder 139 | size is [batch_size x embedding_dim]. Trainable parameter. 140 | embedded_inputs: [sourceL x batch_size x embedding_dim] 141 | hidden: the prev hidden state, size is [batch_size x hidden_dim]. 142 | Initially this is set to (enc_h[-1], enc_c[-1]) 143 | context: encoder outputs, [sourceL x batch_size x hidden_dim] 144 | """ 145 | def recurrence(x, hidden, logit_mask, prev_idxs, step): 146 | 147 | hx, cx = hidden # batch_size x hidden_dim 148 | 149 | gates = self.input_weights(x) + self.hidden_weights(hx) 150 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 151 | 152 | ingate = F.sigmoid(ingate) 153 | forgetgate = F.sigmoid(forgetgate) 154 | cellgate = F.tanh(cellgate) 155 | outgate = F.sigmoid(outgate) 156 | 157 | cy = (forgetgate * cx) + (ingate * cellgate) 158 | hy = outgate * F.tanh(cy) # batch_size x hidden_dim 159 | 160 | g_l = hy 161 | for i in range(self.n_glimpses): 162 | ref, logits = self.glimpse(g_l, context) 163 | logits, logit_mask = self.apply_mask_to_logits(step, logits, logit_mask, prev_idxs) 164 | # [batch_size x h_dim x sourceL] * [batch_size x sourceL x 1] = 165 | # [batch_size x h_dim x 1] 166 | g_l = torch.bmm(ref, self.sm(logits).unsqueeze(2)).squeeze(2) 167 | _, logits = self.pointer(g_l, context) 168 | 169 | logits, logit_mask = self.apply_mask_to_logits(step, logits, logit_mask, prev_idxs) 170 | probs = self.sm(logits) 171 | return hy, cy, probs, logit_mask 172 | 173 | batch_size = context.size(1) 174 | outputs = [] 175 | selections = [] 176 | steps = range(self.max_length) # or until 
terminating symbol ? 177 | inps = [] 178 | idxs = None 179 | mask = None 180 | 181 | if self.decode_type == "stochastic": 182 | for i in steps: 183 | hx, cx, probs, mask = recurrence(decoder_input, hidden, mask, idxs, i) 184 | hidden = (hx, cx) 185 | # select the next inputs for the decoder [batch_size x hidden_dim] 186 | decoder_input, idxs = self.decode_stochastic( 187 | probs, 188 | embedded_inputs, 189 | selections) 190 | inps.append(decoder_input) 191 | # use outs to point to next object 192 | outputs.append(probs) 193 | selections.append(idxs) 194 | return (outputs, selections), hidden 195 | 196 | elif self.decode_type == "beam_search": 197 | 198 | # Expand input tensors for beam search 199 | decoder_input = Variable(decoder_input.data.repeat(self.beam_size, 1)) 200 | context = Variable(context.data.repeat(1, self.beam_size, 1)) 201 | hidden = (Variable(hidden[0].data.repeat(self.beam_size, 1)), 202 | Variable(hidden[1].data.repeat(self.beam_size, 1))) 203 | 204 | beam = [ 205 | Beam(self.beam_size, self.max_length, cuda=self.use_cuda) 206 | for k in range(batch_size) 207 | ] 208 | 209 | for i in steps: 210 | hx, cx, probs, mask = recurrence(decoder_input, hidden, mask, idxs, i) 211 | hidden = (hx, cx) 212 | 213 | probs = probs.view(self.beam_size, batch_size, -1 214 | ).transpose(0, 1).contiguous() 215 | 216 | n_best = 1 217 | # select the next inputs for the decoder [batch_size x hidden_dim] 218 | decoder_input, idxs, active = self.decode_beam(probs, 219 | embedded_inputs, beam, batch_size, n_best, i) 220 | 221 | inps.append(decoder_input) 222 | # use probs to point to next object 223 | if self.beam_size > 1: 224 | outputs.append(probs[:, 0,:]) 225 | else: 226 | outputs.append(probs.squeeze(0)) 227 | # Check for indexing 228 | selections.append(idxs) 229 | # Should be done decoding 230 | if len(active) == 0: 231 | break 232 | decoder_input = Variable(decoder_input.data.repeat(self.beam_size, 1)) 233 | 234 | return (outputs, selections), hidden 235 | 236 | 
def decode_stochastic(self, probs, embedded_inputs, selections): 237 | """ 238 | Return the next input for the decoder by sampling an 239 | index from the pointer distribution `probs` 240 | 241 | Args: 242 | probs: [batch_size x sourceL] 243 | embedded_inputs: [sourceL x batch_size x embedding_dim] 244 | selections: list of all of the previously selected indices during decoding 245 | Returns: 246 | Tensor of size [batch_size x embedding_dim] containing the embeddings 247 | from the inputs corresponding to the [batch_size] indices 248 | selected for this iteration of the decoding, as well as the 249 | corresponding indices 250 | """ 251 | batch_size = probs.size(0) 252 | # idxs is [batch_size] 253 | idxs = probs.multinomial().squeeze(1) 254 | 255 | # due to race conditions, might need to resample here 256 | for old_idxs in selections: 257 | # compare new idxs 258 | # elementwise with the previous idxs. If any matches, 259 | # then need to resample 260 | if old_idxs.eq(idxs).data.any(): 261 | print(' [!] 
resampling due to race condition') 262 | idxs = probs.multinomial().squeeze(1) 263 | break 264 | 265 | sels = embedded_inputs[idxs.data, [i for i in range(batch_size)], :] 266 | return sels, idxs 267 | 268 | def decode_beam(self, probs, embedded_inputs, beam, batch_size, n_best, step): 269 | active = [] 270 | for b in range(batch_size): 271 | if beam[b].done: 272 | continue 273 | 274 | if not beam[b].advance(probs.data[b]): 275 | active += [b] 276 | 277 | 278 | all_hyp, all_scores = [], [] 279 | for b in range(batch_size): 280 | scores, ks = beam[b].sort_best() 281 | all_scores += [scores[:n_best]] 282 | hyps = zip(*[beam[b].get_hyp(k) for k in ks[:n_best]]) 283 | all_hyp += [hyps] 284 | 285 | all_idxs = Variable(torch.LongTensor([[x for x in hyp] for hyp in all_hyp]).squeeze()) 286 | 287 | if all_idxs.dim() == 2: 288 | if all_idxs.size(1) > n_best: 289 | idxs = all_idxs[:,-1] 290 | else: 291 | idxs = all_idxs 292 | elif all_idxs.dim() == 3: 293 | idxs = all_idxs[:, -1, :] 294 | else: 295 | if all_idxs.size(0) > 1: 296 | idxs = all_idxs[-1] 297 | else: 298 | idxs = all_idxs 299 | 300 | if self.use_cuda: 301 | idxs = idxs.cuda() 302 | 303 | if idxs.dim() > 1: 304 | x = embedded_inputs[idxs.transpose(0,1).contiguous().data, 305 | [x for x in range(batch_size)], :] 306 | else: 307 | x = embedded_inputs[idxs.data, [x for x in range(batch_size)], :] 308 | return x.view(idxs.size(0) * n_best, embedded_inputs.size(2)), idxs, active 309 | 310 | class PointerNetwork(nn.Module): 311 | """The pointer network, which is the core seq2seq 312 | model""" 313 | def __init__(self, 314 | embedding_dim, 315 | hidden_dim, 316 | max_decoding_len, 317 | terminating_symbol, 318 | n_glimpses, 319 | tanh_exploration, 320 | use_tanh, 321 | beam_size, 322 | use_cuda): 323 | super(PointerNetwork, self).__init__() 324 | 325 | self.encoder = Encoder( 326 | embedding_dim, 327 | hidden_dim, 328 | use_cuda) 329 | 330 | self.decoder = Decoder( 331 | embedding_dim, 332 | hidden_dim, 333 | 
max_length=max_decoding_len, 334 | tanh_exploration=tanh_exploration, 335 | use_tanh=use_tanh, 336 | terminating_symbol=terminating_symbol, 337 | decode_type="stochastic", 338 | n_glimpses=n_glimpses, 339 | beam_size=beam_size, 340 | use_cuda=use_cuda) 341 | 342 | # Trainable initial hidden states 343 | dec_in_0 = torch.FloatTensor(embedding_dim) 344 | if use_cuda: 345 | dec_in_0 = dec_in_0.cuda() 346 | 347 | self.decoder_in_0 = nn.Parameter(dec_in_0) 348 | self.decoder_in_0.data.uniform_(-(1. / math.sqrt(embedding_dim)), 349 | 1. / math.sqrt(embedding_dim)) 350 | 351 | def forward(self, inputs): 352 | """ Propagate inputs through the network 353 | Args: 354 | inputs: [sourceL x batch_size x embedding_dim] 355 | """ 356 | 357 | (encoder_hx, encoder_cx) = self.encoder.enc_init_state 358 | encoder_hx = encoder_hx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0) 359 | encoder_cx = encoder_cx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0) 360 | 361 | # encoder forward pass 362 | enc_h, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_hx, encoder_cx)) 363 | 364 | dec_init_state = (enc_h_t[-1], enc_c_t[-1]) 365 | 366 | # repeat decoder_in_0 across batch 367 | decoder_input = self.decoder_in_0.unsqueeze(0).repeat(inputs.size(1), 1) 368 | 369 | (pointer_probs, input_idxs), dec_hidden_t = self.decoder(decoder_input, 370 | inputs, 371 | dec_init_state, 372 | enc_h) 373 | 374 | return pointer_probs, input_idxs 375 | 376 | 377 | class CriticNetwork(nn.Module): 378 | """Useful as a baseline in REINFORCE updates""" 379 | def __init__(self, 380 | embedding_dim, 381 | hidden_dim, 382 | n_process_block_iters, 383 | tanh_exploration, 384 | use_tanh, 385 | use_cuda): 386 | super(CriticNetwork, self).__init__() 387 | 388 | self.hidden_dim = hidden_dim 389 | self.n_process_block_iters = n_process_block_iters 390 | 391 | self.encoder = Encoder( 392 | embedding_dim, 393 | hidden_dim, 394 | use_cuda) 395 | 396 | self.process_block = Attention(hidden_dim, 397 | use_tanh=use_tanh, 
C=tanh_exploration, use_cuda=use_cuda) 398 | self.sm = nn.Softmax() 399 | self.decoder = nn.Sequential( 400 | nn.Linear(hidden_dim, hidden_dim), 401 | nn.ReLU(), 402 | nn.Linear(hidden_dim, 1) 403 | ) 404 | 405 | def forward(self, inputs): 406 | """ 407 | Args: 408 | inputs: [sourceL x batch_size x embedding_dim] of embedded inputs 409 | """ 410 | 411 | (encoder_hx, encoder_cx) = self.encoder.enc_init_state 412 | encoder_hx = encoder_hx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0) 413 | encoder_cx = encoder_cx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0) 414 | 415 | # encoder forward pass 416 | enc_outputs, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_hx, encoder_cx)) 417 | 418 | # grab the hidden state and process it via the process block 419 | process_block_state = enc_h_t[-1] 420 | for i in range(self.n_process_block_iters): 421 | ref, logits = self.process_block(process_block_state, enc_outputs) 422 | process_block_state = torch.bmm(ref, self.sm(logits).unsqueeze(2)).squeeze(2) 423 | # produce the final scalar output 424 | out = self.decoder(process_block_state) 425 | return out 426 | 427 | class NeuralCombOptRL(nn.Module): 428 | """ 429 | This module contains the PointerNetwork (actor) and 430 | CriticNetwork (critic). 
It requires 431 | an application-specific reward function 432 | """ 433 | def __init__(self, 434 | input_dim, 435 | embedding_dim, 436 | hidden_dim, 437 | max_decoding_len, 438 | terminating_symbol, 439 | n_glimpses, 440 | n_process_block_iters, 441 | tanh_exploration, 442 | use_tanh, 443 | beam_size, 444 | objective_fn, 445 | is_train, 446 | use_cuda): 447 | super(NeuralCombOptRL, self).__init__() 448 | self.objective_fn = objective_fn 449 | self.input_dim = input_dim 450 | self.is_train = is_train 451 | self.use_cuda = use_cuda 452 | 453 | 454 | self.actor_net = PointerNetwork( 455 | embedding_dim, 456 | hidden_dim, 457 | max_decoding_len, 458 | terminating_symbol, 459 | n_glimpses, 460 | tanh_exploration, 461 | use_tanh, 462 | beam_size, 463 | use_cuda) 464 | 465 | #self.critic_net = CriticNetwork( 466 | # embedding_dim, 467 | # hidden_dim, 468 | # n_process_block_iters, 469 | # tanh_exploration, 470 | # False, 471 | # use_cuda) 472 | 473 | embedding_ = torch.FloatTensor(input_dim, 474 | embedding_dim) 475 | if self.use_cuda: 476 | embedding_ = embedding_.cuda() 477 | self.embedding = nn.Parameter(embedding_) 478 | self.embedding.data.uniform_(-(1. / math.sqrt(embedding_dim)), 479 | 1. 
/ math.sqrt(embedding_dim)) 480 | 481 | def forward(self, inputs): 482 | """ 483 | Args: 484 | inputs: [batch_size, input_dim, sourceL] 485 | """ 486 | batch_size = inputs.size(0) 487 | input_dim = inputs.size(1) 488 | sourceL = inputs.size(2) 489 | 490 | # repeat embeddings across batch_size 491 | # result is [batch_size x input_dim x embedding_dim] 492 | embedding = self.embedding.repeat(batch_size, 1, 1) 493 | embedded_inputs = [] 494 | # result is [batch_size, 1, input_dim, sourceL] 495 | ips = inputs.unsqueeze(1) 496 | 497 | for i in range(sourceL): 498 | # [batch_size x 1 x input_dim] * [batch_size x input_dim x embedding_dim] 499 | # result is [batch_size, embedding_dim] 500 | embedded_inputs.append(torch.bmm( 501 | ips[:, :, :, i].float(), 502 | embedding).squeeze(1)) 503 | 504 | # Result is [sourceL x batch_size x embedding_dim] 505 | embedded_inputs = torch.cat(embedded_inputs).view( 506 | sourceL, 507 | batch_size, 508 | embedding.size(2)) 509 | 510 | # query the actor net for the input indices 511 | # making up the output, and the pointer attn 512 | probs_, action_idxs = self.actor_net(embedded_inputs) 513 | 514 | # Select the actions (inputs pointed to 515 | # by the pointer net) and the corresponding 516 | # logits 517 | # should be size [batch_size x 518 | actions = [] 519 | # inputs is [batch_size, input_dim, sourceL] 520 | inputs_ = inputs.transpose(1, 2) 521 | # inputs_ is [batch_size, sourceL, input_dim] 522 | for action_id in action_idxs: 523 | actions.append(inputs_[[x for x in range(batch_size)], action_id.data, :]) 524 | 525 | if self.is_train: 526 | # probs_ is a list of len sourceL of [batch_size x sourceL] 527 | probs = [] 528 | for prob, action_id in zip(probs_, action_idxs): 529 | probs.append(prob[[x for x in range(batch_size)], action_id.data]) 530 | else: 531 | # return the list of len sourceL of [batch_size x sourceL] 532 | probs = probs_ 533 | 534 | # get the critic value fn estimates for the baseline 535 | # [batch_size] 536 | #v = 
self.critic_net(embedded_inputs) 537 | 538 | # [batch_size] 539 | R = self.objective_fn(actions, self.use_cuda) 540 | 541 | #return R, v, probs, actions, action_idxs 542 | return R, probs, actions, action_idxs 543 | 544 | -------------------------------------------------------------------------------- /plot_attention.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.ticker as ticker 3 | import numpy as np 4 | 5 | 6 | def plot_attention(in_seq, out_seq, attentions): 7 | """ From http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html""" 8 | 9 | # Set up figure with colorbar 10 | fig = plt.figure() 11 | ax = fig.add_subplot(111) 12 | cax = ax.matshow(attentions, cmap='bone') 13 | fig.colorbar(cax) 14 | 15 | # Set up axes 16 | ax.set_xticklabels([' '] + [str(x) for x in in_seq], rotation=90) 17 | ax.set_yticklabels([' '] + [str(x) for x in out_seq]) 18 | 19 | # Show label at every tick 20 | ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) 21 | ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) 22 | 23 | plt.show() 24 | -------------------------------------------------------------------------------- /scripts/hyperparam_search.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import numpy as np 3 | import sys 4 | from math import floor, log10 5 | 6 | 7 | if __name__ == '__main__': 8 | exp_i = sys.argv[1] 9 | #rand_seed = int(sys.argv[2]) 10 | 11 | #np.random.seed(rand_seed) 12 | 13 | 14 | exps = [4] 15 | #num = np.arange(1, 9) 16 | num = [2] 17 | 18 | num_trials = 10 19 | 20 | seeds = [123, 343] 21 | 22 | for i in range(num_trials): 23 | #for rs in seeds: 24 | """ 25 | for exp in exps: 26 | for n in num: 27 | 28 | lr = n * (1./(10 ** exp)) 29 | subprocess.call(["./tune_hyper.sh", str(lr), str(rs), exp_i]) 30 | """ 31 | lr = np.random.normal(2e-4, 1e-5) 32 | 
subprocess.call(["./tune_hyper.sh", str(lr), '4911', exp_i]) 33 | -------------------------------------------------------------------------------- /scripts/plot_reward.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | import sys 4 | 5 | if __name__ == '__main__': 6 | 7 | reward_csv = sys.argv[1] 8 | data = int(sys.argv[2]) 9 | 10 | df = pd.read_csv(reward_csv, index_col=0) 11 | 12 | if data == 0: 13 | plt.figure() 14 | plt.plot(df['Step'], df['Value']) 15 | plt.title('TSP 50, Average Tour Length (Training)') 16 | plt.xlabel('Step') 17 | plt.ylabel('Average Tour Length') 18 | plt.show() 19 | else: 20 | # average every 100 rows 21 | vals = [] 22 | i = 1 23 | s = 0 24 | for index, row in df.iterrows(): 25 | s += row['Value'] 26 | if i % 100 == 0: 27 | vals.append(s / 100.) 28 | s = 0 29 | i += 1 30 | plt.figure() 31 | plt.plot(vals) 32 | plt.plot(range(10), [5.95 for _ in range(10)]) 33 | plt.title('TSP 50, Average Tour Length (Validation)') 34 | plt.xlabel('Step') 35 | plt.ylabel('Average Tour Length') 36 | plt.show() 37 | -------------------------------------------------------------------------------- /scripts/tune_hyper.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TASK='tsp_20' 4 | DROPOUT=0.0 5 | BEAM_SIZE=1 6 | EMBEDDING_DIM=128 7 | HIDDEN_DIM=128 8 | BATCH_SIZE=128 9 | ACTOR_NET_LR=$1 10 | CRITIC_NET_LR=$1 11 | ACTOR_LR_DECAY_RATE=0.96 12 | ACTOR_LR_DECAY_STEP=10000 13 | CRITIC_LR_DECAY_RATE=0.96 14 | CRITIC_LR_DECAY_STEP=10000 15 | N_PROCESS_BLOCKS=3 16 | N_GLIMPSES=1 17 | N_EPOCHS=1 18 | EPOCH_START=0 19 | MAX_GRAD_NORM=1.0 20 | RANDOM_SEED=$2 21 | RUN_NAME="$3-LR-$ACTOR_NET_LR-seed-$RANDOM_SEED" 22 | TRAIN_SIZE=1280000 23 | VAL_SIZE=1000 24 | LOAD_PATH="outputs/tsp_20/hyperparam_search-0.00064669711994-seed-350/epoch-0.pt" 25 | USE_CUDA=True 26 | DISABLE_TENSORBOARD=False 27 | REWARD_SCALE=1 28 | 
USE_TANH=True 29 | 30 | ./trainer.py --task $TASK --dropout $DROPOUT --beam_size $BEAM_SIZE --actor_net_lr $ACTOR_NET_LR --critic_net_lr $CRITIC_NET_LR --n_epochs $N_EPOCHS --random_seed $RANDOM_SEED --max_grad_norm $MAX_GRAD_NORM --run_name $RUN_NAME --epoch_start $EPOCH_START --train_size $TRAIN_SIZE --n_process_blocks $N_PROCESS_BLOCKS --batch_size $BATCH_SIZE --actor_lr_decay_rate $ACTOR_LR_DECAY_RATE --actor_lr_decay_step $ACTOR_LR_DECAY_STEP --critic_lr_decay_rate $CRITIC_LR_DECAY_RATE --critic_lr_decay_step $CRITIC_LR_DECAY_STEP --embedding_dim $EMBEDDING_DIM --hidden_dim $HIDDEN_DIM --val_size $VAL_SIZE --n_glimpses $N_GLIMPSES --use_cuda $USE_CUDA --disable_tensorboard $DISABLE_TENSORBOARD --reward_scale $REWARD_SCALE --use_tanh $USE_TANH 31 | 32 | -------------------------------------------------------------------------------- /sorting_task.py: -------------------------------------------------------------------------------- 1 | # Generate sorting data and store in .txt 2 | # Define the reward function 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torch.autograd import Variable 7 | from tqdm import trange, tqdm 8 | import os 9 | import sys 10 | 11 | 12 | def reward(sample_solution, USE_CUDA=False): 13 | """ 14 | The reward for the sorting task is defined as the 15 | length of the longest sorted consecutive subsequence. 16 | 17 | Input sequences must all be the same length. 
18 | 19 | Example: 20 | 21 | input | output 22 | ==================== 23 | [1 4 3 5 2] | [5 1 2 3 4] 24 | 25 | The output gets a reward of 4/5, or 0.8 26 | 27 | The range is [1/sourceL, 1]; the function returns the negated value, because the trainer minimizes R 28 | 29 | Args: 30 | sample_solution: list of len sourceL of [batch_size] 31 | Tensors 32 | Returns: 33 | [batch_size] containing trajectory rewards 34 | """ 35 | batch_size = sample_solution[0].size(0) 36 | sourceL = len(sample_solution) 37 | 38 | longest = Variable(torch.ones(batch_size, 1), requires_grad=False) 39 | current = Variable(torch.ones(batch_size, 1), requires_grad=False) 40 | 41 | if USE_CUDA: 42 | longest = longest.cuda() 43 | current = current.cuda() 44 | 45 | for i in range(1, sourceL): 46 | # compare solution[i-1] < solution[i] 47 | res = torch.lt(sample_solution[i-1], sample_solution[i]) 48 | # where res == 1, increment the length of the current sorted subsequence 49 | current += res.float() 50 | # else, reset current to 1 51 | current[torch.eq(res, 0)] = 1 52 | #current[torch.eq(res, 0)] -= 1 53 | # wherever current > longest, update longest 54 | mask = torch.gt(current, longest) 55 | longest[mask] = current[mask] 56 | return -torch.div(longest, sourceL) 57 | 58 | def create_dataset( 59 | train_size, 60 | val_size, 61 | #test_size, 62 | data_dir, 63 | data_len, 64 | seed=None): 65 | 66 | if seed is not None: 67 | torch.manual_seed(seed) 68 | 69 | train_task = 'sorting-size-{}-len-{}-train.txt'.format(train_size, data_len) 70 | val_task = 'sorting-size-{}-len-{}-val.txt'.format(val_size, data_len) 71 | #test_task = 'sorting-size-{}-len-{}-test.txt'.format(test_size, data_len) 72 | 73 | train_fname = os.path.join(data_dir, train_task) 74 | val_fname = os.path.join(data_dir, val_task) 75 | 76 | 77 | if not os.path.isdir(data_dir): 78 | os.makedirs(data_dir) 79 | else: 80 | if os.path.exists(train_fname) and os.path.exists(val_fname): 81 | return train_fname, val_fname 82 | 83 | train_set = open(os.path.join(data_dir, train_task), 'w') 84 | val_set = 
open(os.path.join(data_dir, val_task), 'w') 85 | #test_set = open(os.path.join(data_dir, test_task), 'w') 86 | 87 | def to_string(tensor): 88 | """ 89 | Convert a torch.LongTensor 90 | of size data_len to a string 91 | of integers separated by whitespace 92 | and ending in a newline character 93 | """ 94 | line = '' 95 | for j in range(data_len-1): 96 | line += '{} '.format(tensor[j]) 97 | line += str(tensor[-1]) + '\n' 98 | return line 99 | 100 | print('Creating training data set for {}...'.format(train_task)) 101 | 102 | # Generate a training set of size train_size 103 | for i in trange(train_size): 104 | x = torch.randperm(data_len) 105 | train_set.write(to_string(x)) 106 | 107 | print('Creating validation data set for {}...'.format(val_task)) 108 | 109 | for i in trange(val_size): 110 | x = torch.randperm(data_len) 111 | val_set.write(to_string(x)) 112 | 113 | # print('Creating test data set for {}...'.format(test_task)) 114 | # 115 | # for i in trange(test_size): 116 | # x = torch.randperm(data_len) 117 | # test_set.write(to_string(x)) 118 | 119 | train_set.close() 120 | val_set.close() 121 | # test_set.close() 122 | return train_fname, val_fname 123 | 124 | class SortingDataset(Dataset): 125 | 126 | def __init__(self, dataset_fname): 127 | super(SortingDataset, self).__init__() 128 | 129 | print('Loading data from {} into memory'.format(dataset_fname)) 130 | self.data_set = [] 131 | with open(dataset_fname, 'r') as dset: 132 | lines = dset.readlines() 133 | for next_line in tqdm(lines): 134 | toks = next_line.split() 135 | sample = torch.zeros(1, len(toks)).long() 136 | for idx, tok in enumerate(toks): 137 | sample[0, idx] = int(tok) 138 | self.data_set.append(sample) 139 | 140 | self.size = len(self.data_set) 141 | 142 | def __len__(self): 143 | return self.size 144 | 145 | def __getitem__(self, idx): 146 | return self.data_set[idx] 147 | 148 | if __name__ == '__main__': 149 | if int(sys.argv[1]) == 0: 150 | #sample = Variable(torch.Tensor([[3, 2, 1, 4, 5], [2, 3, 5, 1, 4]])) 
151 | sample = [Variable(torch.Tensor([3,2])), Variable(torch.Tensor([2,3])), Variable(torch.Tensor([1,5])), 152 | Variable(torch.Tensor([4, 1])), Variable(torch.Tensor([5, 4]))] 153 | answer = torch.Tensor([-3/5., -3/5.]) # reward() returns the negated fraction 154 | 155 | res = reward(sample) 156 | 157 | print('Expected answer: {}, Actual answer: {}'.format(answer, res.data)) 158 | """ 159 | sample = Variable(torch.Tensor([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]])) 160 | answer = torch.Tensor([1., 1/5]) 161 | 162 | res = reward(sample) 163 | 164 | print('Expected answer: {}, Actual answer: {}'.format(answer, res.data)) 165 | 166 | sample = Variable(torch.Tensor([[1, 2, 5, 4, 3], [4, 1, 2, 3, 5]])) 167 | answer = torch.Tensor([3/5., 4/5]) 168 | 169 | res = reward(sample) 170 | 171 | print('Expected answer: {}, Actual answer: {}'.format(answer, res.data)) 172 | """ 173 | elif int(sys.argv[1]) == 1: 174 | create_dataset(1000, 100, 'data', 10, 123) 175 | elif int(sys.argv[1]) == 2: 176 | 177 | sorting_data = SortingDataset( 178 | os.path.join('data', 'sorting-size-1000-len-10-train.txt')) 179 | 180 | for i in range(len(sorting_data)): 181 | print(sorting_data[i]) 182 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import os 5 | from tqdm import tqdm 6 | 7 | import pprint as pp 8 | import numpy as np 9 | 10 | import torch 11 | print(torch.__version__) 12 | import torch.optim as optim 13 | import torch.autograd as autograd 14 | from torch.optim import lr_scheduler 15 | from torch.autograd import Variable 16 | from torch.utils.data import DataLoader 17 | from tensorboard_logger import configure, log_value 18 | 19 | from neural_combinatorial_rl import NeuralCombOptRL 20 | from plot_attention import plot_attention 21 | 22 | 23 | def str2bool(v): 24 | return v.lower() in ('true', '1') 25 | 26 | parser = 
argparse.ArgumentParser(description="Neural Combinatorial Optimization with RL") 27 | 28 | # Data 29 | parser.add_argument('--task', default='sort_10', help="The task to solve, in the form {COP}_{size}, e.g., tsp_20") 30 | parser.add_argument('--batch_size', default=128, help='') 31 | parser.add_argument('--train_size', default=1000000, help='') 32 | parser.add_argument('--val_size', default=10000, help='') 33 | # Network 34 | parser.add_argument('--embedding_dim', default=128, help='Dimension of input embedding') 35 | parser.add_argument('--hidden_dim', default=128, help='Dimension of hidden layers in Enc/Dec') 36 | parser.add_argument('--n_process_blocks', default=3, help='Number of process block iters to run in the Critic network') 37 | parser.add_argument('--n_glimpses', default=2, help='No. of glimpses to use in the pointer network') 38 | parser.add_argument('--use_tanh', type=str2bool, default=True) 39 | parser.add_argument('--tanh_exploration', default=10, help='Hyperparam controlling exploration in the pointer net by scaling the tanh in the softmax') 40 | parser.add_argument('--dropout', default=0., help='') 41 | parser.add_argument('--terminating_symbol', default='<0>', help='') 42 | parser.add_argument('--beam_size', default=1, help='Beam width for beam search') 43 | 44 | # Training 45 | parser.add_argument('--actor_net_lr', default=1e-4, help="Set the learning rate for the actor network") 46 | parser.add_argument('--critic_net_lr', default=1e-4, help="Set the learning rate for the critic network") 47 | parser.add_argument('--actor_lr_decay_step', default=5000, help='') 48 | parser.add_argument('--critic_lr_decay_step', default=5000, help='') 49 | parser.add_argument('--actor_lr_decay_rate', default=0.96, help='') 50 | parser.add_argument('--critic_lr_decay_rate', default=0.96, help='') 51 | parser.add_argument('--reward_scale', default=2, type=float, help='') 52 | parser.add_argument('--is_train', type=str2bool, default=True, help='') 53 | 
parser.add_argument('--n_epochs', default=1, help='') 54 | parser.add_argument('--random_seed', default=24601, help='') 55 | parser.add_argument('--max_grad_norm', default=2.0, help='Gradient clipping') 56 | parser.add_argument('--use_cuda', type=str2bool, default=True, help='') 57 | parser.add_argument('--critic_beta', type=float, default=0.9, help='Exp mvg average decay') 58 | 59 | # Misc 60 | parser.add_argument('--log_step', default=50, help='Log info every log_step steps') 61 | parser.add_argument('--log_dir', type=str, default='logs') 62 | parser.add_argument('--run_name', type=str, default='0') 63 | parser.add_argument('--output_dir', type=str, default='outputs') 64 | parser.add_argument('--epoch_start', type=int, default=0, help='Restart at epoch #') 65 | parser.add_argument('--load_path', type=str, default='') 66 | parser.add_argument('--disable_tensorboard', type=str2bool, default=False) 67 | parser.add_argument('--plot_attention', type=str2bool, default=False) 68 | parser.add_argument('--disable_progress_bar', type=str2bool, default=False) 69 | 70 | args = vars(parser.parse_args()) 71 | 72 | # Pretty print the run args 73 | pp.pprint(args) 74 | 75 | # Set the random seed 76 | torch.manual_seed(int(args['random_seed'])) 77 | 78 | # Optionally configure tensorboard 79 | if not args['disable_tensorboard']: 80 | configure(os.path.join(args['log_dir'], args['task'], args['run_name'])) 81 | 82 | # Task specific configuration - generate dataset if needed 83 | task = args['task'].split('_') 84 | COP = task[0] 85 | size = int(task[1]) 86 | data_dir = 'data/' + COP 87 | 88 | if COP == 'sort': 89 | import sorting_task 90 | 91 | input_dim = 1 92 | reward_fn = sorting_task.reward 93 | train_fname, val_fname = sorting_task.create_dataset( 94 | int(args['train_size']), 95 | int(args['val_size']), 96 | data_dir, 97 | data_len=size) 98 | training_dataset = sorting_task.SortingDataset(train_fname) 99 | val_dataset = sorting_task.SortingDataset(val_fname) 100 | elif COP == 
'tsp': 101 | import tsp_task 102 | 103 | input_dim = 2 104 | reward_fn = tsp_task.reward 105 | val_fname = tsp_task.create_dataset( 106 | problem_size=str(size), 107 | data_dir=data_dir) 108 | training_dataset = tsp_task.TSPDataset(train=True, size=size, 109 | num_samples=int(args['train_size'])) 110 | val_dataset = tsp_task.TSPDataset(train=True, size=size, 111 | num_samples=int(args['val_size'])) 112 | else: 113 | print('Currently unsupported task!') 114 | exit(1) 115 | 116 | # Load the model parameters from a saved state 117 | if args['load_path'] != '': 118 | print(' [*] Loading model from {}'.format(args['load_path'])) 119 | 120 | model = torch.load( 121 | os.path.join( 122 | os.getcwd(), 123 | args['load_path'] 124 | )) 125 | model.actor_net.decoder.max_length = size 126 | model.is_train = args['is_train'] 127 | else: 128 | # Instantiate the Neural Combinatorial Opt with RL module 129 | model = NeuralCombOptRL( 130 | input_dim, 131 | int(args['embedding_dim']), 132 | int(args['hidden_dim']), 133 | size, # decoder len 134 | args['terminating_symbol'], 135 | int(args['n_glimpses']), 136 | int(args['n_process_blocks']), 137 | float(args['tanh_exploration']), 138 | args['use_tanh'], 139 | int(args['beam_size']), 140 | reward_fn, 141 | args['is_train'], 142 | args['use_cuda']) 143 | 144 | 145 | save_dir = os.path.join(os.getcwd(), 146 | args['output_dir'], 147 | args['task'], 148 | args['run_name']) 149 | 150 | try: 151 | os.makedirs(save_dir) 152 | except: 153 | pass 154 | 155 | #critic_mse = torch.nn.MSELoss() 156 | #critic_optim = optim.Adam(model.critic_net.parameters(), lr=float(args['critic_net_lr'])) 157 | actor_optim = optim.Adam(model.actor_net.parameters(), lr=float(args['actor_net_lr'])) 158 | 159 | actor_scheduler = lr_scheduler.MultiStepLR(actor_optim, 160 | range(int(args['actor_lr_decay_step']), int(args['actor_lr_decay_step']) * 1000, 161 | int(args['actor_lr_decay_step'])), gamma=float(args['actor_lr_decay_rate'])) 162 | 163 | #critic_scheduler = 
lr_scheduler.MultiStepLR(critic_optim, 164 | # range(int(args['critic_lr_decay_step']), int(args['critic_lr_decay_step']) * 1000, 165 | # int(args['critic_lr_decay_step'])), gamma=float(args['critic_lr_decay_rate'])) 166 | 167 | training_dataloader = DataLoader(training_dataset, batch_size=int(args['batch_size']), 168 | shuffle=True, num_workers=4) 169 | 170 | validation_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True, num_workers=1) 171 | 172 | critic_exp_mvg_avg = torch.zeros(1) 173 | beta = args['critic_beta'] 174 | 175 | if args['use_cuda']: 176 | model = model.cuda() 177 | #critic_mse = critic_mse.cuda() 178 | critic_exp_mvg_avg = critic_exp_mvg_avg.cuda() 179 | 180 | step = 0 181 | val_step = 0 182 | 183 | if not args['is_train']: 184 | args['n_epochs'] = '1' 185 | 186 | 187 | epoch = int(args['epoch_start']) 188 | for i in range(epoch, epoch + int(args['n_epochs'])): 189 | 190 | if args['is_train']: 191 | # put in train mode! 192 | model.train() 193 | 194 | # sample_batch is [batch_size x input_dim x sourceL] 195 | for batch_id, sample_batch in enumerate(tqdm(training_dataloader, 196 | disable=args['disable_progress_bar'])): 197 | 198 | 199 | bat = Variable(sample_batch) 200 | if args['use_cuda']: 201 | bat = bat.cuda() 202 | 203 | R, probs, actions, actions_idxs = model(bat) 204 | 205 | if batch_id == 0: 206 | critic_exp_mvg_avg = R.mean() 207 | else: 208 | critic_exp_mvg_avg = (critic_exp_mvg_avg * beta) + ((1. - beta) * R.mean()) 209 | 210 | advantage = R - critic_exp_mvg_avg 211 | 212 | logprobs = 0 213 | nll = 0 214 | for prob in probs: 215 | # compute the sum of the log probs 216 | # for each tour in the batch 217 | logprob = torch.log(prob) 218 | nll += -logprob 219 | logprobs += logprob 220 | 221 | # guard against nan 222 | nll[(nll != nll).detach()] = 0. 223 | # clamp any -inf's to 0 to throw away this tour 224 | logprobs[(logprobs < -1000).detach()] = 0. 
225 | 226 | # weight each tour's summed log-probability by its advantage 227 | reinforce = advantage * logprobs 228 | actor_loss = reinforce.mean() 229 | 230 | actor_optim.zero_grad() 231 | 232 | actor_loss.backward() 233 | 234 | # clip gradient norms 235 | torch.nn.utils.clip_grad_norm(model.actor_net.parameters(), 236 | float(args['max_grad_norm']), norm_type=2) 237 | 238 | actor_optim.step() 239 | actor_scheduler.step() 240 | 241 | critic_exp_mvg_avg = critic_exp_mvg_avg.detach() 242 | 243 | #critic_scheduler.step() 244 | 245 | #R = R.detach() 246 | #critic_loss = critic_mse(v.squeeze(1), R) 247 | #critic_optim.zero_grad() 248 | #critic_loss.backward() 249 | 250 | #torch.nn.utils.clip_grad_norm(model.critic_net.parameters(), 251 | # float(args['max_grad_norm']), norm_type=2) 252 | 253 | #critic_optim.step() 254 | 255 | step += 1 256 | 257 | if not args['disable_tensorboard']: 258 | log_value('avg_reward', R.mean().data[0], step) 259 | log_value('actor_loss', actor_loss.data[0], step) 260 | #log_value('critic_loss', critic_loss.data[0], step) 261 | log_value('critic_exp_mvg_avg', critic_exp_mvg_avg.data[0], step) 262 | log_value('nll', nll.mean().data[0], step) 263 | 264 | if step % int(args['log_step']) == 0: 265 | print('epoch: {}, train_batch_id: {}, avg_reward: {}'.format( 266 | i, batch_id, R.mean().data[0])) 267 | example_output = [] 268 | example_input = [] 269 | for idx, action in enumerate(actions): 270 | if task[0] == 'tsp': 271 | example_output.append(actions_idxs[idx][0].data[0]) 272 | else: 273 | example_output.append(action[0].data[0]) # <-- ?? 
274 | example_input.append(sample_batch[0, :, idx][0]) 275 | #print('Example train input: {}'.format(example_input)) 276 | print('Example train output: {}'.format(example_output)) 277 | 278 | # Use beam search decoding for validation 279 | model.actor_net.decoder.decode_type = "beam_search" 280 | 281 | print('\n~Validating~\n') 282 | 283 | example_input = [] 284 | example_output = [] 285 | avg_reward = [] 286 | 287 | # put in test mode! 288 | model.eval() 289 | 290 | for batch_id, val_batch in enumerate(tqdm(validation_dataloader, 291 | disable=args['disable_progress_bar'])): 292 | bat = Variable(val_batch) 293 | 294 | if args['use_cuda']: 295 | bat = bat.cuda() 296 | 297 | R, probs, actions, action_idxs = model(bat) 298 | 299 | avg_reward.append(R[0].data[0]) 300 | val_step += 1. 301 | 302 | if not args['disable_tensorboard']: 303 | log_value('val_avg_reward', R[0].data[0], int(val_step)) 304 | 305 | if val_step % int(args['log_step']) == 0: 306 | example_output = [] 307 | example_input = [] 308 | for idx, action in enumerate(actions): 309 | if task[0] == 'tsp': 310 | example_output.append(action_idxs[idx][0].data[0]) 311 | else: 312 | example_output.append(action[0].data[0]) 313 | example_input.append(bat[0, :, idx].data[0]) 314 | print('Step: {}'.format(batch_id)) 315 | #print('Example test input: {}'.format(example_input)) 316 | print('Example test output: {}'.format(example_output)) 317 | print('Example test reward: {}'.format(R[0].data[0])) 318 | 319 | 320 | if args['plot_attention']: 321 | probs = torch.cat(probs, 0) 322 | plot_attention(example_input, 323 | example_output, probs.data.cpu().numpy()) 324 | print('Validation overall avg_reward: {}'.format(np.mean(avg_reward))) 325 | print('Validation overall reward var: {}'.format(np.var(avg_reward))) 326 | 327 | if args['is_train']: 328 | model.actor_net.decoder.decode_type = "stochastic" 329 | 330 | print('Saving model...') 331 | 332 | torch.save(model, os.path.join(save_dir, 'epoch-{}.pt'.format(i))) 333 | 
334 | # If the task requires generating new data after each epoch, do that here! 335 | if COP == 'tsp': 336 | training_dataset = tsp_task.TSPDataset(train=True, size=size, 337 | num_samples=int(args['train_size'])) 338 | training_dataloader = DataLoader(training_dataset, batch_size=int(args['batch_size']), 339 | shuffle=True, num_workers=1) 340 | if COP == 'sort': 341 | train_fname, _ = sorting_task.create_dataset( 342 | int(args['train_size']), 343 | int(args['val_size']), 344 | data_dir, 345 | data_len=size) 346 | training_dataset = sorting_task.SortingDataset(train_fname) 347 | training_dataloader = DataLoader(training_dataset, batch_size=int(args['batch_size']), 348 | shuffle=True, num_workers=1) 349 | -------------------------------------------------------------------------------- /tsp_task.py: -------------------------------------------------------------------------------- 1 | # code based in part on 2 | # http://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive/39225039#39225039 3 | # and from 4 | # https://github.com/devsisters/neural-combinatorial-rl-tensorflow/blob/master/data_loader.py 5 | import requests 6 | from tqdm import tqdm 7 | from torch.utils.data import Dataset 8 | from torch.autograd import Variable 9 | import torch 10 | import os 11 | import numpy as np 12 | import re 13 | import zipfile 14 | import itertools 15 | from collections import namedtuple 16 | 17 | 18 | ####################################### 19 | # Reward Fn 20 | ####################################### 21 | def reward(sample_solution, USE_CUDA=False): 22 | """ 23 | Args: 24 | sample_solution: list of length sourceL of [batch_size x 2] Tensors 25 | Returns: 26 | Tensor of shape [batch_size] containing the tour lengths 27 | """ 28 | batch_size = sample_solution[0].size(0) 29 | n = len(sample_solution) 30 | tour_len = Variable(torch.zeros([batch_size])) 31 | 32 | if USE_CUDA: 33 | tour_len = tour_len.cuda() 34 | 35 | for i in range(n-1): 36 | tour_len += torch.norm(sample_solution[i] - 
sample_solution[i+1], dim=1) 37 | 38 | tour_len += torch.norm(sample_solution[n-1] - sample_solution[0], dim=1) 39 | 40 | # For TSP_20 - map to a number between 0 and 1 41 | # min_len = 3.5 42 | # max_len = 10. 43 | # TODO: generalize this for any TSP size 44 | #tour_len = -0.1538*tour_len + 1.538 45 | #tour_len[tour_len < 0.] = 0. 46 | return tour_len 47 | 48 | 49 | ####################################### 50 | # Functions for downloading dataset 51 | ####################################### 52 | TSP = namedtuple('TSP', ['x', 'y', 'name']) 53 | 54 | GOOGLE_DRIVE_IDS = { 55 | 'tsp5_train.zip': '0B2fg8yPGn2TCSW1pNTJMXzFPYTg', 56 | 'tsp10_train.zip': '0B2fg8yPGn2TCbHowM0hfOTJCNkU', 57 | 'tsp5-20_train.zip': '0B2fg8yPGn2TCTWNxX21jTDBGeXc', 58 | 'tsp50_train.zip': '0B2fg8yPGn2TCaVQxSl9ab29QajA', 59 | 'tsp20_test.txt': '0B2fg8yPGn2TCdF9TUU5DZVNCNjQ', 60 | 'tsp40_test.txt': '0B2fg8yPGn2TCcjFrYk85SGFVNlU', 61 | 'tsp50_test.txt.zip': '0B2fg8yPGn2TCUVlCQmQtelpZTTQ', 62 | } 63 | 64 | def download_file_from_google_drive(id, destination): 65 | URL = "https://docs.google.com/uc?export=download" 66 | 67 | session = requests.Session() 68 | 69 | response = session.get(URL, params = { 'id' : id }, stream = True) 70 | token = get_confirm_token(response) 71 | 72 | if token: 73 | params = { 'id' : id, 'confirm' : token } 74 | response = session.get(URL, params = params, stream = True) 75 | 76 | save_response_content(response, destination) 77 | return True 78 | 79 | def get_confirm_token(response): 80 | for key, value in response.cookies.items(): 81 | if key.startswith('download_warning'): 82 | return value 83 | return None 84 | 85 | def save_response_content(response, destination): 86 | CHUNK_SIZE = 32768 87 | 88 | with open(destination, "wb") as f: 89 | for chunk in tqdm(response.iter_content(CHUNK_SIZE)): 90 | if chunk: # filter out keep-alive new chunks 91 | f.write(chunk) 92 | 93 | def download_google_drive_file(data_dir, task, min_length, max_length): 94 | paths = {} 95 | for mode 
in ['train', 'test']: 96 | candidates = [] 97 | candidates.append( 98 | '{}{}_{}'.format(task, max_length, mode)) 99 | candidates.append( 100 | '{}{}-{}_{}'.format(task, min_length, max_length, mode)) 101 | 102 | for key in candidates: 103 | print(key) 104 | for search_key in GOOGLE_DRIVE_IDS.keys(): 105 | if search_key.startswith(key): 106 | path = os.path.join(data_dir, search_key) 107 | print("Download dataset of the paper to {}".format(path)) 108 | 109 | if not os.path.exists(path): 110 | download_file_from_google_drive(GOOGLE_DRIVE_IDS[search_key], path) 111 | if path.endswith('zip'): 112 | with zipfile.ZipFile(path, 'r') as z: 113 | z.extractall(data_dir) 114 | paths[mode] = path 115 | 116 | return paths 117 | 118 | def read_paper_dataset(paths, max_length): 119 | x, y = [], [] 120 | for path in paths: 121 | print("Read dataset {} which is used in the paper..".format(path)) 122 | length = max(re.findall('\d+', path)) 123 | with open(path) as f: 124 | for l in tqdm(f): 125 | inputs, outputs = l.split(' output ') 126 | x.append(np.array(inputs.split(), dtype=np.float32).reshape([-1, 2])) 127 | y.append(np.array(outputs.split(), dtype=np.int32)[:-1]) # skip the last one 128 | 129 | return x, y 130 | 131 | def maybe_generate_and_save(self, except_list=[]): 132 | data = {} 133 | 134 | for name, num in self.data_num.items(): 135 | if name in except_list: 136 | print("Skip creating {} because of given except_list {}".format(name, except_list)) 137 | continue 138 | path = self.get_path(name) 139 | 140 | print("Skip creating {} for [{}]".format(path, self.task)) 141 | tmp = np.load(path) 142 | self.data[name] = TSP(x=tmp['x'], y=tmp['y'], name=name) 143 | 144 | def get_path(self, name): 145 | return os.path.join( 146 | self.data_dir, "{}_{}={}.npz".format( 147 | self.task_name, name, self.data_num[name])) 148 | 149 | def read_zip_and_update_data(self, path, name): 150 | if path.endswith('zip'): 151 | filenames = zipfile.ZipFile(path).namelist() 152 | paths = 
[os.path.join(self.data_dir, filename) for filename in filenames] 153 | else: 154 | paths = [path] 155 | 156 | x_list, y_list = read_paper_dataset(paths, self.max_length) 157 | 158 | x = np.zeros([len(x_list), self.max_length, 2], dtype=np.float32) 159 | y = np.zeros([len(y_list), self.max_length], dtype=np.int32) 160 | 161 | for idx, (nodes, res) in enumerate(tqdm(zip(x_list, y_list))): 162 | x[idx,:len(nodes)] = nodes 163 | y[idx,:len(res)] = res 164 | 165 | if self.data is None: 166 | self.data = {} 167 | 168 | print("Update [{}] data with {} used in the paper".format(name, path)) 169 | self.data[name] = TSP(x=x, y=y, name=name) 170 | 171 | 172 | def create_dataset( 173 | problem_size, 174 | data_dir): 175 | 176 | def find_or_return_empty(data_dir, problem_size): 177 | #train_fname1 = os.path.join(data_dir, 'tsp{}.txt'.format(problem_size)) 178 | val_fname1 = os.path.join(data_dir, 'tsp{}_test.txt'.format(problem_size)) 179 | #train_fname2 = os.path.join(data_dir, 'tsp-{}.txt'.format(problem_size)) 180 | val_fname2 = os.path.join(data_dir, 'tsp-{}_test.txt'.format(problem_size)) 181 | 182 | if not os.path.isdir(data_dir): 183 | os.mkdir(data_dir) 184 | else: 185 | # if os.path.exists(train_fname1) and os.path.exists(val_fname1): 186 | # return train_fname1, val_fname1 187 | # if os.path.exists(train_fname2) and os.path.exists(val_fname2): 188 | # return train_fname2, val_fname2 189 | # return None, None 190 | 191 | # train, val = find_or_return_empty(data_dir, problem_size) 192 | # if train is None and val is None: 193 | # download_google_drive_file(data_dir, 194 | # 'tsp', '', problem_size) 195 | # train, val = find_or_return_empty(data_dir, problem_size) 196 | 197 | # return train, val 198 | if os.path.exists(val_fname1): 199 | return val_fname1 200 | if os.path.exists(val_fname2): 201 | return val_fname2 202 | return None 203 | 204 | val = find_or_return_empty(data_dir, problem_size) 205 | if val is None: 206 | download_google_drive_file(data_dir, 'tsp', '', 
problem_size) 207 | val = find_or_return_empty(data_dir, problem_size) 208 | 209 | return val 210 | 211 | 212 | ####################################### 213 | # Dataset 214 | ####################################### 215 | class TSPDataset(Dataset): 216 | 217 | def __init__(self, dataset_fname=None, train=False, size=50, num_samples=1000000, random_seed=1111): 218 | super(TSPDataset, self).__init__() 219 | #start = torch.FloatTensor([[-1], [-1]]) 220 | 221 | torch.manual_seed(random_seed) 222 | 223 | self.data_set = [] 224 | if not train: 225 | with open(dataset_fname, 'r') as dset: 226 | for l in tqdm(dset): 227 | inputs, outputs = l.split(' output ') 228 | sample = torch.zeros(1, ) 229 | x = np.array(inputs.split(), dtype=np.float32).reshape([-1, 2]).T 230 | #y.append(np.array(outputs.split(), dtype=np.int32)[:-1]) # skip the last one 231 | self.data_set.append(x) 232 | else: 233 | # randomly sample points uniformly from [0, 1] 234 | for l in tqdm(range(num_samples)): 235 | x = torch.FloatTensor(2, size).uniform_(0, 1) 236 | #x = torch.cat([start, x], 1) 237 | self.data_set.append(x) 238 | 239 | self.size = len(self.data_set) 240 | 241 | def __len__(self): 242 | return self.size 243 | 244 | def __getitem__(self, idx): 245 | return self.data_set[idx] 246 | 247 | if __name__ == '__main__': 248 | paths = download_google_drive_file('data/tsp', 'tsp', '', '50') 249 | --------------------------------------------------------------------------------
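The `reward` function in `tsp_task.py` sums the Euclidean distances between consecutive cities and then closes the loop back to the first city. Below is a minimal sketch of the same computation for a single tour, written in plain Python so it runs without PyTorch. The name `tour_length` and the list-of-`(x, y)`-tuples input are assumptions made for this illustration and are not part of the repository.

```python
import math


def tour_length(cities):
    """Length of the closed tour that visits `cities` in the given order.

    `cities` is a list of (x, y) tuples (hypothetical input format,
    chosen only for this sketch).
    """
    n = len(cities)
    total = 0.0
    for i in range(n):
        x0, y0 = cities[i]
        # (i + 1) % n wraps around, so the last city connects back to the first
        x1, y1 = cities[(i + 1) % n]
        total += math.hypot(x1 - x0, y1 - y0)
    return total


if __name__ == '__main__':
    # corners of the unit square, visited in order
    print(tour_length([(0, 0), (1, 0), (1, 1), (0, 1)]))
```

The batched version in `reward` does the same thing with `torch.norm(..., dim=1)` over a list of `[batch_size x 2]` tensors, so it returns one tour length per batch element.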