├── .gitignore ├── NMT ├── CheckPoint.py ├── Loss.py ├── ModelFactory.py ├── Models │ ├── RNN │ │ ├── Decoder.py │ │ ├── DecoderState.py │ │ ├── Encoder.py │ │ ├── Model.py │ │ ├── Modules │ │ │ ├── StackedRNN.py │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── Decoder.cpython-36.pyc │ │ │ ├── DecoderState.cpython-36.pyc │ │ │ ├── Encoder.cpython-36.pyc │ │ │ ├── Model.cpython-36.pyc │ │ │ └── __init__.cpython-36.pyc │ ├── Transformer │ │ ├── Decoder.py │ │ ├── DecoderState.py │ │ ├── Encoder.py │ │ ├── Model.py │ │ ├── Modules │ │ │ ├── MultiHeadedAttn.py │ │ │ ├── PositionalEncoding.py │ │ │ ├── PositionwiseFF.py │ │ │ ├── __init__.py │ │ │ └── __pycache__ │ │ │ │ ├── MultiHeadedAttn.cpython-36.pyc │ │ │ │ ├── PositionalEncoding.cpython-36.pyc │ │ │ │ ├── PositionwiseFF.cpython-36.pyc │ │ │ │ └── __init__.cpython-36.pyc │ │ ├── VModel.py │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── BaseModel.cpython-36.pyc │ │ │ ├── Decoder.cpython-36.pyc │ │ │ ├── DecoderState.cpython-36.pyc │ │ │ ├── Encoder.cpython-36.pyc │ │ │ ├── Model.cpython-36.pyc │ │ │ ├── Utils.cpython-36.pyc │ │ │ └── __init__.cpython-36.pyc │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-36.pyc │ └── backup │ │ ├── Decoders.py │ │ ├── Encoders.py │ │ ├── VNMTModel.py │ │ └── VRNMTModel.py ├── Modules │ ├── GlobalAttention.py │ ├── IntehgerEncoding.py │ ├── MultiHeadedAttn.py │ ├── StackedRNN.py │ ├── __init__.py │ └── __pycache__ │ │ ├── Embeddings.cpython-36.pyc │ │ ├── GlobalAttention.cpython-36.pyc │ │ ├── MultiHeadedAttn.cpython-36.pyc │ │ ├── StackedRNN.cpython-36.pyc │ │ ├── UtilClass.cpython-36.pyc │ │ └── __init__.cpython-36.pyc ├── Optimizer.py ├── Statistics.py ├── Trainer.py ├── __init__.py ├── __pycache__ │ ├── CheckPoint.cpython-36.pyc │ ├── Loss.cpython-36.pyc │ ├── ModelConstructor.cpython-36.pyc │ ├── ModelFactory.cpython-36.pyc │ ├── Optimizer.cpython-36.pyc │ ├── Statistics.cpython-36.pyc │ ├── Trainer.cpython-36.pyc │ └── __init__.cpython-36.pyc └── translate │ ├── Beam.py │ ├── Penalties.py │ ├── Translation.py │ ├── Translator.py │ ├── __init__.py │ └── __pycache__ │ ├── Beam.cpython-36.pyc │ ├── Penalties.cpython-36.pyc │ ├── Translation.cpython-36.pyc │ ├── Translator.cpython-36.pyc │ └── __init__.cpython-36.pyc ├── README.md ├── Utils ├── DataLoader.py ├── __pycache__ │ ├── DataLoader.cpython-36.pyc │ ├── args.cpython-36.pyc │ ├── bleu.cpython-36.pyc │ ├── config.cpython-36.pyc │ ├── log.cpython-36.pyc │ ├── plot.cpython-36.pyc │ ├── rouge.cpython-36.pyc │ └── utils.cpython-36.pyc ├── args.py ├── bleu.py ├── config.py ├── log.py ├── plot.py ├── rouge.py └── utils.py ├── config ├── rnn.ini └── transformer.ini ├── ensemble.py ├── scripts ├── clean_synthetic.sh ├── data-processing-test.sh ├── dataPost-processing.sh ├── eval.test2017.sh ├── eval.test2018.sh ├── extract_raw_test.sh ├── measure_bleu.pl ├── train_bpe.sh ├── train_spm.sh └── train_word2vec.sh ├── train.py └── translate.py /.gitignore: -------------------------------------------------------------------------------- 1 | .* 2 | !/.gitignore 3 | *.pyc 4 | scripts 5 | config/* 6 | __pycache__/* 7 | __pycache__ 8 | -------------------------------------------------------------------------------- /NMT/CheckPoint.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import sys 5 | import NMT 6 | import Utils 7 | from Utils.log import trace 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.autograd import Variable 12 | 13 | def 
dump_checkpoint(model, path, suffix=""): 14 | checkpoint = CheckPoint(model) 15 | checkpoint.dump(path, suffix) 16 | 17 | class CheckPoint(object): 18 | """ 19 | Standard Training Checkpoint. 20 | """ 21 | def __init__(self, checkpoint): 22 | super(CheckPoint, self).__init__() 23 | self.state_dict = dict() 24 | if isinstance(checkpoint, list): 25 | self.checkpoint_ensemble(checkpoint) 26 | else: 27 | self.state_dict['model'] = self.get_state_dict(checkpoint) 28 | def checkpoint_ensemble(self, checkpoints): 29 | total = len(checkpoints) 30 | params = dict() 31 | for cp in checkpoints: 32 | sub_params = self.load(cp) 33 | for k, p in sub_params.items(): 34 | if k not in params: 35 | params[k] = 0. 36 | params[k] += p 37 | for k in params.keys(): 38 | params[k] /= total 39 | self.state_dict['model'] = params 40 | 41 | def load(self, path): 42 | abspath = os.path.abspath(path) 43 | if os.path.isfile(abspath): 44 | saved = torch.load(abspath) 45 | if "model" in saved: 46 | params = saved['model'] 47 | else: 48 | params = saved 49 | return params 50 | else: 51 | trace("#ERROR! Checkpoint file does not exist!") 52 | sys.exit() 53 | 54 | def dump(self, path, suffix=""): 55 | torch.save(self.state_dict['model'], '%s%s.pt' % (path, suffix)) 56 | 57 | def get_state_dict(self, model): 58 | real_model = (model.module 59 | if isinstance(model, nn.DataParallel) 60 | else model) 61 | model_state_dict = real_model.state_dict() 62 | return dict(model_state_dict.items()) -------------------------------------------------------------------------------- /NMT/Loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | import NMT 6 | import Utils 7 | from NMT.Statistics import Statistics 8 | 9 | import torch.nn.functional as F 10 | from Utils.log import trace 11 | 12 | 13 | class LossBase(nn.Module): 14 | """ 15 | Standard NMT CrossEntropy/NLL Loss Computation. 16 | """ 17 | 18 | def __init__(self, config, padding_idx, vocab_size): 19 | super(LossBase, self).__init__() 20 | self.padding_idx = padding_idx 21 | self.criterion = nn.CrossEntropyLoss( 22 | ignore_index=padding_idx, size_average=False) 23 | 24 | self.num_words = 0 25 | self.num_correct = 0 26 | 27 | def compute(self, probs, golds, normalization): 28 | """Compute the forward loss and the batch statistics. 29 | Args: 30 | probs (FloatTensor): model output scores `[(trg_len x batch) x V]` 31 | golds (LongTensor): target examples 32 | normalization: divisor used to scale the loss (e.g. token count) 33 | Returns: 34 | the normalized loss and :NMT.Statistics: for this batch 35 | 36 | """ 37 | 38 | vocab_size = probs.size(-1) 39 | loss = self.criterion(probs.view(-1, vocab_size), golds.view(-1)) 40 | loss = loss.div(normalization) 41 | stats = self.create_stats(float(loss), probs, golds) 42 | return loss, stats 43 | 44 | def create_stats(self, loss, probs, golds): 45 | """ 46 | Args: 47 | loss (`FloatTensor`): the loss computed by the loss criterion. 48 | probs (`FloatTensor`): a score for each possible output 49 | golds (`LongTensor`): true targets 50 | 51 | Returns: 52 | `Statistics` : statistics for this batch.
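Word and correct-prediction counts are taken over non-padding positions only.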
53 | """ 54 | 55 | preds = probs.data.topk(1, dim=-1)[1] 56 | non_padding = golds.ne(self.padding_idx) 57 | correct = preds.squeeze(2).eq(golds).masked_select(non_padding) 58 | num_words = non_padding.long().sum() 59 | num_correct = correct.long().sum() 60 | stats = \ 61 | Statistics(float(loss), int(num_words), int(num_correct)) 62 | return stats 63 | 64 | 65 | class LabelSmoothingLoss(LossBase): 66 | def __init__(self, config, padding_idx, vocab_size, 67 | label_smoothing=0.1): 68 | super(LabelSmoothingLoss, self).__init__( 69 | config, padding_idx, vocab_size) 70 | 71 | self.criterion = nn.KLDivLoss(size_average=False) 72 | one_hot = torch.randn(1, vocab_size).cuda() 73 | one_hot.fill_(config.label_smoothing / (vocab_size - 2)) 74 | one_hot[0][self.padding_idx] = 0 75 | self.register_buffer('one_hot', one_hot) 76 | self.confidence = 1.0 - config.label_smoothing 77 | 78 | def compute(self, probs, golds, normalization, kld=0.): 79 | vocab_size = probs.size(-1) 80 | scores = F.log_softmax(probs.view(-1, vocab_size), dim=-1) 81 | gtruth = golds.view(-1).data 82 | mask = torch.nonzero(gtruth.eq(self.padding_idx)).long() 83 | mask = mask.squeeze() 84 | log_likelihood = torch.gather(scores.data, 1, gtruth.unsqueeze(1)) 85 | tmp = self.one_hot.repeat(gtruth.size(0), 1).cuda() 86 | tmp.scatter_(1, gtruth.unsqueeze(1), self.confidence) 87 | if mask.dim() > 0 and mask.size(0) > 0: 88 | log_likelihood.index_fill_(0, mask, 0) 89 | tmp.index_fill_(0, mask, 0) 90 | gtruth = Variable(tmp, requires_grad=False) 91 | loss = self.criterion(scores, gtruth) 92 | #loss += kld 93 | loss.div(normalization) 94 | stats = self.create_stats(float(loss), probs, golds) 95 | return loss, stats 96 | -------------------------------------------------------------------------------- /NMT/ModelFactory.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is for models creation, which consults options 3 | and creates each encoder and decoder accordingly. 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from Utils.log import trace 9 | from Utils.DataLoader import PAD_WORD 10 | from NMT.Models.RNN import RNNModel 11 | from NMT.Models.Transformer import TransformerModel 12 | from NMT.CheckPoint import CheckPoint 13 | 14 | def make_embeddings(config, *vocab): 15 | """ 16 | Make an Embeddings instance. 17 | Args: 18 | vocab (Vocab): words dictionary. 19 | config: global configuration settings. 
20 | """ 21 | 22 | if len(vocab) == 2: 23 | trace("Making independent embeddings ...") 24 | src_vocab, trg_vocab = vocab 25 | padding_idx = src_vocab.stoi[PAD_WORD] 26 | src_embeddings = nn.Embedding( 27 | src_vocab.vocab_size, 28 | config.src_embed_dim, 29 | padding_idx=padding_idx, 30 | max_norm=None, 31 | norm_type=2, 32 | scale_grad_by_freq=False, 33 | sparse=config.sparse_embeddings) 34 | trg_embeddings = nn.Embedding( 35 | trg_vocab.vocab_size, 36 | config.src_embed_dim, 37 | padding_idx=padding_idx, 38 | max_norm=None, 39 | norm_type=2, 40 | scale_grad_by_freq=False, 41 | sparse=config.sparse_embeddings) 42 | else: 43 | 44 | assert config.trg_embed_dim == config.src_embed_dim 45 | src_vocab = trg_vocab = vocab[0] 46 | padding_idx = trg_vocab.padding_idx 47 | src_embeddings = nn.Embedding( 48 | src_vocab.vocab_size, 49 | config.src_embed_dim, 50 | padding_idx=padding_idx, 51 | max_norm=None, 52 | norm_type=2, 53 | scale_grad_by_freq=False, 54 | sparse=config.sparse_embeddings) 55 | if config.share_embedding: 56 | trace("Making shared embeddings ...") 57 | trg_embeddings = src_embeddings 58 | else: 59 | trace("Making independent embeddings ...") 60 | trg_embeddings = nn.Embedding( 61 | trg_vocab.vocab_size, 62 | config.trg_embed_dim, 63 | padding_idx=padding_idx, 64 | max_norm=None, 65 | norm_type=2, 66 | scale_grad_by_freq=False, 67 | sparse=config.sparse_embeddings) 68 | return src_vocab, trg_vocab, src_embeddings, trg_embeddings 69 | 70 | 71 | 72 | def model_factory(config, checkpoint, *vocab): 73 | # Make embedding. 74 | 75 | 76 | src_vocab, trg_vocab, src_embeddings, trg_embeddings = \ 77 | make_embeddings(config, *vocab) 78 | 79 | if config.system == "RNN": 80 | model = RNNModel( 81 | src_embeddings, trg_embeddings, 82 | trg_vocab.vocab_size, config) 83 | 84 | elif config.system == "Transformer": 85 | model = TransformerModel( 86 | src_embeddings, trg_embeddings, 87 | trg_vocab.vocab_size, trg_vocab.padding_idx, 88 | config) 89 | if checkpoint: 90 | trace("Loading model parameters from checkpoint: %s." % str(checkpoint)) 91 | cp = CheckPoint(checkpoint) 92 | model.load_state_dict(cp.state_dict['model'], strict = False) 93 | 94 | if config.training: 95 | model.train() 96 | else: 97 | model.eval() 98 | 99 | if config.use_gpu is not None: 100 | model.cuda() 101 | else: 102 | model.cpu() 103 | 104 | return model 105 | -------------------------------------------------------------------------------- /NMT/Models/RNN/Decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from NMT.Modules import GlobalAttention 4 | from NMT.Modules import StackedLSTM 5 | from NMT.Modules import StackedGRU 6 | from .DecoderState import RNNDecoderState 7 | 8 | 9 | class RNNDecoder(nn.Module): 10 | def __init__(self, 11 | trg_embedding, 12 | rnn_type, 13 | embedding_size, 14 | hidden_size, 15 | num_layers=2, 16 | attn_type="general", 17 | bidirectional_encoder=True, 18 | dropout=0.0): 19 | 20 | super(RNNDecoder, self).__init__() 21 | 22 | self.trg_embedding = trg_embedding 23 | 24 | self.rnn_type = rnn_type 25 | self.hidden_size = hidden_size 26 | self.num_layers = num_layers 27 | self.input_size = embedding_size + hidden_size 28 | self.bidirectional_encoder = bidirectional_encoder 29 | 30 | self.dropout = nn.Dropout(dropout) 31 | # Build the RNN. 
32 | self.rnn = self._build_rnn(rnn_type, 33 | input_size=self.input_size, 34 | hidden_size=hidden_size, 35 | num_layers=num_layers, 36 | dropout=dropout) 37 | 38 | 39 | self.attn = GlobalAttention(hidden_size, 40 | attn_type=attn_type) 41 | 42 | def _build_rnn(self, rnn_type, input_size, hidden_size, num_layers, dropout): 43 | if rnn_type == "LSTM": 44 | stacked_cell = StackedLSTM 45 | elif rnn_type == "GRU": 46 | stacked_cell = StackedGRU 47 | else: 48 | raise NotImplementedError 49 | return stacked_cell(num_layers, input_size, 50 | hidden_size, dropout) 51 | 52 | def init_decoder_state(self, encoder_state): 53 | assert self.hidden_size == encoder_state[-1].size(-1) 54 | # RNNDecoderState normalizes both cases: LSTM states arrive as a 55 | # (hidden, cell) tuple, GRU states as a single tensor. 56 | return RNNDecoderState(encoder_state) 57 | 58 | def forward(self, trg, encoder_outputs, lengths, state): 59 | """ 60 | Args: 61 | trg (`LongTensor`): sequences of padded token indices 62 | `[L_t x B]`. 63 | encoder_outputs (`FloatTensor`): vectors from the encoder 64 | `[L_s x B x H]`. 65 | 66 | lengths (`LongTensor`): the padded source lengths 67 | `[B]`. 68 | state (`DecoderState`): 69 | decoder state object to initialize the decoder 70 | Returns: 71 | (`FloatTensor`,:obj:`nmt.Models.DecoderState`,`FloatTensor`): 72 | * decoder_outputs: output from the decoder (after attn) 73 | `[trg_len x batch x hidden]`. 74 | * decoder_state: final hidden state from the decoder 75 | * attns: distribution over source words at each target word 76 | `[L_t x B x L_s]`. 77 | """ 78 | # Run the forward pass of the RNN. 79 | decoder_outputs = [] 80 | attns = [] 81 | 82 | encoder_outputs = encoder_outputs.transpose(0, 1) 83 | trg_embed = self.trg_embedding(trg) 84 | 85 | 86 | # iterate over each target word 87 | for t, embed in enumerate(trg_embed.split(1, dim=0)): 88 | output, attn, state = self.forward_step( 89 | embed, encoder_outputs, lengths, state) 90 | decoder_output = self.dropout(output) 91 | decoder_outputs.append(decoder_output) 92 | attns.append(attn) 93 | 94 | 95 | # Concatenate the sequence of tensors along a new (time) dimension. 96 | decoder_outputs = torch.stack(decoder_outputs) 97 | attns = torch.stack(attns) 98 | 99 | return decoder_outputs, state, attns 100 | 101 | 102 | 103 | 104 | 105 | def forward_step(self, trg_embed, encoder_outputs, lengths, state): 106 | """ 107 | Input feeding: the previous attentional output is concatenated 108 | with the token embedding at every time step. 109 | 110 | Args: 111 | trg_embed (`FloatTensor`): embedding of the current target token 112 | `[1 x B x E]`. 113 | encoder_outputs (`FloatTensor`): vectors from the encoder 114 | `[B x L_s x H]`. 115 | 116 | lengths (`LongTensor`): the padded source lengths 117 | `[B]`.
118 | state (`DecoderState`): 119 | decoder state object carrying the RNN state and input feed 120 | """ 121 | 122 | input_feed = state.input_feed # [1 x B x H] 123 | 124 | # teacher forcing 125 | rnn_input = torch.cat([trg_embed, input_feed], -1) 126 | rnn_state = state.rnn_state 127 | 128 | rnn_input = rnn_input.squeeze(0) 129 | # update rnn state and feed to next RNNCell 130 | rnn_output, rnn_state = self.rnn(rnn_input, rnn_state) 131 | 132 | attn_output, attn = self.attn( 133 | rnn_output, encoder_outputs, lengths=lengths) 134 | 135 | # update decoder state and input feed 136 | state = state.update_state(rnn_state, attn_output.unsqueeze(0)) 137 | 138 | return attn_output, attn, state -------------------------------------------------------------------------------- /NMT/Models/RNN/DecoderState.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | class RNNDecoderState(object): 8 | def __init__(self, rnn_state, beam_size=1): 9 | """ 10 | Args: 11 | rnn_state: final hidden state from the encoder, 12 | transformed to shape: layers x batch x (directions*dim). 13 | """ 14 | if isinstance(rnn_state, tuple): # LSTM 15 | self.rnn_state = rnn_state 16 | else: # GRU 17 | self.rnn_state = (rnn_state, ) 18 | 19 | self.beam_size = beam_size 20 | 21 | # Init the input feed [1 x B x H]. 22 | self.input_feed = \ 23 | Variable( 24 | self.rnn_state[0].data.new_zeros( 25 | self.rnn_state[0][-1].size(), 26 | requires_grad=False) 27 | ).unsqueeze(0) 28 | 29 | 30 | def beam_update_state(self, idx, positions): 31 | 32 | p = self.input_feed.view( 33 | self.input_feed.size(0), 34 | self.beam_size, 35 | self.input_feed.size(1) // self.beam_size, 36 | self.input_feed.size(2))[:, :, idx] 37 | 38 | p.data.copy_(p.data.index_select(1, positions)) 39 | 40 | for h in self.rnn_state: 41 | p = h.view(h.size(0), 42 | self.beam_size, 43 | h.size(1) // self.beam_size, 44 | h.size(2))[:, :, idx] 45 | p.data.copy_(p.data.index_select(1, positions)) 46 | 47 | def update_state(self, rnn_state, input_feed): 48 | state = RNNDecoderState(rnn_state, self.beam_size) 49 | state.input_feed = input_feed 50 | return state 51 | 52 | def repeat_beam_size_times(self, beam_size): 53 | """ Repeat beam_size times along batch dimension.
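e.g. a state over a batch of 2 sentences with beam size 5 becomes an
effective batch of 10 rows, one per (sentence, hypothesis) pair.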
""" 54 | self.beam_size = beam_size 55 | repeat_func = lambda x: Variable(x.data.repeat(1, beam_size, 1)) 56 | self.rnn_state = tuple(repeat_func(h) for h in self.rnn_state) 57 | self.input_feed = repeat_func(self.input_feed) -------------------------------------------------------------------------------- /NMT/Models/RNN/Encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pack_padded_sequence as pack 4 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 5 | 6 | class RNNEncoder(nn.Module): 7 | def __init__(self, 8 | src_embedding, 9 | rnn_type, 10 | embed_dim, 11 | hidden_size, 12 | enforce_sorted=True, 13 | num_layers=2, 14 | dropout=0.0, 15 | bidirectional=True): 16 | 17 | super(RNNEncoder, self).__init__() 18 | 19 | num_directions = 2 if bidirectional else 1 20 | assert hidden_size % num_directions == 0 21 | hidden_size = hidden_size // num_directions 22 | self.num_layers = num_layers 23 | self.enforce_sorted = enforce_sorted 24 | self.bidirectional = bidirectional 25 | self.rnn_type = rnn_type 26 | self.rnn = getattr(nn, rnn_type)(embed_dim, 27 | hidden_size=hidden_size, 28 | num_layers=num_layers, 29 | dropout=dropout, 30 | bidirectional=bidirectional) 31 | self.hidden_size = hidden_size * num_directions 32 | self.src_embedding = src_embedding 33 | 34 | def fix_final_state_size(self, final_state): 35 | def resize(h): 36 | # The encoder hidden is (layers*directions) x batch x dim. 37 | # We need to convert it to layers x batch x (directions*dim). 38 | if self.bidirectional: 39 | h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], -1) 40 | return h 41 | if self.bidirectional: 42 | if self.rnn_type == "GRU": 43 | final_state = resize(final_state) 44 | elif self.rnn_type == "LSTM": 45 | final_state = tuple(resize(h) for h in final_state) 46 | return final_state 47 | 48 | def forward(self, src, lengths=None): 49 | """ 50 | Args: 51 | src (`LongTensor`): sequences of padded tokens `[L_s x B]`. 52 | lengths (`LongTensor`): the padded source lengths `[B]`. 53 | Returns: 54 | (`FloatTensor`,:obj:`nmt.Models.DecoderState`,`FloatTensor`): 55 | * decoder_outputs: output from the decoder (after attn) 56 | `[L_t x B x H]`. 57 | * decoder_state: final hidden state from the decoder 58 | * attns: distribution over source words at each target word 59 | `[L_t x B x L_s]`. 60 | """ 61 | src_embed = self.src_embedding(src) 62 | if lengths is not None: 63 | packed = pack(src_embed, lengths.view(-1).tolist(), enforce_sorted=self.enforce_sorted) 64 | 65 | output, final_state = self.rnn(packed) 66 | 67 | if lengths is not None: 68 | output = unpack(output)[0] 69 | 70 | final_state = self.fix_final_state_size(final_state) 71 | 72 | return output, final_state -------------------------------------------------------------------------------- /NMT/Models/RNN/Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | from Utils.log import trace 7 | from .Encoder import RNNEncoder 8 | from .Decoder import RNNDecoder 9 | 10 | 11 | class RNNModel(nn.Module): 12 | """ 13 | Core RNN model for NMT. 14 | { 15 | RNN Encoder + RNN Decoder. 
16 | } 17 | """ 18 | def __init__(self, src_embedding, trg_embedding, 19 | trg_vocab_size, config): 20 | super(RNNModel, self).__init__() 21 | 22 | self.encoder = RNNEncoder( 23 | src_embedding, 24 | config.rnn_type, 25 | config.src_embed_dim, 26 | config.hidden_size, 27 | 'decreasing' == config.mini_batch_sort_order, # enforce_sorted only when batches are sorted by decreasing length 28 | config.enc_num_layers, 29 | config.dropout, 30 | config.bidirectional) 31 | 32 | self.decoder = RNNDecoder( 33 | trg_embedding, 34 | config.rnn_type, 35 | config.trg_embed_dim, 36 | config.hidden_size, 37 | config.dec_num_layers, 38 | config.attn_type, 39 | config.bidirectional, 40 | config.dropout) 41 | 42 | self.generator = nn.Linear(config.hidden_size, trg_vocab_size) 43 | self.config = config 44 | if self.training: 45 | self.param_init() 46 | 47 | def param_init(self): 48 | trace("Initializing model parameters.") 49 | for p in self.parameters(): 50 | p.data.uniform_(-0.1, 0.1) 51 | 52 | 53 | def translate_step(self, trg, encoder_outputs, lengths, decoder_state): 54 | 55 | encoder_outputs = encoder_outputs.transpose(0, 1) 56 | 57 | trg_embed = self.decoder.trg_embedding(trg) 58 | 59 | return self.decoder.forward_step( 60 | trg_embed, encoder_outputs, lengths, decoder_state) 61 | 62 | 63 | def forward(self, src, lengths, trg, state=None): 64 | """ 65 | Forward propagate a `src` and `trg` pair for training. 66 | Optionally initialized with a given decoder state. 67 | 68 | Args: 69 | src (LongTensor): source sequence `[L_s x B]`. 70 | lengths (LongTensor): the src lengths, pre-padding `[B]`. 71 | trg (LongTensor): target sequence `[L_t x B]`. 72 | state (DecoderState): initial decoder state 73 | Returns: 74 | output `[trg_len x batch x hidden]` 75 | attention: distributions of `[L_t x B x L_s]` 76 | state (DecoderState): final decoder state 77 | """ 78 | # encoding side 79 | encoder_outputs, encoder_state = \ 80 | self.encoder(src, lengths) 81 | 82 | # encoder to decoder 83 | decoder_state = \ 84 | self.decoder.init_decoder_state(encoder_state) 85 | 86 | # decoding side 87 | decoder_outputs, decoder_state, attns = \ 88 | self.decoder(trg, encoder_outputs, lengths, decoder_state) 89 | 90 | return decoder_outputs, decoder_state, attns 91 | -------------------------------------------------------------------------------- /NMT/Models/RNN/Modules/StackedRNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | class StackedLSTM(nn.Module): 5 | """ 6 | Custom stacked LSTM whose first-layer input size can differ from the hidden size.
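Built from individual LSTMCell layers rather than nn.LSTM so the decoder
can advance a single time step at a time, as input feeding requires.

Example (a sketch; E and H are placeholder sizes, h_0/c_0 are
[num_layers x B x H] tensors):
    rnn = StackedLSTM(2, input_size=E + H, hidden_size=H, dropout=0.1)
    out, (h_1, c_1) = rnn(x, (h_0, c_0))   # x: [B x (E+H)]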
7 | """ 8 | def __init__(self, num_layers, input_size, hidden_size, dropout, residual=True): 9 | super(StackedLSTM, self).__init__() 10 | self.dropout = nn.Dropout(dropout) 11 | self.num_layers = num_layers 12 | self.residual = residual 13 | self.layers = nn.ModuleList() 14 | 15 | for i in range(num_layers): 16 | if i == 0: 17 | self.layers.append(nn.LSTMCell(input_size, hidden_size)) 18 | else: 19 | self.layers.append(nn.LSTMCell(hidden_size, hidden_size)) 20 | 21 | def forward(self, input, hidden): 22 | h_0, c_0 = hidden 23 | h_1, c_1 = [], [] 24 | for i, layer in enumerate(self.layers): 25 | h_1_i, c_1_i = layer(input, (h_0[i], c_0[i])) 26 | input = h_1_i 27 | if i + 1 != self.num_layers: 28 | input = self.dropout(input) 29 | h_1 += [h_1_i] 30 | c_1 += [c_1_i] 31 | 32 | h_1 = torch.stack(h_1) 33 | c_1 = torch.stack(c_1) 34 | return input, (h_1, c_1) 35 | 36 | 37 | class StackedGRU(nn.Module): 38 | 39 | def __init__(self, num_layers, input_size, hidden_size, dropout, residual=True): 40 | super(StackedGRU, self).__init__() 41 | self.dropout = nn.Dropout(dropout) 42 | self.num_layers = num_layers 43 | self.layers = nn.ModuleList() 44 | self.residual = residual 45 | for i in range(num_layers): 46 | if i == 0: 47 | self.layers.append(nn.GRUCell(input_size, hidden_size)) 48 | else: 49 | self.layers.append(nn.GRUCell(hidden_size, hidden_size)) 50 | 51 | 52 | def forward(self, input, hidden): 53 | """ 54 | Args: 55 | input (FloatTensor): [B x H]. 56 | hidden: [B x H]. 57 | """ 58 | assert len(input.size()) == 2 59 | 60 | h_1 = [] 61 | for i, layer in enumerate(self.layers): 62 | h_1_i = layer(input, hidden[0][i]) 63 | # if self.residual and 0 < i < self.num_layers-1: 64 | # input = h_1_i + input 65 | # else: 66 | input = h_1_i 67 | if i + 1 != self.num_layers: 68 | input = self.dropout(input) 69 | h_1 += [h_1_i] 70 | 71 | h_1 = torch.stack(h_1) 72 | return input, (h_1,) 73 | -------------------------------------------------------------------------------- /NMT/Models/RNN/Modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .StackedRNN import StackedLSTM, StackedGRU 2 | -------------------------------------------------------------------------------- /NMT/Models/RNN/__init__.py: -------------------------------------------------------------------------------- 1 | from .Model import RNNModel -------------------------------------------------------------------------------- /NMT/Models/RNN/__pycache__/Decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/RNN/__pycache__/Decoder.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/RNN/__pycache__/DecoderState.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/RNN/__pycache__/DecoderState.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/RNN/__pycache__/Encoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/RNN/__pycache__/Encoder.cpython-36.pyc -------------------------------------------------------------------------------- 
/NMT/Models/RNN/__pycache__/Model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/RNN/__pycache__/Model.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/RNN/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/RNN/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/Transformer/Decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from NMT.Modules import GlobalAttention 4 | from .Modules import MultiHeadedAttention 5 | from .Modules import PositionwiseFeedForward 6 | from torch.autograd import Variable 7 | from .DecoderState import TransformerDecoderState 8 | import numpy as np 9 | MAX_SIZE = 5000 10 | 11 | 12 | class TransformerDecoderLayer(nn.Module): 13 | def __init__(self, size, dropout, 14 | num_heads=8, hidden_size=1024): 15 | 16 | super(TransformerDecoderLayer, self).__init__() 17 | self.self_attn = MultiHeadedAttention( 18 | num_heads, size, dropout=dropout) 19 | self.context_attn = MultiHeadedAttention( 20 | num_heads, size, dropout=dropout) 21 | self.feed_forward = PositionwiseFeedForward( 22 | size, hidden_size, dropout) 23 | self.layer_norm_1 = nn.LayerNorm(size, eps=1e-6) 24 | self.layer_norm_2 = nn.LayerNorm(size, eps=1e-6) 25 | self.dropout = dropout 26 | self.drop = nn.Dropout(dropout) 27 | mask = self._get_attn_subsequent_mask(MAX_SIZE) 28 | # Register self.mask as a buffer in TransformerDecoderLayer, so 29 | # it gets TransformerDecoderLayer's cuda behavior automatically. 30 | self.register_buffer('mask', mask) 31 | 32 | def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, 33 | previous_input=None): 34 | 35 | dec_mask = torch.gt(tgt_pad_mask + 36 | self.mask[:, :tgt_pad_mask.size(1), 37 | :tgt_pad_mask.size(1)], 0) 38 | input_norm = self.layer_norm_1(inputs) 39 | all_input = input_norm 40 | if previous_input is not None: 41 | all_input = torch.cat((previous_input, input_norm), dim=1) 42 | dec_mask = None 43 | query, attn = self.self_attn(all_input, all_input, input_norm, 44 | mask=dec_mask) 45 | query = self.drop(query) + inputs 46 | 47 | query_norm = self.layer_norm_2(query) 48 | mid, attn = self.context_attn(memory_bank, memory_bank, query_norm, 49 | mask=src_pad_mask) 50 | output = self.feed_forward(self.drop(mid) + query) 51 | 52 | return output, attn, all_input 53 | 54 | def _get_attn_subsequent_mask(self, size): 55 | ''' Get an attention mask to avoid using the subsequent info.''' 56 | attn_shape = (1, size, size) 57 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 58 | subsequent_mask = torch.from_numpy(subsequent_mask) 59 | return subsequent_mask 60 | 61 | class TransformerDecoder(nn.Module): 62 | """ 63 | The Transformer decoder from "Attention is All You Need". 
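Each layer applies masked self-attention over the target prefix, then
multi-headed attention over the encoder memory, then a position-wise
feed-forward block, with a residual connection around each sub-layer.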
64 | """ 65 | def __init__(self, trg_embedding, 66 | trg_embed_dim, 67 | hidden_size, 68 | inner_hidden_size, 69 | num_layers=4, 70 | attn_type="general", 71 | dropout=0.0, 72 | num_heads=8, 73 | padding_idx=0): 74 | super(TransformerDecoder, self).__init__() 75 | 76 | self.trg_embedding = trg_embedding 77 | self.padding_idx = padding_idx 78 | self.num_layers = num_layers 79 | 80 | self.transformer_layers = \ 81 | nn.ModuleList( 82 | [TransformerDecoderLayer( 83 | trg_embed_dim, dropout, 84 | num_heads, inner_hidden_size)\ 85 | for _ in range(num_layers)]) 86 | 87 | self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-6) 88 | 89 | 90 | def init_decoder_state(self, src): 91 | return TransformerDecoderState(Variable(src)) 92 | 93 | def forward(self, trg, encoder_outputs, src_lengths, state): 94 | """ 95 | Args: 96 | trg (`LongTensor`): sequences of padded tokens 97 | `[L_t x B ]`. 98 | encoder_outputs (`FloatTensor`): vectors from the encoder 99 | `[L_s x B x H]`. 100 | 101 | lengths (`LongTensor`): the padded source lengths 102 | `[B]`. 103 | state (`DecoderState`): 104 | decoder state object to initialize the decoder 105 | Returns: 106 | (`FloatTensor`,:obj:`nmt.Models.DecoderState`,`FloatTensor`): 107 | * decoder_outputs: output from the decoder (after attn) 108 | `[trg_len x batch x hidden]`. 109 | * decoder_state: final hidden state from the decoder 110 | * attns: distribution over source words at each target word 111 | `[L_t x B x L_s]`. 112 | """ 113 | src_words = state.src # [L_s, B] 114 | src_words = src_words.transpose(0,1).contiguous() 115 | src_batch, src_len = src_words.size() 116 | 117 | trg_words = trg 118 | trg_words = trg_words.transpose(0,1).contiguous() 119 | trg_batch, trg_len = trg_words.size() 120 | assert (trg_batch == src_batch) 121 | 122 | 123 | if state.input is not None: 124 | trg = torch.cat([state.input, trg], 0) 125 | 126 | encoder_outputs = encoder_outputs.transpose(0, 1) 127 | 128 | # Run the forward pass of the TransformerDecoder. 129 | trg_embed = self.trg_embedding(trg) 130 | if state.input is not None: 131 | trg_embed = trg_embed[-1:, ] 132 | # assert trg_embed.dim() == 3 # len x batch x embedding_dim 133 | 134 | output = trg_embed.transpose(0, 1).contiguous() 135 | # B, L_t, H 136 | src_pad_mask = src_words.data\ 137 | .eq(self.padding_idx)\ 138 | .unsqueeze(1) \ 139 | .expand(src_batch, trg_len, src_len) 140 | 141 | 142 | trg_pad_mask = trg_words.data\ 143 | .eq(self.padding_idx)\ 144 | .unsqueeze(1) \ 145 | .expand(trg_batch, trg_len, trg_len) 146 | 147 | 148 | 149 | saved_inputs = [] 150 | for i in range(self.num_layers): 151 | input = None 152 | if state.input is not None: 153 | input = state.previous_inputs[i] 154 | output, attn, all_input \ 155 | = self.transformer_layers[i]( 156 | output, encoder_outputs, 157 | src_pad_mask, trg_pad_mask, 158 | previous_input=input) 159 | saved_inputs.append(all_input) 160 | saved_inputs = torch.stack(saved_inputs) 161 | output = self.layer_norm(output) 162 | 163 | # Process the result and update the attentions. 164 | output = output.transpose(0, 1).contiguous() 165 | attn = attn.transpose(0, 1).contiguous() 166 | 167 | # Update the state. 
168 | state = state.update_state(trg, saved_inputs) 169 | return output, state, attn 170 | -------------------------------------------------------------------------------- /NMT/Models/Transformer/DecoderState.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | class TransformerDecoderState(object): 8 | def __init__(self, src, input=None, previous_inputs=None, beam_size=1): 9 | """ 10 | Args: 11 | src (LongTensor): a sequence of source word tensors 12 | with optional feature tensors, of size (len x batch). 13 | """ 14 | self.src = src 15 | self.input = input 16 | self.previous_inputs = previous_inputs 17 | self.beam_size = beam_size 18 | 19 | def update_state(self, input, previous_inputs): 20 | """ 21 | input: [L_t, B] 22 | previous_inputs: [[L_t, B], ...] 23 | """ 24 | state = TransformerDecoderState(self.src, 25 | input, 26 | previous_inputs, 27 | self.beam_size) 28 | return state 29 | 30 | def beam_update_state(self, idx, positions): 31 | """ 32 | idx: sentence id in the batch 33 | positions: beam positions to keep 34 | """ 35 | batch_size = self.input.size(1) // self.beam_size 36 | 37 | p = self.input.view(self.input.size(0), 38 | self.beam_size, 39 | batch_size)[:, :, idx] 40 | 41 | p.data.copy_(p.data.index_select(1, positions)) 42 | 43 | pl = self.previous_inputs.view(self.previous_inputs.size(0), 44 | self.beam_size, 45 | batch_size, 46 | self.previous_inputs.size(2), 47 | self.previous_inputs.size(3))[:, :, idx] 48 | 49 | pl.data.copy_(pl.data.index_select(1, positions)) 50 | 51 | 52 | 53 | def repeat_beam_size_times(self, beam_size): 54 | """ Repeat beam_size times along batch dimension. """ 55 | self.beam_size = beam_size 56 | self.src = Variable( 57 | self.src.data.repeat(1, self.beam_size)) 58 | 59 | -------------------------------------------------------------------------------- /NMT/Models/Transformer/Encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pack_padded_sequence as pack 4 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 5 | from .Modules import MultiHeadedAttention 6 | from .Modules import PositionwiseFeedForward 7 | 8 | 9 | class TransformerEncoderLayer(nn.Module): 10 | """ 11 | A single layer of the transformer encoder.
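Pre-norm self-attention with dropout and a residual connection,
followed by a position-wise feed-forward block that applies its own
residual internally.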
12 | """ 13 | 14 | def __init__(self, size, dropout, num_heads=8, hidden_size=1024): 15 | super(TransformerEncoderLayer, self).__init__() 16 | 17 | self.self_attn = MultiHeadedAttention( 18 | num_heads, size, dropout=dropout) 19 | self.ff = PositionwiseFeedForward( 20 | size, hidden_size, dropout) 21 | self.layer_norm = nn.LayerNorm(size, eps=1e-6) 22 | self.dropout = nn.Dropout(dropout) 23 | 24 | def forward(self, input, mask): 25 | input_norm = self.layer_norm(input) 26 | 27 | context, _ = self.self_attn( 28 | input_norm, input_norm, input_norm,mask=mask) 29 | 30 | # residual 31 | out = self.dropout(context) + input 32 | return self.ff(out) 33 | 34 | class TransformerEncoder(nn.Module): 35 | """ 36 | Args: 37 | num_layers (int): number of encoder layers 38 | hidden_size (int): number of hidden units 39 | dropout (float): dropout parameters 40 | embeddings (:obj:`onmt.modules.Embeddings`): 41 | embeddings to use, should have positional encodings 42 | """ 43 | def __init__(self, 44 | src_embedding, 45 | embed_dim, 46 | hidden_size, 47 | inner_hidden_size, 48 | num_layers=2, 49 | dropout=0.0, 50 | num_heads=8, 51 | padding_idx=0): 52 | 53 | super(TransformerEncoder, self).__init__() 54 | self.src_embedding = src_embedding 55 | self.num_layers = num_layers 56 | self.padding_idx = padding_idx 57 | self.inner_hidden_size = inner_hidden_size 58 | self.transformer = \ 59 | nn.ModuleList( 60 | [ 61 | TransformerEncoderLayer(embed_dim, dropout, num_heads, inner_hidden_size)\ 62 | for i in range(num_layers) 63 | ]) 64 | self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-6) 65 | 66 | def make_mask(self, src): 67 | words = src.transpose(0, 1) 68 | mask = words.data.eq(self.padding_idx) 69 | mask = mask.unsqueeze(1).repeat(1, words.size(-1), 1) 70 | return mask 71 | 72 | def forward(self, src, lengths=None): 73 | src_embed = self.src_embedding(src) 74 | src_mask = self.make_mask(src) 75 | 76 | # Run the forward pass of every layer of the tranformer. 77 | output = src_embed.transpose(0, 1) 78 | for i in range(self.num_layers): 79 | output = self.transformer[i](output, src_mask) 80 | output = self.layer_norm(output) 81 | return output.transpose(0, 1), src -------------------------------------------------------------------------------- /NMT/Models/Transformer/Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.nn.init import xavier_uniform_ 6 | 7 | from .Encoder import TransformerEncoder 8 | from .Decoder import TransformerDecoder 9 | from .DecoderState import TransformerDecoderState 10 | from .Modules import PositionalEncoding 11 | 12 | # class BaseTransformerModel(nn.Module): 13 | # """ 14 | # Core RNN model for NMT. 15 | # { 16 | # Transformer Encoder + Transformer Decoder. 17 | # } 18 | # """ 19 | # def __init__(self, src_embedding, trg_embedding, 20 | # trg_vocab_size, padding_idx, config): 21 | # super(BaseTransformerModel, self).__init__() 22 | class TransformerModel(nn.Module): 23 | """ 24 | Core RNN model for NMT. 25 | { 26 | Transformer Encoder + Transformer Decoder. 
27 | } 28 | """ 29 | def __init__(self, src_embedding, trg_embedding, 30 | trg_vocab_size, padding_idx, config): 31 | super(TransformerModel, self).__init__() 32 | self.padding_idx = padding_idx 33 | 34 | self.src_embedding = nn.Sequential( 35 | src_embedding, 36 | PositionalEncoding(config.dropout, config.src_embed_dim) 37 | ) 38 | 39 | self.trg_embedding = nn.Sequential( 40 | trg_embedding, 41 | PositionalEncoding(config.dropout, config.trg_embed_dim) 42 | ) 43 | 44 | self.encoder = TransformerEncoder( 45 | self.src_embedding, 46 | config.src_embed_dim, 47 | config.hidden_size, 48 | config.inner_hidden_size, 49 | num_layers=config.enc_num_layers, 50 | dropout=config.dropout, 51 | num_heads=config.num_heads, 52 | padding_idx=padding_idx) 53 | 54 | 55 | self.decoder = TransformerDecoder( 56 | self.trg_embedding, 57 | config.trg_embed_dim, 58 | config.hidden_size, 59 | config.inner_hidden_size, 60 | num_layers=config.dec_num_layers, 61 | attn_type=config.attn_type, 62 | dropout=config.dropout, 63 | num_heads=config.num_heads, 64 | padding_idx=padding_idx) 65 | 66 | 67 | 68 | self.generator = nn.Linear(config.hidden_size, trg_vocab_size) 69 | self.config = config 70 | if self.training: 71 | self.param_init() 72 | 73 | def param_init(self): 74 | for p in self.parameters(): 75 | if p.dim() > 1: 76 | xavier_uniform_(p) 77 | 78 | def make_mask(self, input): 79 | words = input.transpose(0, 1) 80 | mask = words.data.eq(self.padding_idx) 81 | mask = mask.unsqueeze(1).repeat(1, words.size(-1), 1) 82 | return mask 83 | 84 | 85 | 86 | def decode(self, trg, encoder_outputs, src_lengths, decoder_state): 87 | 88 | return self.decoder(trg, encoder_outputs, 89 | src_lengths, decoder_state) 90 | 91 | def translate_step(self, trg, encoder_outputs, lengths, decoder_state): 92 | output, state, attn = \ 93 | self.decoder(trg, encoder_outputs, lengths, decoder_state) 94 | return output.squeeze(0), attn.squeeze(0), state 95 | 96 | def forward(self, src, src_lengths, trg, decoder_state=None): 97 | # encoding side 98 | encoder_outputs, src = self.encoder(src, src_lengths) 99 | 100 | # encoder to decoder 101 | if decoder_state is None: 102 | decoder_state = self.decoder.init_decoder_state(src) 103 | 104 | # decoding side 105 | decoder_outputs, decoder_state, attns = \ 106 | self.decoder(trg, encoder_outputs, src_lengths, decoder_state) 107 | return decoder_outputs, attns, decoder_state 108 | -------------------------------------------------------------------------------- /NMT/Models/Transformer/Modules/MultiHeadedAttn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | 8 | class MultiHeadedAttention(nn.Module): 9 | def __init__(self, num_heads, model_dim, dropout=0.1): 10 | assert model_dim % num_heads == 0 11 | self.dim_per_head = model_dim // num_heads 12 | self.model_dim = model_dim 13 | 14 | super(MultiHeadedAttention, self).__init__() 15 | self.num_heads = num_heads 16 | 17 | self.linear_keys = nn.Linear(model_dim, 18 | num_heads * self.dim_per_head) 19 | self.linear_values = nn.Linear(model_dim, 20 | num_heads * self.dim_per_head) 21 | self.linear_query = nn.Linear(model_dim, 22 | num_heads * self.dim_per_head) 23 | self.sm = nn.Softmax(dim=-1) 24 | self.dropout = nn.Dropout(dropout) 25 | self.final_linear = nn.Linear(model_dim, model_dim) 26 | 27 | def forward(self, key, value, query, mask=None): 28 | """ 29 | Compute the context vector
and the attention vectors. 30 | Args: 31 | key (`FloatTensor`): set of `key_len` 32 | key vectors `[batch, key_len, dim]` 33 | value (`FloatTensor`): set of `key_len` 34 | value vectors `[batch, key_len, dim]` 35 | query (`FloatTensor`): set of `query_len` 36 | query vectors `[batch, query_len, dim]` 37 | mask: binary mask indicating which keys have 38 | non-zero attention `[batch, query_len, key_len]` 39 | Returns: 40 | (`FloatTensor`, `FloatTensor`) : 41 | * output context vectors `[batch, query_len, dim]` 42 | * one of the attention vectors `[batch, query_len, key_len]` 43 | """ 44 | batch_size = key.size(0) 45 | dim_per_head = self.dim_per_head 46 | num_heads = self.num_heads 47 | key_len = key.size(1) 48 | query_len = query.size(1) 49 | 50 | def shape(x): 51 | return x.view(batch_size, -1, num_heads, dim_per_head) \ 52 | .transpose(1, 2) 53 | 54 | def unshape(x): 55 | return x.transpose(1, 2).contiguous() \ 56 | .view(batch_size, -1, num_heads * dim_per_head) 57 | 58 | # 1) Project key, value, and query. 59 | key_up = shape(self.linear_keys(key)) 60 | value_up = shape(self.linear_values(value)) 61 | query_up = shape(self.linear_query(query)) 62 | 63 | # 2) Calculate and scale scores. 64 | query_up = query_up / math.sqrt(dim_per_head) 65 | scores = torch.matmul(query_up, key_up.transpose(2, 3)) 66 | 67 | if mask is not None: 68 | mask = mask.unsqueeze(1).expand_as(scores) 69 | scores = scores.masked_fill(Variable(mask), -1e18) 70 | 71 | # 3) Apply attention dropout and compute context vectors. 72 | attn = self.sm(scores) 73 | drop_attn = self.dropout(attn) 74 | context = unshape(torch.matmul(drop_attn, value_up)) 75 | 76 | output = self.final_linear(context) 77 | 78 | top_attn = attn \ 79 | .view(batch_size, num_heads, 80 | query_len, key_len)[:, 0, :, :] \ 81 | .contiguous() 82 | 83 | return output, top_attn -------------------------------------------------------------------------------- /NMT/Models/Transformer/Modules/PositionalEncoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from torch.autograd import Variable 5 | 6 | MAX_SIZE = 5000 7 | 8 | class PositionalEncoding(nn.Module): 9 | """ 10 | Args: 11 | dropout (float): 12 | dim (int): embedding size 13 | """ 14 | def __init__(self, dropout, dim, max_len=MAX_SIZE): 15 | pe = torch.zeros(max_len, dim) 16 | position = torch.arange(0, max_len).unsqueeze(1) 17 | div_term = torch.exp(torch.arange(0, dim, 2) * 18 | -(math.log(10000.0) / dim)) 19 | pe[:, 0::2] = torch.sin(position * div_term) 20 | pe[:, 1::2] = torch.cos(position * div_term) 21 | pe = pe.unsqueeze(1) 22 | super(PositionalEncoding, self).__init__() 23 | self.register_buffer('pe', pe) 24 | self.dropout = nn.Dropout(p=dropout) 25 | self.dim = dim 26 | 27 | def forward(self, emb): 28 | 29 | emb = emb * math.sqrt(self.dim) 30 | emb = emb + Variable(self.pe[:emb.size(0)], requires_grad=False) 31 | emb = self.dropout(emb) 32 | return emb 33 | -------------------------------------------------------------------------------- /NMT/Models/Transformer/Modules/PositionwiseFF.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | class PositionwiseFeedForward(nn.Module): 6 | """ A two-layer Feed-Forward-Network with residual layer norm. 
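Computes LayerNorm -> Linear(size, hidden_size) -> ReLU -> Dropout ->
Linear(hidden_size, size) -> Dropout, then adds the input x back as a
residual.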
7 | """ 8 | 9 | def __init__(self, size, hidden_size, dropout=0.1): 10 | super(PositionwiseFeedForward, self).__init__() 11 | self.pw_ff1 = nn.Sequential( 12 | nn.LayerNorm(size, eps=1e-06), 13 | nn.Linear(size, hidden_size), 14 | nn.ReLU(), 15 | nn.Dropout(dropout) 16 | ) 17 | self.pw_ff2 = nn.Sequential( 18 | nn.Linear(hidden_size, size), 19 | nn.Dropout(dropout)) 20 | 21 | def forward(self, x): 22 | inter = self.pw_ff1(x) 23 | output = self.pw_ff2(inter) 24 | return output + x -------------------------------------------------------------------------------- /NMT/Models/Transformer/Modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .MultiHeadedAttn import MultiHeadedAttention 2 | from .PositionwiseFF import PositionwiseFeedForward 3 | from .PositionalEncoding import PositionalEncoding -------------------------------------------------------------------------------- /NMT/Models/Transformer/Modules/__pycache__/MultiHeadedAttn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/Transformer/Modules/__pycache__/MultiHeadedAttn.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/Transformer/Modules/__pycache__/PositionalEncoding.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/Transformer/Modules/__pycache__/PositionalEncoding.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/Transformer/Modules/__pycache__/PositionwiseFF.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/Transformer/Modules/__pycache__/PositionwiseFF.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/Transformer/Modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/Transformer/Modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/Transformer/VModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.nn.init import xavier_uniform_ 6 | 7 | 8 | from .BaseModel import BaseTransformerModel 9 | 10 | class TransformerModel(BaseTransformerModel): 11 | """ 12 | Core RNN model for NMT. 13 | { 14 | Transformer Encoder + Transformer Decoder. 
15 | } 16 | """ 17 | def __init__(self, src_embedding, trg_embedding, 18 | trg_vocab_size, padding_idx, config): 19 | super(TransformerModel, self).__init__(src_embedding, trg_embedding, 20 | trg_vocab_size, padding_idx, config) 21 | self.context_to_mu = nn.Linear( 22 | config.hidden_size, 23 | config.latent_size) 24 | self.context_to_logvar = nn.Linear( 25 | config.hidden_size, 26 | config.latent_size) 27 | if self.training: 28 | self.param_init() 29 | 30 | def reparameterize(self, encoder_outputs): 31 | """ 32 | context [B x 2H] 33 | """ 34 | #hidden = Variable(encoder_outputs.data, requires_grad=False) 35 | hidden = encoder_outputs 36 | mu = self.context_to_mu(hidden) 37 | logvar = self.context_to_logvar(hidden) 38 | if self.training: 39 | std = torch.exp(0.5*logvar) 40 | eps = torch.randn_like(std) 41 | z = eps.mul(std).add_(mu) 42 | else: 43 | z = mu 44 | return z, mu, logvar 45 | 46 | def translate_step(self, trg, encoder_outputs, lengths, decoder_state): 47 | output, state, attn = \ 48 | self.decoder(trg, encoder_outputs, lengths, decoder_state) 49 | return output.squeeze(0), attn.squeeze(0), state 50 | 51 | def forward(self, src, src_lengths, trg, decoder_state=None): 52 | # encoding side 53 | encoder_outputs, src = self.encoder(src, src_lengths) 54 | 55 | encoder_outputs, mu, logvar = self.reparameterize(encoder_outputs) 56 | kld = 0. 57 | if self.training: 58 | kld = self.compute_kld(mu, logvar) 59 | # encoder to decoder 60 | if decoder_state is None: 61 | decoder_state = self.decoder.init_decoder_state(src) 62 | 63 | # decoding side 64 | decoder_outputs, decoder_state, attns = \ 65 | self.decoder(trg, encoder_outputs, src_lengths, decoder_state) 66 | return decoder_outputs, attns, decoder_state, kld 67 | 68 | def compute_kld(self, mu, logvar): 69 | kld = -0.5 * torch.sum(1+ logvar - mu.pow(2) - logvar.exp()) 70 | return kld -------------------------------------------------------------------------------- /NMT/Models/Transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .Model import TransformerModel -------------------------------------------------------------------------------- /NMT/Models/Transformer/__pycache__/BaseModel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/Transformer/__pycache__/BaseModel.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/Transformer/__pycache__/Decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/Transformer/__pycache__/Decoder.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/Transformer/__pycache__/DecoderState.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/Transformer/__pycache__/DecoderState.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/Transformer/__pycache__/Encoder.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/Transformer/__pycache__/Encoder.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/Transformer/__pycache__/Model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/Transformer/__pycache__/Model.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/Transformer/__pycache__/Utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/Transformer/__pycache__/Utils.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/Transformer/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/Transformer/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /NMT/Models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Models/backup/Decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from NMT.Modules import StackedLSTM 5 | from NMT.Modules import StackedGRU 6 | from NMT.Modules import GlobalAttention 7 | 8 | 9 | 10 | 11 | class RNNDecoderBase(nn.Module): 12 | def __init__(self, rnn_type, 13 | input_size, 14 | hidden_size, 15 | num_layers=2, 16 | attn_type="general", 17 | bidirectional_encoder=True, 18 | dropout=0.0, 19 | embeddings=None): 20 | 21 | super(RNNDecoderBase, self).__init__() 22 | 23 | # Basic attributes. 24 | self.rnn_type = rnn_type 25 | self.decoder_type = 'rnn' 26 | self.bidirectional_encoder = bidirectional_encoder 27 | self.num_layers = num_layers 28 | self.input_size = input_size 29 | self.hidden_size = hidden_size 30 | self.embeddings = embeddings 31 | self.dropout = nn.Dropout(dropout) 32 | # Build the RNN. 
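# Same construction as in NMT/Models/RNN: a stack of LSTMCell/GRUCell
# layers that can be stepped one target token at a time.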
33 | self.rnn = self._build_rnn(rnn_type, 34 | input_size=input_size, 35 | hidden_size=hidden_size, 36 | num_layers=num_layers, 37 | dropout=dropout) 38 | 39 | 40 | self.attn = GlobalAttention( 41 | hidden_size, attn_type=attn_type 42 | ) 43 | 44 | def _build_rnn(self, rnn_type, input_size, hidden_size, num_layers, dropout): 45 | if rnn_type == "LSTM": 46 | stacked_cell = StackedLSTM 47 | elif rnn_type == "GRU": 48 | stacked_cell = StackedGRU 49 | else: 50 | raise NotImplementedError 51 | return stacked_cell(num_layers, input_size, 52 | hidden_size, dropout) 53 | 54 | def forward(self, trg, encoder_outputs, lengths, state): 55 | """ 56 | Args: 57 | trg (`LongTensor`): sequences of padded tokens 58 | `[L_t x B x D]`. 59 | encoder_outputs (`FloatTensor`): vectors from the encoder 60 | `[L_s x B x H]`. 61 | 62 | lengths (`LongTensor`): the padded source lengths 63 | `[B]`. 64 | state (`DecoderState`): 65 | decoder state object to initialize the decoder 66 | Returns: 67 | (`FloatTensor`,:obj:`nmt.Models.DecoderState`,`FloatTensor`): 68 | * decoder_outputs: output from the decoder (after attn) 69 | `[trg_len x batch x hidden]`. 70 | * decoder_state: final hidden state from the decoder 71 | * attns: distribution over source words at each target word 72 | `[L_t x B x L_s]`. 73 | """ 74 | # Run the forward pass of the RNN. 75 | decoder_outputs, final_state, attns = self.forward_step( 76 | trg, encoder_outputs, lengths, state) 77 | 78 | # Update the state with the result. 79 | final_output = decoder_outputs[-1] 80 | state.update_state(final_state, final_output.unsqueeze(0)) 81 | 82 | # Concatenates sequence of tensors along a new dimension. 83 | decoder_outputs = torch.stack(decoder_outputs) 84 | for k in attns: 85 | attns[k] = torch.stack(attns[k]) 86 | 87 | return decoder_outputs, state, attns 88 | 89 | 90 | 91 | 92 | 93 | class VarInputFeedRNNDecoder(RNNDecoderBase): 94 | def __init__(self, rnn_type, 95 | embedding_size, 96 | hidden_size, 97 | latent_size, 98 | num_layers=2, 99 | attn_type="general", 100 | bidirectional_encoder=True, 101 | dropout=0.0): 102 | super(VarInputFeedRNNDecoder, self).__init__(rnn_type, 103 | embedding_size + hidden_size + latent_size, 104 | hidden_size, 105 | num_layers, 106 | attn_type, 107 | bidirectional_encoder, 108 | dropout) 109 | 110 | self.context_to_mu = nn.Linear( 111 | hidden_size, 112 | latent_size) 113 | self.context_to_logvar = nn.Linear( 114 | hidden_size, 115 | latent_size) 116 | 117 | def reparameterize(self, state): 118 | """ 119 | context [B x 2H] 120 | """ 121 | hidden = self.get_hidden(state) 122 | mu = self.context_to_mu(hidden) 123 | logvar = self.context_to_logvar(hidden) 124 | if self.training: 125 | std = logvar.mul(0.5).exp_() 126 | eps = Variable(std.data.new(std.size()).normal_()) 127 | z = eps.mul(std).add_(mu) 128 | else: 129 | z = mu 130 | return z, mu, logvar 131 | 132 | 133 | def get_hidden(self, state): 134 | hidden = None 135 | if self.rnn_type == "GRU": 136 | hidden = state[-1] 137 | elif self.rnn_type == "LSTM": 138 | hidden = state[0][-1] 139 | return hidden 140 | def compute_kld(self, mu, logvar): 141 | kld = -0.5 * torch.sum(1+ logvar - mu.pow(2) - logvar.exp()) 142 | return kld 143 | def forward_step(self, trg, encoder_outputs, lengths, dec_state): 144 | """ 145 | Input feed concatenates hidden state with input at every time step. 146 | 147 | Args: 148 | trg (`LongTensor`): sequences of padded tokens 149 | `[L_t x B x nfeats]`. 150 | encoder_outputs (`FloatTensor`): vectors from the encoder 151 | `[L_s x B x H]`. 
152 | state (`DecoderState`): 153 | decoder state object to initialize the decoder 154 | lengths (`LongTensor`): the padded source lengths 155 | `[B]`. 156 | Returns: 157 | (`FloatTensor`,:obj:`nmt.Models.DecoderState`,`FloatTensor`): 158 | * decoder_outputs: output from the decoder (after attn) 159 | `[L_t x B x H]`. 160 | * decoder_state: final hidden state from the decoder 161 | * attns: distribution over source words at each target word 162 | `[L_t x B x L_s]`. 163 | """ 164 | 165 | input_feed = dec_state.input_feed.squeeze(0) # [B x H] 166 | 167 | decoder_outputs = [] 168 | attns = {"std": []} 169 | 170 | rnn_state = dec_state.state 171 | 172 | kld = 0. 173 | for t, emb_t in enumerate(trg.split(1, dim=0)): 174 | # iterate over each target word 175 | emb_t = emb_t.squeeze(0) 176 | # teacher forcing 177 | z, mu, logvar = self.reparameterize(rnn_state[0]) 178 | decoder_input = torch.cat([emb_t, input_feed, z], -1) 179 | kld += self.compute_kld(mu, logvar) 180 | # update state and feed to next RNNCell 181 | rnn_output, rnn_state = self.rnn(decoder_input, rnn_state) 182 | 183 | decoder_output, attn = self.attn( 184 | rnn_output, encoder_outputs.transpose(0, 1), lengths=lengths) 185 | 186 | decoder_output = self.dropout(decoder_output) 187 | 188 | input_feed = decoder_output 189 | decoder_outputs += [decoder_output] 190 | attns["std"] += [attn] 191 | 192 | # Return result. 193 | return decoder_outputs, rnn_state, attns, kld 194 | def forward(self, trg, encoder_outputs, lengths, state): 195 | """ 196 | Args: 197 | trg (`LongTensor`): sequences of padded tokens 198 | `[L_t x B x D]`. 199 | encoder_outputs (`FloatTensor`): vectors from the encoder 200 | `[L_s x B x H]`. 201 | 202 | lengths (`LongTensor`): the padded source lengths 203 | `[B]`. 204 | state (`DecoderState`): 205 | decoder state object to initialize the decoder 206 | Returns: 207 | (`FloatTensor`,:obj:`nmt.Models.DecoderState`,`FloatTensor`): 208 | * decoder_outputs: output from the decoder (after attn) 209 | `[trg_len x batch x hidden]`. 210 | * decoder_state: final hidden state from the decoder 211 | * attns: distribution over source words at each target word 212 | `[L_t x B x L_s]`. 213 | """ 214 | # Run the forward pass of the RNN. 215 | decoder_outputs, final_state, attns, kld = self.forward_step( 216 | trg, encoder_outputs, lengths, state) 217 | 218 | # Update the state with the result. 219 | final_output = decoder_outputs[-1] 220 | state.update_state(final_state, final_output.unsqueeze(0)) 221 | 222 | # Concatenates sequence of tensors along a new dimension. 
223 | decoder_outputs = torch.stack(decoder_outputs) 224 | for k in attns: 225 | attns[k] = torch.stack(attns[k]) 226 | 227 | return decoder_outputs, state, attns, kld 228 | -------------------------------------------------------------------------------- /NMT/Models/backup/Encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | 7 | class PackedRNNEncoder(nn.Module): 8 | def __init__(self, rnn_type, embed_dim, 9 | hidden_size, num_layers=2, dropout=0.0, bidirectional=True): 10 | super(PackedRNNEncoder, self).__init__() 11 | 12 | num_directions = 2 if bidirectional else 1 13 | assert hidden_size % num_directions == 0 14 | hidden_size = hidden_size // num_directions 15 | self.num_layers = num_layers 16 | self.bidirectional = bidirectional 17 | self.rnn_type = rnn_type 18 | self.rnn = getattr(nn, rnn_type)(embed_dim, 19 | hidden_size=hidden_size, 20 | num_layers=num_layers, 21 | dropout=dropout, 22 | bidirectional=bidirectional) 23 | self.hidden_size = hidden_size * num_directions 24 | def fix_final_state(self, final_state): 25 | def resize(h): 26 | # The encoder hidden is (layers*directions) x batch x dim. 27 | # We need to convert it to layers x batch x (directions*dim). 28 | if self.bidirectional: 29 | h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], -1) 30 | return h 31 | if self.bidirectional: 32 | if self.rnn_type == "GRU": 33 | final_state = resize(final_state) 34 | elif self.rnn_type == "LSTM": 35 | final_state = tuple([resize(h) for h in final_state]) 36 | return final_state 37 | def forward(self, input, lengths=None, state=None): 38 | if lengths is not None: 39 | packed = pack_padded_sequence(input, lengths.view(-1).tolist()) 40 | output, final_state = self.rnn(packed, state) 41 | 42 | if lengths is not None: 43 | output = pad_packed_sequence(output)[0] 44 | return output, self.fix_final_state(final_state) 45 | 46 | class RNNEncoder(PackedRNNEncoder): 47 | def __init__(self, *arg, **kwargs): 48 | super(RNNEncoder, self).__init__(*arg, **kwargs) 49 | 50 | def forward(self, input, lengths=None, state=None): 51 | 52 | if lengths is not None: 53 | lengths, rank = torch.sort(lengths, dim=0, descending=True) 54 | input = input.index_select(1, rank) 55 | output, final_state = super(RNNEncoder, self).forward(input, lengths) 56 | _, order = torch.sort(rank, dim=0, descending=False) 57 | if isinstance(final_state, tuple): 58 | final_state = tuple(x.index_select(1, order) for x in final_state) 59 | else: 60 | final_state = final_state.index_select(1, order) 61 | return output.index_select(1, order), final_state 62 | 63 | -------------------------------------------------------------------------------- /NMT/Models/backup/VNMTModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | from NMT.Models import NMTModel 7 | from NMT.Models.Decoders import RNNDecoderState 8 | 9 | def compute_kld(mu, logvar): 10 | kld = -0.5 * torch.sum(1+ logvar - mu.pow(2) - logvar.exp()) 11 | return kld 12 | 13 | 14 | class VNMTModel(NMTModel): 15 | """Recent work has found that VAE + LSTM decoders underperforms vanilla LSTM decoder. 16 | You should use VAE + GRU. 17 | see Yang et. al, ICML 2017, `Improved VAE for Text Modeling using Dilated Convolutions`. 
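    The model follows the conditional-VAE recipe used in this file: the
    encoder's final hidden state is projected to a mean and a log-variance
    (context_to_mu / context_to_logvar); a latent vector z is drawn with the
    reparameterization trick, z = mu + exp(0.5 * logvar) * eps with
    eps ~ N(0, I), during training (z = mu at test time); and z is
    concatenated to every decoder input. forward() also returns
    KLD = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)) so the trainer can add
    it to the reconstruction loss.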
18 | """ 19 | def __init__(self, 20 | encoder, decoder, 21 | src_embedding, trg_embedding, 22 | trg_vocab_size, 23 | config): 24 | super(VNMTModel, self).__init__( 25 | encoder, decoder, 26 | src_embedding, trg_embedding, 27 | trg_vocab_size, config) 28 | self.context_to_mu = nn.Linear( 29 | config.hidden_size, 30 | config.latent_size) 31 | self.context_to_logvar = nn.Linear( 32 | config.hidden_size, 33 | config.latent_size) 34 | self.lstm_state2context = nn.Linear( 35 | 2*config.hidden_size, 36 | config.latent_size) 37 | def get_hidden(self, state): 38 | hidden = None 39 | if self.encoder.rnn_type == "GRU": 40 | hidden = state[-1] 41 | elif self.encoder.rnn_type == "LSTM": 42 | hidden, context = state[0][-1], state[1][-1] 43 | hidden = self.lstm_state2context(torch.cat([hidden, context], -1)) 44 | return hidden 45 | 46 | def reparameterize(self, encoder_state): 47 | """ 48 | context [B x 2H] 49 | """ 50 | hidden = self.get_hidden(encoder_state).detach() 51 | mu = self.context_to_mu(hidden) 52 | logvar = self.context_to_logvar(hidden) 53 | if self.training: 54 | std = torch.exp(0.5*logvar) 55 | eps = torch.randn_like(std) 56 | z = eps.mul(std).add_(mu) 57 | else: 58 | z = mu 59 | return z, mu, logvar 60 | 61 | 62 | def forward(self, src, src_lengths, trg, trg_lengths=None, decoder_state=None): 63 | """ 64 | Forward propagate a `src` and `trg` pair for training. 65 | Possible initialized with a beginning decoder state. 66 | 67 | Args: 68 | src (Tensor): source sequence. [L x B x N]`. 69 | trg (LongTensor): source sequence. [L x B]`. 70 | src_lengths (LongTensor): the src lengths, pre-padding `[batch]`. 71 | trg_lengths (LongTensor): the trg lengths, pre-padding `[batch]`. 72 | dec_state (`DecoderState`, optional): initial decoder state 73 | z (`FloatTensor`): latent variables 74 | Returns: 75 | (:obj:`FloatTensor`, `dict`, :obj:`nmt.Models.DecoderState`): 76 | 77 | * decoder output `[trg_len x batch x hidden]` 78 | * dictionary attention dists of `[trg_len x batch x src_len]` 79 | * final decoder state 80 | """ 81 | 82 | # encoding side 83 | encoder_outputs, encoder_state = self.encode(src, src_lengths) 84 | 85 | # re-parameterize 86 | z, mu, logvar = self.reparameterize(encoder_state) 87 | # encoder to decoder 88 | decoder_state = self.encoder2decoder(encoder_state) 89 | 90 | trg_feed = trg[:-1] 91 | decoder_input = torch.cat([ 92 | self.trg_embedding(trg_feed), 93 | z.unsqueeze(0).repeat(trg_feed.size(0) ,1, 1)], 94 | -1) 95 | 96 | # decoding side 97 | decoder_outputs, decoder_state, attns = self.decoder( 98 | decoder_input, encoder_outputs, src_lengths, decoder_state) 99 | 100 | return decoder_outputs, decoder_state, attns, compute_kld(mu, logvar) 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /NMT/Models/backup/VRNMTModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | from NMT.Models import NMTModel 7 | 8 | 9 | class VRNMTModel(NMTModel): 10 | def forward(self, src, lengths, trg, decoder_state=None): 11 | """ 12 | Forward propagate a `src` and `trg` pair for training. 13 | Possible initialized with a beginning decoder state. 14 | 15 | Args: 16 | src (Tensor): source sequence. [L x B x N]`. 17 | trg (LongTensor): source sequence. [L x B]`. 18 | lengths (LongTensor): the src lengths, pre-padding `[batch]`. 
19 | dec_state (`DecoderState`, optional): initial decoder state 20 | Returns: 21 | (:obj:`FloatTensor`, `dict`, :obj:`nmt.Models.DecoderState`): 22 | 23 | * decoder output `[trg_len x batch x hidden]` 24 | * dictionary attention dists of `[trg_len x batch x src_len]` 25 | * final decoder state 26 | """ 27 | trg = trg[:-1] 28 | # encoding side 29 | encoder_outputs, encoder_state = self.encoder( 30 | self.src_embedding(src), lengths) 31 | 32 | # encoder to decoder 33 | decoder_state = self.encoder2decoder(encoder_state) 34 | 35 | decoder_input = self.trg_embedding(trg) 36 | # decoding side 37 | decoder_outputs, decoder_state, attns, kld = self.decoder( 38 | decoder_input, encoder_outputs, lengths, decoder_state) 39 | 40 | return decoder_outputs, decoder_state, attns, kld 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /NMT/Modules/GlobalAttention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from Utils.utils import aeq, sequence_mask 5 | 6 | 7 | class GlobalAttention(nn.Module): 8 | """ 9 | Global attention takes a matrix and a query vector. It 10 | then computes a parameterized convex combination of the matrix 11 | based on the input query. 12 | 13 | * Luong Attention (dot, general): 14 | * dot: :math:`score(H_j,q) = H_j^T q` 15 | * general: :math:`score(H_j, q) = H_j^T W_a q` 16 | 17 | * Bahdanau Attention (mlp): 18 | * :math:`score(H_j, q) = v_a^T tanh(W_a q + U_a h_j)` 19 | 20 | Args: 21 | dim (int): dimensionality of query and key 22 | coverage (bool): use coverage term 23 | attn_type (str): type of attention to use, options [dot,general,mlp] 24 | 25 | """ 26 | def __init__(self, dim, coverage=False, attn_type="dot"): 27 | super(GlobalAttention, self).__init__() 28 | 29 | self.dim = dim 30 | self.attn_type = attn_type 31 | assert (self.attn_type in ["dot", "general", "mlp"]), ( 32 | "Please select a valid attention type.") 33 | 34 | if self.attn_type == "general": 35 | self.linear_in = nn.Linear(dim, dim, bias=False) 36 | elif self.attn_type == "mlp": 37 | self.linear_context = nn.Linear(dim, dim, bias=False) 38 | self.linear_query = nn.Linear(dim, dim, bias=True) 39 | self.v = nn.Linear(dim, 1, bias=False) 40 | # mlp wants it with bias 41 | out_bias = self.attn_type == "mlp" 42 | self.linear_out = nn.Linear(dim*2, dim, bias=out_bias) 43 | 44 | self.sm = nn.Softmax(dim=-1) 45 | self.tanh = nn.Tanh() 46 | 47 | if coverage: 48 | self.linear_cover = nn.Linear(1, dim, bias=False) 49 | 50 | def score(self, h_t, h_s): 51 | """ 52 | Args: 53 | h_t (`FloatTensor`): sequence of queries `[batch x trg_len x dim]` 54 | h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]` 55 | 56 | Returns: 57 | :obj:`FloatTensor`: 58 | raw attention scores (unnormalized) for each src index 59 | `[batch x trg_len x src_len]` 60 | 61 | """ 62 | 63 | # Check input sizes 64 | src_batch, src_len, src_dim = h_s.size() 65 | trg_batch, trg_len, trg_dim = h_t.size() 66 | aeq(src_batch, trg_batch) 67 | aeq(src_dim, trg_dim) 68 | aeq(self.dim, src_dim) 69 | 70 | if self.attn_type in ["general", "dot"]: 71 | if self.attn_type == "general": 72 | h_t_ = h_t.view(trg_batch*trg_len, trg_dim) 73 | h_t_ = self.linear_in(h_t_) 74 | h_t = h_t_.view(trg_batch, trg_len, trg_dim) 75 | h_s_ = h_s.transpose(1, 2) 76 | # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) 77 | return torch.bmm(h_t, h_s_) 78 | else: 79 | dim = self.dim 80 | wq = 
self.linear_query(h_t.view(-1, dim)) 81 | wq = wq.view(trg_batch, trg_len, 1, dim) 82 | wq = wq.expand(trg_batch, trg_len, src_len, dim) 83 | 84 | uh = self.linear_context(h_s.contiguous().view(-1, dim)) 85 | uh = uh.view(src_batch, 1, src_len, dim) 86 | uh = uh.expand(src_batch, trg_len, src_len, dim) 87 | 88 | # (batch, t_len, s_len, d) 89 | wquh = self.tanh(wq + uh) 90 | 91 | return self.v(wquh.view(-1, dim)).view(trg_batch, trg_len, src_len) 92 | 93 | def forward(self, input, memory_bank, lengths=None): 94 | """ 95 | 96 | Args: 97 | input (`FloatTensor`): query vectors `[batch x trg_len x dim]` 98 | memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]` 99 | lengths (`LongTensor`): the source context lengths `[batch]` 100 | Returns: 101 | (`FloatTensor`, `FloatTensor`): 102 | 103 | * Computed vector `[trg_len x batch x dim]` 104 | * Attention distribtutions for each query 105 | `[trg_len x batch x src_len]` 106 | """ 107 | 108 | # one step input 109 | if input.dim() == 2: 110 | one_step = True 111 | input = input.unsqueeze(1) 112 | else: 113 | one_step = False 114 | 115 | batch, sourceL, dim = memory_bank.size() 116 | batch_, targetL, dim_ = input.size() 117 | 118 | # compute attention scores, as in Luong et al. 119 | align = self.score(input, memory_bank) 120 | 121 | if lengths is not None: 122 | mask = sequence_mask(lengths) 123 | mask = mask.unsqueeze(1) # Make it broadcastable. 124 | align.data.masked_fill_(1 - mask, -float('inf')) 125 | 126 | # Softmax to normalize attention weights 127 | align_vectors = self.sm(align.view(batch*targetL, sourceL)) 128 | align_vectors = align_vectors.view(batch, targetL, sourceL) 129 | 130 | # each context vector c_t is the weighted average 131 | # over all the source hidden states 132 | c = torch.bmm(align_vectors, memory_bank) 133 | 134 | # concatenate 135 | concat_c = torch.cat([c, input], 2).view(batch*targetL, dim*2) 136 | attn_h = self.linear_out(concat_c).view(batch, targetL, dim) 137 | if self.attn_type in ["general", "dot"]: 138 | attn_h = self.tanh(attn_h) 139 | 140 | if one_step: 141 | attn_h = attn_h.squeeze(1) 142 | align_vectors = align_vectors.squeeze(1) 143 | 144 | # Check output sizes 145 | batch_, dim_ = attn_h.size() 146 | aeq(batch, batch_) 147 | aeq(dim, dim_) 148 | batch_, sourceL_ = align_vectors.size() 149 | aeq(batch, batch_) 150 | aeq(sourceL, sourceL_) 151 | else: 152 | attn_h = attn_h.transpose(0, 1).contiguous() 153 | align_vectors = align_vectors.transpose(0, 1).contiguous() 154 | 155 | # Check output sizes 156 | targetL_, batch_, dim_ = attn_h.size() 157 | aeq(targetL, targetL_) 158 | aeq(batch, batch_) 159 | aeq(dim, dim_) 160 | targetL_, batch_, sourceL_ = align_vectors.size() 161 | aeq(targetL, targetL_) 162 | aeq(batch, batch_) 163 | aeq(sourceL, sourceL_) 164 | 165 | return attn_h, align_vectors 166 | -------------------------------------------------------------------------------- /NMT/Modules/IntehgerEncoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | # the neural network model 7 | ############################################################################### 8 | 9 | def sample_gumbel(shape, eps=1e-20): 10 | mu = torch.rand(shape).cuda() 11 | return Variable(-torch.log(-torch.log(mu + eps) + eps)) 12 | 13 | 14 | def gumbel_softmax_sample(logits, temperature): 15 | y = logits + sample_gumbel(logits.size()) 16 | return F.softmax(y / 
temperature, dim=-1) 17 | 18 | 19 | def gumbel_softmax(logits, temperature=1.0): 20 | """ 21 | input: [*, n_class] 22 | return: [*, n_class] an one-hot vector 23 | """ 24 | y = gumbel_softmax_sample(logits, temperature) 25 | shape = y.size() 26 | _, ind = y.max(dim=-1) 27 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 28 | y_hard.scatter_(1, ind.view(-1, 1), 1) 29 | y_hard = y_hard.view(*shape) 30 | return (y_hard - y).detach() + y 31 | 32 | 33 | def softmax2onehot(logits): 34 | """ 35 | input: [*, n_class] 36 | return: [*, n_class] an one-hot vector 37 | """ 38 | 39 | y = F.softmax(logits, dim=-1) 40 | shape = y.size() 41 | _, ind = y.max(dim=-1) 42 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 43 | y_hard.scatter_(1, ind.view(-1, 1), 1) 44 | y_hard = y_hard.view(*shape) 45 | return (y_hard - y).detach() + y 46 | 47 | 48 | class BinaryEncoding(nn.Module): 49 | def __init__(self, embedding_dim, m, k, gpu=False): 50 | # m is the number of codebooks 51 | # k is the number of candidate vectors in each codebooks 52 | # the number of neurons in the hidden layer 53 | hidden_size = int(m * k / 2) 54 | self.package_num = m 55 | self.embedding_dim = embedding_dim 56 | super(Coding, self).__init__() 57 | self.bottleneck = nn.Sequential( 58 | nn.Linear(embedding_dim, hidden_size), 59 | nn.Tanh()) 60 | 61 | self.decomposition = nn.ModuleList( 62 | [nn.Conv1d(embedding_dim, k, 1) for _ in range(m)]) 63 | self.decomposition = nn.ModuleList( 64 | [nn.Linear(hidden_size, k) for _ in range(m)]) 65 | self.composition = nn.ModuleList( 66 | [nn.Linear(k, embedding_dim, bias=False) for _ in range(m)]) 67 | #self.fc = nn.Conv1d(32, 1, 1) 68 | if gpu: 69 | self.decomposition.cuda() 70 | self.composition.cuda() 71 | self.params_init() 72 | 73 | def params_init(self): 74 | if self.training: 75 | print("Initializing model parameters.", file=sys.stderr) 76 | for p in self.parameters(): 77 | p.data.uniform_(-0.1, 0.1) 78 | 79 | def forward(self, input): 80 | # decomposition 81 | context = self.encode(input) 82 | #print([c.size() for c in context]) 83 | # composition 84 | decoder_output = self.decode(context) 85 | return decoder_output 86 | 87 | def encode(self, input): 88 | # hidden layer 89 | h = self.bottleneck(input) 90 | #h = self.c1(input) 91 | #h = input.unsqueeze(2) 92 | #print(h.size()) 93 | # decomposition 94 | ds = [] 95 | for i, block in enumerate(self.decomposition): 96 | #h_i = block(h).squeeze(2) 97 | h_i = block(h) 98 | a_i = F.softplus(h_i) 99 | #d_i = gumbel_softmax(a_i) 100 | d_i = softmax2onehot(a_i) 101 | 102 | ds.append(d_i) 103 | return ds 104 | 105 | def decode(self, context): 106 | output = [] 107 | for i, block in enumerate(self.composition): 108 | # print(encoder_output[i].size()) 109 | # print(self.composition[i].size()) 110 | output.append(block(context[i])) 111 | output = torch.stack(output) 112 | #output = self.fc(output.transpose(0,1)).squeeze(1) 113 | #print(output.size()) 114 | output = torch.sum(output, dim=0) 115 | return output 116 | 117 | def hard_encode(self, input): 118 | """ 119 | This function maps continuous word embeddings to integer word embeddings. 
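        Each of the m codebooks contributes one integer, the argmax over its
        k (near) one-hot activations, so a continuous embedding of dimension
        `embedding_dim` is compressed to an m-dimensional integer code.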
120 | :param input: each row is a word embedding vector 121 | :return: each row is the corresponding integer embedding vector 122 | >>> s = torch.manual_seed(1) 123 | >>> net = Coding(5, 3, 4) # m=3 is the dimension of integer vector 124 | >>> v1, v2 = list(net.get_integer_embed(Variable(torch.FloatTensor([[1,2,3,4,5],[2,3,5,7,1]])))) 125 | >>> list(v1) 126 | [1.0, 2.0, 2.0] 127 | >>> list(v2) 128 | [1.0, 2.0, 2.0] 129 | """ 130 | ds = self.encode(input) 131 | int_vec = [d.max(dim=0)[1].data.item() for d in ds] 132 | return int_vec 133 | 134 | # data 135 | ######################################################################### 136 | 137 | 138 | class Dataset(list): 139 | def __init__(self, vec_path, *args, **kwargs): 140 | super(Dataset, self).__init__(*args, **kwargs) 141 | self.load_embedding(vec_path) 142 | self.word_ids = {w: i for i, (w, _) in enumerate(self)} 143 | self.words = [w for (w, _) in self] 144 | 145 | def load_embedding(self, vec_path): 146 | with open(vec_path, "r", errors="replace") as vec_fin: 147 | for i, line in enumerate(vec_fin): 148 | if i == 0: 149 | self.voc_size, self.dimension = map( 150 | lambda x: int(x), line.rstrip("\n").split()) 151 | continue 152 | cols = line.rstrip("\n").split() 153 | w = cols[0] 154 | v = torch.FloatTensor([float(x) for x in cols[1:]]) 155 | self.append((w, v)) 156 | vec_fin.close() 157 | 158 | 159 | # training 160 | ########################################################################## 161 | def train(optimizer, loss_function, model, trainloader, epoch, is_cuda=False): 162 | print('Training on epoch No.%s' % (epoch)) 163 | model.train() 164 | 165 | loss = 0. 166 | total = 0 167 | batch_size = trainloader.batch_size 168 | 169 | for i, data in enumerate(trainloader): 170 | optimizer.zero_grad() 171 | words, inputs = data 172 | targets = inputs.clone() 173 | cur_size = inputs.size(0) 174 | inputs, targets = Variable(inputs), Variable(targets) 175 | if is_cuda: 176 | inputs, targets = inputs.cuda(), targets.cuda() 177 | 178 | outputs = model(inputs) 179 | 180 | 181 | loss = loss_function(outputs, targets) 182 | loss.backward() 183 | #print (model.decomposition[0].weight) 184 | optimizer.step() 185 | if i != 0 and i % 50 == 0: 186 | print('epoch No.%s , loss:%0.2f' % 187 | (epoch, float(loss.data) / cur_size)) 188 | print(words[0], model.hard_encode(outputs[0])) 189 | 190 | # writing integer vector model 191 | ########################################################################## 192 | 193 | 194 | def dump_int_vec(model, dataset, batch_size, is_cuda=False): 195 | array = [] 196 | pbar = tqdm.tqdm(range(math.ceil(dataset.voc_size / batch_size))) 197 | for i in pbar: 198 | pbar.set_description('producing integer word embedding model') 199 | start = i * batch_size 200 | stop = (i + 1) * batch_size 201 | if stop > dataset.voc_size: 202 | stop = dataset.voc_size 203 | if is_cuda: 204 | new_matrix = model.hard_encode( 205 | Variable(torch.Tensor(dataset[start:stop])).cuda()) 206 | else: 207 | new_matrix = model.hard_encode( 208 | Variable(torch.Tensor(dataset[start:stop]))) 209 | 210 | for vec in new_matrix: 211 | array.append(list(vec)) 212 | dataset.save_new_model(np.array(array)) 213 | 214 | 215 | # argument parser 216 | ################################################################ 217 | def parse_args(): 218 | usage = '\n1. 
You can query by using a command as follows:\n' \ 219 | 'python manager.py -q /path/to/model/dirs -w 5 -d 100\n' 220 | 221 | parser = argparse.ArgumentParser( 222 | description='description:\nThis Python2 program helps to manage model.\n Current version support only querying.', usage=usage) 223 | parser.add_argument('-m', type=int, default=32, 224 | help='number of codebooks') 225 | parser.add_argument('-k', type=int, default=16, 226 | help='number of vectors in each code book') 227 | parser.add_argument('--vec', type=str, 228 | default="/itigo/Uploads/WMT2018/en-tr/orig_clean/ft.embed.vec", 229 | help='path to word embedding file') 230 | parser.add_argument('--gpu', action="store_true", default=False, 231 | help="train on gpu") 232 | parser.add_argument('--epoch', type=int, default=20, 233 | help="number of epochs") 234 | parser.add_argument('--batch_size', type=int, default=128, 235 | help="number of epochs") 236 | args = parser.parse_args() 237 | return args 238 | 239 | 240 | # main function 241 | ########################################################################## 242 | def main(vec, m_codebooks, k, epoch, gpu, batch_size): 243 | """ 244 | >>> main('/itigo/Uploads/Hongyz/WVs/088/088', True, 32, 16, 3, True) 245 | 246 | 247 | """ 248 | 249 | dataset = Dataset(vec) 250 | 251 | loss_function = nn.MSELoss(size_average=False) 252 | code = Coding(dataset.dimension, m_codebooks, k, gpu) 253 | print (code) 254 | optimizer = optim.Adam(code.parameters(), lr=0.001) 255 | trainloader = torch.utils.data.DataLoader( 256 | dataset, batch_size=batch_size, shuffle=True) 257 | if gpu: 258 | code.cuda() 259 | for e in range(epoch): 260 | train(optimizer, loss_function, code, trainloader, e, gpu) 261 | if e == 10: 262 | checkpoint = { 263 | 'model': code.state_dict(), 264 | } 265 | torch.save(checkpoint, 'code_e%d.pt' % e) 266 | 267 | dump_int_vec(code, dataset, batch_size, gpu) 268 | 269 | 270 | if __name__ == "__main__": 271 | args = parse_args() 272 | main(args.vec, args.m, args.k, 273 | args.epoch, args.gpu, args.batch_size) 274 | 275 | -------------------------------------------------------------------------------- /NMT/Modules/MultiHeadedAttn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | 8 | class MultiHeadedAttention(nn.Module): 9 | def __init__(self, num_heads, model_dim, dropout=0.1): 10 | assert model_dim % num_heads == 0 11 | self.dim_per_head = model_dim // num_heads 12 | self.model_dim = model_dim 13 | 14 | super(MultiHeadedAttention, self).__init__() 15 | self.num_heads = num_heads 16 | 17 | self.linear_keys = nn.Linear(model_dim, 18 | num_heads * self.dim_per_head) 19 | self.linear_values = nn.Linear(model_dim, 20 | num_heads * self.dim_per_head) 21 | self.linear_query = nn.Linear(model_dim, 22 | num_heads * self.dim_per_head) 23 | self.sm = nn.Softmax(dim=-1) 24 | self.dropout = nn.Dropout(dropout) 25 | self.final_linear = nn.Linear(model_dim, model_dim) 26 | 27 | def forward(self, key, value, query, mask=None): 28 | """ 29 | Compute the context vector and the attention vectors. 
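        Each of the num_heads heads computes scaled dot-product attention,
        softmax(Q K^T / sqrt(dim_per_head)) V, over its own
        dim_per_head-sized slice of the projected keys, values, and queries;
        the per-head contexts are concatenated back to model_dim and passed
        through final_linear.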
30 | Args: 31 | key (`FloatTensor`): set of `key_len` 32 | key vectors `[batch, key_len, dim]` 33 | value (`FloatTensor`): set of `key_len` 34 | value vectors `[batch, key_len, dim]` 35 | query (`FloatTensor`): set of `query_len` 36 | query vectors `[batch, query_len, dim]` 37 | mask: binary mask indicating which keys have 38 | non-zero attention `[batch, query_len, key_len]` 39 | Returns: 40 | (`FloatTensor`, `FloatTensor`) : 41 | * output context vectors `[batch, query_len, dim]` 42 | * one of the attention vectors `[batch, query_len, key_len]` 43 | """ 44 | batch_size = key.size(0) 45 | dim_per_head = self.dim_per_head 46 | num_heads = self.num_heads 47 | key_len = key.size(1) 48 | query_len = query.size(1) 49 | 50 | def shape(x): 51 | return x.view(batch_size, -1, num_heads, dim_per_head) \ 52 | .transpose(1, 2) 53 | 54 | def unshape(x): 55 | return x.transpose(1, 2).contiguous() \ 56 | .view(batch_size, -1, num_heads * dim_per_head) 57 | 58 | # 1) Project key, value, and query. 59 | key_up = shape(self.linear_keys(key)) 60 | value_up = shape(self.linear_values(value)) 61 | query_up = shape(self.linear_query(query)) 62 | 63 | # 2) Calculate and scale scores. 64 | query_up = query_up / math.sqrt(dim_per_head) 65 | scores = torch.matmul(query_up, key_up.transpose(2, 3)) 66 | 67 | if mask is not None: 68 | mask = mask.unsqueeze(1).expand_as(scores) 69 | scores = scores.masked_fill(Variable(mask), -1e18) 70 | 71 | # 3) Apply attention dropout and compute context vectors. 72 | attn = self.sm(scores) 73 | drop_attn = self.dropout(attn) 74 | context = unshape(torch.matmul(drop_attn, value_up)) 75 | 76 | output = self.final_linear(context) 77 | 78 | top_attn = attn \ 79 | .view(batch_size, num_heads, 80 | query_len, key_len)[:, 0, :, :] \ 81 | .contiguous() 82 | 83 | return output, top_attn -------------------------------------------------------------------------------- /NMT/Modules/StackedRNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | class StackedLSTM(nn.Module): 5 | """ 6 | Customed stacked LSTM, which we can custom the first layer size. 
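    Unlike a single nn.LSTM with num_layers > 1, the first layer's input
    size may differ from hidden_size. This is needed for input feeding,
    where the decoder input at each step is the concatenation of the word
    embedding, the previous attentional output, and (optionally) a latent
    vector.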
7 | """ 8 | def __init__(self, num_layers, input_size, hidden_size, dropout, residual=True): 9 | super(StackedLSTM, self).__init__() 10 | self.dropout = nn.Dropout(dropout) 11 | self.num_layers = num_layers 12 | self.residual = residual 13 | self.layers = nn.ModuleList() 14 | 15 | for i in range(num_layers): 16 | if i == 0: 17 | self.layers.append(nn.LSTMCell(input_size, hidden_size)) 18 | else: 19 | self.layers.append(nn.LSTMCell(hidden_size, hidden_size)) 20 | 21 | def forward(self, input, hidden): 22 | h_0, c_0 = hidden 23 | h_1, c_1 = [], [] 24 | for i, layer in enumerate(self.layers): 25 | h_1_i, c_1_i = layer(input, (h_0[i], c_0[i])) 26 | input = h_1_i 27 | if i + 1 != self.num_layers: 28 | input = self.dropout(input) 29 | h_1 += [h_1_i] 30 | c_1 += [c_1_i] 31 | 32 | h_1 = torch.stack(h_1) 33 | c_1 = torch.stack(c_1) 34 | return input, (h_1, c_1) 35 | 36 | 37 | class StackedGRU(nn.Module): 38 | 39 | def __init__(self, num_layers, input_size, hidden_size, dropout, residual=True): 40 | super(StackedGRU, self).__init__() 41 | self.dropout = nn.Dropout(dropout) 42 | self.num_layers = num_layers 43 | self.layers = nn.ModuleList() 44 | self.residual = residual 45 | for i in range(num_layers): 46 | if i == 0: 47 | self.layers.append(nn.GRUCell(input_size, hidden_size)) 48 | else: 49 | self.layers.append(nn.GRUCell(hidden_size, hidden_size)) 50 | 51 | 52 | def forward(self, input, hidden): 53 | """ 54 | Args: 55 | input (FloatTensor): [B x H]. 56 | hidden: [B x H]. 57 | """ 58 | assert len(input.size()) == 2 59 | 60 | h_1 = [] 61 | for i, layer in enumerate(self.layers): 62 | h_1_i = layer(input, hidden[0][i]) 63 | # if self.residual and 0 < i < self.num_layers-1: 64 | # input = h_1_i + input 65 | # else: 66 | input = h_1_i 67 | if i + 1 != self.num_layers: 68 | input = self.dropout(input) 69 | h_1 += [h_1_i] 70 | 71 | h_1 = torch.stack(h_1) 72 | return input, (h_1,) 73 | -------------------------------------------------------------------------------- /NMT/Modules/__init__.py: -------------------------------------------------------------------------------- 1 | from NMT.Modules.GlobalAttention import GlobalAttention 2 | from NMT.Modules.StackedRNN import StackedLSTM, StackedGRU 3 | from NMT.Modules.MultiHeadedAttn import MultiHeadedAttention 4 | -------------------------------------------------------------------------------- /NMT/Modules/__pycache__/Embeddings.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Modules/__pycache__/Embeddings.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Modules/__pycache__/GlobalAttention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Modules/__pycache__/GlobalAttention.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Modules/__pycache__/MultiHeadedAttn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Modules/__pycache__/MultiHeadedAttn.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Modules/__pycache__/StackedRNN.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Modules/__pycache__/StackedRNN.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Modules/__pycache__/UtilClass.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Modules/__pycache__/UtilClass.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/Modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/Optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from torch.nn.utils import clip_grad_norm 3 | from Utils.log import trace 4 | 5 | 6 | 7 | 8 | class Optimizer(object): 9 | def __init__(self, method, config): 10 | 11 | self.last_ppl = float("inf") 12 | self.lr = config.lr 13 | self.original_lr = config.lr 14 | self.max_grad_norm = config.max_grad_norm 15 | self.method = method 16 | self.lr_decay_rate = config.lr_decay_rate 17 | self.start_decay_at = config.start_decay_at 18 | self.start_decay = False 19 | self.alpha = config.alpha 20 | self._step = 0 21 | self.decreased_steps = 0 22 | self.decay_method = config.decay_method 23 | self.momentum = config.momentum 24 | self.betas = [config.adam_beta1, config.adam_beta2] 25 | self.eps = config.eps 26 | self.warmup_steps = config.warmup_steps 27 | self.model_size = config.hidden_size 28 | self.epochs = config.epochs 29 | 30 | def set_parameters(self, params): 31 | self.params = [] 32 | #self.sparse_params = [] 33 | for k, p in params: 34 | if p.requires_grad: 35 | #if "embed" not in k: 36 | self.params.append(p) 37 | if self.method == 'SGD': 38 | # I recommend SGD when using LSTM. 39 | self.optimizer = optim.SGD(self.params, lr=self.lr) 40 | elif self.method == 'Adam': 41 | self.optimizer = optim.Adam( 42 | self.params, 43 | lr=self.lr, 44 | betas=self.betas, 45 | eps=1e-9) 46 | else: 47 | raise NotImplementedError 48 | # elif self.method == 'Adadelta': 49 | # self.optimizer = optim.Adadelta(self.params, lr=self.lr, rho=0.95) 50 | 51 | # elif self.method == 'RMSprop': 52 | # # does not work properly. 
53 | # self.optimizer = optim.RMSprop(self.params, lr=self.lr, 54 | # alpha=self.alpha, eps=self.eps, weight_decay=self.lr_decay_rate, 55 | # momentum=self.momentum, centered=False) 56 | def _set_rate(self, lr): 57 | self.lr = lr 58 | self.optimizer.param_groups[0]['lr'] = self.lr 59 | 60 | def step(self): 61 | self._step += 1 62 | if self.decay_method == "noam": 63 | self._set_rate(self.original_lr * (self.model_size ** (-0.5) * 64 | min(self._step ** (-0.5), self._step * self.warmup_steps**(-1.5)))) 65 | 66 | if self.max_grad_norm >0: 67 | clip_grad_norm(self.params, self.max_grad_norm) 68 | self.optimizer.step() 69 | 70 | def update_lr(self, ppl, epoch): 71 | if self.start_decay_at is not None and epoch > self.start_decay_at: 72 | self.start_decay = True 73 | if self.last_ppl is not None and ppl > self.last_ppl: 74 | self.start_decay = True 75 | 76 | if self.start_decay: 77 | self.lr = self.lr * self.lr_decay_rate 78 | trace("Decaying learning rate to %g" % self.lr) 79 | self.last_ppl = ppl 80 | self.optimizer.param_groups[0]['lr'] = self.lr 81 | -------------------------------------------------------------------------------- /NMT/Statistics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import math 4 | from Utils.log import trace 5 | 6 | 7 | class Statistics(object): 8 | """ 9 | Accumulator for loss statistics. 10 | Currently calculates: 11 | 12 | * accuracy 13 | * perplexity 14 | * elapsed time 15 | """ 16 | 17 | def __init__(self, loss=0, n_words=0, n_correct=0): 18 | self.loss = loss 19 | self.n_words = n_words 20 | self.n_correct = n_correct 21 | self.n_src_words = 0 22 | 23 | def update(self, stat): 24 | self.loss += stat.loss 25 | self.n_words += stat.n_words 26 | self.n_correct += stat.n_correct 27 | 28 | def accuracy(self): 29 | if self.n_words == 0: 30 | return 0 31 | return 100 * (float(self.n_correct) / self.n_words) 32 | 33 | def report(self, current_epoch, idx, num_batches, lr): 34 | """Write out statistics to stdout. 
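        Note that this method only formats and returns the report string;
        report_and_flush() is what actually writes it (to stderr).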
35 | 36 | Args: 37 | current_epoch (int): current epoch 38 | idx (int): current batch index 39 | num__batch (int): total batches 40 | """ 41 | report = "Epoch {0:d} [{1:d}/{2:d}], Acc: {3:.2f}; PPL: {4:.2f};" 42 | report += " Loss: {5:.2f}; lr: {6:.6f} \r" 43 | return report.format(current_epoch, idx, num_batches, 44 | self.accuracy(), self.ppl(), self.loss, lr) 45 | 46 | def report_and_flush(self, current_epoch, idx, num_batches, lr): 47 | sys.stderr.flush() 48 | sys.stderr.write( 49 | self.report( 50 | current_epoch, idx, 51 | num_batches, lr)) 52 | sys.stderr.flush() 53 | 54 | def __str__(self): 55 | string = "Acc: {0:.2f}; PPL: {1:.2f}; Loss: {2:.2f};" 56 | return string.format(self.accuracy(), self.ppl(), self.loss) 57 | 58 | # def xent(self): 59 | # if self.n_words == 0: 60 | # return 0 61 | # return self.loss / self.n_words 62 | 63 | def ppl(self): 64 | if self.n_words == 0: 65 | return 0 66 | return math.exp(min(float(self.loss) / self.n_words, 100)) 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /NMT/Trainer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | from itertools import count 6 | import torch.nn.functional as F 7 | 8 | from Utils.log import trace 9 | from NMT.Loss import LossBase 10 | from NMT.Loss import LabelSmoothingLoss 11 | from NMT.Optimizer import Optimizer 12 | from NMT.Statistics import Statistics 13 | from NMT.CheckPoint import dump_checkpoint 14 | 15 | 16 | class Trainer(object): 17 | """ 18 | Class that controls the training process. 19 | 20 | Args: 21 | model (NMT.Model.NMTModel): NMT model 22 | config (Config): global configurations 23 | """ 24 | 25 | def __init__(self, model, trg_vocab, padding_idx, config): 26 | self.model = model 27 | 28 | self.padding_idx = padding_idx 29 | # self.train_loss = LabelSmoothingLoss( 30 | # config, padding_idx, len(trg_vocab), 31 | # config.label_smoothing).cuda() 32 | self.train_loss = LossBase( 33 | config, padding_idx, len(trg_vocab)).cuda() 34 | self.valid_loss = LossBase( 35 | config, padding_idx, len(trg_vocab)).cuda() 36 | 37 | self.optim = Optimizer(config.optim, config) 38 | self.optim.set_parameters(model.named_parameters()) 39 | 40 | self.save_model = config.save_model 41 | self.last_ppl = float('inf') 42 | self.steps = 0 43 | self.max_decrease_steps = config.max_decrease_steps 44 | self.stop = False 45 | self.report_every = config.report_every 46 | self.accum_grad_count = 4 47 | self.config = config 48 | self.early_stop = config.early_stop 49 | def validate(self, valid_data): 50 | """ Validate model. 51 | valid_iter: validate data iterator 52 | Returns: 53 | :obj:`nmt.Statistics`: validation loss statistics 54 | """ 55 | self.model.eval() 56 | valid_stats = Statistics() 57 | for batch in iter(valid_data): 58 | normalization = batch.batch_size 59 | src, src_lengths = batch.src, batch.src_Ls 60 | trg, ref = batch.trg[:-1], batch.trg[1:] 61 | outputs = self.model(src, src_lengths, trg)[0] 62 | probs = self.model.generator(outputs) 63 | loss, stats = self.valid_loss.compute( 64 | probs, ref, normalization) 65 | valid_stats.update(stats) 66 | del outputs, probs, stats, loss 67 | self.model.train() 68 | return valid_stats 69 | 70 | def train(self, current_epoch, epochs, train_data, valid_data, num_batches): 71 | """ Train next epoch. 
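        Gradients are accumulated over accum_grad_count batches (4 here)
        before each optimizer step, so the effective batch size is
        accum_grad_count * batch_size; see train_each_batch().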
72 | Args: 73 | train_data (BatchDataIterator): training dataset iterator 74 | valid_data (BatchDataIterator): validation dataset iterator 75 | epoch (int): the epoch number 76 | num_batches (int): the batch number 77 | Returns: 78 | stats (Statistics): epoch loss statistics 79 | """ 80 | self.model.train() 81 | 82 | if self.stop: 83 | return 84 | header = '-' * 30 + "Epoch [%d]" + '-' * 30 85 | trace(header % current_epoch) 86 | train_stats = Statistics() 87 | num_batches = train_data.num_batches 88 | 89 | batch_cache = [] 90 | for idx, batch in enumerate(iter(train_data), 1): 91 | batch_cache.append(batch) 92 | if len(batch_cache) == self.accum_grad_count or idx == num_batches: 93 | stats = self.train_each_batch( 94 | batch_cache, current_epoch, idx, num_batches) 95 | batch_cache = [] 96 | if idx == train_data.num_batches: 97 | train_stats.update(stats) 98 | if idx % self.report_every == 0 or idx == num_batches: 99 | trace(stats.report(current_epoch, idx, num_batches, self.optim.lr)) 100 | if idx % (self.report_every * 10) == 0 and self.early_stop: 101 | valid_stats = self.validate(valid_data) 102 | trace("Validation: " + valid_stats.report(current_epoch, idx, num_batches, self.optim.lr)) 103 | if self.early_stop(valid_stats.ppl()): 104 | self.stop = True 105 | break 106 | valid_stats = self.validate(valid_data) 107 | trace(str(valid_stats)) 108 | suffix = ".acc{0:.2f}.ppl{1:.2f}.e{2:d}".format( 109 | valid_stats.accuracy(), valid_stats.ppl(), current_epoch) 110 | self.optim.update_lr(valid_stats.ppl(), current_epoch) 111 | dump_checkpoint(self.model, self.save_model, suffix) 112 | 113 | def train_each_batch(self, batch_cache, current_epoch, idx, num_batches): 114 | self.model.zero_grad() 115 | batch_stats = Statistics() 116 | normalization = 0 117 | 118 | while batch_cache: 119 | kld = 0 120 | batch = batch_cache.pop(0) 121 | src, src_length = batch.src, batch.src_Ls 122 | trg, ref = batch.trg[:-1], batch.trg[1:] 123 | normalization += batch.batch_size 124 | args = self.model(src, src_length, trg) 125 | outputs= args[0] 126 | # kld = args[-1] 127 | probs = self.model.generator(outputs) 128 | loss, stats = self.train_loss.compute( 129 | probs, ref, normalization) 130 | loss.backward(retain_graph=True) 131 | batch_stats.update(stats) 132 | del probs, outputs, loss 133 | self.optim.step() 134 | 135 | 136 | batch_stats.report_and_flush( 137 | current_epoch, idx, 138 | num_batches, self.optim.lr) 139 | return batch_stats 140 | def early_stop(self, ppl): 141 | if ppl < self.last_ppl: 142 | self.last_ppl = ppl 143 | self.steps = 0 144 | else: 145 | self.steps += 1 146 | if self.steps >= self.max_decrease_steps: 147 | return True 148 | return False 149 | -------------------------------------------------------------------------------- /NMT/__init__.py: -------------------------------------------------------------------------------- 1 | from .Models import * 2 | from .Loss import * 3 | from .Trainer import * 4 | from .Optimizer import * 5 | from .Statistics import * 6 | from .translate import * 7 | from .ModelFactory import * 8 | from .CheckPoint import * -------------------------------------------------------------------------------- /NMT/__pycache__/CheckPoint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/__pycache__/CheckPoint.cpython-36.pyc -------------------------------------------------------------------------------- 
/NMT/__pycache__/Loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/__pycache__/Loss.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/__pycache__/ModelConstructor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/__pycache__/ModelConstructor.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/__pycache__/ModelFactory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/__pycache__/ModelFactory.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/__pycache__/Optimizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/__pycache__/Optimizer.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/__pycache__/Statistics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/__pycache__/Statistics.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/__pycache__/Trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/__pycache__/Trainer.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /NMT/translate/Penalties.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | 4 | 5 | class PenaltyBuilder(object): 6 | """ 7 | Returns the Length and Coverage Penalty function for Beam Search. 8 | 9 | Args: 10 | length_pen (str): option name of length pen 11 | cov_pen (str): option name of cov pen 12 | """ 13 | def __init__(self, cov_pen, length_pen): 14 | self.length_pen = length_pen 15 | self.cov_pen = cov_pen 16 | 17 | def coverage_penalty(self): 18 | if self.cov_pen == "wu": 19 | return self.coverage_wu 20 | elif self.cov_pen == "summary": 21 | return self.coverage_summary 22 | else: 23 | return self.coverage_none 24 | 25 | def length_penalty(self): 26 | if self.length_pen == "wu": 27 | return self.length_wu 28 | elif self.length_pen == "avg": 29 | return self.length_average 30 | else: 31 | return self.length_none 32 | 33 | """ 34 | Below are all the different penalty terms implemented so far 35 | """ 36 | 37 | def coverage_wu(self, beam, cov, beta=0.): 38 | """ 39 | NMT coverage re-ranking score from 40 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 
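        Computes cp = beta * -(sum_i log(min(cov_i, 1.0))), where cov_i is
        the attention mass accumulated on source position i, so positions
        that remain under-attended are penalized.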
41 | """ 42 | penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(1) 43 | return beta * penalty 44 | 45 | def coverage_summary(self, beam, cov, beta=0.): 46 | """ 47 | Our summary penalty. 48 | """ 49 | penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(1) 50 | penalty -= cov.size(1) 51 | return beta * penalty 52 | 53 | def coverage_none(self, beam, cov, beta=0.): 54 | """ 55 | returns zero as penalty 56 | """ 57 | return beam.scores.clone().fill_(0.0) 58 | 59 | def length_wu(self, beam, logprobs, alpha=0.): 60 | """ 61 | NMT length re-ranking score from 62 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 63 | """ 64 | 65 | modifier = (((5 + len(beam.next_ys)) ** alpha) / 66 | ((5 + 1) ** alpha)) 67 | return (logprobs / modifier) 68 | 69 | def length_average(self, beam, logprobs, alpha=0.): 70 | """ 71 | Returns the average probability of tokens in a sequence. 72 | """ 73 | return logprobs / len(beam.next_ys) 74 | 75 | def length_none(self, beam, logprobs, alpha=0., beta=0.): 76 | """ 77 | Returns unmodified scores. 78 | """ 79 | return logprobs 80 | -------------------------------------------------------------------------------- /NMT/translate/Translation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from Utils.DataLoader import EOS_WORD 3 | from Utils.DataLoader import UNK_WORD 4 | from Utils.DataLoader import PAD_WORD 5 | from Utils.DataLoader import BOS_WORD 6 | from Utils.log import trace 7 | 8 | 9 | class TranslationBuilder(object): 10 | """ 11 | Luong et al, 2015. Addressing the Rare Word Problem in Neural Machine Translation. 12 | """ 13 | 14 | def __init__(self, src_vocab, trg_vocab, config): 15 | """ 16 | Args: 17 | vocab (Vocab): vocabulary 18 | replace_unk (bool): replace unknown words using attention 19 | """ 20 | self.src_vocab = src_vocab 21 | self.trg_vocab = trg_vocab 22 | self.replace_unk = config.replace_unk 23 | self.k_best = config.k_best 24 | 25 | def _build_sentence(self, vocab, pred, src, attn=None): 26 | """ 27 | build sentence using predicted output with the given vocabulary. 28 | """ 29 | tokens = [] 30 | for wid in pred: 31 | token = vocab.itos[int(wid)] 32 | if token == BOS_WORD: 33 | continue 34 | if token == EOS_WORD: 35 | break 36 | tokens.append(token) 37 | 38 | if self.replace_unk and (attn is not None) and (src is not None): 39 | for i in range(len(tokens)): 40 | if tokens[i] == UNK_WORD: 41 | _, max_ = attn[i].max(0) 42 | tokens[i] = self.src_vocab.itos[src[int(max_)]] 43 | 44 | return tokens 45 | 46 | def build_target(self, pred, src, attn=None): 47 | 48 | return self._build_sentence( 49 | self.trg_vocab, pred, src, attn) 50 | 51 | def build_source(self, src): 52 | 53 | return self._build_sentence( 54 | self.src_vocab, src, src) 55 | 56 | def build(self, batch, preds, scores, attns, gold_scores=None): 57 | """ 58 | build translation from batch output 59 | Args: 60 | preds : `[B x K_best x L_t]`. 61 | scores : `[B x K_best]`. 62 | attns : `[B x K_best x L_t x L_s]`. 
63 | """ 64 | batch_size = batch.batch_size 65 | 66 | translations = [None] * batch_size 67 | order = batch.sid % batch_size 68 | 69 | 70 | for i in range(batch_size): 71 | src = batch.src[:, i].tolist() 72 | input_sent = self.build_source(src) 73 | pred_sents = [] 74 | for k in range(self.k_best): 75 | sent = self.build_target(preds[i][k], src, attns[i][k]) 76 | pred_sents.append(sent) 77 | if batch.trg is not None: 78 | gold = batch.trg[:, i].tolist() 79 | gold_sent = self.build_target(gold, src) 80 | 81 | translations[order[i]] = Translation( 82 | input_sent, pred_sents, 83 | attns[i], scores[i], 84 | gold_sent, gold_scores[i]) 85 | return translations 86 | 87 | class Translation(object): 88 | """ 89 | Container for a translated sentence. 90 | 91 | Attributes: 92 | src (`LongTensor`): src word ids 93 | src_raw ([str]): raw src words 94 | 95 | pred_sents ([[str]]): words from the n-best translations 96 | pred_scores ([[float]]): log-probs of n-best translations 97 | attns ([`FloatTensor`]) : attention distributions for each translation 98 | gold_sent ([str]): words from gold translation 99 | gold_score ([float]): log-prob of gold translation 100 | 101 | """ 102 | 103 | def __init__(self, src, preds, attns, pred_scores, gold, gold_score): 104 | self.src = src 105 | self.preds = preds 106 | self.attns = attns 107 | self.pred_scores = pred_scores 108 | self.gold = gold 109 | self.gold_score = gold_score 110 | 111 | def pprint(self, sid): 112 | """ 113 | Log translation to stderr. 114 | """ 115 | output = '\nINPUT [{}]: {}\n'.format(sid, " ".join(self.src)) 116 | 117 | best = self.preds[0] 118 | best_score = self.pred_scores[0].sum().float() 119 | output += 'PRED [{}]: {}\t'.format(sid, " ".join(best)) 120 | output += "PRED SCORE: {:.4f}\n".format(best_score) 121 | 122 | if self.gold is not None: 123 | output += 'GOLD [{}]: {}\t'.format(sid, ' '.join(self.gold)) 124 | output += ("GOLD SCORE: {:.4f}\n".format(self.gold_score)) 125 | 126 | return output 127 | -------------------------------------------------------------------------------- /NMT/translate/Translator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn.functional as F 4 | from NMT.translate.Beam import Beam 5 | 6 | from Utils.DataLoader import PAD_WORD 7 | from Utils.DataLoader import BOS_WORD 8 | from Utils.DataLoader import EOS_WORD 9 | 10 | 11 | from .Translation import TranslationBuilder 12 | 13 | 14 | 15 | class BatchTranslator(object): 16 | """ 17 | Uses a model to translate a batch of sentences. 
18 | 19 | Args: 20 | model (:obj:`NMT.Models`): 21 | NMT model to use for translation 22 | fields (dict of Fields): data fields 23 | beam_size (int): size of beam to use 24 | k_best (int): number of translations produced 25 | max_length (int): maximum length output to produce 26 | global_scores (:obj:`GlobalScorer`): 27 | object to rescore final translations 28 | """ 29 | 30 | def __init__(self, model, config, src_vocab, trg_vocab): 31 | self.config = config 32 | self.vocab = trg_vocab 33 | self.model = model 34 | self.k_best = config.k_best 35 | self.max_length = config.max_length 36 | self.beam_size = config.beam_size 37 | self.stepwise_penalty = config.stepwise_penalty 38 | self.PAD_WID = trg_vocab.stoi[PAD_WORD] 39 | self.BOS_WID = trg_vocab.stoi[BOS_WORD] 40 | self.EOS_WID = trg_vocab.stoi[EOS_WORD] 41 | self.use_beam_search = config.use_beam_search 42 | self.builder = TranslationBuilder(src_vocab, trg_vocab, config) 43 | def beam_search_decoding(self, encoder_output, src_length, decoder_state): 44 | """ 45 | beam search. 46 | 47 | Args: 48 | encoder_output (`Variable`): the output of encoder hidden layer [L_s, B, H] 49 | decoder_state (`Variable`): the state of encoder 50 | """ 51 | batch_size = encoder_output.size(1) 52 | beam = [Beam(self.beam_size, self.PAD_WID, self.BOS_WID, self.EOS_WID, 53 | n_best=self.k_best) for _ in range(batch_size)] 54 | 55 | 56 | # repeat source `beam_size` times. 57 | # [seq_len x beam_size x batch_size x ?] 58 | # [L_s, B, H] -> [L_s, K x B, H] 59 | encoder_outputs = Variable( 60 | encoder_output.data.repeat(1, self.beam_size, 1)) 61 | src_lengths = src_length.repeat(self.beam_size) 62 | decoder_state.repeat_beam_size_times(self.beam_size) 63 | 64 | for i in range(self.max_length): 65 | with torch.no_grad(): 66 | if all((b.done() for b in beam)): 67 | break 68 | cads = [b.get_current() for b in beam] 69 | trg = torch.stack(cads).view(1, -1) 70 | 71 | outputs, decoder_state, attn = self.model.decode( 72 | trg, encoder_outputs, src_lengths, decoder_state)[:3] 73 | 74 | dist = F.log_softmax( 75 | self.model.generator(outputs), dim=-1) 76 | 77 | score, idx = dist.topk(1, dim=-1) 78 | 79 | 80 | def unpack(x): 81 | return x.view(self.beam_size, batch_size, -1).contiguous() 82 | outputs = unpack(dist.squeeze(0)) 83 | beam_attn = unpack(attn) 84 | 85 | 86 | for j, b in enumerate(beam): 87 | # j: batch_size 88 | b.advance(outputs.data[:, j, :], 89 | beam_attn.data[:, j, :src_lengths[j]]) 90 | decoder_state.beam_update_state(j, b.get_origin()) 91 | del dist, outputs 92 | return beam 93 | 94 | def monotonic_decoding(self, encoder_outputs, src_length, state): 95 | """ 96 | beam search. 
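        This performs greedy (monotonic) decoding rather than beam search:
        at each step the highest-scoring token (dist.topk(1)) is fed back
        as the next input and its log-probability is accumulated into the
        sentence score; no beam is maintained.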
97 | 98 | Args: 99 | encoder_output (`Variable`): the output of encoder hidden layer [L_s, B, H] 100 | decoder_state (`Variable`): the state of encoder 101 | """ 102 | 103 | batch_size = encoder_outputs.size(1) 104 | attns = [] 105 | preds = [] 106 | scores = [] 107 | trg = torch.LongTensor(batch_size).fill_(self.BOS_WID).cuda() 108 | scores = torch.FloatTensor(batch_size).fill_(0).cuda() 109 | #print([self.vocab.itos[x] for x in trg.tolist()]) 110 | 111 | for i in range(self.max_length): 112 | with torch.no_grad(): 113 | output, attn, state = self.model.translate_step( 114 | trg.unsqueeze(0), encoder_outputs, src_length, state) 115 | 116 | dist = F.log_softmax( 117 | self.model.generator(output), 118 | dim=-1) 119 | 120 | score, idx = dist.topk(1, dim=-1) 121 | del dist, output 122 | trg, score = idx.squeeze(1), score.squeeze(1) 123 | score.masked_fill_(trg.eq(self.PAD_WID), 0).float() 124 | scores += score.data 125 | preds.append(trg) 126 | attns.append(attn) 127 | 128 | preds = torch.stack(preds, dim=0)\ 129 | .transpose(0, 1)\ 130 | .unsqueeze(1).contiguous() 131 | 132 | attns = torch.stack(attns, dim=0)\ 133 | .transpose(0, 1)\ 134 | .unsqueeze(1).contiguous() 135 | 136 | scores = scores.unsqueeze(1).contiguous() 137 | return preds, scores, attns 138 | 139 | 140 | 141 | 142 | def translate(self, batch): 143 | """ 144 | Translate a batch of sentences. 145 | 146 | Args: 147 | batch (Batch): a batch from a dataset object 148 | """ 149 | 150 | # 1. encoding 151 | src, src_length = batch.src, batch.src_Ls 152 | 153 | encoder_outputs, encoder_state = \ 154 | self.model.encoder(src, src_length) 155 | 156 | # 2. encoder to decoder 157 | decoder_state = self.model.decoder.init_decoder_state(encoder_state) 158 | 159 | # 3. generate translations using beam search. 160 | if self.use_beam_search: 161 | beam = self.beam_search_decoding( 162 | encoder_outputs, src_length, decoder_state) 163 | preds, scores, attns = self.extract_from_beam(beam) 164 | else: 165 | preds, scores, attns = self.monotonic_decoding( 166 | encoder_outputs, src_length, decoder_state) 167 | 168 | del encoder_outputs, encoder_state 169 | 170 | gold_scores = self.get_gold_scores(batch) 171 | batch_trans = self.builder.build(batch, preds, scores, attns, gold_scores) 172 | 173 | # print([self.vocab.itos[i] for i in src[:,0].tolist()]) 174 | # print([self.vocab.itos[i] for i in preds.squeeze().tolist()]) 175 | # print([self.vocab.itos[i] for i in batch.trg[:,0].tolist()]) 176 | return batch_trans 177 | 178 | 179 | def extract_from_beam(self, beam): 180 | """ 181 | extract translations from beam. 
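        Returns three parallel lists of length batch_size: the k_best
        predicted word-id sequences per sentence, their scores, and the
        corresponding attention matrices.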
182 |         """
183 |         preds = [[]]
184 |         scores = [[]]
185 |         attns = [[]]
186 |         for b in beam:
187 |             best_k = b.sort_finished(minimum=self.k_best)[:self.k_best]
188 |             for score, t, k in best_k:
189 |                 wid, attn = b.get_hypo(t, k)
190 |                 preds[-1].append(torch.IntTensor(wid))
191 |                 scores[-1].append(score)
192 |                 attns[-1].append(attn)
193 |             preds.append([])
194 |             scores.append([])
195 |             attns.append([])
196 |         return preds[:-1], scores[:-1], attns[:-1]
197 | 
198 |     def get_gold_scores(self, batch):
199 |         src = batch.src
200 |         src_lengths = batch.src_Ls
201 | 
202 |         trg_in = batch.trg[:-1]
203 |         trg_out = batch.trg[1:]
204 | 
205 |         encoder_outputs, encoder_state = \
206 |             self.model.encoder(src, src_lengths)
207 | 
208 |         decoder_state = self.model.decoder.init_decoder_state(encoder_state)
209 | 
210 |         gold_scores = torch.FloatTensor(batch.batch_size).fill_(0).cuda()
211 | 
212 |         # accumulate the model log-probability of each gold target token
213 |         for t_in, t_out in zip(trg_in.split(1, dim=0), trg_out.split(1, dim=0)):
214 |             output = self.model.translate_step(
215 |                 t_in, encoder_outputs, src_lengths, decoder_state)[0]
216 | 
217 |             scores = F.log_softmax(
218 |                 self.model.generator(output), dim=-1)
219 | 
220 |             trg = t_out.transpose(0, 1)
221 |             scores = scores.data.gather(1, trg)
222 |             scores.masked_fill_(trg.eq(self.PAD_WID), 0)  # ignore padded positions
223 |             gold_scores += scores.squeeze(1)
224 |         return gold_scores
--------------------------------------------------------------------------------
/NMT/translate/__init__.py:
--------------------------------------------------------------------------------
1 | from .Translator import BatchTranslator
2 | from .Translation import Translation, TranslationBuilder
3 | from .Beam import Beam
4 | #from .Beam import GNMTGlobalScorer
5 | from .Penalties import PenaltyBuilder
6 | 
--------------------------------------------------------------------------------
/NMT/translate/__pycache__/Beam.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/translate/__pycache__/Beam.cpython-36.pyc
--------------------------------------------------------------------------------
/NMT/translate/__pycache__/Penalties.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/translate/__pycache__/Penalties.cpython-36.pyc
--------------------------------------------------------------------------------
/NMT/translate/__pycache__/Translation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/translate/__pycache__/Translation.cpython-36.pyc
--------------------------------------------------------------------------------
/NMT/translate/__pycache__/Translator.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/translate/__pycache__/Translator.cpython-36.pyc
--------------------------------------------------------------------------------
/NMT/translate/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/NMT/translate/__pycache__/__init__.cpython-36.pyc
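A note on the beam bookkeeping in beam_search_decoding above: the encoder output [L_s, B, H] is tiled into [L_s, K x B, H] with tensor.repeat, and unpack() later views the widened batch dimension as [K, B]. That only lines up because repeat lays the K copies out beam-major. A minimal sketch of this invariant (the shapes are illustrative only, not taken from the codebase):

    import torch

    L_s, B, H, K = 7, 2, 16, 5       # source length, batch size, hidden size, beam size
    enc = torch.randn(L_s, B, H)     # stand-in for an encoder output
    tiled = enc.repeat(1, K, 1)      # [L_s, K*B, H]; copy k lives at tiled[:, k*B:(k+1)*B]
    beams = tiled.view(L_s, K, B, H) # the same beam-major view that unpack() takes
    assert all(torch.equal(beams[:, k], enc) for k in range(K))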
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Neural Machine Translation System (PyTorch)
2 | ==========
3 | State-of-the-art NMT systems are difficult to understand.
4 | 
5 | For beginners, I highly recommend this project, a simplified version of [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py).
6 | 
7 | Though some modules are adapted from OpenNMT-py, most of the code (about 80%) is written from scratch.
8 | With the scripts posted here, you can build your own NMT systems and evaluate them on the WMT 2018 datasets.
9 | 
10 | 
11 | REQUIREMENTS
12 | ------------
13 | (Implemented with PyTorch 0.4.)
14 | Python version >= 3.6 (recommended)
15 | PyTorch version >= 0.4 (recommended)
16 | 
17 | Usage
18 | ------------
19 | For training, use a Moses-style configuration file to specify paths and hyper-parameters:
20 | 
21 |     python train.py --config config/nmt.ini
22 | 
23 | For translation:
24 | 
25 |     python translate.py --config config/nmt.ini --checkpoint {pretrained_model.pt} -v
26 | 
27 | A known bug is that beam search does not support batch decoding:
28 | when using use_beam_search = True,
29 | you need to set test_batch_size = 1 to make the output correct.
30 | 
31 | For monotonic (greedy) decoding without beam search, any test_batch_size works.
32 | 
33 | ## Optimizer
34 | - LSTM:
35 | 
36 |       SGD, lr = 1.0
37 |       learning_rate_decay = 0.9 (recommended)
38 | 
39 | - GRU:
40 | 
41 |       Adam, lr = 1e-4
42 |       max_grad_norm = 5 (recommended)
43 | 
44 | - Transformer:
45 | 
46 |       Adam, lr = 1e-4,
47 |       grad_accum_count = 4~5,
48 |       label_smoothing = 0.1 (recommended)
49 | 
50 | ## References
51 | 
52 | 1. Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems. 2017.
53 | 
54 | 2. Luong, Minh-Thang, Hieu Pham, and Christopher D. Manning. "Effective approaches to attention-based neural machine translation." arXiv preprint arXiv:1508.04025 (2015).
55 | 
56 | 3. Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio. "Neural machine translation by jointly learning to align and translate." arXiv preprint arXiv:1409.0473 (2014).
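## Examples

Two command-line sketches; the checkpoint names are hypothetical, and all flags used below are defined in Utils/args.py.

Beam-search decoding (note the test_batch_size = 1 restriction above):

    python translate.py --config config/transformer.ini \
        --checkpoint model_epoch10.pt \
        --use_beam_search --beam_size 5 --test_batch_size 1 -v

Checkpoint averaging with ensemble.py (writes {save_model}.ensemble.pt):

    python ensemble.py --config config/transformer.ini \
        --checkpoint model_epoch08.pt model_epoch09.pt model_epoch10.pt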
57 | 
--------------------------------------------------------------------------------
/Utils/DataLoader.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import numpy as np
4 | import math
5 | import torch
6 | from collections import Counter, defaultdict
7 | from Utils.utils import aeq
8 | from Utils.log import trace
9 | from itertools import chain
10 | import random
11 | 
12 | 
13 | PAD_WORD = '<pad>'  # 0
14 | UNK_WORD = '<unk>'  # 1
15 | BOS_WORD = '<s>'    # 2
16 | EOS_WORD = '</s>'   # 3
17 | 
18 | 
19 | class Vocab(object):
20 |     def __init__(self, min_freq=1, specials=[PAD_WORD, UNK_WORD, BOS_WORD, EOS_WORD]):
21 |         self.specials = specials
22 |         self.counter = Counter()
23 |         self.stoi = {}
24 |         self.itos = {}
25 |         self.weight = None
26 |         self.min_freq = min_freq
27 |         self.vocab_size = 0
28 |         #config.min_freq
29 | 
30 |     @property
31 |     def padding_idx(self):
32 |         return self.stoi[PAD_WORD]
33 | 
34 |     def make_vocab(self, dataset):
35 |         for x in dataset:
36 |             self.counter.update(x)
37 | 
38 |         if self.min_freq > 1:
39 |             self.counter = {w: i for w, i in filter(
40 |                 lambda x: x[1] >= self.min_freq, self.counter.items())}
41 | 
42 |         for w in self.specials:
43 |             if w not in self.stoi:
44 |                 self.stoi[w] = self.vocab_size
45 |                 self.vocab_size += 1
46 | 
47 |         for w in self.counter.keys():
48 |             if w not in self.stoi:
49 |                 self.stoi[w] = self.vocab_size
50 |                 self.vocab_size += 1
51 | 
52 |         self.itos = {i: w for w, i in self.stoi.items()}
53 | 
54 |     def load_pretrained_embedding(self, embed_path, embed_dim):
55 |         self.weight = np.zeros((self.vocab_size, int(embed_dim)))
56 |         with open(embed_path, "r", errors="replace") as embed_fin:
57 |             for line in embed_fin:
58 |                 cols = line.rstrip("\n").split()
59 |                 w = cols[0]
60 |                 if w in self.stoi:
61 |                     val = np.array(cols[1:])
62 |                     self.weight[self.stoi[w]] = val
63 |                 else:
64 |                     pass
65 |         # reset the UNK embedding (index 1) to zeros
66 |         for i in range(1, 2):
67 |             self.weight[i] = np.zeros((embed_dim,))
68 |         # self.weights[i] = np.random.random_sample(
69 |         #     (self.config.embed_dim,))
70 |         self.weight = torch.from_numpy(self.weight)
71 | 
72 |     def __getitem__(self, key):
73 |         return self.weight[key]
74 | 
75 |     def __len__(self):
76 |         return self.vocab_size
77 | 
78 | class DataSet(list):
79 |     def __init__(self, src_txt, trg_txt=None, filtered=True, max_length=50):
80 |         super(DataSet, self).__init__()
81 |         self.filtered = filtered
82 |         self.max_length = max_length
83 |         self.read(src_txt, trg_txt)
84 | 
85 |     def read(self, src_txt, trg_txt):
86 |         with open(src_txt, "r", encoding='utf-8') as fin_src, \
87 |                 open(trg_txt, "r", encoding='utf-8') as fin_trg:
88 |             for sid, (line1, line2) in enumerate(zip(fin_src, fin_trg)):
89 |                 src, trg = line1.rstrip("\r\n"), line2.rstrip("\r\n")
90 |                 src = src.split()
91 |                 trg = trg.split()
92 |                 if self.filtered:
93 |                     if len(src) <= self.max_length and \
94 |                             len(trg) <= self.max_length:
95 |                         self.append((sid, (src, trg)))
96 |                 else:
97 |                     self.append((sid, (src, trg)))
98 | 
99 | 
100 | 
101 |     def _numericalize(self, words, stoi):
102 |         return [1 if x not in stoi else stoi[x] for x in words]  # 1 = UNK id
103 | 
104 |     @staticmethod
105 |     def _denumericalize(words, itos):
106 |         return ['UNK' if x not in itos else itos[x] for x in words]
107 | 
108 |     def numericalize(self, src_w2id, trg_w2id):
109 |         for i, (sid, example) in enumerate(self):
110 |             x, y = example
111 |             x = self._numericalize(x, src_w2id)
112 |             y = self._numericalize(y, trg_w2id)
113 |             self[i] = (sid, (x, y))
114 | 
115 | class DataBatchIterator(object):
116 |     def __init__(self, src_txt, trg_txt,
117 |                  training=True, shuffle=False, share_vocab=False,
118 |                  batch_size=64, max_length=50, vocab=None,
119 |                  mini_batch_sort_order='decreasing'):
120 |         self.batch_size = batch_size
121 |         self.shuffle = shuffle
122 |         self.training = training
123 | 
124 |         # init dataset
125 |         self.examples = DataSet(src_txt, trg_txt,
126 |                                 filtered=training,
127 |                                 max_length=max_length)
128 | 
129 |         # init vocab
130 |         self.share_vocab = share_vocab
131 |         self.src_vocab = Vocab()
132 |         if self.share_vocab:
133 |             self.trg_vocab = self.src_vocab
134 |         else:
135 |             self.trg_vocab = Vocab()
136 |         self.init_vocab(vocab)
137 | 
138 |         self.examples.numericalize(
139 |             src_w2id=self.src_vocab.stoi,
140 |             trg_w2id=self.trg_vocab.stoi)
141 |         self.num_batches = math.ceil(len(self.examples)/self.batch_size)
142 |         self.mini_batch_sort_order = mini_batch_sort_order
143 | 
144 |     def init_vocab(self, path):
145 |         if os.path.isfile(path):
146 |             self.load_vocab(path)
147 |         else:
148 |             self.make_vocab()
149 |             self.save_vocab(path)
150 | 
151 |     def make_vocab(self):
152 |         trace("Building vocabulary ...")
153 |         self.src_vocab.make_vocab(map(lambda x: x[1][0], self.examples))
154 |         self.trg_vocab.make_vocab(map(lambda x: x[1][1], self.examples))
155 | 
156 | 
157 |     def load_vocab(self, path):
158 |         trace("Loading vocabulary ...")
159 |         if self.share_vocab:
160 |             self.trg_vocab = self.src_vocab = torch.load(path)
161 |         else:
162 |             self.src_vocab, self.trg_vocab = torch.load(path)
163 | 
164 |     def get_vocab(self):
165 |         if self.share_vocab:
166 |             return (self.trg_vocab, )
167 |         else:
168 |             return (self.src_vocab, self.trg_vocab)
169 | 
170 |     def _pad(self, sentence, max_L, w2id, add_bos=False, add_eos=False):
171 |         if add_bos:
172 |             sentence = [w2id[BOS_WORD]] + sentence
173 |         if add_eos:
174 |             sentence = sentence + [w2id[EOS_WORD]]
175 |         if len(sentence) < max_L:
176 |             sentence = sentence + [w2id[PAD_WORD]] * (max_L-len(sentence))
177 |         return list(sentence)
178 | 
179 | 
180 |     def pad_seq_pair(self, samples):
181 |         samples = samples if 'none' == self.mini_batch_sort_order else \
182 |             sorted(
183 |                 samples,
184 |                 key=lambda x: len(x[1][0]),
185 |                 reverse=('decreasing' == self.mini_batch_sort_order)
186 |             )
187 |         pairs = [x[1] for x in samples]
188 |         sid = [x[0] for x in samples]
189 | 
190 |         src_Ls = [len(pair[0])+2 for pair in pairs]  # +2 for the added BOS/EOS
191 |         trg_Ls = [len(pair[1])+2 for pair in pairs]
192 | 
193 |         max_trg_Ls = max(trg_Ls)
194 |         max_src_Ls = max(src_Ls)
195 |         src = [self._pad(src, max_src_Ls, self.src_vocab.stoi,
196 |                          add_bos=True, add_eos=True) for src, _ in pairs]
197 |         trg = [self._pad(trg, max_trg_Ls, self.trg_vocab.stoi,
198 |                          add_bos=True, add_eos=True) for _, trg in pairs]
199 | 
200 | 
201 |         batch = Batch()
202 |         batch.src = torch.LongTensor(src).transpose(0, 1).cuda()
203 |         batch.trg = torch.LongTensor(trg).transpose(0, 1).cuda()
204 |         batch.sid = torch.LongTensor(sid).cuda()
205 |         batch.src_Ls = torch.LongTensor(src_Ls).cuda()
206 |         batch.trg_Ls = torch.LongTensor(trg_Ls).cuda()
207 |         return batch
208 | 
209 |     def save_vocab(self, path):
210 |         if self.share_vocab:
211 |             torch.save(self.trg_vocab, path)
212 |         else:
213 |             torch.save([self.src_vocab, self.trg_vocab], path)
214 | 
215 |     def __len__(self):
216 |         return self.num_batches
217 | 
218 |     def __iter__(self):
219 |         if self.shuffle:
220 |             random.shuffle(self.examples)
221 |         total_num = len(self.examples)
222 |         for i in range(self.num_batches):
223 |             samples = self.examples[i * self.batch_size: \
224 | 
min(total_num, self.batch_size*(i+1))] 225 | yield self.pad_seq_pair(samples) 226 | 227 | def __repr__(self): 228 | info = "" 229 | if self.share_vocab: 230 | assert self.src_vocab.vocab_size == self.trg_vocab.vocab_size 231 | info += "Using shared vocab, " 232 | info += "Vocab: [{0}], ".format(self.src_vocab.vocab_size) 233 | else: 234 | info += "Using independent vocab, " 235 | info += "Source: [{0}], ".format(self.src_vocab.vocab_size) 236 | info += "Target: [{0}], ".format(self.trg_vocab.vocab_size) 237 | info += "Dataset: [{0}] ".format(len(self.examples)) 238 | return info 239 | 240 | class Batch(object): 241 | def __init__(self): 242 | self.src = None 243 | self.trg = None 244 | self.src_Ls = None 245 | self.trg_Ls = None 246 | 247 | def __len__(self): 248 | return self.src_Ls.size(0) 249 | @property 250 | def batch_size(self): 251 | return self.src_Ls.size(0) 252 | 253 | -------------------------------------------------------------------------------- /Utils/__pycache__/DataLoader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/Utils/__pycache__/DataLoader.cpython-36.pyc -------------------------------------------------------------------------------- /Utils/__pycache__/args.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/Utils/__pycache__/args.cpython-36.pyc -------------------------------------------------------------------------------- /Utils/__pycache__/bleu.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/Utils/__pycache__/bleu.cpython-36.pyc -------------------------------------------------------------------------------- /Utils/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/Utils/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /Utils/__pycache__/log.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/Utils/__pycache__/log.cpython-36.pyc -------------------------------------------------------------------------------- /Utils/__pycache__/plot.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/Utils/__pycache__/plot.cpython-36.pyc -------------------------------------------------------------------------------- /Utils/__pycache__/rouge.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/Utils/__pycache__/rouge.cpython-36.pyc -------------------------------------------------------------------------------- /Utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-h/pynmt/1f488f320e01137cd5b3439557779ec4c9d35cf4/Utils/__pycache__/utils.cpython-36.pyc 
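A minimal usage sketch for the DataLoader module above; the corpus paths and the vocab location are hypothetical, and a CUDA device is required because batches are moved to the GPU:

    from Utils.DataLoader import DataBatchIterator

    # Builds (or loads) the vocabularies, numericalizes the corpus, and
    # yields padded mini-batches as Batch objects.
    train_iter = DataBatchIterator(
        "data/train.zh", "data/train.ja",   # hypothetical parallel files
        training=True, shuffle=True,
        batch_size=64, max_length=50,
        vocab="workspace/vocab")            # vocab is saved here on the first run

    for batch in train_iter:
        # batch.src / batch.trg: LongTensors of shape [max_len, batch_size]
        print(batch.src.size(), batch.trg.size())
        break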
--------------------------------------------------------------------------------
/Utils/args.py:
--------------------------------------------------------------------------------
1 | import time
2 | import argparse
3 | 
4 | from distutils.util import strtobool
5 | 
6 | 
7 | # parser = ArgumentParser()
8 | # parser.is_training = True if mode == "train" else False
9 | 
10 | # def set_defaults(parser, **kwarg):
11 | #     parser.set_defaults(**kwarg)
12 | 
13 | # def get_defaults(parser):
14 | #     return parser.parse_args()
15 | 
16 | # @property
17 | # def sections(parser):
18 | #     print(dir(parser))
19 | 
20 | def make_parser(training=True):
21 |     parser = argparse.ArgumentParser()
22 |     add_default_args(parser)
23 |     add_data_args(parser)
24 |     add_gpu_args(parser)
25 |     add_embed_args(parser)
26 |     add_common_network_args(parser)
27 |     add_rnn_args(parser)
28 |     add_transformer_args(parser)
29 | 
30 |     if training:
31 |         add_optim_args(parser)
32 |         add_train_args(parser)
33 |     else:
34 |         add_translate_args(parser)
35 |     return parser
36 | 
37 | def add_default_args(parser):
38 |     group = parser.add_argument_group('Default')
39 | 
40 |     group.add_argument('--config', type=str, required=True,
41 |                        help="Path to config")
42 | 
43 |     group.add_argument('--system', type=str,
44 |                        help="which kind of NMT system to use [RNN, Transformer]")
45 | 
46 |     group.add_argument('-v', '--verbose', action="store_true",
47 |                        help='verbose mode can print more information.')
48 | 
49 |     group.add_argument('--save_log', type=str,
50 |                        help="Path to log file")
51 | 
52 |     group.add_argument('--save_vocab', type=str,
53 |                        help="Path to vocab files")
54 | 
55 |     group.add_argument('--save_model', default='model',
56 |                        help="""Path prefix for saving model checkpoints""")
57 | 
58 |     group.add_argument('--start_epoch', type=int, default=1,
59 |                        help='Epoch number to start training from')
60 | 
61 |     group.add_argument('--checkpoint', default=[], nargs='+', type=str,
62 |                        help="""If training from a checkpoint then this is the
63 |                        path to the pretrained model's state_dict.""")
64 | 
65 | 
66 | 
67 | def add_gpu_args(parser):
68 |     group = parser.add_argument_group('GPU')
69 | 
70 |     group.add_argument('--use_gpu', default=[], nargs='+', type=int,
71 |                        help="Use CUDA on the listed devices.")
72 | 
73 |     group.add_argument('--use_cpu', default=False, action='store_true',
74 |                        help="Use CPU.")
75 | 
76 | def add_data_args(parser):
77 |     group = parser.add_argument_group('Data')
78 | 
79 |     group.add_argument('--src_lang', type=str,
80 |                        help="source language name suffix")
81 | 
82 |     group.add_argument('--trg_lang', type=str,
83 |                        help="target language name suffix")
84 | 
85 |     group.add_argument('--data_path', type=str,
86 |                        help="path to datasets")
87 | 
88 |     group.add_argument('--train_dataset', default='train', type=str,
89 |                        help="""The training dataset""")
90 | 
91 |     group.add_argument('--valid_dataset', default='dev', type=str,
92 |                        help="""The validation dataset""")
93 | 
94 |     group.add_argument('--test_dataset', default='test', type=str,
95 |                        help="""The test dataset (name prefix of the files
96 |                        to be translated)""")
97 | 
98 |     group.add_argument('--max_seq_len', type=int, default=50,
99 |                        help="Maximum sequence length")
100 | 
101 |     group.add_argument('--shuffle_data', type=lambda x: bool(strtobool(x)), default=True,
102 |                        help="Whether to shuffle data")
103 | 
104 |     group.add_argument('--mini_batch_sort_order', type=str, default='decreasing',
105 |                        choices=['decreasing', 'increasing', 'none'],
106 |                        help='Order for sorting mini-batches by length')
107 | 
108 | def add_embed_args(parser):
109 |     group = parser.add_argument_group('Embedding')
110 | 
111 |     group.add_argument('--min_freq', type=int, default=5,
112 |                        help="Minimal frequency for the prepared data")
113 | 
114 |     group.add_argument('--src_embed_dim', type=int, default=512,
115 |                        help='Word embedding size for source.')
116 | 
117 |     group.add_argument('--trg_embed_dim', type=int, default=512,
118 |                        help='Word embedding size for target.')
119 | 
120 |     group.add_argument('--share_vocab', action='store_true',
121 |                        help="""sharing vocabulary across languages.""")
122 | 
123 |     group.add_argument('--share_embedding', action='store_true',
124 |                        help="""sharing embedding across languages.""")
125 | 
126 |     group.add_argument('--sparse_embeddings', type=lambda x: bool(strtobool(x)), default=False,
127 |                        help='Whether to use sparse embeddings')
128 | 
129 | 
130 | 
131 | 
132 | def add_train_args(parser):
133 |     group = parser.add_argument_group('Train')
134 | 
135 |     group.add_argument('--max_decrease_steps', type=int, default=10,
136 |                        help='Number of maximal decreased steps for early stopping.')
137 | 
138 |     group.add_argument('--epochs', type=int, default=-1,
139 |                        help='Number of training epochs')
140 | 
141 |     group.add_argument('--train_batch_size', type=int, default=64,
142 |                        help='Maximum batch size for training')
143 | 
144 |     group.add_argument('--valid_batch_size', type=int, default=32,
145 |                        help='Maximum batch size for validation')
146 | 
147 |     group.add_argument('--report_every', type=int, default=100,
148 |                        help='Report statistics after the determined number of steps.')
149 | 
150 | def add_translate_args(parser):
151 |     group = parser.add_argument_group('Translate')
152 | 
153 | 
154 | 
155 | 
156 | 
157 |     group.add_argument('--output', default='pred.txt',
158 |                        help="""Path to output the predictions (each line will
159 |                        be the decoded sequence)""")
160 | 
161 | 
162 | 
163 |     group.add_argument('--load_epoch', type=int, default=-1,
164 |                        help="""Epoch of the pre-trained model to load""")
165 | 
166 |     group.add_argument('--test_batch_size', type=int, default=32,
167 |                        help='Maximum batch size for testing')
168 | 
169 |     group.add_argument('--beam_size', type=int, default=5,
170 |                        help='Beam size')
171 | 
172 |     group.add_argument('--k_best', type=int, default=1,
173 |                        help='Output K-best translations')
174 | 
175 |     group.add_argument('--max_length', type=int, default=100,
176 |                        help='Maximum prediction length.')
177 | 
178 |     # Alpha and Beta values for Google Length + Coverage penalty
179 |     # Described here: https://arxiv.org/pdf/1609.08144.pdf, Section 7
180 |     group.add_argument('--stepwise_penalty', action='store_true',
181 |                        help="""Apply penalty at every decoding step.
182 |                        Helpful for summary penalty.""")
183 |     group.add_argument('--length_penalty', default='none',
184 |                        choices=['none', 'wu', 'avg'],
185 |                        help="""Length Penalty to use.""")
186 |     group.add_argument('--coverage_penalty', default='none',
187 |                        choices=['none', 'wu', 'summary'],
188 |                        help="""Coverage Penalty to use.""")
189 |     group.add_argument('--alpha', type=float, default=0.,
190 |                        help="""Google NMT length penalty parameter
191 |                        (higher = longer generation)""")
192 |     group.add_argument('--beta', type=float, default=0.,
193 |                        help="""Coverage penalty parameter""")
194 | 
195 |     group.add_argument('--block_ngram_repeat', type=int, default=0,
196 |                        help='Block repetition of ngrams during decoding.')
197 | 
198 |     group.add_argument('--ignore_when_blocking', nargs='+', type=str,
199 |                        default=[],
200 |                        help="""Ignore these strings when blocking repeats.
201 |                        You want to block sentence delimiters.""")
202 | 
203 |     group.add_argument('--replace_unk', action="store_true",
204 |                        help="""Replace the generated UNK tokens with the
205 |                        source token that had highest attention weight. If
206 |                        phrase_table is provided, it will look up the
207 |                        identified source token and give the corresponding
208 |                        target token. If it is not provided (or the identified
209 |                        source token does not exist in the table) then it
210 |                        will copy the source token""")
211 | 
212 | 
213 |     group.add_argument('--plot_attn', action="store_true",
214 |                        help='Plot attention matrix for each pair')
215 | 
216 |     group.add_argument('--use_beam_search', action="store_true",
217 |                        help='use beam search for decoding.')
218 | 
219 |     group.add_argument('--ensemble', default=False, action='store_true',
220 |                        help="Use ensemble-decoding.")
221 | 
222 | def add_common_network_args(parser):
223 |     group = parser.add_argument_group('Network')
224 | 
225 |     group.add_argument('--enc_num_layers', type=int, default=2,
226 |                        help='Number of layers in the encoder')
227 | 
228 |     group.add_argument('--dec_num_layers', type=int, default=2,
229 |                        help='Number of layers in the decoder')
230 | 
231 |     group.add_argument('--attn_type', type=str, default='general',
232 |                        choices=['dot', 'general', 'mlp'],
233 |                        help="""The attention type to use:
234 |                        dotprod or general (Luong) or MLP (Bahdanau)""")
235 | 
236 |     group.add_argument('--hidden_size', type=int, default=512,
237 |                        help='Number of hidden states')
238 | 
239 |     group.add_argument('--dropout', type=float, default=0.3,
240 |                        help="Dropout probability; applied in RNN stacks.")
241 | 
242 | 
243 | def add_rnn_args(parser):
244 |     group = parser.add_argument_group('RNN')
245 | 
246 |     group.add_argument('--bidirectional', action='store_true',
247 |                        help="""bidirectional encoding for encoder.""")
248 | 
249 |     group.add_argument('--rnn_type', type=str,
250 |                        choices=['LSTM', 'GRU'],
251 |                        help="""The gate type to use in the RNNs""")
252 | 
253 | 
254 |     group.add_argument('--residual', action='store_true',
255 |                        help="""using residual RNN.""")
256 | 
257 | 
258 | 
259 | def add_transformer_args(parser):
260 |     group = parser.add_argument_group('Transformer')
261 | 
262 |     group.add_argument('--num_heads', type=int, default=8,
263 |                        help='Number of heads in the MultiHeadedAttention')
264 | 
265 |     group.add_argument('--inner_hidden_size', type=int, default=1024,
266 |                        help='Inner hidden size of the position-wise feed-forward layer')
267 | 
268 |     group.add_argument('--latent_size', type=int, default=512,
269 |                        help='Size of the latent variable (for the variational model)')
270 | 
271 | def add_optim_args(parser):
272 |     # Optimization options
273 |     group = parser.add_argument_group('Optimizer')
274 | 
275 | 
276 |     group.add_argument('--early_stop', action='store_true',
277 |                        help="""Enable early stopping.""")
278 | 
279 |     group.add_argument('--optim', default='Adam',
280 |                        choices=['SGD', 'Adadelta', 'Adam'],
281 |                        help="""Optimization method.""")
282 | 
283 |     group.add_argument('--max_grad_norm', type=float, default=0,
284 |                        help="""If the norm of the gradient vector exceeds this,
285 |                        renormalize it to have the norm equal to
286 |                        max_grad_norm""")
287 | 
288 |     group.add_argument('--lr', type=float, default=1e-4,
289 |                        help="""Starting learning rate.
290 |                        Recommended settings: SGD = 1, Adadelta = 1,
291 |                        Adam = 0.001""")
292 | 
293 |     group.add_argument('--lr_decay_rate', type=float, default=0.5,
294 |                        help="""If update_learning_rate, decay learning rate by
295 |                        this much if (i) perplexity does not decrease on the
296 |                        validation set or (ii) epoch has gone past
297 |                        start_decay_at""")
298 | 
299 |     group.add_argument('--start_decay_at', type=int, default=8,
300 |                        help="""Start decaying every epoch after and including this
301 |                        epoch""")
302 | 
303 |     group.add_argument('--decay_method', type=str, default="",
304 |                        choices=['noam'], help="Use a custom decay rate.")
305 | 
306 |     group.add_argument('--warmup_steps', type=float, default=4000,
307 |                        help="""warmup_steps for Adam.""")
308 | 
309 |     group.add_argument('--alpha', type=float, default=0.9,
310 |                        help="""The alpha parameter used by RMSprop.""")
311 | 
312 |     group.add_argument('--eps', type=float, default=1e-8,
313 |                        help="""The eps parameter used by RMSprop/Adam.""")
314 | 
315 |     group.add_argument('--rho', type=float, default=0.95,
316 |                        help="""The rho parameter used by RMSprop.""")
317 | 
318 |     group.add_argument('--weight_decay', type=float, default=0,
319 |                        help="""The weight_decay parameter used by RMSprop.""")
320 | 
321 |     group.add_argument('--momentum', type=float, default=0,
322 |                        help="""The momentum parameter used by RMSprop[0]/SGD[0.9].""")
323 | 
324 |     group.add_argument('--adam_beta1', type=float, default=0.9,
325 |                        help="""The beta1 parameter used by Adam.""")
326 | 
327 |     group.add_argument('--adam_beta2', type=float, default=0.999,
328 |                        help="""The beta2 parameter used by Adam.""")
329 | 
330 |     group.add_argument('--label_smoothing', type=float, default=0.1,
331 |                        help="""Label smoothing value.""")
332 | 
333 |     group.add_argument('--grad_accum_count', type=int, default=4,
334 |                        help="""Accumulate gradients over this many batches before an update.""")
335 | 
336 | 
337 | 
338 | 
339 | 
340 | 
341 | 
342 | 
343 | 
--------------------------------------------------------------------------------
/Utils/bleu.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Python implementation of BLEU and smooth-BLEU.
17 | This module provides a Python implementation of BLEU and smooth-BLEU.
18 | Smooth BLEU is computed following the method outlined in the paper:
19 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
20 | evaluation metrics for machine translation. COLING 2004.
21 | """
22 | 
23 | import collections
24 | import math
25 | 
26 | 
27 | def _get_ngrams(segment, max_order):
28 |     """Extracts all n-grams up to a given maximum order from an input segment.
29 |     Args:
30 |         segment: text segment from which n-grams will be extracted.
31 |         max_order: maximum length in tokens of the n-grams returned by this
32 |             method.
33 |     Returns:
34 |         The Counter containing all n-grams up to max_order in segment
35 |         with a count of how many times each n-gram occurred.
36 |     """
37 |     ngram_counts = collections.Counter()
38 |     for order in range(1, max_order + 1):
39 |         for i in range(0, len(segment) - order + 1):
40 |             ngram = tuple(segment[i:i+order])
41 |             ngram_counts[ngram] += 1
42 |     return ngram_counts
43 | 
44 | 
45 | def compute_bleu(reference_corpus, translation_corpus, max_order=4,
46 |                  smooth=False):
47 |     """Computes BLEU score of translated segments against one or more references.
48 |     Args:
49 |         reference_corpus: list of lists of references for each translation. Each
50 |             reference should be tokenized into a list of tokens.
51 |         translation_corpus: list of translations to score. Each translation
52 |             should be tokenized into a list of tokens.
53 |         max_order: Maximum n-gram order to use when computing BLEU score.
54 |         smooth: Whether or not to apply Lin et al. 2004 smoothing.
55 |     Returns:
56 |         6-tuple with the BLEU score, n-gram precisions, brevity penalty,
57 |         length ratio, translation length, and reference length.
58 |     """
59 |     matches_by_order = [0] * max_order
60 |     possible_matches_by_order = [0] * max_order
61 |     reference_length = 0
62 |     translation_length = 0
63 |     for (references, translation) in zip(reference_corpus,
64 |                                          translation_corpus):
65 |         reference_length += min(len(r) for r in references)
66 |         translation_length += len(translation)
67 | 
68 |         merged_ref_ngram_counts = collections.Counter()
69 |         for reference in references:
70 |             merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
71 |         translation_ngram_counts = _get_ngrams(translation, max_order)
72 |         overlap = translation_ngram_counts & merged_ref_ngram_counts
73 |         for ngram in overlap:
74 |             matches_by_order[len(ngram)-1] += overlap[ngram]
75 |         for order in range(1, max_order+1):
76 |             possible_matches = len(translation) - order + 1
77 |             if possible_matches > 0:
78 |                 possible_matches_by_order[order-1] += possible_matches
79 | 
80 |     precisions = [0] * max_order
81 |     for i in range(0, max_order):
82 |         if smooth:
83 |             precisions[i] = ((matches_by_order[i] + 1.) /
84 |                              (possible_matches_by_order[i] + 1.))
85 |         else:
86 |             if possible_matches_by_order[i] > 0:
87 |                 precisions[i] = (float(matches_by_order[i]) /
88 |                                  possible_matches_by_order[i])
89 |             else:
90 |                 precisions[i] = 0.0
91 | 
92 |     if min(precisions) > 0:
93 |         p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
94 |         geo_mean = math.exp(p_log_sum)
95 |     else:
96 |         geo_mean = 0
97 | 
98 |     ratio = float(translation_length) / reference_length
99 | 
100 |     if ratio > 1.0:
101 |         bp = 1.
102 |     else:
103 |         bp = math.exp(1 - 1. / ratio)
104 | 
105 |     bleu = geo_mean * bp
106 | 
107 |     return (bleu, precisions, bp, ratio, translation_length, reference_length)
--------------------------------------------------------------------------------
/Utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 | import logging
5 | import argparse
6 | from configparser import ConfigParser
7 | 
8 | from Utils.args import make_parser
9 | from Utils.log import trace
10 | from Utils.log import set_logging
11 | 
12 | 
13 | 
14 | def check_save_path(path):
15 |     save_path = os.path.abspath(path)
16 |     dirname = os.path.dirname(save_path)
17 |     if not os.path.exists(dirname):
18 |         os.makedirs(dirname)
19 | 
20 | class Config(object):
21 | 
22 |     def __init__(self, prefix, training=True):
23 |         self.training = training
24 |         self.parser = make_parser(training)
25 |         self.args = self.parser.parse_args()
26 | 
27 |         self.config_file = self.args.config
28 |         self.read_config(self.config_file)
29 |         self.filter()
30 |         check_save_path(self.args.save_log)
31 |         check_save_path(self.args.save_vocab)
32 |         check_save_path(self.args.save_model)
33 |         set_logging(prefix, self.args.save_log)
34 | 
35 |     def filter(self):
36 |         # drop use_cpu when False, so the flag only exists if explicitly set
37 |         if not self.args.use_cpu:
38 |             del self.args.use_cpu
39 | 
40 |     def __getattr__(self, name):
41 |         return getattr(self.args, name)
42 | 
43 |     def set_defaults(self, config, section):
44 |         defaults = dict(config.items(section))
45 |         for key, val in defaults.items():
46 |             if key == "use_gpu":
47 |                 val = eval(val)
48 |             else:
49 |                 for attr in ["getint", "getfloat", "getboolean"]:
50 |                     try:
51 |                         val = getattr(config[section], attr)(key)
52 |                         break
53 |                     except ValueError: pass
54 |             defaults[key] = val
55 | 
56 |         self.parser.set_defaults(**defaults)
57 |         self.args = self.parser.parse_args()
58 | 
59 |     def check_config_exist(self):
60 |         if not os.path.isfile(self.config_file):
61 |             trace("""# Cannot find the configuration file.
62 |             {} does not exist! Please check.""".format(self.config_file))
63 |             sys.exit(1)
64 | 
65 |     def read_config(self, config_file):
66 |         self.check_config_exist()
67 |         config = ConfigParser()
68 |         config.read(config_file)
69 | 
70 |         groups = set(group.title for group in self.parser._action_groups)
71 |         sections = groups.intersection(set(config.sections()))
72 |         for section in sections:
73 |             #print("#section", section)
74 |             self.set_defaults(config, section)
75 | 
76 |         save_path = os.path.abspath(self.args.save_model)
77 |         dirname = os.path.dirname(save_path)
78 | 
79 |         config_file_bak = os.path.join(dirname, 'config.ini')
80 |         check_save_path(config_file_bak)
81 |         if not os.path.isfile(config_file_bak):
82 |             with open(config_file_bak, 'w') as configfile:
83 |                 config.write(configfile)
84 |     def __repr__(self):
85 |         ret = "\n"
86 |         pattern = r"<class '(.+)'>"  # e.g. "<class 'int'>" -> "int"
87 |         for key, value in vars(self.args).items():
88 |             class_type = re.search(pattern, str(type(value))).group(1)
89 |             class_type = "[{}]".format(class_type)
90 |             value_string = str(value)
91 |             if len(value_string) > 80:
92 |                 value_string = "/".join(value_string.split("/")[:2]) +\
93 |                     "/.../" + "/".join(value_string.split("/")[-2:])
94 |             ret += "  {}\t{}\t{}\n".format(key.ljust(15),
95 |                                            class_type.ljust(8), value_string)
96 |         return ret
97 | 
98 | # def config_device(config):
99 | #     if config.device_ids:
100 | #         config.device_ids = [int(x) for x in config.device_ids.split(",")]
101 | #     else:
102 | #         os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
103 | #             [str(idx) for idx in list(
104 | #                 range(config.gpu_ids, config.gpu_ids + config.num_gpus))])
105 | #         config.device_ids = list(range(config.num_gpus))
106 | 
--------------------------------------------------------------------------------
/Utils/log.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 | 
5 | 
6 | class LoggerWriter(object):
7 |     def __init__(self, name, filename):
8 |         self.logger = logging.getLogger(name)
9 |         self.logger.setLevel(logging.DEBUG)
10 |         self.datefmt = '%Y/%m/%d %H:%M:%S'
11 |         self.set_default_log_handler()
12 |         self.set_log_handler_to_file(filename)
13 | 
14 | 
15 | 
16 |     def set_default_log_handler(self):
17 |         handler1 = logging.StreamHandler()
18 |         handler1.setFormatter(
19 |             logging.Formatter("%(asctime)s %(message)s",
20 |                               datefmt=self.datefmt)
21 |         )
22 |         self.logger.addHandler(handler1)
23 | 
24 |     def set_log_handler_to_file(self, filename):
25 |         if os.path.exists(filename):
26 |             # start each run with a fresh log file
27 |             os.remove(filename)
28 |         handler2 = logging.FileHandler(filename=filename)
29 |         handler2.setLevel(logging.DEBUG)
30 |         handler2.setFormatter(
31 |             logging.Formatter("%(asctime)s %(message)s",
32 |                               datefmt=self.datefmt)
33 |         )
34 |         self.logger.addHandler(handler2)
35 | 
36 |     def write(self, message):
37 |         if message != "\n" and message != "":
38 |             self.logger.info(message)
39 | 
40 |     def flush(self):
41 |         handler = self.logger.handlers[0]
42 |         handler.flush()
43 | 
44 | logger = sys.stderr
45 | def set_logging(name, filename):
46 |     global logger
47 |     logger = LoggerWriter(name, filename + "." + name)
48 | 
49 | def trace(*args):
50 |     global logger
51 |     logger.write(" ".join(map(str, args)))
--------------------------------------------------------------------------------
/Utils/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import matplotlib.ticker as ticker
3 | import numpy as np
4 | 
5 | from matplotlib.font_manager import FontProperties
6 | from Utils.DataLoader import BOS_WORD, EOS_WORD
7 | JaFont = FontProperties(fname='/usr/share/fonts/truetype/takao-gothic/TakaoGothic.ttf')
8 | ZhFont = FontProperties(fname='/usr/share/fonts/truetype/wqy/wqy-microhei.ttc')
9 | 
10 | def plot_attn(source, target, attn):
11 |     # Set up figure with colorbar
12 |     fig = plt.figure()
13 |     ax = fig.add_subplot(111)
14 |     cax = ax.matshow(attn.numpy(), cmap='bone')
15 |     fig.colorbar(cax)
16 |     # Set up axes
17 | 
18 |     ax.set_xticklabels(['']+[BOS_WORD]+source+[EOS_WORD], rotation=90, fontproperties=ZhFont)
19 |     ax.set_yticklabels(['']+target+[EOS_WORD], fontproperties=JaFont)
20 | 
21 |     # Show label at every tick
22 |     ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
23 |     ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
24 | 
25 |     plt.show()
26 | 
--------------------------------------------------------------------------------
/Utils/rouge.py:
--------------------------------------------------------------------------------
1 | """ROUGE metric implementation.
2 | Copy from tf_seq2seq/seq2seq/metrics/rouge.py.
3 | This is a modified and slightly extended version of
4 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py.
5 | """
6 | 
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 | from __future__ import unicode_literals
11 | 
12 | import itertools
13 | import numpy as np
14 | 
15 | #pylint: disable=C0103
16 | 
17 | 
18 | def _get_ngrams(n, text):
19 |     """Calculates n-grams.
20 |     Args:
21 |         n: which n-grams to calculate
22 |         text: An array of tokens
23 |     Returns:
24 |         A set of n-grams
25 |     """
26 |     ngram_set = set()
27 |     text_length = len(text)
28 |     max_index_ngram_start = text_length - n
29 |     for i in range(max_index_ngram_start + 1):
30 |         ngram_set.add(tuple(text[i:i + n]))
31 |     return ngram_set
32 | 
33 | 
34 | def _split_into_words(sentences):
35 |     """Splits multiple sentences into words and flattens the result"""
36 |     return list(itertools.chain(*[_.split(" ") for _ in sentences]))
37 | 
38 | 
39 | def _get_word_ngrams(n, sentences):
40 |     """Calculates word n-grams for multiple sentences.
41 |     """
42 |     assert len(sentences) > 0
43 |     assert n > 0
44 | 
45 |     words = _split_into_words(sentences)
46 |     return _get_ngrams(n, words)
47 | 
48 | 
49 | def _len_lcs(x, y):
50 |     """
51 |     Returns the length of the Longest Common Subsequence between sequences x
52 |     and y.
53 |     Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
54 |     Args:
55 |         x: sequence of words
56 |         y: sequence of words
57 |     Returns
58 |         integer: Length of LCS between x and y
59 |     """
60 |     table = _lcs(x, y)
61 |     n, m = len(x), len(y)
62 |     return table[n, m]
63 | 
64 | 
65 | def _lcs(x, y):
66 |     """
67 |     Computes the length of the longest common subsequence (lcs) between two
68 |     strings. The implementation below uses a dynamic programming algorithm and runs
69 |     in O(nm) time where n = len(x) and m = len(y).
70 |     Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
71 |     Args:
72 |         x: collection of words
73 |         y: collection of words
74 |     Returns:
75 |         Table of dictionary of coord and len lcs
76 |     """
77 |     n, m = len(x), len(y)
78 |     table = dict()
79 |     for i in range(n + 1):
80 |         for j in range(m + 1):
81 |             if i == 0 or j == 0:
82 |                 table[i, j] = 0
83 |             elif x[i - 1] == y[j - 1]:
84 |                 table[i, j] = table[i - 1, j - 1] + 1
85 |             else:
86 |                 table[i, j] = max(table[i - 1, j], table[i, j - 1])
87 |     return table
88 | 
89 | 
90 | def _recon_lcs(x, y):
91 |     """
92 |     Returns the Longest Common Subsequence between x and y.
93 |     Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
94 |     Args:
95 |         x: sequence of words
96 |         y: sequence of words
97 |     Returns:
98 |         sequence: LCS of x and y
99 |     """
100 |     i, j = len(x), len(y)
101 |     table = _lcs(x, y)
102 | 
103 |     def _recon(i, j):
104 |         """private recon calculation"""
105 |         if i == 0 or j == 0:
106 |             return []
107 |         elif x[i - 1] == y[j - 1]:
108 |             return _recon(i - 1, j - 1) + [(x[i - 1], i)]
109 |         elif table[i - 1, j] > table[i, j - 1]:
110 |             return _recon(i - 1, j)
111 |         else:
112 |             return _recon(i, j - 1)
113 | 
114 |     recon_tuple = tuple(map(lambda x: x[0], _recon(i, j)))
115 |     return recon_tuple
116 | 
117 | 
118 | def rouge_n(evaluated_sentences, reference_sentences, n=2):
119 |     """
120 |     Computes ROUGE-N of two text collections of sentences.
121 |     Source: http://research.microsoft.com/en-us/um/people/cyl/download/
122 |     papers/rouge-working-note-v1.3.1.pdf
123 |     Args:
124 |         evaluated_sentences: The sentences that have been picked by the summarizer
125 |         reference_sentences: The sentences from the reference set
126 |         n: Size of ngram. Defaults to 2.
127 |     Returns:
128 |         A tuple (f1, precision, recall) for ROUGE-N
129 |     Raises:
130 |         ValueError: raises exception if a param has len <= 0
131 |     """
132 |     if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
133 |         raise ValueError("Collections must contain at least 1 sentence.")
134 | 
135 |     evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences)
136 |     reference_ngrams = _get_word_ngrams(n, reference_sentences)
137 |     reference_count = len(reference_ngrams)
138 |     evaluated_count = len(evaluated_ngrams)
139 | 
140 |     # Gets the overlapping ngrams between evaluated and reference
141 |     overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
142 |     overlapping_count = len(overlapping_ngrams)
143 | 
144 |     # Handle edge case. This isn't mathematically correct, but it's good enough
145 |     if evaluated_count == 0:
146 |         precision = 0.0
147 |     else:
148 |         precision = overlapping_count / evaluated_count
149 | 
150 |     if reference_count == 0:
151 |         recall = 0.0
152 |     else:
153 |         recall = overlapping_count / reference_count
154 | 
155 |     f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
156 | 
157 |     # return overlapping_count / reference_count
158 |     return f1_score, precision, recall
159 | 
160 | 
161 | def _f_p_r_lcs(llcs, m, n):
162 |     """
163 |     Computes the LCS-based F-measure score
164 |     Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/
165 |     rouge-working-note-v1.3.1.pdf
166 |     Args:
167 |         llcs: Length of LCS
168 |         m: number of words in reference summary
169 |         n: number of words in candidate summary
170 |     Returns:
171 |         Float. LCS-based F-measure score
172 |     """
173 |     r_lcs = llcs / m
174 |     p_lcs = llcs / n
175 |     beta = p_lcs / (r_lcs + 1e-12)
176 |     num = (1 + (beta**2)) * r_lcs * p_lcs
177 |     denom = r_lcs + ((beta**2) * p_lcs)
178 |     f_lcs = num / (denom + 1e-12)
179 |     return f_lcs, p_lcs, r_lcs
180 | 
181 | 
182 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences):
183 |     """
184 |     Computes ROUGE-L (sentence level) of two text collections of sentences.
185 |     http://research.microsoft.com/en-us/um/people/cyl/download/papers/
186 |     rouge-working-note-v1.3.1.pdf
187 |     Calculated according to:
188 |     R_lcs = LCS(X,Y)/m
189 |     P_lcs = LCS(X,Y)/n
190 |     F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
191 |     where:
192 |     X = reference summary
193 |     Y = Candidate summary
194 |     m = length of reference summary
195 |     n = length of candidate summary
196 |     Args:
197 |         evaluated_sentences: The sentences that have been picked by the summarizer
198 |         reference_sentences: The sentences from the reference set
199 |     Returns:
200 |         A float: F_lcs
201 |     Raises:
202 |         ValueError: raises exception if a param has len <= 0
203 |     """
204 |     if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
205 |         raise ValueError("Collections must contain at least 1 sentence.")
206 |     reference_words = _split_into_words(reference_sentences)
207 |     evaluated_words = _split_into_words(evaluated_sentences)
208 |     m = len(reference_words)
209 |     n = len(evaluated_words)
210 |     lcs = _len_lcs(evaluated_words, reference_words)
211 |     return _f_p_r_lcs(lcs, m, n)
212 | 
213 | 
214 | def _union_lcs(evaluated_sentences, reference_sentence):
215 |     """
216 |     Returns LCS_u(r_i, C) which is the LCS score of the union longest common
217 |     subsequence between reference sentence ri and candidate summary C. For example
218 |     if r_i= w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and
219 |     c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is
220 |     "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5". The
221 |     union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5" and
222 |     LCS_u(r_i, C) = 4/5.
223 |     Args:
224 |         evaluated_sentences: The sentences that have been picked by the summarizer
225 |         reference_sentence: One of the sentences in the reference summaries
226 |     Returns:
227 |         float: LCS_u(r_i, C)
228 |     ValueError:
229 |         Raises exception if a param has len <= 0
230 |     """
231 |     if len(evaluated_sentences) <= 0:
232 |         raise ValueError("Collections must contain at least 1 sentence.")
233 | 
234 |     lcs_union = set()
235 |     reference_words = _split_into_words([reference_sentence])
236 |     combined_lcs_length = 0
237 |     for eval_s in evaluated_sentences:
238 |         evaluated_words = _split_into_words([eval_s])
239 |         lcs = set(_recon_lcs(reference_words, evaluated_words))
240 |         combined_lcs_length += len(lcs)
241 |         lcs_union = lcs_union.union(lcs)
242 | 
243 |     union_lcs_count = len(lcs_union)
244 |     union_lcs_value = union_lcs_count / combined_lcs_length
245 |     return union_lcs_value
246 | 
247 | 
248 | def rouge_l_summary_level(evaluated_sentences, reference_sentences):
249 |     """
250 |     Computes ROUGE-L (summary level) of two text collections of sentences.
251 |     http://research.microsoft.com/en-us/um/people/cyl/download/papers/
252 |     rouge-working-note-v1.3.1.pdf
253 |     Calculated according to:
254 |     R_lcs = SUM(1, u)[LCS(r_i,C)]/m
255 |     P_lcs = SUM(1, u)[LCS(r_i,C)]/n
256 |     F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
257 |     where:
258 |     SUM(i,u) = SUM from i through u
259 |     u = number of sentences in reference summary
260 |     C = Candidate summary made up of v sentences
261 |     m = number of words in reference summary
262 |     n = number of words in candidate summary
263 |     Args:
264 |         evaluated_sentences: The sentences that have been picked by the summarizer
265 |         reference_sentences: The sentences from the reference set
266 |     Returns:
267 |         A float: F_lcs
268 |     Raises:
269 |         ValueError: raises exception if a param has len <= 0
270 |     """
271 |     if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
272 |         raise ValueError("Collections must contain at least 1 sentence.")
273 | 
274 |     # total number of words in reference sentences
275 |     m = len(_split_into_words(reference_sentences))
276 | 
277 |     # total number of words in evaluated sentences
278 |     n = len(_split_into_words(evaluated_sentences))
279 | 
280 |     union_lcs_sum_across_all_references = 0
281 |     for ref_s in reference_sentences:
282 |         union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences,
283 |                                                           ref_s)
284 |     return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n)
285 | 
286 | 
287 | def rouge(hypotheses, references):
288 |     """Calculates average rouge scores for a list of hypotheses and
289 |     references"""
290 | 
291 |     # Filter out hyps that are of 0 length
292 |     # hyps_and_refs = zip(hypotheses, references)
293 |     # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0]
294 |     # hypotheses, references = zip(*hyps_and_refs)
295 | 
296 |     # Calculate ROUGE-1 F1, precision, recall scores
297 |     rouge_1 = [
298 |         rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references)
299 |     ]
300 |     rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1))
301 | 
302 |     # Calculate ROUGE-2 F1, precision, recall scores
303 |     rouge_2 = [
304 |         rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references)
305 |     ]
306 |     rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2))
307 | 
308 |     # Calculate ROUGE-L F1, precision, recall scores
309 |     rouge_l = [
310 |         rouge_l_sentence_level([hyp], [ref])
311 |         for hyp, ref in zip(hypotheses, references)
312 |     ]
313 |     rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l))
314 | 
315 |     return {
316 |         "rouge_1/f_score": rouge_1_f,
317 |         "rouge_1/r_score": rouge_1_r,
318 |         "rouge_1/p_score": rouge_1_p,
319 |         "rouge_2/f_score": rouge_2_f,
320 |         "rouge_2/r_score": rouge_2_r,
321 |         "rouge_2/p_score": rouge_2_p,
322 |         "rouge_l/f_score": rouge_l_f,
323 |         "rouge_l/r_score": rouge_l_r,
324 |         "rouge_l/p_score": rouge_l_p,
325 |     }
--------------------------------------------------------------------------------
/Utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import random
5 | import math
6 | import numpy as np
7 | from Utils.bleu import compute_bleu
8 | from Utils.rouge import rouge
9 | from Utils.log import trace
10 | 
11 | def aeq(*args):
12 |     """
13 |     Assert all arguments have the same value
14 |     """
15 |     arguments = iter(args)
16 |     first = next(arguments)
17 |     assert all(arg == first for arg in arguments), \
18 |         "Not all arguments have the same value: " + str(args)
19 | 
20 | 
21 | def sequence_mask(lengths, max_len=None):
22 |     """
23 |     Creates a boolean mask from sequence lengths.
24 |     """
25 |     batch_size = lengths.numel()
26 |     max_len = max_len or lengths.max()
27 |     return (torch.arange(0, max_len)
28 |             .type_as(lengths)
29 |             .repeat(batch_size, 1)
30 |             .lt(lengths.unsqueeze(1)))
31 | 
32 | 
33 | 
34 | def check_save_path(path):
35 |     save_path = os.path.abspath(path)
36 |     dirname = os.path.dirname(save_path)
37 |     if not os.path.exists(dirname):
38 |         os.makedirs(dirname)
39 | 
40 | def check_file_exist(path):
41 |     # abort early when a required file is missing
42 |     if not os.path.isfile(os.path.abspath(path)):
43 |         trace("# File does not exist: {}".format(path))
44 |         sys.exit(1)
45 | 
46 | 
47 | def report_bleu(reference_corpus, translation_corpus):
48 | 
49 |     bleu, precisions, bp, ratio, trans_length, ref_length =\
50 |         compute_bleu([[x] for x in reference_corpus], translation_corpus)
51 |     trace("BLEU: %.2f [%.2f/%.2f/%.2f/%.2f] Pred_len:%d, Ref_len:%d"%(
52 |         bleu*100, *[p*100 for p in precisions], trans_length, ref_length))
53 | 
54 | 
55 | def report_rouge(reference_corpus, translation_corpus):
56 | 
57 |     scores = rouge([" ".join(x) for x in translation_corpus],
58 |                    [" ".join(x) for x in reference_corpus])
59 | 
60 | 
61 |     trace("ROUGE-1:%.2f, ROUGE-2:%.2f"%(
62 |         scores["rouge_1/f_score"]*100, scores["rouge_2/f_score"]*100))
--------------------------------------------------------------------------------
/config/rnn.ini:
--------------------------------------------------------------------------------
1 | ###################################
2 | #        NMT configuration        #
3 | ###################################
4 | 
5 | [DEFAULT]
6 | 
7 | system = RNN
8 | src_lang = zh
9 | trg_lang = ja
10 | id = 01
11 | user = XXXX
12 | category = baseline
13 | prefix = /home/%(User)s/workspace/result/%(src_lang)s-%(trg_lang)s
14 | workspace = %(prefix)s/%(system)s.%(src_lang)s-%(trg_lang)s.%(Category)s.%(id)s
15 | output = %(workspace)s/output
16 | save_log = %(workspace)s/log
17 | save_model = %(workspace)s/model
18 | save_vocab = %(workspace)s/vocab
19 | 
20 | 
21 | 
22 | [Data]
23 | max_seq_len = 50
24 | data_path = /home/%(User)s/workspace/ASPEC-JC.clean/Juman+Stanford
25 | train_dataset = train
26 | valid_dataset = dev
27 | test_dataset = test
28 | 
29 | [GPU]
30 | use_gpu = [0]
31 | 
32 | [Embedding]
33 | min_freq = 1
34 | src_embed_dim = 512
35 | trg_embed_dim = 512
36 | #share_vocab = True
37 | #share_embedding = True
38 | 
39 | [Train]
40 | epochs = 10
41 | train_batch_size = 32
42 | valid_batch_size = 16
43 | max_decrease_steps = 30
44 | report_every = 50
45 | 
46 | [Optimizer]
47 | lr = 4e-4
48 | optim = Adam
49 | max_grad_norm = 5
50 | warmup_steps = 8000
51 | adam_beta2 = 0.99
52 | grad_accum_count = 1
53 | 
54 | [Network]
55 | dropout = 0.1
56 | enc_num_layers = 4
57 | dec_num_layers = 4
58 | hidden_size = 512
59 | attn_type = general
60 | 
61 | [RNN]
62 | bidirectional = True
63 | rnn_type = GRU
64 | 
65 | [Translate]
66 | test_batch_size = 32
67 | use_beam_search = False
68 | k_best = 1
69 | beam_size = 5
70 | replace_unk = True
71 | 
--------------------------------------------------------------------------------
/config/transformer.ini:
--------------------------------------------------------------------------------
1 | [DEFAULT]
2 | system = Transformer
3 | src_lang = en
4 | trg_lang = tr
5 | id = 03
6 | user = XXXX
7 | category = baseline
8 | prefix = /home/%(User)s/workspace/result/%(src_lang)s-%(trg_lang)s
9 | workspace = %(prefix)s/%(system)s.%(src_lang)s-%(trg_lang)s.%(Category)s.%(id)s
10 | output = %(workspace)s/output
11 | save_log = %(workspace)s/log
12 | 
save_model = %(workspace)s/model 13 | save_vocab = %(workspace)s/vocab 14 | 15 | [Data] 16 | max_seq_len = 100 17 | data_path = /home/%(User)s/workspace/WMT2018/en-tr/SPM 18 | train_dataset = train 19 | valid_dataset = test2017 20 | test_dataset = test2017 21 | 22 | [GPU] 23 | use_gpu = [0] 24 | 25 | [Embedding] 26 | min_freq = 1 27 | src_embed_dim = 512 28 | trg_embed_dim = 512 29 | #share_vocab = True 30 | #share_embedding = True 31 | 32 | [Train] 33 | epochs = 10 34 | train_batch_size = 64 35 | valid_batch_size = 16 36 | max_decrease_steps = 30 37 | report_every = 50 38 | 39 | [Optimizer] 40 | lr = 2 41 | optim = Adam 42 | max_grad_norm = 0 43 | decay_method = noam 44 | warmup_steps = 8000 45 | adam_beta2 = 0.98 46 | grad_accum_count = 4 47 | 48 | [Network] 49 | dropout = 0.1 50 | enc_num_layers = 4 51 | dec_num_layers = 4 52 | hidden_size = 512 53 | attn_type = general 54 | 55 | [Transformer] 56 | num_heads = 8 57 | inner_hidden_size = 2048 58 | 59 | [Translate] 60 | test_batch_size = 32 61 | use_beam_search = False 62 | k_best = 1 63 | beam_size = 5 64 | replace_unk = True 65 | 66 | -------------------------------------------------------------------------------- /ensemble.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import argparse 4 | import math 5 | import codecs 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from itertools import count 10 | 11 | 12 | from Utils.log import trace 13 | from Utils.config import Config 14 | from Utils.DataLoader import DataBatchIterator 15 | from NMT.ModelFactory import model_factory 16 | from NMT.CheckPoint import CheckPoint 17 | from NMT.Trainer import Statistics 18 | from NMT.translate import BatchTranslator 19 | from Utils.plot import plot_attn 20 | from Utils.utils import report_bleu 21 | from Utils.utils import report_rouge 22 | from train import load_dataset 23 | from NMT.CheckPoint import dump_checkpoint 24 | 25 | def main(): 26 | """main function for checkpoint ensemble.""" 27 | config = Config("ensemble", training=True) 28 | trace(config) 29 | torch.backends.cudnn.benchmark = True 30 | 31 | train_data = load_dataset(config.train_dataset, 32 | config.train_batch_size, 33 | config, prefix="Training:") 34 | 35 | # Build model. 
36 | vocab = train_data.get_vocab() 37 | model = model_factory(config, config.checkpoint, *vocab) 38 | cp = CheckPoint(config.checkpoint) 39 | model.load_state_dict(cp.state_dict['model']) 40 | dump_checkpoint(model, config.save_model, ".ensemble") 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /scripts/clean_synthetic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # author: Hao WANG 3 | echo "Usage: bash $0 prefix LANG_F [=en] LANG_E" 4 | 5 | prefix=$1 6 | LANG_F=$2 7 | LANG_E=$3 8 | TASK=${LANG_F}-${LANG_E} 9 | spm_encode=spm_encode 10 | spm_model=/path-to/WMT2018/$TASK/$TASK.spm.model 11 | truecase_model=/path-to/$TASK/true 12 | MOSES_SCRIPT=/path-to/mosesdecoder-RELEASE-2.1.1/scripts 13 | 14 | 15 | mkdir -p SPM 16 | 17 | cat $prefix.${LANG_F} \ 18 | | ${MOSES_SCRIPT}/tokenizer/normalize-punctuation.perl -l ${LANG_F} \ 19 | | ${MOSES_SCRIPT}/tokenizer/tokenizer.perl -a -l ${LANG_F} \ 20 | | ${MOSES_SCRIPT}/recaser/truecase.perl -model ${truecase_model}/truecase-model.${LANG_F} \ 21 | | $spm_encode --model=$spm_model --output_format=piece \ 22 | > SPM/synthetic.${LANG_E}-${LANG_F}.${LANG_F} 23 | cat $prefix.${LANG_E} \ 24 | | ${MOSES_SCRIPT}/tokenizer/normalize-punctuation.perl -l ${LANG_E} \ 25 | | ${MOSES_SCRIPT}/tokenizer/tokenizer.perl -a -l ${LANG_E} \ 26 | | ${MOSES_SCRIPT}/recaser/truecase.perl -model ${truecase_model}/truecase-model.${LANG_E} \ 27 | | $spm_encode --model=$spm_model --output_format=piece \ 28 | > SPM/synthetic.${LANG_F}-${LANG_E}.${LANG_E} 29 | 30 | -------------------------------------------------------------------------------- /scripts/data-processing-test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # author: Hao WANG 3 | echo "Usage: bash $0 WORKSPACE OUTPUT_FILE LANG_F LANG_E" 4 | LANG_F=$3 5 | LANG_E=$4 6 | TASK=${LANG_F}-${LANG_E} 7 | 8 | MOSES_SCRIPT=/itigo/files/Tools/accessByOldOrganization/TranslationEngines/mosesdecoder-RELEASE-2.1.1/scripts 9 | #Evaluation 10 | EVAL_SCRIPTS=/itigo/files/Tools/accessByOldOrganization/MTEvaluation 11 | 12 | WORKSPACE=$1 13 | OUTPUT=$2 14 | 15 | # NOTE: $src, $tgt, $dev_dir, $devset and $CORPUS are assumed to be set by the caller; 16 | # the redirect target for the extracted plain text below is likewise a placeholder. 17 | for lang in $src $tgt; do 18 | side="ref" 19 | if [ $lang == $tgt ]; then 20 | side="src" 21 | fi 22 | ${MOSES_SCRIPT}/ems/support/input-from-sgm.perl \ 23 | < $dev_dir/news$devset-${LANG_E}${LANG_F}-$side.$lang.sgm \ 24 | > $dev_dir/news$devset.$lang 25 | done 26 | 27 | mkdir -p ${WORKSPACE}/eval 28 | 29 | cp $CORPUS/test.${LANG_E} ${WORKSPACE}/eval/ref.${LANG_E} 30 | cp ${WORKSPACE}/${OUTPUT} ${WORKSPACE}/eval/pred.${LANG_E} 31 | 32 | 33 | cd ${WORKSPACE}/eval 34 | 35 | for file in pred; do 36 | cat ${file}.${LANG_E} |\ 37 | sed -r 's/(@@ )|(@@ ?$)//g'\ 38 | > detok 39 | done 40 | 41 | # For BLEU 42 | perl ${MOSES_SCRIPT}/generic/multi-bleu.perl ref.${LANG_E} < detok 43 | 44 | -------------------------------------------------------------------------------- /scripts/dataPost-processing.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # author: Hao WANG 3 | echo "Usage: bash dataPost-processing.sh WORKSPACE OUTPUT_FILE LANG_F LANG_E" 4 | LANG_F=$3 5 | LANG_E=$4 6 | TASK=${LANG_F}-${LANG_E} 7 | 8 | MOSES_SCRIPT=/itigo/files/Tools/accessByOldOrganization/TranslationEngines/mosesdecoder-RELEASE-2.1.1/scripts 9 | #Evaluation 10 | EVAL_SCRIPTS=/itigo/files/Tools/accessByOldOrganization/MTEvaluation 11 | 12 | WORKSPACE=$1 13 | OUTPUT=$2 14 | 15 | if [ 
${LANG_F} == 'en' ]; then 16 | CORPUS=/itigo/Uploads/WMT2018/${LANG_F}-${LANG_E}/orig_clean 17 | else 18 | CORPUS=/itigo/Uploads/WMT2018/${LANG_E}-${LANG_F}/orig_clean 19 | fi 20 | 21 | mkdir -p ${WORKSPACE}/eval 22 | 23 | cp $CORPUS/test.${LANG_E} ${WORKSPACE}/eval/ref.${LANG_E} 24 | cp ${WORKSPACE}/${OUTPUT} ${WORKSPACE}/eval/pred.${LANG_E} 25 | 26 | 27 | cd ${WORKSPACE}/eval 28 | 29 | for file in pred; do 30 | cat ${file}.${LANG_E} |\ 31 | sed -r 's/(@@ )|(@@ ?$)//g'\ 32 | > detok 33 | done 34 | 35 | # For BLEU 36 | perl ${MOSES_SCRIPT}/generic/multi-bleu.perl ref.${LANG_E} < detok 37 | 38 | 39 | -------------------------------------------------------------------------------- /scripts/eval.test2017.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # author: Hao WANG 3 | echo "Usage: bash $0 OUTPUT_FILE LANG_F LANG_E" 4 | 5 | OUTPUT=$1 6 | LANG_F=$2 7 | LANG_E=$3 8 | TASK=${LANG_F}-${LANG_E} 9 | 10 | MOSES_SCRIPT=/itigo/files/Tools/accessByOldOrganization/TranslationEngines/mosesdecoder-RELEASE-2.1.1/scripts 11 | wmt17=/itigo/Uploads/WMT2018/wmt17-submitted-data 12 | spm_decode=spm_decode 13 | 14 | if [[ ${LANG_F} != 'en' ]]; then 15 | spm_model=/itigo/Uploads/WMT2018/${LANG_E}-${LANG_F}/${LANG_E}-${LANG_F}.spm.model 16 | else 17 | spm_model=/itigo/Uploads/WMT2018/${LANG_F}-${LANG_E}/${LANG_F}-${LANG_E}.spm.model 18 | fi 19 | mteval=${MOSES_SCRIPT}/generic/mteval-v13a.pl 20 | src=${wmt17}/sgm/sources/newstest2017-${LANG_F}${LANG_E}-src.${LANG_F}.sgm 21 | ref=${wmt17}/sgm/references/newstest2017-${LANG_F}${LANG_E}-ref.${LANG_E}.sgm 22 | 23 | 24 | mkdir -p eval 25 | 26 | # cat $OUTPUT \ 27 | # > eval/decoder-output.sgm 28 | 29 | # | sed 's/\@\@ //g' \ 30 | # | ${MOSES_SCRIPT}/recaser/detruecase.perl\ 31 | # | ${MOSES_SCRIPT}/tokenizer/detokenizer.perl -l ${LANG_E}\ 32 | # | ${MOSES_SCRIPT}/tokenizer/normalize-punctuation.perl -l ${LANG_E}\ 33 | # | ${MOSES_SCRIPT}/ems/support/wrap-xml.perl ${LANG_E} $src WAU\ 34 | 35 | 36 | cat $OUTPUT \ 37 | | $spm_decode --model $spm_model --input_format=piece \ 38 | | ${MOSES_SCRIPT}/recaser/detruecase.perl\ 39 | | ${MOSES_SCRIPT}/tokenizer/detokenizer.perl -l ${LANG_E}\ 40 | | ${MOSES_SCRIPT}/tokenizer/normalize-punctuation.perl -l ${LANG_E}\ 41 | | ${MOSES_SCRIPT}/ems/support/wrap-xml.perl ${LANG_E} $src WAU\ 42 | > eval/decoder-output.newstest2017.sgm 43 | # For BLEU 44 | perl $mteval -s $src -r $ref -t eval/decoder-output.newstest2017.sgm 45 | 46 | 47 | -------------------------------------------------------------------------------- /scripts/eval.test2018.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # author: Hao WANG 3 | echo "Usage: bash $0 OUTPUT_FILE LANG_F LANG_E" 4 | 5 | OUTPUT=$1 6 | LANG_F=$2 7 | LANG_E=$3 8 | TASK=${LANG_F}-${LANG_E} 9 | 10 | MOSES_SCRIPT=/path-to/mosesdecoder-RELEASE-2.1.1/scripts 11 | wmt18=/path-to/WMT2018/test2018 12 | 13 | spm_decode=spm_decode 14 | 15 | if [[ ${LANG_F} != 'en' ]]; then 16 | spm_model=/path-to/WMT2018/${LANG_E}-${LANG_F}/${LANG_E}-${LANG_F}.spm.model 17 | else 18 | spm_model=/path-to/WMT2018/${LANG_F}-${LANG_E}/${LANG_F}-${LANG_E}.spm.model 19 | fi 20 | 21 | 22 | mteval=${MOSES_SCRIPT}/generic/mteval-v13a.pl 23 | 24 | 25 | src=${wmt18}/newstest2018-${LANG_F}${LANG_E}-src-ts.${LANG_F}.sgm 26 | 27 | mkdir -p eval 28 | cat $OUTPUT \ 29 | | $spm_decode --model $spm_model --input_format=piece \ 30 | | ${MOSES_SCRIPT}/recaser/detruecase.perl\ 31 | | 
${MOSES_SCRIPT}/tokenizer/detokenizer.perl -l ${LANG_E}\ 32 | | ${MOSES_SCRIPT}/tokenizer/normalize-punctuation.perl -l ${LANG_E}\ 33 | | ${MOSES_SCRIPT}/ems/support/wrap-xml.perl ${LANG_E} $src WAU\ 34 | > eval/decoder-output.newstest2018.sgm 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /scripts/extract_raw_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # author: Hao WANG 3 | echo "Usage: bash $0 CORPUS LANG_F LANG_E" 4 | LANG_F=$2 5 | LANG_E=$3 6 | TASK=${LANG_F}-${LANG_E} 7 | 8 | MOSES_SCRIPT=/itigo/files/Tools/accessByOldOrganization/TranslationEngines/mosesdecoder-RELEASE-2.1.1/scripts 9 | 10 | CORPUS=$1 11 | 12 | 13 | truecase_model=/itigo/Uploads/WMT2018/${LANG_F}-${LANG_E}/true 14 | 15 | 16 | for prefix in newstest2018; do 17 | perl ${MOSES_SCRIPT}/ems/support/input-from-sgm.perl \ 18 | < ${CORPUS}/$prefix-${LANG_F}${LANG_E}-src-ts.${LANG_F}.sgm \ 19 | | ${MOSES_SCRIPT}/tokenizer/normalize-punctuation.perl -l ${LANG_F} \ 20 | | ${MOSES_SCRIPT}/tokenizer/tokenizer.perl -a -l ${LANG_F} \ 21 | | ${MOSES_SCRIPT}/recaser/truecase.perl -model ${truecase_model}/truecase-model.${LANG_F} \ 22 | > ${CORPUS}/../${LANG_F}-${LANG_E}/newstest2018.${LANG_F} 23 | 24 | perl ${MOSES_SCRIPT}/ems/support/input-from-sgm.perl \ 25 | < ${CORPUS}/$prefix-${LANG_E}${LANG_F}-src-ts.${LANG_E}.sgm \ 26 | | ${MOSES_SCRIPT}/tokenizer/normalize-punctuation.perl -l ${LANG_E} \ 27 | | ${MOSES_SCRIPT}/tokenizer/tokenizer.perl -a -l ${LANG_E} \ 28 | | ${MOSES_SCRIPT}/recaser/truecase.perl -model ${truecase_model}/truecase-model.${LANG_E} \ 29 | > ${CORPUS}/../${LANG_F}-${LANG_E}/newstest2018.${LANG_E} 30 | 31 | done -------------------------------------------------------------------------------- /scripts/measure_bleu.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | # add additional references explicitly specified on the command line 35 | shift; 36 | foreach my $stem (@ARGV) { 37 | &add_to_ref($stem,\@REF) if -e $stem; 38 | } 39 | 40 | 41 | 42 | sub add_to_ref { 43 | my ($file,$REF) = @_; 44 | my $s=0; 45 | if ($file =~ /.gz$/) { 46 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 47 | } else { 48 | open(REF,$file) or die "Can't read $file"; 49 | } 50 | while(<REF>) { 51 | chop; 52 | push @{$$REF[$s++]}, $_; 53 | } 54 | close(REF); 55 | } 56 | 57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 58 | my $s=0; 59 | while(<STDIN>) { 60 | chop; 61 | $_ = lc if $lowercase; 62 | my @WORD = split; 63 | my %REF_NGRAM = (); 64 | my $length_translation_this_sentence = scalar(@WORD); 65 | my ($closest_diff,$closest_length) = (9999,9999); 66 | foreach my $reference (@{$REF[$s]}) { 67 | # print "$s $_ <=> $reference\n"; 68 | $reference = lc($reference) if $lowercase; 69 | my @WORD = split(' ',$reference); 70 | my $length = scalar(@WORD); 71 | my $diff = abs($length_translation_this_sentence-$length); 72 | if ($diff < $closest_diff) { 73 | $closest_diff = $diff; 74 | $closest_length = $length; 75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | for(my $n=1;$n<=4;$n++) { 82 | my %REF_NGRAM_N = (); 83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 84 | my $ngram = "$n"; 85 | for(my $w=0;$w<$n;$w++) { 86 | $ngram .= " ".$WORD[$start+$w]; 87 | } 88 | $REF_NGRAM_N{$ngram}++; 89 | } 90 | foreach my $ngram (keys %REF_NGRAM_N) { 91 | if (!defined($REF_NGRAM{$ngram}) || 92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>
\n"; 95 | } 96 | } 97 | } 98 | } 99 | $length_translation += $length_translation_this_sentence; 100 | $length_reference += $closest_length; 101 | for(my $n=1;$n<=4;$n++) { 102 | my %T_NGRAM = (); 103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 104 | my $ngram = "$n"; 105 | for(my $w=0;$w<$n;$w++) { 106 | $ngram .= " ".$WORD[$start+$w]; 107 | } 108 | $T_NGRAM{$ngram}++; 109 | } 110 | foreach my $ngram (keys %T_NGRAM) { 111 | $ngram =~ /^(\d+) /; 112 | my $n = $1; 113 | # my $corr = 0; 114 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 115 | $TOTAL[$n] += $T_NGRAM{$ngram}; 116 | if (defined($REF_NGRAM{$ngram})) { 117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 118 | $CORRECT[$n] += $T_NGRAM{$ngram}; 119 | # $corr = $T_NGRAM{$ngram}; 120 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 121 | } 122 | else { 123 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 124 | # $corr = $REF_NGRAM{$ngram}; 125 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 126 | } 127 | } 128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 130 | } 131 | } 132 | $s++; 133 | } 134 | my $brevity_penalty = 1; 135 | my $bleu = 0; 136 | 137 | my @bleu=(); 138 | 139 | for(my $n=1;$n<=4;$n++) { 140 | if (defined ($TOTAL[$n])){ 141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 143 | }else{ 144 | $bleu[$n]=0; 145 | } 146 | } 147 | 148 | if ($length_reference==0){ 149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 150 | exit(1); 151 | } 152 | 153 | if ($length_translation<$length_reference) { 154 | $brevity_penalty = exp(1-$length_reference/$length_translation); 155 | } 156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 157 | my_log( $bleu[2] ) + 158 | my_log( $bleu[3] ) + 159 | my_log( $bleu[4] ) ) / 4) ; 160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 161 | 100*$bleu, 162 | 100*$bleu[1], 163 | 100*$bleu[2], 164 | 100*$bleu[3], 165 | 100*$bleu[4], 166 | $brevity_penalty, 167 | $length_translation / $length_reference, 168 | $length_translation, 169 | $length_reference; 170 | 171 | 172 | print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n"; 173 | 174 | sub my_log { 175 | return -9999999999 unless $_[0]; 176 | return log($_[0]); 177 | } -------------------------------------------------------------------------------- /scripts/train_bpe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # author: Hao WANG 3 | 4 | BPE_SCRIPTS=/path-to/BPE 5 | CORPUS=/path-to/WMT2018/en-tr/orig_clean 6 | MONO_CORPUS=/path-to/WMT2018/en-tr/mono_data 7 | mkdir -p BPE 8 | ${BPE_SCRIPTS}/learn_joint_bpe_and_vocab.py --input $CORPUS/train.$1 $CORPUS/train.$2 --min-frequency 5 --symbols 30000 --output BPE/vocab.$1-$2 --write-vocabulary BPE/vocab.$1 BPE/vocab.$2 9 | ${BPE_SCRIPTS}/apply_bpe.py --input $CORPUS/train.$1 --codes BPE/vocab.$1-$2 --output BPE/train.$1 10 | ${BPE_SCRIPTS}/apply_bpe.py --input $CORPUS/dev.$1 --codes BPE/vocab.$1-$2 --output BPE/dev.$1 11 | ${BPE_SCRIPTS}/apply_bpe.py --input $CORPUS/test.$1 --codes BPE/vocab.$1-$2 --output BPE/test.$1 12 | 13 | ${BPE_SCRIPTS}/apply_bpe.py --input $CORPUS/train.$2 --codes BPE/vocab.$1-$2 --output BPE/train.$2 14 | ${BPE_SCRIPTS}/apply_bpe.py --input $CORPUS/dev.$2 --codes BPE/vocab.$1-$2 --output BPE/dev.$2 15 | ${BPE_SCRIPTS}/apply_bpe.py --input $CORPUS/test.$2 --codes BPE/vocab.$1-$2 --output BPE/test.$2 16 | 17 | # ${BPE_SCRIPTS}/apply_bpe.py --input $MONO_CORPUS/synthetic.orig.$1 --codes BPE/vocab.$1-$2 --output BPE/synthetic.orig.$1 18 | # ${BPE_SCRIPTS}/apply_bpe.py --input $MONO_CORPUS/synthetic.orig.$2 --codes BPE/vocab.$1-$2 --output BPE/synthetic.orig.$2 19 | 20 | #${BPE_SCRIPTS}/apply_bpe.py --input $MONO_CORPUS/synthetic.orig.2M.$1 --codes BPE/vocab.$1-$2 --output BPE/synthetic.orig.2M.$1 21 | #${BPE_SCRIPTS}/apply_bpe.py --input $MONO_CORPUS/synthetic.orig.2M.$2 --codes BPE/vocab.$1-$2 --output BPE/synthetic.orig.2M.$2 22 | 23 | 24 | #${BPE_SCRIPTS}/apply_bpe.py --input 
$MONO_CORPUS/synthetic.trans.$1 --codes BPE/vocab.$1-$2 --output BPE/synthetic.trans.$1-$2.$1 25 | #cp BPE/synthetic.orig.$2 BPE/synthetic.trans.$1-$2.$2 26 | #${BPE_SCRIPTS}/apply_bpe.py --input $MONO_CORPUS/synthetic.orig.$2 --codes BPE/vocab.$1-$2 --output BPE/synthetic.orig.$2 27 | 28 | -------------------------------------------------------------------------------- /scripts/train_spm.sh: -------------------------------------------------------------------------------- 1 | src=$1 2 | trg=$2 3 | spm_train=spm_train 4 | spm_decode=spm_decode 5 | spm_encode=spm_encode 6 | 7 | mkdir -p SPM 8 | $spm_train --add_dummy_prefix False --input corpus --vocab_size=16000 --model_prefix $src-$trg.spm 9 | 10 | for lang in $1 $2; do 11 | $spm_encode --model=$src-$trg.spm.model --output_format=piece < train.$lang > SPM/train.$lang 12 | $spm_encode --model=$src-$trg.spm.model --output_format=piece < dev2018.$lang > SPM/dev2018.$lang 13 | $spm_encode --model=$src-$trg.spm.model --output_format=piece < newstest2018.$lang > SPM/newstest2018.$lang 14 | 15 | done -------------------------------------------------------------------------------- /scripts/train_word2vec.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # author: Hao WANG 3 | 4 | LANG1=$1 5 | LANG2=$2 6 | EMBEDED_DIM=500 7 | TASK=$1-$2 8 | word2vec=/itigo/files/Tools/accessByOldOrganization/Word2Vec/word2vec 9 | Data=/itigo/Uploads/ASPEC-JC.clean/BPE 10 | 11 | 12 | 13 | $word2vec -train $Data/train.$LANG1 -size ${EMBEDED_DIM} -output w2v.$LANG1 14 | 15 | $word2vec -train $Data/train.$LANG2 -size ${EMBEDED_DIM} -output w2v.$LANG2 16 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | import glob 6 | import random 7 | import argparse 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from Utils.log import trace 13 | from Utils.config import Config 14 | from Utils.DataLoader import DataBatchIterator 15 | from Utils.DataLoader import PAD_WORD 16 | 17 | 18 | from NMT import Trainer 19 | from NMT import Statistics 20 | from NMT import model_factory 21 | from NMT import dump_checkpoint 22 | 23 | 24 | def main(): 25 | # Load config. 26 | config = Config("train", training=True) 27 | trace(config) 28 | torch.backends.cudnn.benchmark = True 29 | 30 | # Load train dataset. 31 | train_data = load_dataset( 32 | config.train_dataset, 33 | config.train_batch_size, 34 | config, prefix="Training:") 35 | 36 | # Load valid dataset. 37 | valid_data = load_dataset( 38 | config.valid_dataset, 39 | config.valid_batch_size, 40 | config, prefix="Validation:") 41 | 42 | # Build model. 43 | vocab = train_data.get_vocab() 44 | model = model_factory(config, 45 | config.checkpoint, *vocab) 46 | if config.verbose: trace(model) 47 | 48 | # start training 49 | trg_vocab = train_data.trg_vocab 50 | padding_idx = trg_vocab.padding_idx 51 | trainer = Trainer(model, trg_vocab, padding_idx, config) 52 | start_epoch = 1 53 | for epoch in range(start_epoch, config.epochs + 1): 54 | trainer.train(epoch, config.epochs, 55 | train_data, valid_data, 56 | train_data.num_batches) 57 | dump_checkpoint(trainer.model, config.save_model) 58 | 59 | 60 | def load_dataset(dataset, batch_size, config, prefix): 61 | # Load training/validation dataset. 62 | train_src = os.path.join( 63 | config.data_path, dataset + "." 
+ config.src_lang) 64 | train_trg = os.path.join( 65 | config.data_path, dataset + "." + config.trg_lang) 66 | train_data = DataBatchIterator( 67 | train_src, train_trg, 68 | share_vocab=config.share_vocab, 69 | training=config.training, 70 | shuffle=config.shuffle_data, 71 | batch_size=batch_size, 72 | max_length=config.max_seq_len, 73 | vocab=config.save_vocab, 74 | mini_batch_sort_order=config.mini_batch_sort_order) 75 | trace(prefix, train_data) 76 | return train_data 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import argparse 4 | import math 5 | import codecs 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from itertools import count 10 | 11 | 12 | from Utils.log import trace 13 | from Utils.config import Config 14 | from Utils.DataLoader import DataBatchIterator 15 | from NMT.ModelFactory import model_factory 16 | 17 | #from NMT.Loss import LossBase 18 | from NMT.Trainer import Statistics 19 | from NMT.translate import BatchTranslator 20 | from Utils.plot import plot_attn 21 | from Utils.utils import report_bleu 22 | from Utils.utils import report_rouge 23 | from train import load_dataset 24 | 25 | 26 | def main(): 27 | config = Config("translate", training=False) 28 | if config.verbose: trace(config) 29 | torch.backends.cudnn.benchmark = True 30 | 31 | test_data = load_dataset(config.test_dataset, 32 | config.test_batch_size, 33 | config, prefix="Translate:") 34 | 35 | 36 | # Build model. 37 | vocab = test_data.get_vocab() 38 | pred_file = codecs.open(config.output+".pred.txt", 'w', 'utf-8') 39 | 40 | 41 | model = model_factory(config, config.checkpoint, *vocab) 42 | translator = BatchTranslator(model, config, test_data.src_vocab, test_data.trg_vocab) 43 | 44 | 45 | # Statistics 46 | counter = count(1) 47 | pred_list = [] 48 | gold_list = [] 49 | for batch in tqdm(iter(test_data), total=test_data.num_batches): 50 | 51 | batch_trans = translator.translate(batch) 52 | 53 | for trans in batch_trans: 54 | if config.verbose: 55 | sent_number = next(counter) 56 | trace(trans.pprint(sent_number)) 57 | 58 | if config.plot_attn: 59 | plot_attn(trans.src, trans.preds[0], trans.attns[0].cpu()) 60 | 61 | pred_file.write(" ".join(trans.preds[0]) + "\n") 62 | pred_list.append(trans.preds[0]) 63 | gold_list.append(trans.gold) 64 | report_bleu(gold_list, pred_list) 65 | report_rouge(gold_list, pred_list) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | --------------------------------------------------------------------------------