├── README.md ├── ch_5.png ├── convex_hull.py ├── dataset.py ├── layers ├── .DS_Store ├── attention.py └── seq2seq │ ├── decoder.py │ └── encoder.py ├── pointer_network.png └── pointer_network.py /README.md: -------------------------------------------------------------------------------- 1 | # Pointer network in PyTorch 2 | 3 | This is an implementation of [Pointer Network](https://arxiv.org/abs/1506.03134) in PyTorch for [Convex Hull](https://en.wikibooks.org/wiki/Convexity/The_convex_hull) problem 4 | 5 | ## Network 6 | ![](pointer_network.png) 7 | 8 | ## Environment 9 | * Python 3.* 10 | * Pytorch 0.3.* 11 | * TensorboardX 1.1 12 | 13 | ## Data 14 | Convex Hull data is aviailable at [link](https://drive.google.com/drive/folders/0B2fg8yPGn2TCMzBtS0o4Q2RJaEU) 15 | 16 | ## Usage 17 | Training: 18 | 19 | ```bash 20 | python convex_hull.py 21 | ``` 22 | Evaluating: 23 | 24 | Not implemented yet 25 | 26 | Visualization: 27 | ```bash 28 | tensorboard --logdir LOG_DIR 29 | ``` 30 | 31 | ## Results 32 | * Training on Convex Hull 5 33 | ![](ch_5.png) 34 | -------------------------------------------------------------------------------- /ch_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thidtc/PointerNetwork-PyTorch/34e94a6ccf3a9f637d87a2f397863a92b4d82752/ch_5.png -------------------------------------------------------------------------------- /convex_hull.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import numpy as np 4 | import tqdm 5 | import torch 6 | from torch import optim 7 | from torch.utils.data import DataLoader 8 | from torch.autograd import Variable 9 | from torch.nn.utils import clip_grad_norm 10 | import argparse 11 | import logging 12 | import sys 13 | from tensorboardX import SummaryWriter 14 | 15 | from dataset import CHDataset 16 | from pointer_network import PointerNet, PointerNetLoss 17 | 18 | if __name__ == 
if __name__ == "__main__":
  # Parse command-line arguments for the Convex Hull training run.
  parser = argparse.ArgumentParser("Convex Hull")
  parser.add_argument("--gpu", type=int, default=0)
  parser.add_argument("--bz", type=int, default=256)
  parser.add_argument("--max_in_seq_len", type=int, default=5)
  parser.add_argument("--max_out_seq_len", type=int, default=6)
  parser.add_argument("--rnn_hidden_size", type=int, default=128)
  parser.add_argument("--attention_size", type=int, default=128)
  parser.add_argument("--num_layers", type=int, default=1)
  parser.add_argument("--beam_width", type=int, default=2)
  parser.add_argument("--lr", type=float, default=1e-3)
  parser.add_argument("--clip_norm", type=float, default=5.)
  parser.add_argument('--weight_decay', type=float, default=0.1)
  parser.add_argument("--check_interval", type=int, default=20)
  parser.add_argument("--nepoch", type=int, default=200)
  parser.add_argument("--train_filename", type=str, default="./data/convex_hull_5_test.txt")
  parser.add_argument("--model_file", type=str, default=None)
  parser.add_argument("--log_dir", type=str, default="./log")

  args = parser.parse_args()

  # PyTorch configuration
  if args.gpu >= 0 and torch.cuda.is_available():
    args.use_cuda = True
    # FIX: torch.cuda.device(args.gpu) only constructs a context-manager
    # object and discards it, so the active GPU was never switched.
    # set_device actually selects the device for subsequent allocations.
    torch.cuda.set_device(args.gpu)
  else:
    args.use_cuda = False

  # Logger: INFO/DEBUG to stdout with timestamps.
  logger = logging.getLogger("Convex Hull")
  formatter = logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s')
  console_handler = logging.StreamHandler(sys.stdout)
  console_handler.formatter = formatter
  logger.addHandler(console_handler)
  logger.setLevel(logging.DEBUG)

  # TensorBoard summary writer
  writer = SummaryWriter(args.log_dir)

  # Loading data
  train_ds = CHDataset(args.train_filename, args.max_in_seq_len,
                       args.max_out_seq_len)
  logger.info("Train data size: {}".format(len(train_ds)))

  train_dl = DataLoader(train_ds, num_workers=2,
                        batch_size=args.bz)

  # Init model: 2 = (x, y) coordinate input size per point.
  model = PointerNet("LSTM",
                     True,
                     args.num_layers,
                     2,
                     args.rnn_hidden_size,
                     0.0)
  criterion = PointerNetLoss()
  # FIX: --weight_decay was parsed but never passed to the optimizer;
  # wire it through so the flag actually takes effect (its default, 0.1,
  # now applies — pass --weight_decay 0 to reproduce the old behavior).
  optimizer = optim.Adam(model.parameters(), lr=args.lr,
                         weight_decay=args.weight_decay)

  if args.use_cuda:
    model.cuda()

  # Training loop: teacher-forced decoding, average loss logged per epoch.
  for epoch in range(args.nepoch):
    model.train()
    total_loss = 0.
    batch_cnt = 0.
    for b_inp, b_inp_len, b_outp_in, b_outp_out, b_outp_len in train_dl:
      # Variable wrapping is the PyTorch 0.3 idiom this project targets.
      b_inp = Variable(b_inp)
      b_outp_in = Variable(b_outp_in)
      b_outp_out = Variable(b_outp_out)
      if args.use_cuda:
        b_inp = b_inp.cuda()
        b_inp_len = b_inp_len.cuda()
        b_outp_in = b_outp_in.cuda()
        b_outp_out = b_outp_out.cuda()
        b_outp_len = b_outp_len.cuda()

      optimizer.zero_grad()
      align_score = model(b_inp, b_inp_len, b_outp_in, b_outp_len)
      loss = criterion(b_outp_out, align_score, b_outp_len)

      # loss.data[0]: scalar extraction in the PyTorch 0.3 API.
      l = loss.data[0]
      total_loss += l
      batch_cnt += 1

      loss.backward()
      # Gradient clipping to stabilize RNN training.
      clip_grad_norm(model.parameters(), args.clip_norm)
      optimizer.step()
    writer.add_scalar('train/loss', total_loss / batch_cnt, epoch)
    logger.info("Epoch : {}, loss {}".format(epoch, total_loss / batch_cnt))

    # Checkpoint every check_interval epochs.
    if epoch % args.check_interval == args.check_interval - 1:
      # Save model
      if args.model_file is not None:
        torch.save(model.state_dict(), args.model_file)
class CHDataset(Dataset):
  """ Dataset for Convex Hull Problem data

  Each line of the data file looks like
  "x1 y1 x2 y2 ... output i1 i2 ..." where the i's are 1-based indices of
  the hull points in the input point list.

  Args:
    filename : the dataset file name
    max_in_seq_len : maximum input sequence length
    max_out_seq_len : maximum output sequence length
  """
  def __init__(self, filename, max_in_seq_len, max_out_seq_len):
    super(CHDataset, self).__init__()
    self.max_in_seq_len = max_in_seq_len
    self.max_out_seq_len = max_out_seq_len
    # START doubles as the index-0 sentinel point prepended to every input;
    # END is also (0, 0) and is used purely as coordinate padding.
    self.START = [0, 0]
    self.END = [0, 0]
    self._load_data(filename)

  def _load_data(self, filename):
    # Parses the whole file eagerly into numpy arrays; one tuple per sample:
    # (inp, inp_len, outp_in, outp_out, outp_len).
    with open(filename, 'r') as f:
      data = []
      for line in f:
        # 'output' is the literal separator between points and hull indices.
        inp, outp = line.strip().split('output')
        inp = list(map(float, inp.strip().split(' ')))
        # Hull indices are 1-based; index 0 is reserved for the START/END
        # sentinel prepended below.
        outp = list(map(int, outp.strip().split(' ')))
        # Decoder input starts with the START coordinates (teacher forcing);
        # copy() so the shared self.START list is never mutated.
        outp_in = copy(self.START)
        outp_out = []
        for idx in outp:
          # inp is still a flat [x1, y1, x2, y2, ...] list here, so point
          # idx (1-based) occupies slots 2*(idx-1) and 2*idx-1.
          outp_in += inp[2 * (idx - 1): 2 * idx]
          outp_out += [idx]
        # Append END token (index 0) so the decoder learns to stop.
        outp_out += [0]

        # Padding input
        inp_len = len(inp) // 2
        # Prepend the START sentinel point; all lengths below count it,
        # hence the "+ 1" everywhere.
        inp = self.START + inp
        inp_len += 1
        # Sanity check: the sample must fit in max_in_seq_len (+1 sentinel).
        assert self.max_in_seq_len + 1 >= inp_len
        for i in range(self.max_in_seq_len + 1 - inp_len):
          inp += self.END
        # Reshape flat coordinate list to (seq_len, 2).
        inp = np.array(inp).reshape([-1, 2])
        inp_len = np.array([inp_len])
        # Padding output: outp_len counts hull points + END token.
        outp_len = len(outp) + 1
        for i in range(self.max_out_seq_len + 1 - outp_len):
          outp_in += self.START
        outp_in = np.array(outp_in).reshape([-1, 2])
        # Target index sequence padded with 0 (the END/sentinel index).
        outp_out = outp_out + [0] * (self.max_out_seq_len + 1 - outp_len)
        outp_out = np.array(outp_out)
        outp_len = np.array([outp_len])

        data.append((inp.astype("float32"), inp_len, outp_in.astype("float32"), outp_out, outp_len))
    self.data = data

  def __len__(self):
    # Number of parsed samples.
    return len(self.data)

  def __getitem__(self, index):
    # Returns (inp, inp_len, outp_in, outp_out, outp_len) numpy arrays.
    inp, inp_len, outp_in, outp_out, outp_len = self.data[index]
    return inp, inp_len, outp_in, outp_out, outp_len
def sequence_mask(lengths, max_len=None):
  """ Create mask for lengths

  Args:
    lengths (LongTensor) : lengths, shaped (bz, 1) so that .lt() broadcasts
      row-wise against the position grid
    max_len (int) : maximum length (defaults to lengths.max())
  Return:
    mask (bz, max_len), 1 where position < length, else 0
  """
  bz = lengths.numel()
  max_len = max_len or lengths.max()
  return (torch.arange(0, max_len)
          .type_as(lengths)
          .repeat(bz, 1)
          .lt(lengths))

class Attention(nn.Module):
  """ Attention layer

  Args:
    attn_type : attention type ["dot", "general"]
    dim : input dimension size
  """
  def __init__(self, attn_type, dim):
    super(Attention, self).__init__()
    self.attn_type = attn_type
    # bias only for the (unimplemented) "mlp" variant, per Luong attention.
    bias_out = attn_type == "mlp"
    self.linear_out = nn.Linear(dim * 2, dim, bias_out)
    if self.attn_type == "general":
      self.linear = nn.Linear(dim, dim, bias=False)
    elif self.attn_type == "dot":
      pass
    else:
      raise NotImplementedError()

  def score(self, src, tgt):
    """ Attention score calculation

    Args:
      src : source values (bz, src_len, dim)
      tgt : target values (bz, tgt_len, dim)
    Return:
      scores (bz, tgt_len, src_len)
    """
    bz, src_len, dim = src.size()
    _, tgt_len, _ = tgt.size()

    # FIX: was ["genenral", "dot"] — the typo sent attn_type == "general"
    # into the else branch and raised NotImplementedError.
    if self.attn_type in ["general", "dot"]:
      tgt_ = tgt
      if self.attn_type == "general":
        tgt_ = self.linear(tgt_)
      src_ = src.transpose(1, 2)
      # Batched dot product of every target step with every source step.
      return torch.bmm(tgt_, src_)
    else:
      raise NotImplementedError()

  def forward(self, src, tgt, src_lengths=None):
    """
    Args:
      src : source values (bz, src_len, dim)
      tgt : target values (bz, tgt_len, dim), or (bz, dim) for one step
      src_lengths : source values length
    Return:
      attn_h : attended hidden state, align_score : attention weights
    """
    if tgt.dim() == 2:
      # Single decoding step: lift tgt to (bz, 1, dim) and squeeze back
      # at the end.
      # FIX: was src = src.unsqueeze(1) — unsqueezing the already-3D src
      # broke the bmm in score(); it is tgt that needs the extra dim.
      one_step = True
      tgt = tgt.unsqueeze(1)
    else:
      one_step = False

    bz, src_len, dim = src.size()
    _, tgt_len, _ = tgt.size()

    align_score = self.score(src, tgt)

    if src_lengths is not None:
      mask = sequence_mask(src_lengths)
      # (bz, max_len) -> (bz, 1, max_len)
      # so mask can broadcast over the tgt_len dimension
      mask = mask.unsqueeze(1)
      # Padded source positions get -inf so softmax assigns them 0 weight.
      # NOTE(review): `1 - mask` assumes a ByteTensor mask (PyTorch 0.3
      # era); newer PyTorch would want `~mask` — confirm target version.
      align_score.data.masked_fill_(1 - mask, -float('inf'))

    # Normalize weights
    align_score = F.softmax(align_score, -1)

    # Context vector: weighted sum of source states.
    c = torch.bmm(align_score, src)

    concat_c = torch.cat([c, tgt], -1)
    attn_h = self.linear_out(concat_c)

    if one_step:
      attn_h = attn_h.squeeze(1)
      align_score = align_score.squeeze(1)

    return attn_h, align_score
def rnn_factory(rnn_type, **kwargs):
  """ Build a stacked RNN module of the requested cell type.

  Args:
    rnn_type : one of ["LSTM", "GRU", "RNN"]
    **kwargs : forwarded to the torch.nn constructor
      (input_size, hidden_size, num_layers, bidirectional, dropout, ...)
  Return:
    (rnn module, pack_padded_seq flag)
  Raises:
    ValueError : for an unsupported rnn_type
  """
  pack_padded_seq = True
  if rnn_type in ["LSTM", "GRU", "RNN"]:
    rnn = getattr(nn, rnn_type)(**kwargs)
    return rnn, pack_padded_seq
  # FIX: previously an unknown type fell off the end and implicitly
  # returned None, producing a confusing TypeError at the unpack site.
  raise ValueError("Unsupported rnn_type: {}".format(rnn_type))

class EncoderBase(nn.Module):
  """ encoder base class: defines the (src, lengths, hidden) interface. """
  def __init__(self):
    super(EncoderBase, self).__init__()

  def forward(self, src, lengths=None, hidden=None):
    """
    Args:
      src (FloatTensor) : input sequence
      lengths (LongTensor) : lengths of input sequence
      hidden : init hidden state
    """
    raise NotImplementedError()
class RNNEncoder(EncoderBase):
  """ RNN encoder class

  Consumes a time-major (seq_len, bz, input_size) sequence and returns the
  per-step outputs plus the final hidden state.

  Args:
    rnn_type : rnn cell type, ["LSTM", "GRU", "RNN"]
    bidirectional : whether use bidirectional rnn
    num_layers : number of layers in stacked rnn
    input_size : input dimension size
    hidden_size : rnn hidden dimension size
    dropout : dropout rate
    use_bridge : TODO: implement bridge
  """
  def __init__(self, rnn_type, bidirectional, num_layers,
      input_size, hidden_size, dropout, use_bridge=False):
    super(RNNEncoder, self).__init__()
    if bidirectional:
      # Halve per-direction size so the concatenated bidirectional output
      # still has dimension hidden_size.
      assert hidden_size % 2 == 0
      hidden_size = hidden_size // 2
    self.rnn, self.pack_padded_seq = rnn_factory(rnn_type,
                        input_size=input_size,
                        hidden_size=hidden_size,
                        bidirectional=bidirectional,
                        num_layers=num_layers,
                        dropout=dropout)
    self.use_bridge = use_bridge
    if self.use_bridge:
      raise NotImplementedError()

  def forward(self, src, lengths=None, hidden=None):
    """
    Same as BaseEncoder.forward

    src is expected time-major: (seq_len, bz, input_size).
    """
    packed_src = src
    if self.pack_padded_seq and lengths is not None:
      # pack_padded_sequence takes a plain python list of lengths.
      # NOTE(review): this API requires lengths sorted in decreasing
      # order; in the convex-hull-5 data all lengths are equal, so this
      # holds — confirm before feeding variable-length batches.
      lengths = lengths.view(-1).tolist()
      packed_src = pack(src, lengths)

    memory_bank, hidden_final = self.rnn(packed_src, hidden)

    if self.pack_padded_seq and lengths is not None:
      # unpack returns (padded_output, lengths); keep only the output.
      memory_bank = unpack(memory_bank)[0]

    if self.use_bridge:
      raise NotImplementedError()
    return memory_bank, hidden_final
class PointerNetRNNDecoder(RNNDecoderBase):
  """ Pointer network RNN Decoder, process all the output together

  The attention weights over encoder positions ARE the pointer
  distribution; no separate output projection is used.
  """
  def __init__(self, rnn_type, bidirectional, num_layers,
      input_size, hidden_size, dropout):
    super(PointerNetRNNDecoder, self).__init__(rnn_type, bidirectional, num_layers,
      input_size, hidden_size, dropout)
    self.attention = Attention("dot", hidden_size)

  def forward(self, tgt, memory_bank, hidden, memory_lengths=None):
    # RNN over the (teacher-forced) target sequence, seeded with the
    # encoder's final hidden state.
    rnn_output, hidden_final = self.rnn(tgt, hidden)
    # Attention expects batch-major tensors; rnn outputs are time-major.
    memory_bank = memory_bank.transpose(0, 1)
    rnn_output = rnn_output.transpose(0, 1)
    attn_h, align_score = self.attention(memory_bank, rnn_output, memory_lengths)

    # align_score (bz, tgt_len, src_len): softmax over input positions.
    return align_score

class PointerNet(nn.Module):
  """ Pointer network
  Args:
    rnn_type (str) : rnn cell type
    bidirectional : whether rnn is bidirectional
    num_layers : number of layers of stacked rnn
    encoder_input_size : input size of encoder
    rnn_hidden_size : rnn hidden dimension size
    dropout : dropout rate
  """
  def __init__(self, rnn_type, bidirectional, num_layers,
      encoder_input_size, rnn_hidden_size, dropout):
    super(PointerNet, self).__init__()
    self.encoder = RNNEncoder(rnn_type, bidirectional,
      num_layers, encoder_input_size, rnn_hidden_size, dropout)
    self.decoder = PointerNetRNNDecoder(rnn_type, bidirectional,
      num_layers, encoder_input_size, rnn_hidden_size, dropout)

  def forward(self, inp, inp_len, outp, outp_len):
    # Convert batch-major (bz, seq, 2) inputs to time-major for the RNNs.
    inp = inp.transpose(0, 1)
    outp = outp.transpose(0, 1)
    memory_bank, hidden_final = self.encoder(inp, inp_len)
    align_score = self.decoder(outp, memory_bank, hidden_final, inp_len)
    return align_score

class PointerNetLoss(nn.Module):
  """ Loss function for pointer network

  Masked negative log-likelihood of the target indices under the
  pointer (attention) distribution.
  """
  def __init__(self):
    super(PointerNetLoss, self).__init__()

  def forward(self, target, logits, lengths):
    """
    Args:
      target : label data (bz, tgt_max_len)
      logits : predicts (bz, tgt_max_len, src_max_len) — despite the name,
        these are softmax probabilities, hence the explicit log below
      lengths : length of label data (bz)
    """
    _, tgt_max_len = target.size()
    # Flatten to (bz * tgt_max_len, src_max_len) for a single gather.
    logits_flat = logits.view(-1, logits.size(-1))
    # NOTE(review): log(0) yields -inf if a target position ever receives
    # zero probability — consider clamping/eps if that occurs in practice.
    log_logits_flat = torch.log(logits_flat)
    target_flat = target.view(-1, 1)
    # NLL: pick the log-probability of each gold index.
    losses_flat = -torch.gather(log_logits_flat, dim=1, index=target_flat)
    losses = losses_flat.view(*target.size())
    # Zero out positions past each sequence's true output length.
    mask = sequence_mask(lengths, tgt_max_len)
    mask = Variable(mask)
    losses = losses * mask.float()
    # Normalize by total number of real (unmasked) target steps.
    loss = losses.sum() / lengths.float().sum()
    return loss