├── .gitignore
├── LICENSE
├── README.md
├── dataset.py
├── model.py
└── train_sort.py

/.gitignore:
--------------------------------------------------------------------------------

# Created by https://www.gitignore.io/api/pycharm+all
# Edit at https://www.gitignore.io/?templates=pycharm+all

### PyCharm+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### PyCharm+all Patch ###
# Ignores the whole .idea folder and all .iml files
# See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360

.idea/

# Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023

*.iml
modules.xml
.idea/misc.xml
*.ipr

# Sonarlint plugin
.idea/sonarlint

# End of https://www.gitignore.io/api/pycharm+all

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Sungtae An

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pointer-networks-pytorch
Implementation of Pointer Networks using PyTorch:

*Vinyals, Oriol, Meire Fortunato, and Navdeep Jaitly. "Pointer networks." Advances in Neural Information Processing Systems. 2015.* [[Paper](https://papers.nips.cc/paper/5866-pointer-networks)]

The training entry point is `train_sort.py`, which trains the network to sort short sequences of random integers.

**This code was tested with _Python 3.7.3_ and _PyTorch 1.1.0_.**

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import itertools

import numpy as np
import torch
from torch.utils.data import Dataset


class IntegerSortDataset(Dataset):
    def __init__(self, num_samples=10000, low=0, high=100, min_len=1, max_len=10, seed=1):

        self.prng = np.random.RandomState(seed=seed)
        self.input_dim = high

        # Here, we assume that each sample is a list of single-integer lists, e.g., [[10], [3], [5], [0]].
        # This makes later extensions easier, even though it is not necessary for this simple sorting example.
        self.seqs = [list(map(lambda x: [x], self.prng.choice(np.arange(low, high), size=self.prng.randint(min_len, max_len+1)).tolist())) for _ in range(num_samples)]
        self.labels = [sorted(range(len(seq)), key=seq.__getitem__) for seq in self.seqs]

    def __getitem__(self, index):
        seq = self.seqs[index]
        label = self.labels[index]

        len_seq = len(seq)
        row_col_index = list(zip(*[(i, number) for i, numbers in enumerate(seq) for number in numbers]))
        num_values = len(row_col_index[0])

        i = torch.LongTensor(row_col_index)
        v = torch.FloatTensor([1]*num_values)
        data = torch.sparse.FloatTensor(i, v, torch.Size([len_seq, self.input_dim]))

        return data, len_seq, label

    def __len__(self):
        return len(self.seqs)


def sparse_seq_collate_fn(batch):
    batch_size = len(batch)

    sorted_seqs, sorted_lengths, sorted_labels = zip(*sorted(batch, key=lambda x: x[1], reverse=True))

    padded_seqs = [seq.resize_as_(sorted_seqs[0]) for seq in sorted_seqs]

    # (Sparse) batch_size X max_seq_len X input_dim
    seq_tensor = torch.stack(padded_seqs)

    # batch_size
    length_tensor = torch.LongTensor(sorted_lengths)

    padded_labels = list(zip(*(itertools.zip_longest(*sorted_labels, fillvalue=-1))))

    # batch_size X max_seq_len (-1 padding)
    label_tensor = torch.LongTensor(padded_labels).view(batch_size, -1)

    # TODO: Currently, a PyTorch DataLoader with num_workers >= 1 (multiprocessing) does not support sparse tensors.
    # TODO: Meanwhile, use a dense tensor when num_workers >= 1.
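    # The sparse batch is converted into a dense (batch_size, max_seq_len, input_dim) tensor below so
    # that it can cross the worker-process boundary; with num_workers=0 one could keep the stacked
    # sparse tensor as-is, provided the downstream model handled sparse inputs.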
    seq_tensor = seq_tensor.to_dense()

    return seq_tensor, length_tensor, label_tensor

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


# Adapted from allennlp (https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py)
def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    ``torch.nn.functional.log_softmax(vector)`` does not work if some elements of ``vector`` should be
    masked. This performs a log_softmax on just the non-masked portions of ``vector``. Passing
    ``None`` in for the mask is also acceptable; you'll just get a regular log_softmax.
    ``vector`` can have an arbitrary number of dimensions; the only requirement is that ``mask`` is
    broadcastable to ``vector's`` shape. If ``mask`` has fewer dimensions than ``vector``, we will
    unsqueeze on dimension 1 until they match. If you need a different unsqueezing of your mask,
    do it yourself before passing the mask into this function.
    In the case that the input vector is completely masked, the return value of this function is
    arbitrary, but not ``nan``. You should be masking the result of whatever computation comes out
    of this in that case, anyway, so the specific values returned shouldn't matter. Also, the way
    that we deal with this case relies on having single-precision floats; mixing half-precision
    floats with fully-masked vectors will likely give you ``nans``.
    If your logits are all extremely negative (i.e., the max value in your logit vector is -50 or
    lower), the way we handle masking here could mess you up. But if you've got logit values that
    extreme, you've got bigger problems than this.
    """
    if mask is not None:
        mask = mask.float()
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
        # results in nans when the whole vector is masked. We need a very small value instead of a
        # zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely
        # just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it
        # becomes 0 - this is just the smallest value we can actually use.
        vector = vector + (mask + 1e-45).log()
    return torch.nn.functional.log_softmax(vector, dim=dim)


# Adapted from allennlp (https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py)
def masked_max(vector: torch.Tensor,
               mask: torch.Tensor,
               dim: int,
               keepdim: bool = False,
               min_val: float = -1e7) -> (torch.Tensor, torch.Tensor):
    """
    Calculate the max along a certain dimension over masked values.
    Parameters
    ----------
    vector : ``torch.Tensor``
        The vector to calculate the max over; unmasked parts are assumed to be already zeros.
    mask : ``torch.Tensor``
        The mask of the vector. It must be broadcastable with vector.
    dim : ``int``
        The dimension to calculate the max over.
    keepdim : ``bool``
        Whether to keep the dimension.
    min_val : ``float``
        The minimal value used for paddings.
    Returns
    -------
    A tuple of ``torch.Tensor``s containing the maximum values and their indices.
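    Examples
    --------
    A small illustrative case (the masked middle position is ignored):
    >>> vector = torch.tensor([[0.2, 0.9, 0.4]])
    >>> mask = torch.tensor([[1.0, 0.0, 1.0]])
    >>> max_value, max_index = masked_max(vector, mask, dim=-1)  # max_value: 0.4, max_index: 2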
60 | """ 61 | one_minus_mask = (1.0 - mask).byte() 62 | replaced_vector = vector.masked_fill(one_minus_mask, min_val) 63 | max_value, max_index = replaced_vector.max(dim=dim, keepdim=keepdim) 64 | return max_value, max_index 65 | 66 | 67 | class Encoder(nn.Module): 68 | def __init__(self, embedding_dim, hidden_size, num_layers=1, batch_first=True, bidirectional=True): 69 | super(Encoder, self).__init__() 70 | 71 | self.batch_first = batch_first 72 | self.rnn = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=num_layers, 73 | batch_first=batch_first, bidirectional=bidirectional) 74 | 75 | def forward(self, embedded_inputs, input_lengths): 76 | # Pack padded batch of sequences for RNN module 77 | packed = nn.utils.rnn.pack_padded_sequence(embedded_inputs, input_lengths, batch_first=self.batch_first) 78 | # Forward pass through RNN 79 | outputs, hidden = self.rnn(packed) 80 | # Unpack padding 81 | outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=self.batch_first) 82 | # Return output and final hidden state 83 | return outputs, hidden 84 | 85 | 86 | class Attention(nn.Module): 87 | def __init__(self, hidden_size): 88 | super(Attention, self).__init__() 89 | self.hidden_size = hidden_size 90 | self.W1 = nn.Linear(hidden_size, hidden_size, bias=False) 91 | self.W2 = nn.Linear(hidden_size, hidden_size, bias=False) 92 | self.vt = nn.Linear(hidden_size, 1, bias=False) 93 | 94 | def forward(self, decoder_state, encoder_outputs, mask): 95 | # (batch_size, max_seq_len, hidden_size) 96 | encoder_transform = self.W1(encoder_outputs) 97 | 98 | # (batch_size, 1 (unsqueezed), hidden_size) 99 | decoder_transform = self.W2(decoder_state).unsqueeze(1) 100 | 101 | # 1st line of Eq.(3) in the paper 102 | # (batch_size, max_seq_len, 1) => (batch_size, max_seq_len) 103 | u_i = self.vt(torch.tanh(encoder_transform + decoder_transform)).squeeze(-1) 104 | 105 | # softmax with only valid inputs, excluding zero padded parts 106 | # log-softmax for a better numerical stability 107 | log_score = masked_log_softmax(u_i, mask, dim=-1) 108 | 109 | return log_score 110 | 111 | 112 | class PointerNet(nn.Module): 113 | def __init__(self, input_dim, embedding_dim, hidden_size, bidirectional=True, batch_first=True): 114 | super(PointerNet, self).__init__() 115 | 116 | # Embedding dimension 117 | self.embedding_dim = embedding_dim 118 | # (Decoder) hidden size 119 | self.hidden_size = hidden_size 120 | # Bidirectional Encoder 121 | self.bidirectional = bidirectional 122 | self.num_directions = 2 if bidirectional else 1 123 | self.num_layers = 1 124 | self.batch_first = batch_first 125 | 126 | # We use an embedding layer for more complicate application usages later, e.g., word sequences. 
        self.embedding = nn.Linear(in_features=input_dim, out_features=embedding_dim, bias=False)
        self.encoder = Encoder(embedding_dim=embedding_dim, hidden_size=hidden_size, num_layers=self.num_layers,
                               bidirectional=bidirectional, batch_first=batch_first)
        self.decoding_rnn = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
        self.attn = Attention(hidden_size=hidden_size)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, input_seq, input_lengths):

        if self.batch_first:
            batch_size = input_seq.size(0)
            max_seq_len = input_seq.size(1)
        else:
            batch_size = input_seq.size(1)
            max_seq_len = input_seq.size(0)

        # Embedding
        embedded = self.embedding(input_seq)
        # (batch_size, max_seq_len, embedding_dim)

        # encoder_outputs => (batch_size, max_seq_len, num_directions * hidden_size) if batch_first
        #                    else (max_seq_len, batch_size, num_directions * hidden_size)
        # hidden_size is usually set the same as the embedding size
        # encoder_hidden => (num_layers * num_directions, batch_size, hidden_size) for each of h_n and c_n
        encoder_outputs, encoder_hidden = self.encoder(embedded, input_lengths)

        if self.bidirectional:
            # Optionally, sum the bidirectional RNN outputs
            encoder_outputs = encoder_outputs[:, :, :self.hidden_size] + encoder_outputs[:, :, self.hidden_size:]

        encoder_h_n, encoder_c_n = encoder_hidden
        encoder_h_n = encoder_h_n.view(self.num_layers, self.num_directions, batch_size, self.hidden_size)
        encoder_c_n = encoder_c_n.view(self.num_layers, self.num_directions, batch_size, self.hidden_size)

        # Let's use zeros as the initial decoder input for this sorting example
        decoder_input = encoder_outputs.new_zeros(torch.Size((batch_size, self.hidden_size)))
        decoder_hidden = (encoder_h_n[-1, 0, :, :].squeeze(), encoder_c_n[-1, 0, :, :].squeeze())

        range_tensor = torch.arange(max_seq_len, device=input_lengths.device, dtype=input_lengths.dtype).expand(batch_size, max_seq_len, max_seq_len)
        each_len_tensor = input_lengths.view(-1, 1, 1).expand(batch_size, max_seq_len, max_seq_len)

        row_mask_tensor = (range_tensor < each_len_tensor)
        col_mask_tensor = row_mask_tensor.transpose(1, 2)
        mask_tensor = row_mask_tensor * col_mask_tensor

        pointer_log_scores = []
        pointer_argmaxs = []

        for i in range(max_seq_len):
            # For simplicity, we only mask when calculating attention and the max (and the loss later),
            # not the inputs and hidden states themselves
            sub_mask = mask_tensor[:, i, :].float()

            # h, c: (batch_size, hidden_size)
            h_i, c_i = self.decoding_rnn(decoder_input, decoder_hidden)

            # Next hidden state
            decoder_hidden = (h_i, c_i)

            # Get a pointer distribution over the encoder outputs using attention
            # (batch_size, max_seq_len)
            log_pointer_score = self.attn(h_i, encoder_outputs, sub_mask)
            pointer_log_scores.append(log_pointer_score)

            # Get the indices of the maximum pointer
            _, masked_argmax = masked_max(log_pointer_score, sub_mask, dim=1, keepdim=True)

            pointer_argmaxs.append(masked_argmax)
            index_tensor = masked_argmax.unsqueeze(-1).expand(batch_size, 1, self.hidden_size)

            # (batch_size, hidden_size)
            decoder_input = torch.gather(encoder_outputs, dim=1, index=index_tensor).squeeze(1)
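
        # Collect the per-step results:
        #   pointer_log_scores -> (batch_size, max_seq_len, max_seq_len): a log-distribution over
        #                         input positions for each decoding step
        #   pointer_argmaxs    -> (batch_size, max_seq_len): the greedily selected input position per step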
        pointer_log_scores = torch.stack(pointer_log_scores, 1)
        pointer_argmaxs = torch.cat(pointer_argmaxs, 1)

        return pointer_log_scores, pointer_argmaxs, mask_tensor

--------------------------------------------------------------------------------
/train_sort.py:
--------------------------------------------------------------------------------
import argparse
import random
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader


from dataset import IntegerSortDataset, sparse_seq_collate_fn
from model import PointerNet

parser = argparse.ArgumentParser(description='PtrNet-Sorting-Integer')

parser.add_argument('--low', type=int, default=0, help='lowest value in dataset (default: 0)')
parser.add_argument('--high', type=int, default=100, help='highest value in dataset (default: 100)')
parser.add_argument('--min-length', type=int, default=5, help='minimum length of sequences (default: 5)')
parser.add_argument('--max-length', type=int, default=10, help='maximum length of sequences (default: 10)')
parser.add_argument('--train-samples', type=int, default=100000, help='number of samples in train set (default: 100000)')
parser.add_argument('--test-samples', type=int, default=1000, help='number of samples in test set (default: 1000)')

parser.add_argument('--emb-dim', type=int, default=8, help='embedding dimension (default: 8)')
parser.add_argument('--batch-size', type=int, default=256, help='input batch size for training (default: 256)')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)')

parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 1e-3)')
parser.add_argument('--wd', default=1e-5, type=float, help='weight decay (default: 1e-5)')

parser.add_argument('--workers', type=int, default=4, help='number of data loading workers (default: 4)')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
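
# Example invocation (all flags are optional; the defaults are listed above):
#   python train_sort.py --epochs 100 --batch-size 256 --workers 4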


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def masked_accuracy(output, target, mask):
    """Computes a batch accuracy with a mask (for padded sequences)"""
    with torch.no_grad():
        masked_output = torch.masked_select(output, mask)
        masked_target = torch.masked_select(target, mask)
        accuracy = masked_output.eq(masked_target).float().mean()

    return accuracy


def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    cudnn.benchmark = True if use_cuda else False

    train_set = IntegerSortDataset(num_samples=args.train_samples, high=args.high, min_len=args.min_length, max_len=args.max_length, seed=1)
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, collate_fn=sparse_seq_collate_fn)

    test_set = IntegerSortDataset(num_samples=args.test_samples, high=args.high, min_len=args.min_length, max_len=args.max_length, seed=2)
    test_loader = DataLoader(dataset=test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, collate_fn=sparse_seq_collate_fn)

    model = PointerNet(input_dim=args.high, embedding_dim=args.emb_dim, hidden_size=args.emb_dim).to(device)
    optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    train_loss = AverageMeter()
    train_accuracy = AverageMeter()
    test_loss = AverageMeter()
    test_accuracy = AverageMeter()

    for epoch in range(args.epochs):
        # Train
        model.train()
        for batch_idx, (seq, length, target) in enumerate(train_loader):
            seq, length, target = seq.to(device), length.to(device), target.to(device)

            optimizer.zero_grad()
            log_pointer_score, argmax_pointer, mask = model(seq, length)

            unrolled = log_pointer_score.view(-1, log_pointer_score.size(-1))
            loss = F.nll_loss(unrolled, target.view(-1), ignore_index=-1)
            assert not np.isnan(loss.item()), 'Model diverged with loss = NaN'

            loss.backward()
            optimizer.step()

            train_loss.update(loss.item(), seq.size(0))

            mask = mask[:, 0, :]
            train_accuracy.update(masked_accuracy(argmax_pointer, target, mask).item(), mask.int().sum().item())

            if batch_idx % 20 == 0:
                print('Epoch {}: Train [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'
                      .format(epoch, batch_idx * len(seq), len(train_loader.dataset),
                              100. * batch_idx / len(train_loader), train_loss.avg, train_accuracy.avg))

        # Test
        model.eval()
        for seq, length, target in test_loader:
            seq, length, target = seq.to(device), length.to(device), target.to(device)

            log_pointer_score, argmax_pointer, mask = model(seq, length)
            unrolled = log_pointer_score.view(-1, log_pointer_score.size(-1))
            loss = F.nll_loss(unrolled, target.view(-1), ignore_index=-1)
            assert not np.isnan(loss.item()), 'Model diverged with loss = NaN'

            test_loss.update(loss.item(), seq.size(0))

            mask = mask[:, 0, :]
            test_accuracy.update(masked_accuracy(argmax_pointer, target, mask).item(), mask.int().sum().item())
        print('Epoch {}: Test\tLoss: {:.6f}\tAccuracy: {:.6f}'.format(epoch, test_loss.avg, test_accuracy.avg))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------