├── .gitignore ├── README.md ├── config.py ├── model.py ├── test.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | .DS_Store 3 | *.DS_Store 4 | *.zip 5 | *.swp 6 | *.pkl 7 | log/ 8 | *.log 9 | .nfs* 10 | 11 | data/ 12 | models/ 13 | RELEASE-1.5.5/ 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | .static_storage/ 70 | .media/ 71 | local_settings.py 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | 120 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sequence to sequence implementation with PyTorch 2 | 3 | Seq2seq model with 4 | - global attention 5 | - self-critical sequence training 6 | 7 | ## Dependencies 8 | - Python 3.6 9 | - PyTorch 0.3 10 | - Spacy 2.0.4 11 | - Torchtext 0.2.3 12 | - Numpy 13 | 14 | You can install Torchtext following: https://stackoverflow.com/questions/42711144/how-can-i-install-torchtext 15 | 16 | You need to install Spacy models specified in `config.py` (`src_lang` and `trg_lang`). Usually you can do this by running `python -m spacy download en` after installing Spacy. 17 | 18 | ## Start training 19 | 20 | 1. create `models`, `data` and `log` folders in the root. 21 | 2. Prepare data files in `data` folder. 22 | * Prepare 6 files named as `[train/test/valid].[src/trg]`, where each line in `*.src` is a source sentence, and in `*.trg` is a target sentence. 23 | 3. You can modify the configurations in `config.py` 24 | 4. Start training 25 | - `python train.py --config --exp ` to train the model. 26 | 27 | ### Options 28 | - Define model settings in `config.py` and choose with `--config`. 29 | - The model will use GPU if available, add `--disable_cuda` to use cpu explicitly. 30 | - Use `CUDA_VISIBLE_DEVICES=2` to choose GPU device. For example, `CUDA_VISIBLE_DEVICES=1 python train.py --config chatbot_twitter`. 31 | - Add `--resume` to resume from a certain saved model, specified by `--config` and `--exp`. 32 | - Add `--early_stopping` and set `--patient ` to enable early stopping, the training process will end if the validation loss doesn't decrease for `n` epochs, or `max_epoch` is reached. Without `--early_stopping`, we'll train the model for `num_epoch` epochs. 33 | - Set `--self_critical

` to use hybrid loss. 34 | 35 | ## TODO 36 | 37 | ## Reference: 38 | - OpenNMT-py: https://github.com/OpenNMT/OpenNMT-py 39 | - Pytorch NMT tutorial: http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html (Note that the tutorial has some faults) 40 | - torchtext: https://github.com/pytorch/text 41 | - Effective Approaches to Attention-based Neural Machine Translation 42 | - Neural Text Generation: A Practical Guide 43 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim:fenc=utf-8 4 | # 5 | # Copyright © 2017 Yifan WANG 6 | # 7 | # Distributed under terms of the MIT license. 8 | 9 | """ 10 | Configurations for different tasks 11 | """ 12 | from torchtext.data import Example 13 | 14 | def gigawords(): 15 | c = {} 16 | # filename is of form: 'train.src', 'test.trg' 17 | c['root'] = 'data/summarization/' 18 | c['prefix'] = 'summarization' 19 | c['splits'] = ['train', 'test', 'valid'] 20 | # names of Spacy models 21 | c['src_lang'] = 'en_core_web_sm' 22 | c['trg_lang'] = 'en_core_web_sm' 23 | c['model_path'] = './models/' 24 | c['log_step'] = 1000 25 | c['save_step'] = 1000 26 | c['test_step'] = 4000 27 | c['beam_size'] = -1 28 | # model settings 29 | c['encoder_embed_size'] = 300 30 | c['decoder_embed_size'] = 300 31 | c['share_embed'] = False 32 | c['encoder_hidden_size'] = 512 33 | c['decoder_hidden_size'] = 512 34 | # training settings 35 | c['num_epoch'] = 5 36 | c['max_epoch'] = 50 37 | c['num_layers'] = 1 38 | c['batch_size'] = 32 39 | c['learning_rate'] = 0.0001 40 | c['encoder_vocab'] = 30000 41 | c['decoder_vocab'] = 20000 42 | 43 | def load(src_path, trg_path, src_field, trg_field): 44 | """ 45 | Function used to load examples from file 46 | """ 47 | src = open(src_path, 'r').readlines() 48 | trg = open(trg_path, 'r').readlines() 49 | examples = [] 50 | for (l1, l2) in zip(src,trg): 51 | examples.append(Example.fromlist([l1, l2], [('src', src_field), ('trg', trg_field)])) 52 | return examples 53 | 54 | c['load'] = load 55 | return c 56 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim:fenc=utf-8 4 | # 5 | # Copyright © 2017 Yifan WANG 6 | # 7 | # Distributed under terms of the MIT license. 8 | 9 | """ 10 | Sequence to sequence model with global attention. 11 | """ 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from torch.autograd import Variable 16 | from torch import optim 17 | import torch.nn.functional as F 18 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 19 | import os, time, sys 20 | 21 | 22 | class GlobalAttention(nn.Module): 23 | """ 24 | Global Attention as described in 'Effective Approaches to Attention-based Neural Machine Translation' 25 | """ 26 | def __init__(self, enc_hidden, dec_hidden): 27 | super(GlobalAttention, self).__init__() 28 | self.enc_hidden = enc_hidden 29 | self.dec_hidden = dec_hidden 30 | 31 | # a = h_t^T W h_s 32 | self.linear_in = nn.Linear(enc_hidden, dec_hidden, bias=False) 33 | # W [c, h_t] 34 | self.linear_out = nn.Linear(dec_hidden + enc_hidden, dec_hidden) 35 | self.softmax = nn.Softmax() 36 | self.tanh = nn.Tanh() 37 | 38 | def sequence_mask(self, lengths, max_len=None): 39 | """ 40 | Creates a boolean mask from sequence lengths. 41 | """ 42 | batch_size = lengths.numel() 43 | max_len = max_len or lengths.max() 44 | return (torch.arange(0, max_len) 45 | .type_as(lengths) 46 | .repeat(batch_size, 1) 47 | .lt(lengths.unsqueeze(1))) 48 | 49 | def forward(self, inputs, context, context_lengths): 50 | """ 51 | input (FloatTensor): batch x tgt_len x dim: decoder's rnn's output. (h_t) 52 | context (FloatTensor): batch x src_len x dim: src hidden states 53 | context_lengths (LongTensor): the source context lengths. 54 | """ 55 | # (batch, tgt_len, src_len) 56 | align = self.score(inputs, context) 57 | batch, tgt_len, src_len = align.size() 58 | 59 | 60 | mask = self.sequence_mask(context_lengths) 61 | # (batch, 1, src_len) 62 | mask = mask.unsqueeze(1) # Make it broadcastable. 63 | if next(self.parameters()).is_cuda: 64 | mask = mask.cuda() 65 | align.data.masked_fill_(1 - mask, -float('inf')) # fill with -inf 66 | 67 | align_vectors = self.softmax(align.view(batch*tgt_len, src_len)) 68 | align_vectors = align_vectors.view(batch, tgt_len, src_len) 69 | 70 | # (batch, tgt_len, src_len) * (batch, src_len, enc_hidden) -> (batch, tgt_len, enc_hidden) 71 | c = torch.bmm(align_vectors, context) 72 | 73 | # \hat{h_t} = tanh(W [c_t, h_t]) 74 | concat_c = torch.cat([c, inputs], 2).view(batch*tgt_len, self.enc_hidden + self.dec_hidden) 75 | attn_h = self.tanh(self.linear_out(concat_c).view(batch, tgt_len, self.dec_hidden)) 76 | 77 | # transpose will make it non-contiguous 78 | attn_h = attn_h.transpose(0, 1).contiguous() 79 | align_vectors = align_vectors.transpose(0, 1).contiguous() 80 | # (tgt_len, batch, dim) 81 | return attn_h, align_vectors 82 | 83 | def score(self, h_t, h_s): 84 | """ 85 | h_t (FloatTensor): batch x tgt_len x dim, inputs 86 | h_s (FloatTensor): batch x src_len x dim, context 87 | """ 88 | tgt_batch, tgt_len, tgt_dim = h_t.size() 89 | src_batch, src_len, src_dim = h_s.size() 90 | 91 | h_t = h_t.view(tgt_batch*tgt_len, tgt_dim) 92 | h_t_ = self.linear_in(h_t) 93 | h_t = h_t.view(tgt_batch, tgt_len, tgt_dim) 94 | # (batch, d, s_len) 95 | h_s_ = h_s.transpose(1, 2) 96 | # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) 97 | return torch.bmm(h_t, h_s_) 98 | 99 | 100 | class EncoderRNN(nn.Module): 101 | def __init__(self, vocab_size, embed_size, hidden_size, n_layers=1, padding_idx=1): 102 | super(EncoderRNN, self).__init__() 103 | self.n_layers = n_layers 104 | self.vocab_size = vocab_size 105 | self.hidden_size = hidden_size 106 | self.embed_size = embed_size 107 | 108 | self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=padding_idx) 109 | self.rnn = nn.GRU(input_size=embed_size, hidden_size=hidden_size, num_layers=n_layers) 110 | 111 | def forward(self, inputs, lengths, return_packed=False): 112 | """ 113 | Inputs: 114 | inputs: (seq_length, batch_size), non-packed inputs 115 | lengths: (batch_size) 116 | """ 117 | # [seq_length, batch_size, embed_length] 118 | embedded = self.embedding(inputs) 119 | packed = pack_padded_sequence(embedded, lengths=lengths.numpy()) 120 | outputs, hiddens = self.rnn(packed) 121 | if not return_packed: 122 | return pad_packed_sequence(outputs)[0], hiddens 123 | return outputs, hiddens 124 | 125 | 126 | class DecoderRNN(nn.Module): 127 | """ 128 | """ 129 | def __init__(self, vocab_size, embed_size, hidden_size, n_layers=1, encoder_hidden=None, dropout_p=0.2, padding_idx=1, packed=True): 130 | super(DecoderRNN, self).__init__() 131 | self.n_layers = n_layers 132 | self.hidden_size = hidden_size 133 | self.vocab_size = vocab_size 134 | self.embed_size = embed_size 135 | 136 | self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=padding_idx) 137 | self.rnn = nn.GRU(input_size=embed_size, hidden_size=hidden_size, num_layers=n_layers) 138 | 139 | # h_t^T W h_s 140 | self.linear_out = nn.Linear(hidden_size, vocab_size) 141 | self.attn = GlobalAttention(encoder_hidden, hidden_size) 142 | self.dropout = nn.Dropout(dropout_p) 143 | 144 | def forward(self, inputs, hidden, context, context_lengths): 145 | """ 146 | inputs: (tgt_len, batch_size, d) 147 | hidden: last hidden state from encoder 148 | context: (src_len, batch_size, hidden_size), outputs of encoder 149 | """ 150 | # Teacher-forcing, not packed! 151 | embedded = self.embedding(inputs) 152 | embedded = self.dropout(embedded) 153 | decoder_unpacked, decoder_hidden = self.rnn(embedded, hidden) 154 | # Calculate the attention. 155 | attn_outputs, attn_scores = self.attn( 156 | decoder_unpacked.transpose(0, 1).contiguous(), # (len, batch, d) -> (batch, len, d) 157 | context.transpose(0, 1).contiguous(), # (len, batch, d) -> (batch, len, d) 158 | context_lengths=context_lengths 159 | ) 160 | # Don't need LogSoftmax with CrossEntropyLoss 161 | # the outputs are not normalized, and can be negative 162 | # Note that a mask is needed to compute the loss 163 | outputs = self.linear_out(attn_outputs) 164 | return outputs, decoder_hidden 165 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim:fenc=utf-8 4 | # 5 | # Copyright © 2017 Yifan WANG 6 | # 7 | # Distributed under terms of the MIT license. 8 | 9 | """ 10 | 11 | """ 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import Variable 17 | from torch import optim 18 | import torch.nn.functional as F 19 | from torch.utils import data 20 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 21 | from torch.distributions import Categorical 22 | from torchtext.vocab import Vocab 23 | # from torchtext.vocab import GloVe 24 | from torchtext.data import Field, Pipeline, RawField, Dataset, Example, BucketIterator 25 | from torchtext.data import get_tokenizer 26 | import os, time, sys, datetime, argparse, pickle 27 | 28 | from model import EncoderRNN, DecoderRNN 29 | import config 30 | from utils import * 31 | 32 | EOS = "" 33 | SOS = "" 34 | PAD = "" 35 | np.random.seed(666) 36 | 37 | 38 | def main(args): 39 | start = time.time() 40 | print(since(start) + "Loading data with configuration '{0}'...".format(args.config)) 41 | c = getattr(config, args.config)() 42 | c['use_cuda'] = args.use_cuda 43 | datasets, src_field, trg_field = load_data(c) 44 | # TODO: validation dataset 45 | 46 | train = datasets['train'] 47 | src_field.build_vocab(train, max_size=c['encoder_vocab']) 48 | trg_field.build_vocab(train, max_size=c['decoder_vocab']) 49 | del train 50 | print("Source vocab: {0}".format(len(src_field.vocab.itos))) 51 | print("Target vocab: {0}".format(len(trg_field.vocab.itos))) 52 | 53 | test = datasets['test'] 54 | n_test = len(test.examples) 55 | 56 | test_iter = iter(BucketIterator( 57 | dataset=test, batch_size=1, 58 | sort_key=lambda x: -len(x.src), device=-1)) 59 | 60 | PAD_IDX = trg_field.vocab.stoi[PAD] # default=1 61 | 62 | print(since(start) + "Loading models...") 63 | encoder = torch.load(c['model_path'] + c['prefix'] + 'encoder.pkl') 64 | decoder = torch.load(c['model_path'] + c['prefix'] + 'decoder.pkl') 65 | 66 | if c['use_cuda']: 67 | encoder.cuda() 68 | decoder.cuda() 69 | else: 70 | encoder.cpu() 71 | decoder.cpu() 72 | 73 | CEL = nn.CrossEntropyLoss(size_average=True, ignore_index=PAD_IDX) 74 | test_losses = [] 75 | test_rouges = [] 76 | gts = [] 77 | greedys = [] 78 | synchronize(c) 79 | for i in range(n_test): 80 | test_batch = next(test_iter) 81 | test_encoder_inputs, test_encoder_lengths = test_batch.src 82 | test_decoder_inputs, test_decoder_lengths = test_batch.trg 83 | test_encoder_inputs = cuda(Variable(test_encoder_inputs.data, volatile=True), c['use_cuda']) 84 | test_decoder_inputs = cuda(Variable(test_decoder_inputs.data, volatile=True), c['use_cuda']) 85 | 86 | test_encoder_packed, test_encoder_hidden = encoder(test_encoder_inputs, test_encoder_lengths) 87 | test_encoder_unpacked = pad_packed_sequence(test_encoder_packed)[0] 88 | # remove last symbol 89 | test_decoder_unpacked, test_decoder_hidden = decoder(test_decoder_inputs[:-1,:], test_encoder_hidden, test_encoder_unpacked, test_encoder_lengths) 90 | trg_len, batch_size, d = test_decoder_unpacked.size() 91 | 92 | test_loss = CEL(test_decoder_unpacked.view(trg_len*batch_size, d), test_decoder_inputs[1:,:].view(-1)) 93 | 94 | test_enc_input = (test_encoder_inputs[:,0].unsqueeze(1), torch.LongTensor([test_encoder_lengths[0]])) 95 | # use self critical training 96 | test_greedy_out, _ = sample(encoder, decoder, test_enc_input, trg_field, 97 | max_len=30, greedy=True, config=c) 98 | test_greedy_sent = tostr(clean(test_greedy_out)) 99 | 100 | test_gt_sent = tostr(clean(itos(test_decoder_inputs[:,0].cpu().data.numpy(), trg_field))) 101 | 102 | gts.append(test_gt_sent) 103 | greedys.append(test_greedy_sent) 104 | test_rouges.append(score(hyps=test_greedy_sent, refs=test_gt_sent, metric='rouge')['rouge-1']['f']) 105 | test_losses.append(float(test_loss.cpu().data.numpy().tolist()[0])) 106 | synchronize(c) 107 | print("\tTest ROUGE-1_f: ", np.mean(test_rouges)) 108 | print("\tTest Loss: ", np.mean(test_losses)) 109 | 110 | with open('test.log' ,'w') as f: 111 | f.write("Test loss: {0}\n".format(np.mean(test_losses))) 112 | f.write("{0} samples, svg ROUGE-1_f: {1}\n".format(n_test, np.mean(test_rouges))) 113 | for i in range(n_test): 114 | f.write(str(test_losses[i]) + '\n') 115 | f.write(str(test_rouges[i]) + '\n') 116 | f.write(str(gts[i]) + '\n') 117 | f.write(str(greedys[i]) + '\n') 118 | 119 | 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('--config', type=str, default=None , 124 | help='model configurations, defined in config.py') 125 | parser.add_argument('--from_scratch', type=bool, default=False) 126 | parser.add_argument('--disable_cuda', type=bool, default=False) 127 | parser.add_argument('--self_critical', type=float, default=0.) 128 | args = parser.parse_args() 129 | args.use_cuda = not args.disable_cuda and torch.cuda.is_available() 130 | if args.use_cuda: 131 | print("Use GPU...") 132 | else: 133 | print("Use CPU...") 134 | main(args) 135 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim:fenc=utf-8 4 | # 5 | # Copyright © 2017 Yifan WANG 6 | # 7 | # Distributed under terms of the MIT license. 8 | 9 | """ 10 | """ 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from torch.autograd import Variable 16 | from torch import optim 17 | import torch.nn.functional as F 18 | from torch.utils import data 19 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 20 | from torch.distributions import Categorical 21 | from torchtext.vocab import Vocab 22 | # from torchtext.vocab import GloVe 23 | from torchtext.data import Field, Pipeline, RawField, Dataset, Example, BucketIterator 24 | from torchtext.data import get_tokenizer 25 | import os, time, sys, datetime, argparse, pickle, json 26 | 27 | from model import EncoderRNN, DecoderRNN 28 | import config 29 | from utils import * 30 | 31 | EOS = "" 32 | SOS = "" 33 | PAD = "" 34 | np.random.seed(666) 35 | 36 | 37 | def main(args): 38 | # Load configurations 39 | start = time.time() 40 | c = getattr(config, args.config)() 41 | c['use_cuda'], c['exp'], c['mode'] = args.use_cuda, args.exp, args.mode 42 | assert c['exp'] is not None, "'exp' must be specified." 43 | logger = init_logging('log/{0}_{1}_{2}.log'.format(c['prefix'], c['exp'], start)) 44 | logger.info(since(start) + "Loading data with configuration '{0}':\n{1}".format(args.config, str(c))) 45 | 46 | # Load datasets 47 | datasets, src_field, trg_field = load_data(c) 48 | 49 | train = datasets['train'] 50 | n_train = len(train.examples) 51 | test = datasets['test'] 52 | n_test = len(test.examples) 53 | valid = datasets['valid'] 54 | n_valid = len(valid.examples) 55 | num_epoch = c['num_epoch'] if not args.early_stopping else c['max_epoch'] 56 | batch_per_epoch = n_train // c['batch_size'] if n_train % c['batch_size'] == 0 else n_train // c['batch_size']+1 57 | n_iters = batch_per_epoch * num_epoch 58 | 59 | # Build vocabularies 60 | src_field.build_vocab(train, max_size=c['encoder_vocab']) 61 | trg_field.build_vocab(train, max_size=c['decoder_vocab']) 62 | PAD_IDX = trg_field.vocab.stoi[PAD] # default=1 63 | 64 | logger.info("Source vocab: {0}".format(len(src_field.vocab.itos))) 65 | logger.info("Target vocab: {0}".format(len(trg_field.vocab.itos))) 66 | logger.info(since(start) + "{0} training samples, {1} epochs, batch size={2}, {3} batches per epoch.".format( 67 | n_train, num_epoch, c['batch_size'], batch_per_epoch)) 68 | 69 | train_iter = iter(BucketIterator( 70 | dataset=train, batch_size=c['batch_size'], sort=True, 71 | sort_key=lambda x: len(x.src), device=-1)) 72 | 73 | test_iter = iter(BucketIterator( 74 | dataset=test, batch_size=1, sort=True, 75 | sort_key=lambda x: len(x.src), device=-1)) 76 | 77 | valid_iter = iter(BucketIterator( 78 | dataset=valid, batch_size=1,sort=True, 79 | sort_key=lambda x: len(x.src), device=-1)) 80 | 81 | del train 82 | del test 83 | del valid 84 | 85 | encoder = EncoderRNN(vocab_size=len(src_field.vocab), embed_size=c['encoder_embed_size'],\ 86 | hidden_size=c['encoder_hidden_size'], padding_idx=PAD_IDX, n_layers=c['num_layers']) 87 | decoder = DecoderRNN(vocab_size=len(trg_field.vocab), embed_size=c['decoder_embed_size'],\ 88 | hidden_size=c['decoder_hidden_size'], encoder_hidden=c['encoder_hidden_size'],\ 89 | padding_idx=PAD_IDX, n_layers=c['num_layers']) 90 | if not args.resume: 91 | # Train from scratch 92 | params = list(encoder.parameters()) + list(decoder.parameters()) 93 | optimizer = optim.Adam(params, lr=c['learning_rate']) 94 | init_epoch = init_step = 0 95 | history = {'epochs':[], 96 | 'train_loss':[], 97 | 'valid_loss':[], 98 | 'test_loss':[], 99 | 'test_score':[], 100 | 'best_epoch':-1, 101 | 'best_loss':float("inf")} 102 | 103 | logger.info(since(start) + "Start training... {0} epochs, {1} steps per epoch.".format( 104 | num_epoch, batch_per_epoch)) 105 | else: 106 | assert os.path.isfile("{0}{1}_{2}.pkl".format(c['model_path'], c['prefix'], c['exp'])) 107 | # Load checkpoint 108 | logger.info(since(start) + "Loading from {0}{1}_{2}.pkl".format(c['model_path'], c['prefix'], c['exp'])) 109 | cp = torch.load("{0}{1}_{2}.pkl".format(c['model_path'], c['prefix'], c['exp'])) 110 | encoder.load_state_dict(cp['encoder']) 111 | decoder.load_state_dict(cp['decoder']) 112 | params = list(encoder.parameters()) + list(decoder.parameters()) 113 | optimizer = optim.Adam(params, lr=c['learning_rate']) 114 | optimizer.load_state_dict(cp['optimizer']) 115 | init_epoch, init_step, others, history = cp['epoch'], cp['step'], cp['others'], cp['history'] 116 | del cp 117 | logger.info(since(start) + "Resume training from {0}/{1} epoch, {2}/{3} step".format( 118 | init_epoch+1, num_epoch, init_step+1, batch_per_epoch)) 119 | 120 | if c['use_cuda']: 121 | encoder.cuda() 122 | decoder.cuda() 123 | else: 124 | encoder.cpu() 125 | decoder.cpu() 126 | 127 | CEL = nn.CrossEntropyLoss(size_average=True, ignore_index=PAD_IDX) 128 | print_loss = 0 129 | 130 | # Start training 131 | for e in range(init_epoch, num_epoch): 132 | for j in range(init_step, batch_per_epoch): 133 | init_step = 0 134 | i = batch_per_epoch*e + j + 1 # global step 135 | 136 | batch = next(train_iter) 137 | encoder_inputs, encoder_lengths = batch.src 138 | decoder_inputs, decoder_lengths = batch.trg 139 | 140 | encoder_inputs = cuda(encoder_inputs, c['use_cuda']) 141 | decoder_inputs = cuda(decoder_inputs, c['use_cuda']) 142 | 143 | encoder_unpacked, encoder_hidden = encoder(encoder_inputs, encoder_lengths, return_packed=False) 144 | # we don't remove the last symbol 145 | decoder_unpacked, decoder_hidden = decoder(decoder_inputs[:-1,:], encoder_hidden, encoder_unpacked, encoder_lengths) 146 | trg_len, batch_size, d = decoder_unpacked.size() 147 | # remove first symbol 148 | ce_loss = CEL(decoder_unpacked.view(trg_len*batch_size, d), decoder_inputs[1:,:].view(-1)) 149 | print_loss += ce_loss.data 150 | 151 | # Self-critical sequence training 152 | assert args.self_critical >= 0. and args.self_critical <= 1. 153 | if args.self_critical > 1e-5: 154 | sc_loss = cuda(Variable(torch.Tensor([0.])), c['use_cuda']) 155 | for j in range(batch_size): 156 | enc_input = (encoder_inputs[:,j].unsqueeze(1), torch.LongTensor([encoder_lengths[j]])) 157 | # use self critical training 158 | greedy_out, _ = sample(encoder, decoder, enc_input, trg_field, 159 | max_len=30, greedy=True, config=c) 160 | greedy_sent = tostr(clean(greedy_out)) 161 | sample_out, sample_logp = sample(encoder, decoder, enc_input, trg_field, 162 | max_len=30, greedy=False, config=c) 163 | sample_sent = tostr(clean(sample_out)) 164 | # Ground truth 165 | gt_sent = tostr(clean(itos(decoder_inputs[:,j].cpu().data.numpy(), trg_field))) 166 | greedy_score = score(hyps=greedy_sent, refs=gt_sent, metric='rouge') 167 | sample_score = score(hyps=sample_sent, refs=gt_sent, metric='rouge') 168 | reward = Variable(torch.Tensor([sample_score["rouge-1"]['f'] - greedy_score["rouge-1"]['f']]), requires_grad=False) 169 | reward = cuda(reward, c['use_cuda']) 170 | sc_loss -= reward*torch.sum(sample_logp) 171 | 172 | if i % c['log_step'] == 0: 173 | logger.info("CE loss: {0}".format(ce_loss)) 174 | logger.info("RL loss: {0}".format(sc_loss)) 175 | logger.info("Ground truth: {0}".format(gt_sent)) 176 | logger.info("Greedy: {0}, {1}".format(greedy_score['rouge-1']['f'], greedy_sent)) 177 | logger.info("Sample: {0}, {1}".format(sample_score['rouge-1']['f'], sample_sent)) 178 | 179 | loss = (1-args.self_critical) * ce_loss + args.self_critical * sc_loss 180 | else: 181 | loss = ce_loss 182 | 183 | optimizer.zero_grad() 184 | loss.backward() 185 | optimizer.step() 186 | 187 | del encoder_inputs, decoder_inputs 188 | 189 | if i % c['save_step'] == 0: 190 | # Save model for resuming 191 | synchronize(c) 192 | logger.info(since(start) + "Saving models.") 193 | cp = {'encoder': encoder.state_dict(), 'decoder': decoder.state_dict(), 194 | 'optimizer': optimizer.state_dict(), 'others': {}, 195 | 'step': j+1, 'epoch': e, 'history': history} 196 | torch.save(cp, "{0}{1}_{2}.pkl".format(c['model_path'], c['prefix'], c['exp'])) 197 | 198 | if i % c['log_step'] == 0: 199 | synchronize(c) 200 | logger.info(since(start) + 'epoch {0}/{1}, iteration {2}/{3}'.format(e, num_epoch, i, batch_per_epoch)) 201 | logger.info("\tTrain loss: {0}".format(print_loss.cpu().numpy().tolist()[0] / c['log_step'])) 202 | print_loss = 0 203 | random_eval(encoder, decoder, batch, n=1, src_field=src_field, trg_field=trg_field, config=c, 204 | greedy=True, logger=logger) 205 | 206 | # Evaluate on test set 207 | if i % c['test_step'] == 0: 208 | test_loss = 0 209 | test_rouge = 0 210 | refs = [] 211 | greedys = [] 212 | for j in range(n_test): 213 | test_batch = next(test_iter) 214 | test_encoder_inputs, test_encoder_lengths = test_batch.src 215 | test_decoder_inputs, test_decoder_lengths = test_batch.trg 216 | # GPU 217 | test_encoder_inputs = cuda(Variable(test_encoder_inputs.data, volatile=True), c['use_cuda']) 218 | test_decoder_inputs = cuda(Variable(test_decoder_inputs.data, volatile=True), c['use_cuda']) 219 | 220 | test_encoder_unpacked, test_encoder_hidden = encoder(test_encoder_inputs, test_encoder_lengths, return_packed=False) 221 | # we don't remove the last symbol 222 | test_decoder_unpacked, test_decoder_hidden = decoder(test_decoder_inputs[:-1,:], test_encoder_hidden, test_encoder_unpacked, test_encoder_lengths) 223 | trg_len, batch_size, d = test_decoder_unpacked.size() 224 | # remove first symbol 225 | test_ce_loss = CEL(test_decoder_unpacked.view(trg_len*batch_size, d), test_decoder_inputs[1:,:].view(-1)) 226 | test_loss += test_ce_loss.data 227 | 228 | test_enc_input = (test_encoder_inputs[:,0].unsqueeze(1), torch.LongTensor([test_encoder_lengths[0]])) 229 | test_greedy_out, _ = sample(encoder, decoder, test_enc_input, trg_field, 230 | max_len=30, greedy=True, config=c) 231 | test_greedy_sent = tostr(clean(test_greedy_out)) 232 | 233 | test_gt_sent = tostr(clean(itos(test_decoder_inputs[:,0].cpu().data.numpy(), trg_field))) 234 | refs.append(test_gt_sent) 235 | greedys.append(test_greedy_sent) 236 | 237 | 238 | rouges = get_rouge(hyps=greedys, refs=refs) 239 | synchronize(c) 240 | logger.info(since(start) + "Test loss: {0}".format(test_loss.cpu().numpy().tolist()[0]/n_test)) 241 | logger.info(rouges) 242 | 243 | # One epoch is finished 244 | logger.info(since(start) + "Epoch {0} is finished.".format(e)) 245 | # Evaluate on validation set and perform early stopping 246 | valid_loss = 0 247 | for j in range(n_valid): 248 | batch = next(valid_iter) 249 | encoder_inputs, encoder_lengths = batch.src 250 | decoder_inputs, decoder_lengths = batch.trg 251 | 252 | encoder_inputs = cuda(encoder_inputs, c['use_cuda']) 253 | decoder_inputs = cuda(decoder_inputs, c['use_cuda']) 254 | 255 | encoder_unpacked, encoder_hidden = encoder(encoder_inputs, encoder_lengths, return_packed=False) 256 | decoder_unpacked, decoder_hidden = decoder(decoder_inputs[:-1,:], encoder_hidden, encoder_unpacked, encoder_lengths) 257 | trg_len, batch_size, d = decoder_unpacked.size() 258 | valid_ce = CEL(decoder_unpacked.view(trg_len*batch_size, d), decoder_inputs[1:,:].view(-1)) 259 | valid_loss += valid_ce.data 260 | history['valid_loss'].append(valid_loss.cpu().numpy().tolist()[0]/n_valid) 261 | synchronize(c) 262 | logger.info(since(start) + "Saving models.") 263 | cp = {'encoder': encoder.state_dict(), 'decoder': decoder.state_dict(), 264 | 'optimizer': optimizer.state_dict(), 'others': {}, 265 | 'step': j+1, 'epoch': e, 'history': history} 266 | torch.save(cp, "{0}{1}_{2}_epoch{3}.pkl".format(c['model_path'], c['prefix'], c['exp'], e)) 267 | logger.info("Epoch {0}, valid loss: {1}".format(e, history['valid_loss'][-1])) 268 | if history['valid_loss'][-1] < history['best_loss']: 269 | history['best_loss'] = history['valid_loss'][-1] 270 | history['best_epoch'] = e 271 | torch.save(cp, "{0}{1}_{2}_best.pkl".format(c['model_path'], c['prefix'], c['exp'])) 272 | elif args.early_stopping and e - history['best_epoch'] > patient: 273 | # early stopping 274 | logger.info(since(start) + "Early stopping at epoch {0}, best result at epoch {1}".format(e, history['best_epoch'])) 275 | return 276 | 277 | if __name__ == '__main__': 278 | parser = argparse.ArgumentParser() 279 | parser.add_argument('--config', type=str, default=None , 280 | help='model configurations, defined in config.py') 281 | parser.add_argument('--disable_cuda', type=bool, default=False) 282 | parser.add_argument('--resume', dest='resume', action='store_true') 283 | parser.add_argument('--early_stopping', dest='early_stopping', action='store_true', 284 | help='With early stopping, the training will end when valid loss doesn\'t decrease \ 285 | for `--patient` epochs, or at `max_epoch` epoch.') 286 | parser.add_argument('--self_critical', type=float, default=0.) 287 | parser.add_argument('--exp', type=str, default=None, help='A string that specify the name of the experiment') 288 | parser.add_argument('--mode', type=str, default='train') 289 | parser.add_argument('--patient', type=int, default=5) 290 | args = parser.parse_args() 291 | args.use_cuda = not args.disable_cuda and torch.cuda.is_available() 292 | 293 | print(args) 294 | go = input('Start? y/n\n') 295 | if go != 'y': 296 | exit() 297 | if args.use_cuda: 298 | print("Use GPU...") 299 | else: 300 | print("Use CPU...") 301 | main(args) 302 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim:fenc=utf-8 4 | # 5 | # Copyright © 2017 Yifan WANG 6 | # 7 | # Distributed under terms of the MIT license. 8 | 9 | """ 10 | Utility functions 11 | """ 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from torch.autograd import Variable 16 | from torch import optim 17 | import torch.nn.functional as F 18 | from torch.utils import data 19 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 20 | from torch.distributions import Categorical 21 | import os, time, sys, datetime, argparse, pickle, logging 22 | import spacy 23 | from torchtext.vocab import Vocab 24 | # from torchtext.vocab import GloVe 25 | from torchtext.data import Field, Pipeline, RawField, Dataset, Example, BucketIterator 26 | from torchtext.data import get_tokenizer 27 | from rouge import Rouge 28 | 29 | # TODO: add these into configuration 30 | EOS = "" 31 | SOS = "" 32 | PAD = "" 33 | 34 | def split_data(root, filenames, exts, train_ratio=0.8, test_ratio=0.2): 35 | """ 36 | Examples: filenames = ['en.txt', 'fr.txt'], exts = ['src', 'trg'] 37 | => train.src, train.trg; test.src, test.trg 38 | """ 39 | # TODO: check the extension names 40 | eps = 1e-5 41 | valid_ratio = 1 - train_ratio - test_ratio 42 | p = None 43 | for name, ext in zip(filenames, exts): 44 | print("Opening {0}".format(name)) 45 | with open(root + name, 'r') as f: 46 | lines = f.readlines() 47 | n = len(lines) 48 | p = np.random.permutation(n) if p is None else p 49 | train, test, valid = np.split(np.arange(n)[p], [int(n*train_ratio), int(n*train_ratio+n*test_ratio)]) 50 | 51 | train = [lines[i] for i in train] 52 | test = [lines[i] for i in test] 53 | valid = [lines[i] for i in valid] if valid_ratio > eps else valid 54 | for samples, mode in [(train, 'train'), (test, 'test'), (valid, 'valid')]: 55 | if valid_ratio < eps and mode == 'valid': 56 | continue 57 | out = open(root + mode + ext, 'w') 58 | for l in samples: 59 | out.write(l.strip() + '\n') 60 | out.close() 61 | print("Train: {0}\nTest: {1}\nValidation: {2}".format(len(train), len(test), len(valid))) 62 | 63 | 64 | def stoi(s, field): 65 | sent = [field.vocab.stoi[w] for w in s] 66 | return sent 67 | 68 | def itos(s, field): 69 | sent = [field.vocab.itos[w] for w in s] 70 | return sent 71 | 72 | def since(t): 73 | return '[' + str(datetime.timedelta(seconds=time.time() - t)) + '] ' 74 | 75 | def init_logging(log_name): 76 | """ 77 | 78 | """ 79 | formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(module)s: %(message)s', 80 | datefmt='%m/%d/%Y %H:%M:%S' ) 81 | handler = logging.FileHandler(log_name) 82 | out = logging.StreamHandler(sys.stdout) 83 | 84 | handler.setFormatter(formatter) 85 | out.setFormatter(formatter) 86 | out.setLevel(logging.INFO) 87 | logging.getLogger().addHandler(handler) 88 | logging.getLogger().addHandler(out) 89 | logging.getLogger().setLevel(logging.INFO) 90 | return logging 91 | 92 | 93 | def load_data(c): 94 | """ 95 | Load datasets, return a dictionary of datasets and fields 96 | """ 97 | 98 | # TODO: add field for context 99 | 100 | spacy_src = spacy.load(c['src_lang']) 101 | spacy_trg = spacy.load(c['trg_lang']) 102 | 103 | def tokenize_src(text): 104 | return [tok.text for tok in spacy_src.tokenizer(text)] 105 | 106 | def tokenize_trg(text): 107 | return [tok.text for tok in spacy_trg.tokenizer(text)] 108 | 109 | src_field = Field(tokenize=tokenize_src, include_lengths=True, eos_token=EOS, lower=True) 110 | trg_field= Field(tokenize=tokenize_trg, include_lengths=True, eos_token=EOS, lower=True, init_token=SOS) 111 | 112 | datasets = {} 113 | # load processed data 114 | for split in c['splits']: 115 | if os.path.isfile(c['root'] + split + '.pkl'): 116 | print('Loading {0}'.format(c['root'] + split + '.pkl')) 117 | examples = pickle.load(open(c['root'] + split + '.pkl', 'rb')) 118 | datasets[split] = Dataset(examples = examples, fields={'src':src_field,'trg': trg_field}) 119 | else: 120 | src_path = c['root'] + split + '.src' 121 | trg_path = c['root'] + split + '.trg' 122 | examples = c['load'](src_path, trg_path, src_field, trg_field) 123 | datasets[split] = Dataset(examples = examples, fields={'src':src_field,'trg': trg_field}) 124 | print('Saving to {0}'.format(c['root'] + split + '.pkl')) 125 | pickle.dump(examples, open(c['root'] + split + '.pkl', 'wb')) 126 | 127 | return datasets, src_field, trg_field 128 | 129 | 130 | def cuda(var, use_cuda): 131 | if use_cuda: 132 | var = var.cuda() 133 | return var 134 | 135 | 136 | def evaluate(encoder, decoder, var, trg_field, max_len=30, beam_size=-1): 137 | """ 138 | var: tuple of tensors 139 | """ 140 | logsm = nn.LogSoftmax() 141 | # Beam search 142 | # TODO: check the beam search 143 | H = [([SOS], 0.)] 144 | H_temp = [] 145 | H_final = [] 146 | use_cuda = next(encoder.parameters()).is_cuda 147 | 148 | outputs = [] 149 | encoder_inputs, encoder_lengths = var 150 | encoder_inputs = cuda(encoder_inputs, use_cuda) 151 | # encoder_lengths = cuda(encoder_lengths, use_cuda) 152 | encoder_unpacked, encoder_hidden = encoder(encoder_inputs, encoder_lengths, return_packed=False) 153 | 154 | decoder_hidden = encoder_hidden 155 | decoder_inputs, decoder_lenghts = trg_field.numericalize(([[SOS]], [1]), device=-1) 156 | decoder_inputs = cuda(decoder_inputs, use_cuda) 157 | if beam_size > 0: 158 | for i in range(max_len): 159 | for h in H: 160 | hyp, s = h 161 | decoder_inputs, decoder_lenghts = trg_field.numericalize(([hyp], [len(hyp)]), device=-1) 162 | decoder_unpacked, decoder_hidden = decoder(decoder_inputs, decoder_hidden, encoder_unpacked, encoder_lengths) 163 | topv, topi = decoder_unpacked.data[-1].topk(beam_size) 164 | topv = logsm(topv) 165 | for j in range(beam_size): 166 | nj = int(topi.numpy()[0][j]) 167 | hyp_new = hyp + [trg_field.vocab.itos[nj]] 168 | s_new = s + topv.data.numpy().tolist()[-1][j] 169 | if trg_field.vocab.itos[nj] == EOS: 170 | H_final.append((hyp_new, s_new)) 171 | else: 172 | H_temp.append((hyp_new, s_new)) 173 | H_temp = sorted(H_temp, key=lambda x:x[1], reverse=True) 174 | H = H_temp[:beam_size] 175 | H_temp = [] 176 | 177 | H_final = sorted(H_final, key=lambda x:x[1], reverse=True) 178 | outputs = [" ".join(H_final[i][0]) for i in range(beam_size)] 179 | 180 | else: 181 | for i in range(max_len): 182 | # Eval mode, dropout is not used 183 | decoder_unpacked, decoder_hidden = decoder.eval()(decoder_inputs, decoder_hidden, encoder_unpacked, encoder_lengths) 184 | topv, topi = decoder_unpacked.data.topk(1) 185 | ni = int(topi.cpu().numpy()[0][0][0]) 186 | if trg_field.vocab.itos[ni] == EOS: 187 | outputs.append(EOS) 188 | break 189 | else: 190 | outputs.append(trg_field.vocab.itos[ni]) 191 | decoder_inputs = Variable(torch.LongTensor([[ni]])) 192 | decoder_inputs = cuda(decoder_inputs, use_cuda) 193 | outputs = " ".join(outputs) 194 | return outputs.strip() 195 | 196 | def sample(encoder, decoder, var, trg_field, max_len=30, greedy=False, config=None): 197 | """ Sample an output given the input 198 | Args: 199 | var: (Tensor, List) tuple 200 | 201 | Returns: (outputs, log_probas) 202 | outputs: a list of str 203 | log_probas: Tensor (1, len) 204 | 205 | """ 206 | # use_cuda = next(encoder.parameters()).is_cuda 207 | use_cuda = config['use_cuda'] 208 | ls = nn.LogSoftmax() 209 | log_probas = [] 210 | outputs = [] 211 | 212 | encoder_inputs, encoder_lengths = var 213 | encoder_inputs = cuda(encoder_inputs, use_cuda) 214 | encoder_unpacked, encoder_hidden = encoder(encoder_inputs, encoder_lengths, return_packed=False) 215 | decoder_hidden = encoder_hidden 216 | decoder_inputs, decoder_lenghts = trg_field.numericalize(([[SOS]], [1]), device=-1) 217 | decoder_inputs = cuda(decoder_inputs, use_cuda) 218 | for i in range(max_len): 219 | # TODO: shall we use eval mode? 220 | # decoder_unpacked: (1, 1, vocab_size), eval() is effective to Dropout and BatchNorm 221 | decoder_unpacked, decoder_hidden = decoder.eval()(decoder_inputs, decoder_hidden, encoder_unpacked, encoder_lengths) 222 | if greedy: 223 | logp, ni = torch.max(ls(decoder_unpacked.squeeze()), 0) 224 | # ni must be an integer, not like numpy.int32 225 | ni = int(ni.data.cpu().numpy()[0]) 226 | else: 227 | m = Categorical(F.softmax(decoder_unpacked.squeeze())) 228 | ni = m.sample() 229 | logp = m.log_prob(ni) 230 | ni = int(ni.cpu().data.numpy()[0]) 231 | if trg_field.vocab.itos[ni] == EOS: 232 | outputs.append(EOS) 233 | log_probas.append(logp) 234 | # Note that the log proba of EOS is not saved, 235 | # In this case, there will be no log proba 236 | break 237 | else: 238 | outputs.append(trg_field.vocab.itos[ni]) 239 | log_probas.append(logp) 240 | decoder_inputs = Variable(torch.LongTensor([[ni]])) 241 | decoder_inputs = cuda(decoder_inputs, use_cuda) 242 | # => row vector 243 | seq_log_probas = torch.cat([p.unsqueeze(1) for p in log_probas], 1) 244 | return outputs, seq_log_probas 245 | 246 | 247 | 248 | def random_eval(encoder, decoder, batch, n, src_field, trg_field, config=None, 249 | greedy=False, metric='rouge', logger=None): 250 | 251 | enc_inputs, enc_lengths = batch.src 252 | dec_inputs, dec_lengths = batch.trg 253 | 254 | N = enc_inputs.size()[1] 255 | idx = np.random.choice(N, n) 256 | for i in idx: 257 | logger.info('> ' + tostr(clean(itos(enc_inputs[:,i].cpu().data.numpy(), src_field)))) 258 | logger.info('= ' + tostr(clean(itos(dec_inputs[:,i].cpu().data.numpy(), trg_field)))) 259 | enc_input = (enc_inputs[:,i].unsqueeze(1), torch.LongTensor([enc_lengths[i]])) 260 | outputs, _ = sample(encoder, decoder, enc_input, trg_field, max_len=30, greedy=greedy, config=config) 261 | # sent = evaluate(encoder, decoder, enc_input, trg_field=trg_field, beam_size=beam_size) 262 | logger.info('< ' + tostr(clean(outputs)) + '\n') 263 | 264 | 265 | def score(hyps, refs, metric='rouge'): 266 | """ 267 | Args: 268 | hyp: predicted sentence 269 | ref: reference sentence 270 | metric: metric to use 271 | """ 272 | assert metric in ['rouge', 'bleu'] 273 | if metric is 'rouge': 274 | rouge = Rouge() 275 | # {"rouge-1": {"f": _, "p": _, "r": _}, "rouge-2" : { .. }, "rouge-3": { ... }} 276 | scores = rouge.get_scores(hyps, refs, avg=True) 277 | elif metric is 'bleu': 278 | pass 279 | return scores 280 | 281 | def get_rewards(encoder, decoder,src_field, trg_field, beam_size=-1, metric='rouge'): 282 | pass 283 | 284 | def synchronize(config): 285 | if config['use_cuda']: 286 | torch.cuda.synchronize() 287 | 288 | def clean(l): 289 | """ 290 | Remove special symbols from a list of str 291 | """ 292 | symbols = [EOS, SOS, PAD] 293 | return [w for w in l if w not in symbols] 294 | 295 | def tostr(l): 296 | return " ".join(l) 297 | 298 | def get_rouge(hyps, refs): 299 | """ 300 | Get average ROUGE-1, ROUGE-2, ROUGE-L F-1 scores 301 | """ 302 | scores = score(hyps=hyps, refs=refs, metric='rouge') 303 | s = "\nROUGE-1: {0}\nROUGE-2: {1}\nROUGE-L: {2}\n".format( 304 | scores['rouge-1']['f'], scores['rouge-2']['f'], 305 | scores['rouge-l']['f']) 306 | return s 307 | --------------------------------------------------------------------------------