├── README.md
├── main.py
├── model.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
# PyTorch-PtrNet
PyTorch implementation of a Pointer Network (Ptr-Net) that learns to solve the sorting problem.
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import numpy as np

import train


def generate_sanity_check_batch(seq_len, batch_size=64):
    # every row is the already-sorted sequence 0..seq_len-1
    return np.tile(np.arange(seq_len), (batch_size, 1))


def generate_random_batch(seq_len, batch_size=64):
    return np.random.rand(batch_size, seq_len)


def generate_sorted_onehot(input):
    # returns (seq_len, batch, seq_len) one-hot targets: slice j one-hots the
    # index of the j-th smallest element of each row
    input = np.array(input, dtype=np.float32)
    out = np.zeros((input.shape[1], input.shape[0], input.shape[1]), dtype=np.float32)
    for a in range(input.shape[0]):
        ind = np.argsort(input[a])
        for j in range(len(ind)):
            out[j][a][ind[j]] = 1
    return out


def sanity_check_sorted_onehot():
    arr = [[.3, .1, .2], [.1, .2, .3]]
    out = generate_sorted_onehot(arr)
    print(out)
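

# A worked example of the target layout (an illustrative addition; the
# function name below is hypothetical): for the single-row batch
# [[.3, .1, .2]] the ascending order of indices is [1, 2, 0], so decode
# step j one-hots the position of the j-th smallest element.
def sanity_check_sorted_onehot_asserts():
    out = generate_sorted_onehot([[.3, .1, .2]])
    assert out.shape == (3, 1, 3)   # (seq_len, batch, seq_len)
    assert out[0, 0].argmax() == 1  # step 0 points at .1
    assert out[1, 0].argmax() == 2  # step 1 points at .2
    assert out[2, 0].argmax() == 0  # step 2 points at .3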


# hyper-parameters config
MAX_EPISODES = 1000000
SEQ_LEN = 5
INPUT_DIM = 1
HIDDEN_SIZE = 128
BATCH_SIZE = 128
LEARNING_RATE = 0.002


# main code
def train_model():
    trainer = train.Trainer(SEQ_LEN, INPUT_DIM, HIDDEN_SIZE, BATCH_SIZE, LEARNING_RATE)
    for i in range(MAX_EPISODES):
        # input_batch = generate_sanity_check_batch(SEQ_LEN, BATCH_SIZE)
        input_batch = generate_random_batch(SEQ_LEN, BATCH_SIZE)
        correct_out = generate_sorted_onehot(input_batch)

        trainer.train(input_batch, correct_out)

        if i % 1000 == 0:
            trainer.save_model(i)


def test_model(ep):
    trainer = train.Trainer(SEQ_LEN, INPUT_DIM, HIDDEN_SIZE, BATCH_SIZE, LEARNING_RATE)
    trainer.load_model(ep)
    input_batch = generate_random_batch(SEQ_LEN, BATCH_SIZE)
    trainer.test_batch(input_batch)


if __name__ == '__main__':
    sanity_check_sorted_onehot()
    train_model()
    # test_model(6000)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class PtrNet(nn.Module):

    def __init__(self, batch_size, seq_len, input_dim, hidden_dim):
        super(PtrNet, self).__init__()

        self.batch_size = batch_size  # B
        self.seq_len = seq_len        # N
        self.input_dim = input_dim    # I
        self.hidden_dim = hidden_dim  # H

        # encoder: one LSTMCell per time step (weights are not shared across
        # steps); nn.ModuleList is required so the cells' parameters are
        # registered and visible to .parameters() -- a plain Python list
        # would leave them invisible to the optimizer and untrained
        self.encoder = nn.ModuleList(
            [nn.LSTMCell(self.input_dim, self.hidden_dim) for _ in range(self.seq_len)])

        # decoder
        self.decoder = nn.ModuleList(
            [nn.LSTMCell(self.hidden_dim, self.hidden_dim) for _ in range(self.seq_len)])

        # for creating pointers: u_i[j] = v * tanh(W1 * e[j] + W2 * c_i),
        # so V maps a hidden vector to a single scalar score per position
        self.W_encoder = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.W_decoder = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.V = nn.Linear(self.hidden_dim, 1)

    def forward(self, input):
        encoded_input = []

        # initialize hidden state and cell state to zeros
        h = Variable(torch.zeros([self.batch_size, self.hidden_dim]))  # B*H
        c = Variable(torch.zeros([self.batch_size, self.hidden_dim]))  # B*H
        for i in range(self.seq_len):
            inp = Variable(torch.from_numpy(input[:, i])).unsqueeze(1)  # B, -> B*I
            inp = inp.type(torch.FloatTensor)
            h, c = self.encoder[i](inp, (h, c))  # B*H
            encoded_input.append(h)

        d_i = Variable(torch.Tensor(self.batch_size, self.hidden_dim).fill_(-1.0))  # B*H
        distributions = []
        for i in range(self.seq_len):
            h, c = self.decoder[i](d_i, (h, c))  # B*H

            # the attention part as obtained from the paper
            # u_i[j] = v * tanh(W1 * e[j] + W2 * c_i)
            u_i = []
            c_i = self.W_decoder(c)  # B*H
            for j in range(self.seq_len):
                e_j = self.W_encoder(encoded_input[j])  # B*H
                u_j = self.V(F.tanh(c_i + e_j)).squeeze(1)  # B,
                u_i.append(u_j)

            # a_i[j] = softmax(u_i[j])
            u_i = torch.stack(u_i).t()   # N*B -> B*N
            a_i = F.softmax(u_i, dim=1)  # B*N
            distributions.append(a_i)

            # d_i+1 = sum(a_i[j]*e[j]) over j -- the attention-weighted sum of
            # encoder states feeds the next decoder step, so it lives in B*H
            d_i = Variable(torch.zeros([self.batch_size, self.hidden_dim]))  # B*H
            for j in range(self.seq_len):
                # select jth column of a_i
                a_j = torch.index_select(a_i, 1, Variable(torch.LongTensor([j])))  # B*1
                a_j = a_j.expand(self.batch_size, self.hidden_dim)  # B*H
                d_i = d_i + (a_j * encoded_input[j])  # B*H

        distributions = torch.stack(distributions)
        return distributions  # N*B*N
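

# A minimal shape check (an illustrative addition, not part of the original
# training path): push one random batch through an untrained PtrNet and
# confirm the stacked output really has the N*B*N layout documented above.
if __name__ == '__main__':
    import numpy as np
    net = PtrNet(batch_size=4, seq_len=5, input_dim=1, hidden_dim=16)
    probs = net(np.random.rand(4, 5))
    assert probs.size() == (5, 4, 5)  # N*B*N
    print('PtrNet smoke test passed:', probs.size())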
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import model


class Trainer:

    def __init__(self, seq_len, input_dim, hidden_dim=256, batch_size=128, learning_rate=0.001):
        self.seq_len = seq_len
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.episode = 0

        self.ptrNet = model.PtrNet(self.batch_size, self.seq_len, self.input_dim, self.hidden_dim)

        self.optimizer = torch.optim.RMSprop(self.ptrNet.parameters(), lr=self.learning_rate)

    def incorrect_count_loss(self, actual, pred):
        # average number of decode steps per example whose argmax points at
        # the wrong position
        actual = actual.data.numpy()
        pred = pred.data.numpy()
        count = 0.0
        for i in range(pred.shape[0]):
            for j in range(pred.shape[1]):
                if actual[i][j][np.argmax(pred[i][j])] != 1:
                    count += 1.0
        return count / pred.shape[1]

    def correct_count_score(self, actual, pred):
        # average number of decode steps per example whose argmax points at
        # the right position
        actual = actual.data.numpy()
        pred = pred.data.numpy()
        count = 0.0
        for i in range(pred.shape[0]):
            for j in range(pred.shape[1]):
                if actual[i][j][np.argmax(pred[i][j])] == 1:
                    count += 1.0
        return count / pred.shape[1]

    def train(self, input, ground_truth):
        correct_out = Variable(torch.from_numpy(ground_truth))
        pred_out = self.ptrNet(input)

        loss = torch.sqrt(torch.mean(torch.pow(correct_out - pred_out, 2)))
        # gradients accumulate across backward() calls, so they must be
        # cleared before every update
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.episode += 1
        print('Episode :- ', self.episode, ' L2 Loss :- ', loss.data.numpy(),
              ' My Loss :- ', self.incorrect_count_loss(correct_out, pred_out),
              ' My Score :- ', self.correct_count_score(correct_out, pred_out))

        if self.episode % 500 == 0:
            self.test_batch(input)

    def test_batch(self, input):
        pred_out = self.ptrNet(input)
        pred_out = pred_out.data.numpy()

        print(' --- INPUT ---')
        for i in range(5):
            print(input[i])

        print(' --- OUTPUT ---')
        for i in range(5):
            for j in range(input.shape[1]):
                print(input[i][np.argmax(pred_out[j][i])])
                # print(pred_out[j][i])
            print('\n')

        print(' ---- PROB ----')
        for i in range(5):
            for j in range(input.shape[1]):
                # print(input[i][np.argmax(pred_out[j][i])])
                print(pred_out[j][i])
            print('\n')

    def save_model(self, episode_count):
        """
        saves the model
        :param episode_count: the count of episodes iterated
        :return:
        """
        os.makedirs('./Models', exist_ok=True)
        torch.save(self.ptrNet.state_dict(), './Models/' + str(episode_count) + '_net.pt')
        print('Model saved successfully')

    def load_model(self, episode):
        """
        loads the model
        :param episode: the count of episodes iterated (used to find the file name)
        :return:
        """
        self.ptrNet.load_state_dict(torch.load('./Models/' + str(episode) + '_net.pt'))
        print('Model loaded successfully')
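

# A minimal smoke run (an illustrative addition, mirroring test_model in
# main.py): build a small Trainer and decode one random batch with the
# untrained network.
if __name__ == '__main__':
    trainer = Trainer(seq_len=5, input_dim=1, hidden_dim=32, batch_size=8)
    trainer.test_batch(np.random.rand(8, 5))
--------------------------------------------------------------------------------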