├── Data_Generator.py
├── PointerNet.py
├── README.md
└── Train.py

/Data_Generator.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
import numpy as np
import itertools
from tqdm import tqdm


def tsp_opt(points):
    """
    Dynamic programming (Held-Karp) solution for TSP - O(2^n * n^2)
    https://gist.github.com/mlalevic/6222750

    :param points: List of (x, y) points
    :return: Optimal tour as an array of point indices
    """

    def length(x_coord, y_coord):
        return np.linalg.norm(np.asarray(x_coord) - np.asarray(y_coord))

    # Calculate all pairwise distances
    all_distances = [[length(x, y) for y in points] for x in points]
    # Initial value - just the distance from 0 to every other point + keep track of edges
    A = {(frozenset([0, idx + 1]), idx + 1): (dist, [0, idx + 1])
         for idx, dist in enumerate(all_distances[0][1:])}
    cnt = len(points)
    for m in range(2, cnt):
        B = {}
        for S in [frozenset(C) | {0} for C in itertools.combinations(range(1, cnt), m)]:
            for j in S - {0}:
                # min() orders by the 0th tuple element, the same as if key=itemgetter(0) were used
                B[(S, j)] = min([(A[(S - {j}, k)][0] + all_distances[k][j], A[(S - {j}, k)][1] + [j])
                                 for k in S if k != 0 and k != j])
        A = B
    res = min([(A[d][0] + all_distances[0][d[1]], A[d][1]) for d in iter(A)])
    return np.asarray(res[1])


class TSPDataset(Dataset):
    """
    Random TSP dataset
    """

    def __init__(self, data_size, seq_len, solver=tsp_opt, solve=True):
        self.data_size = data_size
        self.seq_len = seq_len
        self.solve = solve
        self.solver = solver
        self.data = self._generate_data()

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        tensor = torch.from_numpy(self.data['Points_List'][idx]).float()
        solution = torch.from_numpy(self.data['Solutions'][idx]).long() if self.solve else None

        sample = {'Points': tensor, 'Solution': solution}

        return sample

    def _generate_data(self):
        """
        :return: Set of points and their optimal-tour solutions (index arrays)
        """
        points_list = []
        solutions = []
        data_iter = tqdm(range(self.data_size), unit='data')
        for i, _ in enumerate(data_iter):
            data_iter.set_description('Data points %i/%i' % (i + 1, self.data_size))
            points_list.append(np.random.random((self.seq_len, 2)))
        if self.solve:
            solutions_iter = tqdm(points_list, unit='solve')
            for i, points in enumerate(solutions_iter):
                solutions_iter.set_description('Solved %i/%i' % (i + 1, len(points_list)))
                solutions.append(self.solver(points))
        else:
            solutions = None

        return {'Points_List': points_list, 'Solutions': solutions}

    def _to1hotvec(self, points):
        """
        :param points: List of integers representing point indices
        :return: Matrix of one-hot vectors
        """
        vec = np.zeros((len(points), self.seq_len))
        for i, v in enumerate(vec):
            v[points[i]] = 1

        return vec
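

# A minimal sanity check (not part of the original pipeline; sizes are
# illustrative): build a few 5-point instances and verify that each Held-Karp
# solution is a permutation of all point indices.
if __name__ == '__main__':
    demo = TSPDataset(data_size=4, seq_len=5)
    sample = demo[0]
    print(sample['Points'].shape)   # torch.Size([5, 2])
    print(sample['Solution'])       # e.g. tensor([0, 3, 1, 4, 2])
    assert sorted(sample['Solution'].tolist()) == list(range(5))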
--------------------------------------------------------------------------------
/PointerNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F


class Encoder(nn.Module):
    """
    Encoder class for Pointer-Net
    """

    def __init__(self, embedding_dim,
                 hidden_dim,
                 n_layers,
                 dropout,
                 bidir):
        """
        Initialize the Encoder

        :param int embedding_dim: Number of embedding channels
        :param int hidden_dim: Number of hidden units for the LSTM
        :param int n_layers: Number of LSTM layers
        :param float dropout: Dropout probability, between 0 and 1
        :param bool bidir: Bidirectional
        """

        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim // 2 if bidir else hidden_dim
        self.n_layers = n_layers * 2 if bidir else n_layers
        self.bidir = bidir
        self.lstm = nn.LSTM(embedding_dim,
                            self.hidden_dim,
                            n_layers,
                            dropout=dropout,
                            bidirectional=bidir)

        # Used for propagating .cuda() command
        self.h0 = Parameter(torch.zeros(1), requires_grad=False)
        self.c0 = Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, embedded_inputs,
                hidden):
        """
        Encoder - Forward-pass

        :param Tensor embedded_inputs: Embedded inputs of Pointer-Net
        :param Tensor hidden: Initial hidden units for the LSTM (h, c)
        :return: LSTM outputs and hidden units (h, c)
        """

        # nn.LSTM expects (seq_len, batch, features)
        embedded_inputs = embedded_inputs.permute(1, 0, 2)

        outputs, hidden = self.lstm(embedded_inputs, hidden)

        return outputs.permute(1, 0, 2), hidden

    def init_hidden(self, embedded_inputs):
        """
        Initialize hidden units

        :param Tensor embedded_inputs: The embedded input of Pointer-Net
        :return: Initial hidden units for the LSTM (h, c)
        """

        batch_size = embedded_inputs.size(0)

        # Reshaping (expanding) to (n_layers * num_directions, batch, hidden_dim)
        h0 = self.h0.unsqueeze(0).unsqueeze(0).repeat(self.n_layers,
                                                      batch_size,
                                                      self.hidden_dim)
        c0 = self.c0.unsqueeze(0).unsqueeze(0).repeat(self.n_layers,
                                                      batch_size,
                                                      self.hidden_dim)

        return h0, c0


class Attention(nn.Module):
    """
    Attention model for Pointer-Net
    """

    def __init__(self, input_dim,
                 hidden_dim):
        """
        Initialize the Attention module

        :param int input_dim: Input dimension
        :param int hidden_dim: Number of hidden units in the attention
        """

        super(Attention, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.input_linear = nn.Linear(input_dim, hidden_dim)
        self.context_linear = nn.Conv1d(input_dim, hidden_dim, 1, 1)
        self.V = Parameter(torch.FloatTensor(hidden_dim), requires_grad=True)
        self._inf = Parameter(torch.FloatTensor([float('-inf')]), requires_grad=False)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)

        # Initialize vector V
        nn.init.uniform_(self.V, -1, 1)

    def forward(self, input,
                context,
                mask):
        """
        Attention - Forward-pass

        :param Tensor input: Hidden state h
        :param Tensor context: Attention context
        :param BoolTensor mask: Selection mask (True for already-selected positions)
        :return: tuple of - (attended hidden state, alphas)
        """

        # (batch, hidden_dim, seq_len)
        inp = self.input_linear(input).unsqueeze(2).expand(-1, -1, context.size(1))

        # (batch, hidden_dim, seq_len)
        context = context.permute(0, 2, 1)
        ctx = self.context_linear(context)

        # (batch, 1, hidden_dim)
        V = self.V.unsqueeze(0).expand(context.size(0), -1).unsqueeze(1)

        # (batch, seq_len)
        att = torch.bmm(V, self.tanh(inp + ctx)).squeeze(1)
        if len(att[mask]) > 0:
            att[mask] = self.inf[mask]
        alpha = self.softmax(att)

        hidden_state = torch.bmm(ctx, alpha.unsqueeze(2)).squeeze(2)

        return hidden_state, alpha

    def init_inf(self, mask_size):
        # Must be called before forward(); builds the -inf buffer used for masking
        self.inf = self._inf.unsqueeze(1).expand(*mask_size)
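

# How the decoder below points: at each step it scores every encoder output
# e_j against the current decoder state d_t with the additive attention
# u_j = V . tanh(W1 * e_j + W2 * d_t), and softmax(u) serves directly as the
# output distribution over input positions (Vinyals et al., 2015). Positions
# already pointed at are masked to -inf so each point is chosen exactly once.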
class Decoder(nn.Module):
    """
    Decoder model for Pointer-Net
    """

    def __init__(self, embedding_dim,
                 hidden_dim):
        """
        Initialize the Decoder

        :param int embedding_dim: Number of embedding channels
        :param int hidden_dim: Number of hidden units for the decoder's RNN
        """

        super(Decoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        self.input_to_hidden = nn.Linear(embedding_dim, 4 * hidden_dim)
        self.hidden_to_hidden = nn.Linear(hidden_dim, 4 * hidden_dim)
        self.hidden_out = nn.Linear(hidden_dim * 2, hidden_dim)
        self.att = Attention(hidden_dim, hidden_dim)

        # Used for propagating .cuda() command
        self.mask = Parameter(torch.ones(1), requires_grad=False)
        self.runner = Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, embedded_inputs,
                decoder_input,
                hidden,
                context):
        """
        Decoder - Forward-pass

        :param Tensor embedded_inputs: Embedded inputs of Pointer-Net
        :param Tensor decoder_input: First decoder input
        :param Tensor hidden: First decoder hidden state
        :param Tensor context: Encoder's outputs
        :return: (Output probabilities, Pointer indices), last hidden state
        """

        batch_size = embedded_inputs.size(0)
        input_length = embedded_inputs.size(1)

        # (batch, seq_len)
        mask = self.mask.repeat(input_length).unsqueeze(0).repeat(batch_size, 1)
        self.att.init_inf(mask.size())

        # Generating arange(input_length), broadcast across batch_size
        runner = self.runner.repeat(input_length)
        for i in range(input_length):
            runner.data[i] = i
        runner = runner.unsqueeze(0).expand(batch_size, -1).long()

        outputs = []
        pointers = []

        def step(x, hidden):
            """
            Recurrence step function

            :param Tensor x: Input at time t
            :param tuple(Tensor, Tensor) hidden: Hidden states at time t-1
            :return: Hidden states at time t (h, c), attention probabilities (alpha)
            """

            # Regular LSTM cell
            h, c = hidden

            gates = self.input_to_hidden(x) + self.hidden_to_hidden(h)
            input, forget, cell, out = gates.chunk(4, 1)

            input = torch.sigmoid(input)
            forget = torch.sigmoid(forget)
            cell = torch.tanh(cell)
            out = torch.sigmoid(out)

            c_t = (forget * c) + (input * cell)
            h_t = out * torch.tanh(c_t)

            # Attention section
            hidden_t, output = self.att(h_t, context, torch.eq(mask, 0))
            hidden_t = torch.tanh(self.hidden_out(torch.cat((hidden_t, h_t), 1)))

            return hidden_t, c_t, output

        # Recurrence loop
        for _ in range(input_length):
            h_t, c_t, outs = step(decoder_input, hidden)
            hidden = (h_t, c_t)

            # Masking selected inputs
            masked_outs = outs * mask

            # Get maximum probabilities and indices
            max_probs, indices = masked_outs.max(1)
            one_hot_pointers = (runner == indices.unsqueeze(1).expand(-1, outs.size()[1])).float()

            # Update mask to ignore seen indices
            mask = mask * (1 - one_hot_pointers)

            # Get embedded inputs by max indices
            embedding_mask = one_hot_pointers.unsqueeze(2).expand(-1, -1, self.embedding_dim).bool()
            decoder_input = embedded_inputs[embedding_mask].view(batch_size, self.embedding_dim)

            outputs.append(outs.unsqueeze(0))
            pointers.append(indices.unsqueeze(1))

        outputs = torch.cat(outputs).permute(1, 0, 2)
        pointers = torch.cat(pointers, 1)

        return (outputs, pointers), hidden


class PointerNet(nn.Module):
    """
    Pointer-Net
    """

    def __init__(self, embedding_dim,
                 hidden_dim,
                 lstm_layers,
                 dropout,
                 bidir=False):
        """
        Initialize the Pointer-Net

        :param int embedding_dim: Number of embedding channels
        :param int hidden_dim: Encoder hidden units
        :param int lstm_layers: Number of LSTM layers
        :param float dropout: Dropout probability, between 0 and 1
        :param bool bidir: Bidirectional
        """

        super(PointerNet, self).__init__()
        self.embedding_dim = embedding_dim
        self.bidir = bidir
        self.embedding = nn.Linear(2, embedding_dim)
        self.encoder = Encoder(embedding_dim,
                               hidden_dim,
                               lstm_layers,
                               dropout,
                               bidir)
        self.decoder = Decoder(embedding_dim, hidden_dim)
        self.decoder_input0 = Parameter(torch.FloatTensor(embedding_dim), requires_grad=False)

        # Initialize decoder_input0
        nn.init.uniform_(self.decoder_input0, -1, 1)

    def forward(self, inputs):
        """
        PointerNet - Forward-pass

        :param Tensor inputs: Input sequence
        :return: Pointer probabilities and indices
        """

        batch_size = inputs.size(0)
        input_length = inputs.size(1)

        decoder_input0 = self.decoder_input0.unsqueeze(0).expand(batch_size, -1)

        inputs = inputs.view(batch_size * input_length, -1)
        embedded_inputs = self.embedding(inputs).view(batch_size, input_length, -1)

        encoder_hidden0 = self.encoder.init_hidden(embedded_inputs)
        encoder_outputs, encoder_hidden = self.encoder(embedded_inputs,
                                                       encoder_hidden0)
        if self.bidir:
            # Concatenate the last layer's forward and backward states
            decoder_hidden0 = (torch.cat((encoder_hidden[0][-2], encoder_hidden[0][-1]), dim=-1),
                               torch.cat((encoder_hidden[1][-2], encoder_hidden[1][-1]), dim=-1))
        else:
            decoder_hidden0 = (encoder_hidden[0][-1],
                               encoder_hidden[1][-1])
        (outputs, pointers), decoder_hidden = self.decoder(embedded_inputs,
                                                           decoder_input0,
                                                           decoder_hidden0,
                                                           encoder_outputs)

        return outputs, pointers
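

# A quick shape check (illustrative only; the hyperparameters are arbitrary):
# runs a random batch through an untrained model and confirms the network
# emits one probability distribution and one pointer per input point.
if __name__ == '__main__':
    net = PointerNet(embedding_dim=16, hidden_dim=32, lstm_layers=1, dropout=0., bidir=True)
    points = torch.rand(4, 5, 2)   # (batch, seq_len, (x, y))
    probs, pointers = net(points)
    print(probs.shape)             # torch.Size([4, 5, 5])
    print(pointers.shape)          # torch.Size([4, 5])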
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PointerNet
PyTorch implementation of Pointer Networks - [Link](http://arxiv.org/pdf/1506.03134v1.pdf)
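
A minimal usage sketch (the hyperparameters here are illustrative, not the `Train.py` defaults):

```python
import torch
from PointerNet import PointerNet
from Data_Generator import TSPDataset

model = PointerNet(embedding_dim=128, hidden_dim=512, lstm_layers=2, dropout=0., bidir=True)
dataset = TSPDataset(data_size=10, seq_len=5)   # small toy set; the solver is O(2^n * n^2)
probs, pointers = model(dataset[0]['Points'].unsqueeze(0))
print(pointers)   # predicted visiting order, e.g. tensor([[0, 3, 1, 4, 2]])
```

To train, run `python Train.py` with the options listed at the top of that script.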
6 | 7 | """ 8 | 9 | import torch 10 | import torch.optim as optim 11 | import torch.backends.cudnn as cudnn 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | 15 | import numpy as np 16 | import argparse 17 | from tqdm import tqdm 18 | 19 | from PointerNet import PointerNet 20 | from Data_Generator import TSPDataset 21 | 22 | parser = argparse.ArgumentParser(description="Pytorch implementation of Pointer-Net") 23 | 24 | # Data 25 | parser.add_argument('--train_size', default=1000000, type=int, help='Training data size') 26 | parser.add_argument('--val_size', default=10000, type=int, help='Validation data size') 27 | parser.add_argument('--test_size', default=10000, type=int, help='Test data size') 28 | parser.add_argument('--batch_size', default=256, type=int, help='Batch size') 29 | # Train 30 | parser.add_argument('--nof_epoch', default=50000, type=int, help='Number of epochs') 31 | parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate') 32 | # GPU 33 | parser.add_argument('--gpu', default=True, action='store_true', help='Enable gpu') 34 | # TSP 35 | parser.add_argument('--nof_points', type=int, default=5, help='Number of points in TSP') 36 | # Network 37 | parser.add_argument('--embedding_size', type=int, default=128, help='Embedding size') 38 | parser.add_argument('--hiddens', type=int, default=512, help='Number of hidden units') 39 | parser.add_argument('--nof_lstms', type=int, default=2, help='Number of LSTM layers') 40 | parser.add_argument('--dropout', type=float, default=0., help='Dropout value') 41 | parser.add_argument('--bidir', default=True, action='store_true', help='Bidirectional') 42 | 43 | params = parser.parse_args() 44 | 45 | if params.gpu and torch.cuda.is_available(): 46 | USE_CUDA = True 47 | print('Using GPU, %i devices.' 
for epoch in range(params.nof_epoch):
    batch_loss = []
    iterator = tqdm(dataloader, unit='Batch')

    for i_batch, sample_batched in enumerate(iterator):
        iterator.set_description('Epoch %i/%i' % (epoch + 1, params.nof_epoch))

        train_batch = sample_batched['Points']
        target_batch = sample_batched['Solution']

        if USE_CUDA:
            train_batch = train_batch.cuda()
            target_batch = target_batch.cuda()

        o, p = model(train_batch)
        o = o.contiguous().view(-1, o.size()[-1])

        target_batch = target_batch.view(-1)

        loss = CCE(o, target_batch)

        losses.append(loss.item())
        batch_loss.append(loss.item())

        model_optim.zero_grad()
        loss.backward()
        model_optim.step()

        iterator.set_postfix(loss='{}'.format(loss.item()))

    iterator.set_postfix(loss=np.average(batch_loss))
--------------------------------------------------------------------------------