├── .gitignore
├── NMT
│   ├── Loss.py
│   ├── ModelConstructor.py
│   ├── Models
│   │   ├── Decoders.py
│   │   ├── Encoders.py
│   │   ├── NMTModel.py
│   │   ├── VNMTModel.py
│   │   ├── VRNMTModel.py
│   │   └── __init__.py
│   ├── Modules
│   │   ├── GlobalAttention.py
│   │   ├── StackedRNN.py
│   │   └── __init__.py
│   ├── Optimizer.py
│   ├── Statistics.py
│   ├── Trainer.py
│   ├── __init__.py
│   └── translate
│       ├── Beam.py
│       ├── Penalties.py
│       ├── Translation.py
│       ├── Translator.py
│       └── __init__.py
├── README.md
├── Utils
│   ├── DataLoader.py
│   ├── args.py
│   ├── bleu.py
│   ├── config.py
│   ├── plot.py
│   ├── rouge.py
│   └── utils.py
├── config
│   ├── nmt.ini
│   ├── vnmt.ini
│   └── vrnmt.ini
├── train.py
└── translate.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .*
2 | !/.gitignore
3 | *.pyc
4 | scripts
5 | __pycache__/*
6 | __pycache__
--------------------------------------------------------------------------------
/NMT/Loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | 
5 | import NMT
6 | import Utils
7 | from NMT.Trainer import Trainer
8 | from NMT.Statistics import Statistics
9 | 
10 | import torch.nn.functional as F
11 | from Utils.utils import trace
12 | 
13 | class NMTLoss(nn.Module):
14 |     """
15 |     Standard NMT Loss Computation.
16 |     """
17 |     def __init__(self, config, padding_idx):
18 |         super(NMTLoss, self).__init__()
19 |         self.padding_idx = padding_idx
20 |         self.criterion = nn.CrossEntropyLoss(
21 |             ignore_index=self.padding_idx, size_average=False)
22 |         self.kld_weight = config.kld_weight
23 |         self.kld_increase = 0.01
24 |     def kld_weight_step(self, epoch, start_increase_kld_at=8):
25 |         pass  # KLD-weight annealing is currently disabled (see the schedule below).
26 |         # if self.kld_weight < 0.1 and epoch >= start_increase_kld_at:
27 |         #     self.kld_weight += self.kld_increase
28 |         #     trace("Increase KLD weight to %.2f" % self.kld_weight)
29 | 
30 |     def compute_batch_loss(self, probs, golds, normalization, kld_loss):
31 |         """Compute the forward loss for the batch (cross-entropy plus the weighted KLD term).
32 | 
33 |         Args:
34 |             probs (FloatTensor) : output distribution of the model `[(trg_len x batch) x V]`
35 |             golds (LongTensor) : target examples
36 |             normalization (int) : divisor for the loss (here, the batch size)
37 |             kld_loss : accumulated KL-divergence term of the variational models
38 | 
39 |         Returns:
40 |             the normalized loss and a `NMT.Statistics` object for this batch
41 |         """
42 |         ce_loss = self.criterion(probs, golds.view(-1))
43 | 
44 |         loss = ce_loss + self.kld_weight * kld_loss
45 |         loss = loss.div(normalization)
46 |         loss_dict = {
47 |             "CELoss": float(ce_loss)/normalization,
48 |             "KLDLoss": float(kld_loss)/normalization
49 |         }
50 |         del ce_loss, kld_loss
51 |         batch_stats = self.create_stats(float(loss), probs, golds.view(-1), loss_dict)
52 |         return loss, batch_stats
53 | 
54 |     def create_stats(self, loss, probs, golds, loss_dict):
55 |         """
56 |         Args:
57 |             loss (float): the loss computed by the loss criterion.
58 |             probs (`FloatTensor`): a score for each possible output
59 |             golds (`LongTensor`): the true targets
60 | 
61 |         Returns:
62 |             `Statistics` : statistics for this batch.
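        Example (hypothetical shapes): with `probs` of size
        `[(trg_len*batch) x V]` and `golds` flattened to `[trg_len*batch]`,
        word and correct counts are taken only over non-padding positions.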
63 |         """
64 |         preds = probs.data.topk(1, dim=-1)[1]
65 |         non_padding = golds.ne(self.padding_idx)
66 |         correct = preds.squeeze().eq(golds).masked_select(non_padding)
67 |         num_words = non_padding.long().sum()
68 |         num_correct = correct.long().sum()
69 |         return Statistics(
70 |             float(loss), int(num_words),
71 |             int(num_correct), loss_dict)
--------------------------------------------------------------------------------
/NMT/ModelConstructor.py:
--------------------------------------------------------------------------------
1 | """
2 | This file handles model creation: it consults the configuration options
3 | and builds each encoder and decoder accordingly.
4 | """
5 | import torch
6 | import torch.nn as nn
7 | 
8 | 
9 | from NMT.Models import NMTModel
10 | from NMT.Models import VNMTModel
11 | from NMT.Models import VRNMTModel
12 | from NMT.Models import PackedRNNEncoder
13 | from NMT.Models import RNNEncoder
14 | from NMT.Models import InputFeedRNNDecoder
15 | from NMT.Models import VarInputFeedRNNDecoder
16 | 
17 | from torch.nn.init import xavier_uniform
18 | from Utils.DataLoader import PAD_WORD
19 | from Utils.utils import trace, aeq
20 | 
21 | def make_embeddings(vocab_size, embed_dim, dropout, padding_idx):
22 |     """
23 |     Make an Embeddings instance.
24 |     Args:
25 |         vocab_size (int): size of the vocabulary.
26 |         embed_dim (int): dimensionality of the embeddings.
27 |         dropout (float): dropout rate (kept for interface parity; unused here).
28 |         padding_idx (int): index of the padding token.
29 |     """
30 |     return nn.Embedding(vocab_size,
31 |                         embed_dim,
32 |                         padding_idx=padding_idx,
33 |                         max_norm=None,
34 |                         norm_type=2,
35 |                         scale_grad_by_freq=False,
36 |                         sparse=False)
37 |     # return nn.Embeddings(embed_dim,
38 |     #                      position_encoding=False,
39 |     #                      dropout=dropout,
40 |     #                      word_padding_idx=padding_idx,
41 |     #                      word_vocab_size=vocab_size,
42 |     #                      sparse=True)
43 | 
44 | def model_factory(config, src_vocab, trg_vocab, train_mode=True, checkpoint=None):
45 | 
46 |     # Make embeddings (source uses src_embed_dim, target uses trg_embed_dim).
47 |     padding_idx = src_vocab.stoi[PAD_WORD]
48 |     src_embeddings = make_embeddings(src_vocab.vocab_size, config.src_embed_dim, config.dropout, padding_idx)
49 | 
50 |     padding_idx = trg_vocab.stoi[PAD_WORD]
51 |     trg_embeddings = make_embeddings(trg_vocab.vocab_size, config.trg_embed_dim, config.dropout, padding_idx)
52 | 
53 |     # Make NMT Model (= encoder + decoder).
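    # A summary of what each branch below builds (see NMT/Models for details):
    #   "NMT"   -- PackedRNNEncoder + InputFeedRNNDecoder: plain attentional
    #              seq2seq.
    #   "VNMT"  -- same encoder/decoder pair, but the decoder input is widened
    #              by `latent_size`: VNMTModel samples one sentence-level
    #              latent z from the final encoder state and appends it to
    #              every target embedding.
    #   "VRNMT" -- VarInputFeedRNNDecoder, which re-samples a latent z from
    #              its own hidden state at every decoding step.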
54 |     if config.system == "NMT":
55 | 
56 |         encoder = PackedRNNEncoder(
57 |             config.rnn_type,
58 |             config.src_embed_dim,
59 |             config.hidden_size, config.enc_num_layers,
60 |             config.dropout, config.bidirectional)
61 |         decoder = InputFeedRNNDecoder(
62 |             config.rnn_type, config.trg_embed_dim, config.hidden_size,
63 |             config.dec_num_layers, config.attn_type,
64 |             config.bidirectional, config.dropout)
65 |         model = NMTModel(
66 |             encoder, decoder,
67 |             src_embeddings, trg_embeddings,
68 |             trg_vocab.vocab_size, config)
69 | 
70 |     elif config.system == "VNMT":
71 |         encoder = PackedRNNEncoder(
72 |             config.rnn_type,
73 |             config.src_embed_dim,
74 |             config.hidden_size,
75 |             config.enc_num_layers,
76 |             config.dropout,
77 |             config.bidirectional)
78 | 
79 |         decoder = InputFeedRNNDecoder(
80 |             config.rnn_type,
81 |             config.trg_embed_dim+config.latent_size,
82 |             config.hidden_size,
83 |             config.dec_num_layers,
84 |             config.attn_type,
85 |             config.bidirectional,
86 |             config.dropout)
87 | 
88 |         model = VNMTModel(
89 |             encoder, decoder,
90 |             src_embeddings, trg_embeddings,
91 |             trg_vocab.vocab_size,
92 |             config)
93 | 
94 |     elif config.system == "VRNMT":
95 |         encoder = PackedRNNEncoder(
96 |             config.rnn_type,
97 |             config.src_embed_dim,
98 |             config.hidden_size,
99 |             config.enc_num_layers,
100 |             config.dropout,
101 |             config.bidirectional)
102 | 
103 |         decoder = VarInputFeedRNNDecoder(
104 |             config.rnn_type,
105 |             config.trg_embed_dim,
106 |             config.hidden_size,
107 |             config.latent_size,
108 |             config.dec_num_layers,
109 |             config.attn_type,
110 |             config.bidirectional,
111 |             config.dropout)
112 | 
113 |         model = VRNMTModel(
114 |             encoder, decoder,
115 |             src_embeddings, trg_embeddings,
116 |             trg_vocab.vocab_size,
117 |             config)
118 |     else:
119 |         raise ValueError("unknown system: %s" % config.system)
120 |     if checkpoint is not None:
121 |         trace('Loading model parameters.')
122 |         model.load_state_dict(checkpoint['model'])
123 | 
124 |     # Initialize the model states only when not resuming from a checkpoint.
125 |     elif train_mode and config.param_init != 0.0:
126 |         trace("Initializing model parameters.")
127 |         for p in model.parameters():
128 |             p.data.uniform_(-config.param_init, config.param_init)
129 | 
130 | 
131 | 
132 |     # if hasattr(model.encoder, 'embeddings'):
133 |     #     model.encoder.embeddings.load_pretrained_vectors(
134 |     #         config.pre_word_vecs_enc, config.fix_word_vecs_enc)
135 |     # if hasattr(model.decoder, 'embeddings'):
136 |     #     model.decoder.embeddings.load_pretrained_vectors(
137 |     #         config.pre_word_vecs_dec, config.fix_word_vecs_dec)
138 | 
139 |     if train_mode:
140 |         model.train()
141 |     else:
142 |         model.eval()
143 | 
144 |     if config.gpu_ids is not None:
145 |         model.cuda()
146 |     else:
147 |         model.cpu()
148 | 
149 |     return model
150 | 
--------------------------------------------------------------------------------
/NMT/Models/Decoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | from NMT.Modules import StackedLSTM
5 | from NMT.Modules import StackedGRU
6 | from NMT.Modules import GlobalAttention
7 | 
8 | 
9 | class RNNDecoderState(object):
10 |     def __init__(self, hidden_size, rnn_state):
11 |         """
12 |         Args:
13 |             hidden_size (int): the size of hidden layer of the decoder.
14 |             rnn_state: final hidden state from the encoder,
15 |                 transformed to shape: layers x batch x (directions*dim).
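                For LSTM this is a tuple `(h, c)`; for GRU a single tensor.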
16 | """ 17 | if isinstance(rnn_state, tuple): 18 | # LSTM 19 | self.state = rnn_state 20 | else: 21 | # GRU 22 | self.state = (rnn_state, ) 23 | 24 | self.coverage = None 25 | 26 | # Init the input feed. 27 | batch_size = self.state[0].size(1) 28 | 29 | self.input_feed = Variable( 30 | self.state[0].data.new(batch_size, hidden_size).zero_()).unsqueeze(0) 31 | 32 | @property 33 | def _all(self): 34 | return self.state + (self.input_feed,) 35 | 36 | 37 | def beam_update(self, idx, positions, beam_size): 38 | for e in self._all: 39 | sizes = e.size() 40 | br = sizes[1] 41 | if len(sizes) == 3: 42 | sent_states = e.view(sizes[0], beam_size, br // beam_size, 43 | sizes[2])[:, :, idx] 44 | else: 45 | sent_states = e.view(sizes[0], beam_size, 46 | br // beam_size, 47 | sizes[2], 48 | sizes[3])[:, :, idx] 49 | 50 | sent_states.data.copy_( 51 | sent_states.data.index_select(1, positions)) 52 | 53 | def update_state(self, rnn_state, input_feed): 54 | if not isinstance(rnn_state, tuple): 55 | self.state = (rnn_state,) 56 | else: 57 | self.state = rnn_state 58 | self.input_feed = input_feed 59 | 60 | def repeat_beam_size_times(self, beam_size): 61 | """ Repeat beam_size times along batch dimension. """ 62 | 63 | repeat_func = lambda x: Variable( 64 | x.data.repeat(1, beam_size, 1), requires_grad=False) 65 | vars = tuple(repeat_func(h) for h in self._all) 66 | self.state = vars[:-1] 67 | self.input_feed = vars[-1] 68 | 69 | class RNNDecoderBase(nn.Module): 70 | def __init__(self, rnn_type, 71 | input_size, 72 | hidden_size, 73 | num_layers=2, 74 | attn_type="general", 75 | bidirectional_encoder=True, 76 | dropout=0.0, 77 | embeddings=None): 78 | 79 | super(RNNDecoderBase, self).__init__() 80 | 81 | # Basic attributes. 82 | self.rnn_type = rnn_type 83 | self.decoder_type = 'rnn' 84 | self.bidirectional_encoder = bidirectional_encoder 85 | self.num_layers = num_layers 86 | self.input_size = input_size 87 | self.hidden_size = hidden_size 88 | self.embeddings = embeddings 89 | self.dropout = nn.Dropout(dropout) 90 | # Build the RNN. 91 | self.rnn = self._build_rnn(rnn_type, 92 | input_size=input_size, 93 | hidden_size=hidden_size, 94 | num_layers=num_layers, 95 | dropout=dropout) 96 | 97 | 98 | self.attn = GlobalAttention( 99 | hidden_size, attn_type=attn_type 100 | ) 101 | 102 | def _build_rnn(self, rnn_type, input_size, hidden_size, num_layers, dropout): 103 | if rnn_type == "LSTM": 104 | stacked_cell = StackedLSTM 105 | elif rnn_type == "GRU": 106 | stacked_cell = StackedGRU 107 | else: 108 | raise NotImplementedError 109 | return stacked_cell(num_layers, input_size, 110 | hidden_size, dropout) 111 | 112 | def forward(self, trg, encoder_outputs, lengths, state): 113 | """ 114 | Args: 115 | trg (`LongTensor`): sequences of padded tokens 116 | `[L_t x B x D]`. 117 | encoder_outputs (`FloatTensor`): vectors from the encoder 118 | `[L_s x B x H]`. 119 | 120 | lengths (`LongTensor`): the padded source lengths 121 | `[B]`. 122 | state (`DecoderState`): 123 | decoder state object to initialize the decoder 124 | Returns: 125 | (`FloatTensor`,:obj:`nmt.Models.DecoderState`,`FloatTensor`): 126 | * decoder_outputs: output from the decoder (after attn) 127 | `[trg_len x batch x hidden]`. 128 | * decoder_state: final hidden state from the decoder 129 | * attns: distribution over source words at each target word 130 | `[L_t x B x L_s]`. 131 | """ 132 | # Run the forward pass of the RNN. 
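        # forward_step consumes the whole (already embedded) target sequence
        # with teacher forcing and returns one output per target position,
        # the final RNN state, and the per-step attention weights.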
133 | decoder_outputs, final_state, attns = self.forward_step( 134 | trg, encoder_outputs, lengths, state) 135 | 136 | # Update the state with the result. 137 | final_output = decoder_outputs[-1] 138 | state.update_state(final_state, final_output.unsqueeze(0)) 139 | 140 | # Concatenates sequence of tensors along a new dimension. 141 | decoder_outputs = torch.stack(decoder_outputs) 142 | for k in attns: 143 | attns[k] = torch.stack(attns[k]) 144 | 145 | return decoder_outputs, state, attns 146 | 147 | 148 | 149 | class InputFeedRNNDecoder(RNNDecoderBase): 150 | def __init__(self, rnn_type, 151 | embedding_size, 152 | hidden_size, 153 | num_layers=2, 154 | attn_type="general", 155 | bidirectional_encoder=True, 156 | dropout=0.0): 157 | super(InputFeedRNNDecoder, self).__init__(rnn_type, 158 | embedding_size + hidden_size, 159 | hidden_size, 160 | num_layers, 161 | attn_type, 162 | bidirectional_encoder, 163 | dropout) 164 | 165 | def forward_step(self, trg, encoder_outputs, lengths, dec_state): 166 | """ 167 | Input feed concatenates hidden state with input at every time step. 168 | 169 | Args: 170 | trg (`LongTensor`): sequences of padded tokens 171 | `[L_t x B x nfeats]`. 172 | encoder_outputs (`FloatTensor`): vectors from the encoder 173 | `[L_s x B x H]`. 174 | state (`DecoderState`): 175 | decoder state object to initialize the decoder 176 | lengths (`LongTensor`): the padded source lengths 177 | `[B]`. 178 | Returns: 179 | (`FloatTensor`,:obj:`nmt.Models.DecoderState`,`FloatTensor`): 180 | * decoder_outputs: output from the decoder (after attn) 181 | `[L_t x B x H]`. 182 | * decoder_state: final hidden state from the decoder 183 | * attns: distribution over source words at each target word 184 | `[L_t x B x L_s]`. 185 | """ 186 | 187 | input_feed = dec_state.input_feed.squeeze(0) # [B x H] 188 | 189 | decoder_outputs = [] 190 | attns = {"std": []} 191 | 192 | rnn_state = dec_state.state 193 | 194 | for t, emb_t in enumerate(trg.split(1, dim=0)): 195 | # iterate over each target word 196 | emb_t = emb_t.squeeze(0) 197 | # teacher forcing 198 | decoder_input = torch.cat([emb_t, input_feed], 1) 199 | # update state and feed to next RNNCell 200 | rnn_output, rnn_state = self.rnn(decoder_input, rnn_state) 201 | 202 | decoder_output, attn = self.attn( 203 | rnn_output, encoder_outputs.transpose(0, 1), lengths=lengths) 204 | 205 | decoder_output = self.dropout(decoder_output) 206 | 207 | input_feed = decoder_output 208 | decoder_outputs += [decoder_output] 209 | attns["std"] += [attn] 210 | 211 | # Return result. 
212 |         return decoder_outputs, rnn_state, attns
213 | 
214 | class VarInputFeedRNNDecoder(RNNDecoderBase):
215 |     def __init__(self, rnn_type,
216 |                  embedding_size,
217 |                  hidden_size,
218 |                  latent_size,
219 |                  num_layers=2,
220 |                  attn_type="general",
221 |                  bidirectional_encoder=True,
222 |                  dropout=0.0):
223 |         super(VarInputFeedRNNDecoder, self).__init__(rnn_type,
224 |                  embedding_size + hidden_size + latent_size,
225 |                  hidden_size,
226 |                  num_layers,
227 |                  attn_type,
228 |                  bidirectional_encoder,
229 |                  dropout)
230 | 
231 |         self.context_to_mu = nn.Linear(
232 |             hidden_size,
233 |             latent_size)
234 |         self.context_to_logvar = nn.Linear(
235 |             hidden_size,
236 |             latent_size)
237 | 
238 |     def reparameterize(self, state):
239 |         """
240 |         state: stacked decoder hidden state h `[layers x B x H]`
241 |         """
242 |         hidden = self.get_hidden(state)
243 |         mu = self.context_to_mu(hidden)
244 |         logvar = self.context_to_logvar(hidden)
245 |         if self.training:
246 |             std = logvar.mul(0.5).exp_()
247 |             eps = Variable(std.data.new(std.size()).normal_())
248 |             z = eps.mul(std).add_(mu)
249 |         else:
250 |             z = mu
251 |         return z, mu, logvar
252 | 
253 | 
254 |     def get_hidden(self, state):
255 |         # `state` is the stacked hidden tensor h `[layers x B x H]`
256 |         # (forward_step passes rnn_state[0] for GRU and LSTM alike),
257 |         # so condition the latent variable on the top layer.
258 |         hidden = state[-1]
259 | 
260 |         return hidden
261 |     def compute_kld(self, mu, logvar):
262 |         kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
263 |         return kld
264 |     def forward_step(self, trg, encoder_outputs, lengths, dec_state):
265 |         """
266 |         Input feed concatenates hidden state with input at every time step.
267 | 
268 |         Args:
269 |             trg (`LongTensor`): sequences of padded tokens
270 |                 `[L_t x B x nfeats]`.
271 |             encoder_outputs (`FloatTensor`): vectors from the encoder
272 |                 `[L_s x B x H]`.
273 |             state (`DecoderState`):
274 |                 decoder state object to initialize the decoder
275 |             lengths (`LongTensor`): the padded source lengths
276 |                 `[B]`.
277 |         Returns:
278 |             (`FloatTensor`,:obj:`nmt.Models.DecoderState`,`FloatTensor`):
279 |             * decoder_outputs: output from the decoder (after attn)
280 |                 `[L_t x B x H]`.
281 |             * decoder_state: final hidden state from the decoder
282 |             * attns: distribution over source words at each target word
283 |                 `[L_t x B x L_s]`.
284 |         """
285 | 
286 |         input_feed = dec_state.input_feed.squeeze(0) # [B x H]
287 | 
288 |         decoder_outputs = []
289 |         attns = {"std": []}
290 | 
291 |         rnn_state = dec_state.state
292 | 
293 |         kld = 0.
294 |         for t, emb_t in enumerate(trg.split(1, dim=0)):
295 |             # iterate over each target word
296 |             emb_t = emb_t.squeeze(0)
297 |             # teacher forcing
298 |             z, mu, logvar = self.reparameterize(rnn_state[0])
299 |             decoder_input = torch.cat([emb_t, input_feed, z], -1)
300 |             kld += self.compute_kld(mu, logvar)
301 |             # update state and feed to next RNNCell
302 |             rnn_output, rnn_state = self.rnn(decoder_input, rnn_state)
303 | 
304 |             decoder_output, attn = self.attn(
305 |                 rnn_output, encoder_outputs.transpose(0, 1), lengths=lengths)
306 | 
307 |             decoder_output = self.dropout(decoder_output)
308 | 
309 |             input_feed = decoder_output
310 |             decoder_outputs += [decoder_output]
311 |             attns["std"] += [attn]
312 | 
313 |         # Return result.
314 |         return decoder_outputs, rnn_state, attns, kld
315 |     def forward(self, trg, encoder_outputs, lengths, state):
316 |         """
317 |         Args:
318 |             trg (`LongTensor`): sequences of padded tokens
319 |                 `[L_t x B x D]`.
320 |             encoder_outputs (`FloatTensor`): vectors from the encoder
321 |                 `[L_s x B x H]`.
322 | 
323 |             lengths (`LongTensor`): the padded source lengths
324 |                 `[B]`.
325 | state (`DecoderState`): 326 | decoder state object to initialize the decoder 327 | Returns: 328 | (`FloatTensor`,:obj:`nmt.Models.DecoderState`,`FloatTensor`): 329 | * decoder_outputs: output from the decoder (after attn) 330 | `[trg_len x batch x hidden]`. 331 | * decoder_state: final hidden state from the decoder 332 | * attns: distribution over source words at each target word 333 | `[L_t x B x L_s]`. 334 | """ 335 | # Run the forward pass of the RNN. 336 | decoder_outputs, final_state, attns, kld = self.forward_step( 337 | trg, encoder_outputs, lengths, state) 338 | 339 | # Update the state with the result. 340 | final_output = decoder_outputs[-1] 341 | state.update_state(final_state, final_output.unsqueeze(0)) 342 | 343 | # Concatenates sequence of tensors along a new dimension. 344 | decoder_outputs = torch.stack(decoder_outputs) 345 | for k in attns: 346 | attns[k] = torch.stack(attns[k]) 347 | 348 | return decoder_outputs, state, attns, kld 349 | -------------------------------------------------------------------------------- /NMT/Models/Encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torch.nn.utils.rnn import pack_padded_sequence 5 | from torch.nn.utils.rnn import pad_packed_sequence 6 | 7 | 8 | class PackedRNNEncoder(nn.Module): 9 | def __init__(self, rnn_type, embed_dim, 10 | hidden_size, num_layers=2, dropout=0.0, bidirectional=True): 11 | super(PackedRNNEncoder, self).__init__() 12 | 13 | num_directions = 2 if bidirectional else 1 14 | assert hidden_size % num_directions == 0 15 | hidden_size = hidden_size // num_directions 16 | self.num_layers = num_layers 17 | self.bidirectional = bidirectional 18 | self.rnn_type = rnn_type 19 | self.rnn = getattr(nn, rnn_type)(embed_dim, 20 | hidden_size=hidden_size, 21 | num_layers=num_layers, 22 | dropout=dropout, 23 | bidirectional=bidirectional) 24 | self.hidden_size = hidden_size * num_directions 25 | def fix_final_state(self, final_state): 26 | def resize(h): 27 | # The encoder hidden is (layers*directions) x batch x dim. 28 | # We need to convert it to layers x batch x (directions*dim). 
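            # h stacks directions as [fwd_0, bwd_0, fwd_1, bwd_1, ...] along
            # dim 0; slicing even/odd rows and concatenating on the feature
            # dimension pairs up both directions of each layer.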
29 |             if self.bidirectional:
30 |                 h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], -1)
31 |             return h
32 |         if self.bidirectional:
33 |             if self.rnn_type == "GRU":
34 |                 final_state = resize(final_state)
35 |             elif self.rnn_type == "LSTM":
36 |                 final_state = tuple([resize(h) for h in final_state])
37 |         return final_state
38 |     def forward(self, input, lengths=None, state=None):
39 |         packed = input  # fall back to the unpacked input when no lengths are given
40 |         if lengths is not None:
41 |             packed = pack_padded_sequence(input, lengths.view(-1).tolist())
42 |         output, final_state = self.rnn(packed, state)
43 |         if lengths is not None:
44 |             output = pad_packed_sequence(output)[0]
45 |         return output, self.fix_final_state(final_state)
46 | 
47 | class RNNEncoder(PackedRNNEncoder):
48 |     def __init__(self, *arg, **kwargs):
49 |         super(RNNEncoder, self).__init__(*arg, **kwargs)
50 | 
51 |     def forward(self, input, lengths=None, state=None):
52 | 
53 |         if lengths is not None:
54 |             lengths, rank = torch.sort(lengths, dim=0, descending=True)
55 |             input = input.index_select(1, rank)
56 |         output, final_state = super(RNNEncoder, self).forward(input, lengths)
57 |         _, order = torch.sort(rank, dim=0, descending=False)
58 |         if isinstance(final_state, tuple):
59 |             final_state = tuple(x.index_select(1, order) for x in final_state)
60 |         else:
61 |             final_state = final_state.index_select(1, order)
62 |         return output.index_select(1, order), final_state
63 | 
64 | 
--------------------------------------------------------------------------------
/NMT/Models/NMTModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from NMT.Models.Decoders import RNNDecoderState
6 | 
7 | class NMTModel(nn.Module):
8 |     """
9 |     Core model for NMT.
10 |     {
11 |         Encoder + Decoder.
12 |     }
13 |     """
14 |     def __init__(self,
15 |                  encoder, decoder,
16 |                  src_embedding, trg_embedding,
17 |                  trg_vocab_size,
18 |                  config):
19 |         super(NMTModel, self).__init__()
20 | 
21 |         self.encoder = encoder
22 |         self.src_embedding = src_embedding
23 | 
24 |         self.trg_embedding = trg_embedding
25 | 
26 |         self.decoder = decoder
27 |         self.generator = nn.Linear(config.hidden_size, trg_vocab_size)
28 |         self.config = config
29 | 
30 | 
31 |     def encoder2decoder(self, encoder_state):
32 |         if isinstance(encoder_state, tuple):
33 |             # LSTM: encoder_state = (hidden, cell)
34 |             return RNNDecoderState(self.encoder.hidden_size, encoder_state)
35 |         else:
36 |             # GRU: encoder_state = hidden
37 |             return RNNDecoderState(self.encoder.hidden_size, encoder_state)
38 | 
39 |     def forward(self, src, lengths, trg, decoder_state=None):
40 |         """
41 |         Forward propagate a `src` and `trg` pair for training.
42 |         Possibly initialized with a beginning decoder state.
43 | 
44 |         Args:
45 |             src (Tensor): source sequence `[L x B x N]`.
46 |             trg (LongTensor): target sequence `[L x B]`.
47 |             lengths (LongTensor): the src lengths, pre-padding `[batch]`.
48 |             dec_state (`DecoderState`, optional): initial decoder state
49 |         Returns:
50 |             (:obj:`FloatTensor`, `dict`, :obj:`nmt.Models.DecoderState`):
51 | 
52 |                 * decoder output `[trg_len x batch x hidden]`
53 |                 * dictionary attention dists of `[trg_len x batch x src_len]`
54 |                 * final decoder state
55 |         """
56 |         trg = trg[:-1]  # the decoder input is shifted right relative to the references
57 |         # encoding side
58 |         encoder_outputs, encoder_state = self.encoder(
59 |             self.src_embedding(src), lengths)
60 | 
61 |         # encoder to decoder
62 |         decoder_state = self.encoder2decoder(encoder_state)
63 | 
64 |         decoder_input = self.trg_embedding(trg)
65 |         # decoding side
66 |         decoder_outputs, decoder_state, attns = self.decoder(
67 |             decoder_input, encoder_outputs, lengths, decoder_state)
68 | 
69 |         return decoder_outputs, decoder_state, attns
70 | 
71 | 
72 | 
--------------------------------------------------------------------------------
/NMT/Models/VNMTModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from torch.autograd import Variable
6 | from NMT.Models import NMTModel
7 | from NMT.Models.Decoders import RNNDecoderState
8 | 
9 | def compute_kld(mu, logvar):
10 |     kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
11 |     return kld
12 | 
13 | 
14 | class VNMTModel(NMTModel):
15 |     """Recent work has found that VAE + LSTM decoders underperform a vanilla LSTM decoder.
16 |     You should use VAE + GRU instead;
17 |     see Yang et al., ICML 2017, `Improved VAE for Text Modeling using Dilated Convolutions`.
18 |     """
19 |     def __init__(self,
20 |                  encoder, decoder,
21 |                  src_embedding, trg_embedding,
22 |                  trg_vocab_size,
23 |                  config):
24 |         super(VNMTModel, self).__init__(
25 |             encoder, decoder,
26 |             src_embedding, trg_embedding,
27 |             trg_vocab_size, config)
28 |         self.context_to_mu = nn.Linear(
29 |             config.hidden_size,
30 |             config.latent_size)
31 |         self.context_to_logvar = nn.Linear(
32 |             config.hidden_size,
33 |             config.latent_size)
34 |         self.lstm_state2context = nn.Linear(
35 |             2*config.hidden_size,
36 |             config.hidden_size)  # [h; c] -> H, the input width context_to_mu/logvar expect
37 |     def get_hidden(self, state):
38 |         hidden = None
39 |         if self.encoder.rnn_type == "GRU":
40 |             hidden = state[-1]
41 |         elif self.encoder.rnn_type == "LSTM":
42 |             hidden, context = state[0][-1], state[1][-1]
43 |             hidden = self.lstm_state2context(torch.cat([hidden, context], -1))
44 |         return hidden
45 | 
46 |     def reparameterize(self, encoder_state):
47 |         """
48 |         Draw z ~ N(mu, sigma^2) with the reparameterization trick (at test time, z = mu).
49 |         """
50 |         hidden = self.get_hidden(encoder_state)
51 |         mu = self.context_to_mu(hidden)
52 |         logvar = self.context_to_logvar(hidden)
53 |         if self.training:
54 |             std = logvar.mul(0.5).exp_()
55 |             eps = Variable(std.data.new(std.size()).normal_())
56 |             z = eps.mul(std).add_(mu)
57 |         else:
58 |             z = mu
59 |         return z, mu, logvar
60 | 
61 | 
62 |     def forward(self, src, src_lengths, trg, trg_lengths=None, decoder_state=None):
63 |         """
64 |         Forward propagate a `src` and `trg` pair for training.
65 |         Possibly initialized with a beginning decoder state.
66 | 
67 |         Args:
68 |             src (Tensor): source sequence `[L x B x N]`.
69 |             trg (LongTensor): target sequence `[L x B]`.
70 |             src_lengths (LongTensor): the src lengths, pre-padding `[batch]`.
71 |             trg_lengths (LongTensor): the trg lengths, pre-padding `[batch]`.
72 |             decoder_state (`DecoderState`, optional): initial decoder state
73 | 
74 |         Returns:
75 |             (:obj:`FloatTensor`, `dict`, :obj:`nmt.Models.DecoderState`):
76 | 
77 |                 * decoder output `[trg_len x batch x hidden]`
78 |                 * dictionary attention dists of `[trg_len x batch x src_len]`
79 |                 * final decoder state
80 |         """
81 | 
82 |         # encoding side
83 |         encoder_outputs, encoder_state = self.encoder(
84 |             self.src_embedding(src), src_lengths)
85 | 
86 |         # re-parameterize
87 |         z, mu, logvar = self.reparameterize(encoder_state)
88 |         # encoder to decoder
89 |         decoder_state = self.encoder2decoder(encoder_state)
90 | 
91 |         trg_feed = trg[:-1]
92 |         decoder_input = torch.cat([
93 |             self.trg_embedding(trg_feed),
94 |             z.unsqueeze(0).repeat(trg_feed.size(0), 1, 1)],
95 |             -1)
96 | 
97 |         # decoding side
98 |         decoder_outputs, decoder_state, attns = self.decoder(
99 |             decoder_input, encoder_outputs, src_lengths, decoder_state)
100 | 
101 |         return decoder_outputs, decoder_state, attns, compute_kld(mu, logvar)
102 | 
103 | 
104 | 
105 | 
--------------------------------------------------------------------------------
/NMT/Models/VRNMTModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from torch.autograd import Variable
6 | from NMT.Models import NMTModel
7 | 
8 | 
9 | class VRNMTModel(NMTModel):
10 |     def forward(self, src, lengths, trg, decoder_state=None):
11 |         """
12 |         Forward propagate a `src` and `trg` pair for training.
13 |         Possibly initialized with a beginning decoder state.
14 | 
15 |         Args:
16 |             src (Tensor): source sequence `[L x B x N]`.
17 |             trg (LongTensor): target sequence `[L x B]`.
18 |             lengths (LongTensor): the src lengths, pre-padding `[batch]`.
19 | dec_state (`DecoderState`, optional): initial decoder state 20 | Returns: 21 | (:obj:`FloatTensor`, `dict`, :obj:`nmt.Models.DecoderState`): 22 | 23 | * decoder output `[trg_len x batch x hidden]` 24 | * dictionary attention dists of `[trg_len x batch x src_len]` 25 | * final decoder state 26 | """ 27 | trg = trg[:-1] 28 | # encoding side 29 | encoder_outputs, encoder_state = self.encoder( 30 | self.src_embedding(src), lengths) 31 | 32 | # encoder to decoder 33 | decoder_state = self.encoder2decoder(encoder_state) 34 | 35 | decoder_input = self.trg_embedding(trg) 36 | # decoding side 37 | decoder_outputs, decoder_state, attns, kld = self.decoder( 38 | decoder_input, encoder_outputs, lengths, decoder_state) 39 | 40 | return decoder_outputs, decoder_state, attns, kld 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /NMT/Models/__init__.py: -------------------------------------------------------------------------------- 1 | from NMT.Models.NMTModel import NMTModel 2 | from NMT.Models.VNMTModel import VNMTModel 3 | from NMT.Models.VRNMTModel import VRNMTModel 4 | from NMT.Models.Encoders import PackedRNNEncoder 5 | from NMT.Models.Decoders import RNNDecoderState 6 | from NMT.Models.Encoders import RNNEncoder 7 | from NMT.Models.Decoders import InputFeedRNNDecoder 8 | from NMT.Models.Decoders import VarInputFeedRNNDecoder 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /NMT/Modules/GlobalAttention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from Utils.utils import aeq, sequence_mask 5 | 6 | 7 | class GlobalAttention(nn.Module): 8 | """ 9 | Global attention takes a matrix and a query vector. It 10 | then computes a parameterized convex combination of the matrix 11 | based on the input query. 
12 | 
13 |     * Luong Attention (dot, general):
14 |        * dot: :math:`score(H_j,q) = H_j^T q`
15 |        * general: :math:`score(H_j, q) = H_j^T W_a q`
16 | 
17 |     * Bahdanau Attention (mlp):
18 |        * :math:`score(H_j, q) = v_a^T tanh(W_a q + U_a h_j)`
19 | 
20 |     Args:
21 |        dim (int): dimensionality of query and key
22 |        coverage (bool): use coverage term
23 |        attn_type (str): type of attention to use, options [dot,general,mlp]
24 | 
25 |     """
26 |     def __init__(self, dim, coverage=False, attn_type="dot"):
27 |         super(GlobalAttention, self).__init__()
28 | 
29 |         self.dim = dim
30 |         self.attn_type = attn_type
31 |         assert (self.attn_type in ["dot", "general", "mlp"]), (
32 |             "Please select a valid attention type.")
33 | 
34 |         if self.attn_type == "general":
35 |             self.linear_in = nn.Linear(dim, dim, bias=False)
36 |         elif self.attn_type == "mlp":
37 |             self.linear_context = nn.Linear(dim, dim, bias=False)
38 |             self.linear_query = nn.Linear(dim, dim, bias=True)
39 |             self.v = nn.Linear(dim, 1, bias=False)
40 |         # mlp wants it with bias
41 |         out_bias = self.attn_type == "mlp"
42 |         self.linear_out = nn.Linear(dim*2, dim, bias=out_bias)
43 | 
44 |         self.sm = nn.Softmax(dim=-1)
45 |         self.tanh = nn.Tanh()
46 | 
47 |         if coverage:
48 |             self.linear_cover = nn.Linear(1, dim, bias=False)
49 | 
50 |     def score(self, h_t, h_s):
51 |         """
52 |         Args:
53 |             h_t (`FloatTensor`): sequence of queries `[batch x trg_len x dim]`
54 |             h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`
55 | 
56 |         Returns:
57 |             :obj:`FloatTensor`:
58 |                 raw attention scores (unnormalized) for each src index
59 |                 `[batch x trg_len x src_len]`
60 | 
61 |         """
62 | 
63 |         # Check input sizes
64 |         src_batch, src_len, src_dim = h_s.size()
65 |         trg_batch, trg_len, trg_dim = h_t.size()
66 |         aeq(src_batch, trg_batch)
67 |         aeq(src_dim, trg_dim)
68 |         aeq(self.dim, src_dim)
69 | 
70 |         if self.attn_type in ["general", "dot"]:
71 |             if self.attn_type == "general":
72 |                 h_t_ = h_t.view(trg_batch*trg_len, trg_dim)
73 |                 h_t_ = self.linear_in(h_t_)
74 |                 h_t = h_t_.view(trg_batch, trg_len, trg_dim)
75 |             h_s_ = h_s.transpose(1, 2)
76 |             # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
77 |             return torch.bmm(h_t, h_s_)
78 |         else:
79 |             dim = self.dim
80 |             wq = self.linear_query(h_t.view(-1, dim))
81 |             wq = wq.view(trg_batch, trg_len, 1, dim)
82 |             wq = wq.expand(trg_batch, trg_len, src_len, dim)
83 | 
84 |             uh = self.linear_context(h_s.contiguous().view(-1, dim))
85 |             uh = uh.view(src_batch, 1, src_len, dim)
86 |             uh = uh.expand(src_batch, trg_len, src_len, dim)
87 | 
88 |             # (batch, t_len, s_len, d)
89 |             wquh = self.tanh(wq + uh)
90 | 
91 |             return self.v(wquh.view(-1, dim)).view(trg_batch, trg_len, src_len)
92 | 
93 |     def forward(self, input, memory_bank, lengths=None):
94 |         """
95 | 
96 |         Args:
97 |             input (`FloatTensor`): query vectors `[batch x trg_len x dim]`
98 |             memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
99 |             lengths (`LongTensor`): the source context lengths `[batch]`
100 |         Returns:
101 |             (`FloatTensor`, `FloatTensor`):
102 | 
103 |             * Computed vector `[trg_len x batch x dim]`
104 |             * Attention distributions for each query
105 |                `[trg_len x batch x src_len]`
106 |         """
107 | 
108 |         # one step input
109 |         if input.dim() == 2:
110 |             one_step = True
111 |             input = input.unsqueeze(1)
112 |         else:
113 |             one_step = False
114 | 
115 |         batch, sourceL, dim = memory_bank.size()
116 |         batch_, targetL, dim_ = input.size()
117 | 
118 |         # compute attention scores, as in Luong et al.
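        # align: [batch x trg_len x src_len]. With attn_type "general" this
        # is score(h_t, h_s) = h_t^T W_a h_s, computed as a batched matmul
        # after projecting h_t through linear_in; "dot" skips the projection
        # and "mlp" uses v_a^T tanh(W_a q + U_a h) instead.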
119 |         align = self.score(input, memory_bank)
120 | 
121 |         if lengths is not None:
122 |             mask = sequence_mask(lengths)
123 |             mask = mask.unsqueeze(1)  # Make it broadcastable.
124 |             align.data.masked_fill_(1 - mask, -float('inf'))
125 | 
126 |         # Softmax to normalize attention weights
127 |         align_vectors = self.sm(align.view(batch*targetL, sourceL))
128 |         align_vectors = align_vectors.view(batch, targetL, sourceL)
129 | 
130 |         # each context vector c_t is the weighted average
131 |         # over all the source hidden states
132 |         c = torch.bmm(align_vectors, memory_bank)
133 | 
134 |         # concatenate
135 |         concat_c = torch.cat([c, input], 2).view(batch*targetL, dim*2)
136 |         attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
137 |         if self.attn_type in ["general", "dot"]:
138 |             attn_h = self.tanh(attn_h)
139 | 
140 |         if one_step:
141 |             attn_h = attn_h.squeeze(1)
142 |             align_vectors = align_vectors.squeeze(1)
143 | 
144 |             # Check output sizes
145 |             batch_, dim_ = attn_h.size()
146 |             aeq(batch, batch_)
147 |             aeq(dim, dim_)
148 |             batch_, sourceL_ = align_vectors.size()
149 |             aeq(batch, batch_)
150 |             aeq(sourceL, sourceL_)
151 |         else:
152 |             attn_h = attn_h.transpose(0, 1).contiguous()
153 |             align_vectors = align_vectors.transpose(0, 1).contiguous()
154 | 
155 |             # Check output sizes
156 |             targetL_, batch_, dim_ = attn_h.size()
157 |             aeq(targetL, targetL_)
158 |             aeq(batch, batch_)
159 |             aeq(dim, dim_)
160 |             targetL_, batch_, sourceL_ = align_vectors.size()
161 |             aeq(targetL, targetL_)
162 |             aeq(batch, batch_)
163 |             aeq(sourceL, sourceL_)
164 | 
165 |         return attn_h, align_vectors
166 | 
--------------------------------------------------------------------------------
/NMT/Modules/StackedRNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class StackedLSTM(nn.Module):
6 |     """
7 |     Custom stacked LSTM in which the first layer's input size can differ from the hidden size.
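    Unlike nn.LSTM, each layer here is an LSTMCell driven one time step at a
    time: the input-feed decoder must concatenate the attention output to the
    next input, so the time loop lives in the decoder, not in the RNN.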
8 |     """
9 |     def __init__(self, num_layers, input_size, hidden_size, dropout):
10 |         super(StackedLSTM, self).__init__()
11 |         self.dropout = nn.Dropout(dropout)
12 |         self.num_layers = num_layers
13 |         self.layers = nn.ModuleList()
14 | 
15 |         for i in range(num_layers):
16 |             if i == 0:
17 |                 self.layers.append(nn.LSTMCell(input_size, hidden_size))
18 |             else:
19 |                 self.layers.append(nn.LSTMCell(hidden_size, hidden_size))
20 | 
21 |     def forward(self, input, hidden):
22 |         h_0, c_0 = hidden
23 |         h_1, c_1 = [], []
24 |         for i, layer in enumerate(self.layers):
25 |             h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
26 |             input = h_1_i
27 |             if i + 1 != self.num_layers:
28 |                 input = self.dropout(input)
29 |             h_1 += [h_1_i]
30 |             c_1 += [c_1_i]
31 | 
32 |         h_1 = torch.stack(h_1)
33 |         c_1 = torch.stack(c_1)
34 | 
35 |         return input, (h_1, c_1)
36 | 
37 | 
38 | class StackedGRU(nn.Module):
39 | 
40 |     def __init__(self, num_layers, input_size, hidden_size, dropout):
41 |         super(StackedGRU, self).__init__()
42 |         self.dropout = nn.Dropout(dropout)
43 |         self.num_layers = num_layers
44 |         self.layers = nn.ModuleList()
45 | 
46 |         for i in range(num_layers):
47 |             if i == 0:
48 |                 self.layers.append(nn.GRUCell(input_size, hidden_size))
49 |             else:
50 |                 self.layers.append(nn.GRUCell(hidden_size, hidden_size))
51 | 
52 | 
53 |     def forward(self, input, hidden):
54 |         h_1 = []
55 |         for i, layer in enumerate(self.layers):
56 |             h_1_i = layer(input, hidden[0][i])
57 |             input = h_1_i
58 |             if i + 1 != self.num_layers:
59 |                 input = self.dropout(input)
60 |             h_1 += [h_1_i]
61 | 
62 |         h_1 = torch.stack(h_1)
63 |         return input, (h_1,)
64 | 
--------------------------------------------------------------------------------
/NMT/Modules/__init__.py:
--------------------------------------------------------------------------------
1 | from NMT.Modules.GlobalAttention import GlobalAttention
2 | from NMT.Modules.StackedRNN import StackedLSTM, StackedGRU
3 | 
--------------------------------------------------------------------------------
/NMT/Optimizer.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | from torch.nn.utils import clip_grad_norm
3 | from Utils.utils import trace
4 | 
5 | 
6 | class MultipleOptimizer(object):
7 |     def __init__(self, op):
8 |         self.optimizers = op
9 | 
10 |     def zero_grad(self):
11 |         for op in self.optimizers:
12 |             op.zero_grad()
13 | 
14 |     def step(self):
15 |         for op in self.optimizers:
16 |             op.step()
17 | 
18 | class Optimizer(object):
19 |     def __init__(self, method, config):
20 | 
21 |         self.last_ppl = None
22 |         self.lr = config.lr
23 |         self.original_lr = config.lr
24 |         self.max_grad_norm = config.max_grad_norm
25 |         self.method = method
26 |         self.lr_decay_rate = config.lr_decay_rate
27 |         self.start_decay_at = config.start_decay_at
28 |         self.start_decay = False
29 |         self.alpha = config.alpha
30 |         self._step = 0
31 |         self.momentum = config.momentum
32 |         self.betas = [config.adam_beta1, config.adam_beta2]
33 |         self.adagrad_accum = config.adagrad_accum_init
34 |         self.decay_method = config.decay_method
35 |         self.warmup_steps = config.warmup_steps
36 |         self.model_size = config.hidden_size
37 |         self.eps = config.eps
38 | 
39 |     def set_parameters(self, params):
40 |         self.params = []
41 |         # self.sparse_params = []
42 |         for k, p in params:
43 |             if p.requires_grad:
44 |                 if self.method != 'sparseadam' or "embed" not in k:
45 |                     self.params.append(p)
46 |                 # else:
47 |                 #     self.sparse_params.append(p)
48 |         if self.method == 'SGD':
49 |             self.optimizer = optim.SGD(self.params, lr=self.lr)
50 |         elif 
self.method == 'Adagrad': 51 | self.optimizer = optim.Adagrad(self.params, lr=self.lr) 52 | for group in self.optimizer.param_groups: 53 | for p in group['params']: 54 | self.optimizer.state[p]['sum'] = self.optimizer\ 55 | .state[p]['sum'].fill_(self.adagrad_accum) 56 | elif self.method == 'Adadelta': 57 | self.optimizer = optim.Adadelta(self.params, lr=self.lr) 58 | elif self.method == 'Adam': 59 | self.optimizer = optim.Adam(self.params, lr=self.lr, 60 | betas=self.betas, eps=1e-9) 61 | elif self.method == 'RMSprop': 62 | # does not work properly. 63 | self.optimizer = optim.RMSprop(self.params, lr=self.lr, 64 | alpha=self.alpha, eps=self.eps, weight_decay=self.lr_decay_rate, 65 | momentum=self.momentum, centered=False) 66 | def _set_rate(self, lr): 67 | self.lr = lr 68 | if self.method != 'sparseadam': 69 | self.optimizer.param_groups[0]['lr'] = self.lr 70 | else: 71 | for op in self.optimizer.optimizers: 72 | op.param_groups[0]['lr'] = self.lr 73 | 74 | def step(self): 75 | """Update the model parameters based on current gradients. 76 | 77 | Optionally, will employ gradient modification or update learning 78 | rate. 79 | """ 80 | self._step += 1 81 | 82 | if self.max_grad_norm: 83 | clip_grad_norm(self.params, self.max_grad_norm) 84 | self.optimizer.step() 85 | 86 | def update_lr(self, ppl, epoch): 87 | """ 88 | Decay learning rate if val perf does not improve 89 | or we hit the start_decay_at limit. 90 | """ 91 | 92 | if self.start_decay_at is not None and epoch >= self.start_decay_at: 93 | self.start_decay = True 94 | # if self.last_ppl is not None and ppl > self.last_ppl: 95 | # self.start_decay = True 96 | 97 | if self.start_decay: 98 | self.lr = self.lr * self.lr_decay_rate 99 | trace("Decaying learning rate to %g" % self.lr) 100 | 101 | self.last_ppl = ppl 102 | if self.method != 'sparseadam': 103 | self.optimizer.param_groups[0]['lr'] = self.lr 104 | -------------------------------------------------------------------------------- /NMT/Statistics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import math 4 | 5 | class Statistics(object): 6 | """ 7 | Accumulator for loss statistics. 
8 |     Currently calculates:
9 | 
10 |     * accuracy
11 |     * perplexity
12 |     * elapsed time
13 |     """
14 |     def __init__(self, loss=0, n_words=0, n_correct=0, loss_dict=None):
15 |         self.loss = loss
16 |         # loss_dict=None avoids a shared mutable default argument.
17 |         self.loss_dict = loss_dict if loss_dict is not None else {"CELoss": 0.0, "KLDLoss": 0.0}
18 |         self.n_words = n_words
19 |         self.n_correct = n_correct
20 |         self.n_src_words = 0
21 |         self.start_time = time.time()
22 | 
23 |     def update(self, stat):
24 |         self.loss += stat.loss
25 |         for key, val in stat.loss_dict.items():
26 |             self.loss_dict[key] += val
27 |         self.n_words += stat.n_words
28 |         self.n_correct += stat.n_correct
29 | 
30 |     def accuracy(self):
31 |         if self.n_words == 0:
32 |             return 0
33 |         return 100 * (float(self.n_correct) / self.n_words)
34 | 
35 |     def loss_detail(self):
36 |         return (self.loss_dict["CELoss"], self.loss_dict["KLDLoss"])
37 | 
38 |     def xent(self):
39 |         if self.n_words == 0:
40 |             return 0
41 |         return self.loss / self.n_words
42 | 
43 |     def ppl(self):
44 |         if self.n_words == 0:
45 |             return 0
46 |         return math.exp(min(float(self.loss)/self.n_words, 100))
47 |         #return math.exp(min(self.loss, 100))
48 | 
49 |     def elapsed_time(self):
50 |         return time.time() - self.start_time
51 | 
52 | 
53 | 
--------------------------------------------------------------------------------
/NMT/Trainer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | 
6 | import torch.nn.functional as F
7 | from NMT.Statistics import Statistics
8 | from NMT.Models import NMTModel
9 | from NMT.Models import VNMTModel
10 | from NMT.Models import VRNMTModel
11 | from Utils.utils import report_stats
12 | class Trainer(object):
13 |     """
14 |     Class that controls the training process.
15 | 
16 |     Args:
17 |         model (NMT.Model.NMTModel): NMT base model
18 | 
19 |         loss_func (NMT.Loss.NMTLoss): loss computation module
20 |         optimizer (NMT.Optimizer.Optimizer): optimizer
21 |         config (Config): global configurations
22 |     """
23 | 
24 |     def __init__(self, model, loss_func, optimizer, config):
25 |         # Basic attributes.
26 |         self.model = model
27 |         self.loss_func = loss_func
28 |         self.optim = optimizer
29 |         self.config = config
30 |         self.progress_step = 0
31 | 
32 |     def train(self, train_iter, epoch, num_batches):
33 |         """ Train next epoch.
34 |         Args:
35 |             train_iter (BatchDataIterator): training data iterator
36 |             epoch (int): the epoch number
37 |             num_batches (int): the number of batches per epoch
38 |         Returns:
39 |             stats (Statistics): epoch loss statistics
40 |         """
41 |         self.model.train()
42 | 
43 |         total_stats = Statistics()
44 |         self.loss_func.kld_weight_step(epoch, self.config.start_increase_kld_at)
45 |         for idx, batch in enumerate(train_iter):
46 |             self.model.zero_grad()
47 |             src, src_lengths = batch.src, batch.src_Ls
48 |             trg, trg_lengths = batch.trg, batch.trg_Ls
49 |             ref = batch.trg[1:]
50 |             kld_loss = 0.
51 |             normalization = batch.batch_size
52 |             if isinstance(self.model, VRNMTModel):
53 |                 outputs, _, _, kld_loss = self.model(
54 |                     src, src_lengths, trg)
55 |             elif isinstance(self.model, VNMTModel):
56 |                 outputs, _, _, kld_loss = self.model(
57 |                     src, src_lengths, trg)
58 |             elif isinstance(self.model, NMTModel):
59 |                 outputs, _, _ = self.model(
60 |                     src, src_lengths, trg)
61 | 
62 |             probs = self.model.generator(
63 |                 outputs.view(-1, outputs.size(2)))
64 | 
65 |             loss, batch_stats = self.loss_func.compute_batch_loss(
66 |                 probs, ref, normalization, kld_loss=kld_loss)
67 | 
68 |             loss.backward()
69 |             # Update the parameters and statistics.
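            # (Optimizer.step clips the global gradient norm to
            # config.max_grad_norm before applying the update.)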
70 |             self.optim.step()
71 | 
72 |             del loss, outputs, probs
73 | 
74 | 
75 |             report_stats(
76 |                 batch_stats, epoch, idx+1, num_batches,
77 |                 self.progress_step, self.optim.lr)
78 | 
79 |             total_stats.update(batch_stats)
80 |             self.progress_step += 1
81 | 
82 | 
83 |         return total_stats
84 | 
85 |     def validate(self, valid_iter):
86 |         """ Validate model.
87 |             valid_iter: validate data iterator
88 |         Returns:
89 |             :obj:`nmt.Statistics`: validation loss statistics
90 |         """
91 | 
92 |         self.model.eval()
93 |         total_stats = Statistics()
94 |         self.model.zero_grad()
95 |         for batch in valid_iter:
96 |             kld_loss = 0.
97 |             normalization = batch.batch_size
98 | 
99 |             src, src_lengths = batch.src, batch.src_Ls
100 |             trg, trg_lengths = batch.trg, batch.trg_Ls
101 |             ref = batch.trg[1:]
102 |             # F-prop through the model.
103 |             if isinstance(self.model, VRNMTModel):
104 |                 outputs, _, _, kld_loss = self.model(
105 |                     src, src_lengths, trg)
106 |             elif isinstance(self.model, VNMTModel):
107 |                 outputs, _, _, kld_loss = self.model(
108 |                     src, src_lengths, trg)
109 |             elif isinstance(self.model, NMTModel):
110 |                 outputs, _, _ = self.model(
111 |                     src, src_lengths, trg)
112 | 
113 |             probs = self.model.generator(outputs.view(-1, outputs.size(2)))
114 |             loss, batch_stats = self.loss_func.compute_batch_loss(
115 |                 probs, ref, normalization, kld_loss=kld_loss)
116 | 
117 |             # Update statistics.
118 |             total_stats.update(batch_stats)
119 |             del outputs, probs, batch_stats, loss
120 |         # NOTE: the caller is responsible for switching back to train mode.
121 |         return total_stats
122 | 
123 |     def lr_step(self, ppl, epoch):
124 |         return self.optim.update_lr(ppl, epoch)
125 | 
126 |     def dump_checkpoint(self, epoch, config, valid_stats):
127 |         """
128 |         Save a checkpoint.
129 | 
130 |         Args:
131 |             epoch (int): epoch number
132 |             config (Config): global configurations
133 |             valid_stats : statistics of last validation run
134 |         """
135 | 
136 |         real_model = (self.model.module
137 |                       if isinstance(self.model, nn.DataParallel)
138 |                       else self.model)
139 | 
140 |         model_state_dict = real_model.state_dict()
141 | 
142 |         checkpoint = {
143 |             'model': model_state_dict,
144 |             'config': config,
145 |             'epoch': epoch,
146 |             'optim': self.optim,
147 |         }
148 |         # torch.save(checkpoint,
149 |         #            '%s_acc_%.2f_loss_%.2f_e%d.pt'
150 |         #            % (config.save_model, valid_stats.accuracy(),
151 |         #               valid_stats.loss, epoch))
152 |         if epoch == config.epochs:
153 |             torch.save(checkpoint, '%s.pt' % config.save_model)
154 | 
--------------------------------------------------------------------------------
/NMT/__init__.py:
--------------------------------------------------------------------------------
1 | from NMT.Models import *
2 | from NMT.Loss import *
3 | from NMT.Trainer import *
4 | from NMT.Optimizer import *
5 | from NMT.Statistics import *
6 | from NMT.translate import *
7 | from NMT.ModelConstructor import *
8 | 
--------------------------------------------------------------------------------
/NMT/translate/Beam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from NMT.translate import Penalties
3 | from Utils.DataLoader import EOS_WORD
4 | 
5 | class Beam(object):
6 |     """
7 |     Args:
8 |         size (int): beam size
9 |         pad, bos, eos (int): indices of padding, beginning, and ending.
10 | k_best (int): nbest size to use 11 | global_scorer (:obj:`GlobalScorer`) 12 | """ 13 | def __init__(self, size, pad, bos, eos, config, k_best=1, global_scorer=None): 14 | 15 | self.size = size 16 | 17 | # The score for each translation on the beam. 18 | self.scores = torch.FloatTensor(size).zero_().cuda() 19 | self.all_scores = [] 20 | 21 | # The backpointers at each time-step. 22 | self.prev_ks = [] 23 | 24 | # The outputs at each time-step. 25 | self.next_ys = [torch.LongTensor(size) 26 | .fill_(pad).cuda()] 27 | self.next_ys[0][0] = bos 28 | 29 | # Has EOS topped the beam yet. 30 | self._eos = eos 31 | self.eos_top = False 32 | 33 | # The attentions (matrix) for each time. 34 | self.attn = [] 35 | 36 | # Time and k pair for finished. 37 | self.finished = [] 38 | self.k_best = k_best 39 | 40 | # Information for global scoring. 41 | self.global_scorer = global_scorer 42 | self.global_state = {} 43 | 44 | # Minimum prediction length 45 | self.min_length = 3 46 | 47 | # Apply Penalty at every step 48 | self.stepwise_penalty = config.stepwise_penalty 49 | self.block_ngram_repeat = config.block_ngram_repeat 50 | self.ignore_when_blocking = set(config.ignore_when_blocking) 51 | 52 | def get_current_state(self): 53 | "Get the outputs for the current timestep." 54 | return self.next_ys[-1] 55 | 56 | def get_current_origin(self): 57 | "Get the backpointers for the current timestep." 58 | return self.prev_ks[-1] 59 | 60 | def advance(self, word_probs, attn_out): 61 | """ 62 | Given prob over words for every last beam `wordLk` and attention 63 | `attn_out`: Compute and update the beam search. 64 | 65 | Parameters: 66 | 67 | * `word_probs`- probs of advancing from the last step (K x words) 68 | * `attn_out`- attention at the last step 69 | 70 | Returns: True if beam search is complete. 71 | """ 72 | num_words = word_probs.size(1) 73 | if self.stepwise_penalty: 74 | self.global_scorer.update_score(self, attn_out) 75 | # force the output to be longer than self.min_length 76 | # cur_len = len(self.next_ys) 77 | # if cur_len < self.min_length: 78 | # for k in range(len(word_probs)): 79 | # word_probs[k][self._eos] = -1e20 80 | # Sum the previous scores. 81 | if len(self.prev_ks) > 0: 82 | beam_scores = word_probs + \ 83 | self.scores.unsqueeze(1).expand_as(word_probs) 84 | # Don't let EOS have children. 
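            # A hypothesis that has produced EOS is frozen: filling its row
            # with -1e20 keeps any continuation of it out of the next top-k
            # selection.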
85 |             for i in range(self.next_ys[-1].size(0)):
86 |                 if self.next_ys[-1][i] == self._eos:
87 |                     beam_scores[i] = -1e20
88 | 
89 |             # Block ngram repeats
90 |             if self.block_ngram_repeat > 0:
91 |                 ngrams = []
92 |                 le = len(self.next_ys)
93 |                 for j in range(self.next_ys[-1].size(0)):
94 |                     hyp, _ = self.get_hypo(le-1, j)
95 |                     ngrams = set()
96 |                     fail = False
97 |                     gram = []
98 |                     for i in range(le-1):
99 |                         # Last n tokens, n = block_ngram_repeat
100 |                         gram = (gram + [hyp[i]])[-self.block_ngram_repeat:]
101 |                         # Skip the blocking if it is in the exclusion list
102 |                         if set(gram) & self.ignore_when_blocking:
103 |                             continue
104 |                         if tuple(gram) in ngrams:
105 |                             fail = True
106 |                         ngrams.add(tuple(gram))
107 |                     if fail:
108 |                         beam_scores[j] = -10e20
109 |         else:
110 |             beam_scores = word_probs[0]
111 |         flat_beam_scores = beam_scores.view(-1)
112 |         best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0,
113 |                                                             True, True)
114 | 
115 |         self.all_scores.append(self.scores)
116 |         self.scores = best_scores
117 | 
118 |         # best_scores_id is flattened beam x word array, so calculate which
119 |         # word and beam each score came from
120 |         prev_k = best_scores_id / num_words
121 |         self.prev_ks.append(prev_k)
122 |         self.next_ys.append((best_scores_id - prev_k * num_words))
123 |         self.attn.append(attn_out.index_select(0, prev_k))
124 |         self.global_scorer.update_global_state(self)
125 | 
126 |         for i in range(self.next_ys[-1].size(0)):
127 |             if self.next_ys[-1][i] == self._eos:
128 |                 global_scores = self.global_scorer.score(self, self.scores)
129 |                 s = global_scores[i]
130 |                 self.finished.append((s, len(self.next_ys) - 1, i))
131 | 
132 |         # End condition is when top-of-beam is EOS and no global score.
133 |         if self.next_ys[-1][0] == self._eos:
134 |             self.all_scores.append(self.scores)
135 |             self.eos_top = True
136 | 
137 |     def done(self):
138 |         return self.eos_top and len(self.finished) >= self.k_best
139 | 
140 |     def sort_finished(self, minimum=None):
141 |         if minimum is not None:
142 |             i = 0
143 |             # Add from beam until we have minimum outputs.
144 |             while len(self.finished) < minimum:
145 |                 global_scores = self.global_scorer.score(self, self.scores)
146 |                 s = global_scores[i]
147 |                 self.finished.append((s, len(self.next_ys) - 1, i))
148 |                 i += 1
149 | 
150 |         self.finished.sort(key=lambda a: -a[0])
151 |         scores = [sc for sc, _, _ in self.finished]
152 |         ks = [(t, k) for _, t, k in self.finished]
153 |         return scores, ks
154 | 
155 |     def get_hypo(self, timestep, k):
156 |         """
157 |         Walk back to construct the full hypothesis.
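        Follows the prev_ks backpointers from (timestep, k) back to the start
        and returns token ids and per-step attentions in left-to-right order.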
158 | """ 159 | hyp, attn = [], [] 160 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1): 161 | hyp.append(self.next_ys[j+1][k]) 162 | attn.append(self.attn[j][k]) 163 | k = self.prev_ks[j][k] 164 | return hyp[::-1], torch.stack(attn[::-1]) 165 | 166 | 167 | class GNMTGlobalScorer(object): 168 | """ 169 | NMT re-ranking score from 170 | "Google's Neural Machine Translation System" :cite:`wu2016google` 171 | 172 | Args: 173 | alpha (float): length parameter 174 | beta (float): coverage parameter 175 | """ 176 | def __init__(self, alpha, beta, cov_penalty, length_penalty): 177 | self.alpha = alpha 178 | self.beta = beta 179 | penalty_builder = Penalties.PenaltyBuilder(cov_penalty, 180 | length_penalty) 181 | # Term will be subtracted from probability 182 | self.cov_penalty = penalty_builder.coverage_penalty() 183 | # Probability will be divided by this 184 | self.length_penalty = penalty_builder.length_penalty() 185 | 186 | def score(self, beam, logprobs): 187 | """ 188 | Rescores a prediction based on penalty functions 189 | """ 190 | normalized_probs = self.length_penalty(beam, 191 | logprobs, 192 | self.alpha) 193 | if not beam.stepwise_penalty: 194 | penalty = self.cov_penalty(beam, 195 | beam.global_state["coverage"], 196 | self.beta) 197 | normalized_probs -= penalty 198 | 199 | return normalized_probs 200 | 201 | def update_score(self, beam, attn): 202 | """ 203 | Function to update scores of a Beam that is not finished 204 | """ 205 | if "prev_penalty" in beam.global_state.keys(): 206 | beam.scores.add_(beam.global_state["prev_penalty"]) 207 | penalty = self.cov_penalty(beam, 208 | beam.global_state["coverage"] + attn, 209 | self.beta) 210 | beam.scores.sub_(penalty) 211 | 212 | def update_global_state(self, beam): 213 | "Keeps the coverage vector as sum of attentions" 214 | if len(beam.prev_ks) == 1: 215 | beam.global_state["prev_penalty"] = beam.scores.clone().fill_(0.0) 216 | beam.global_state["coverage"] = beam.attn[-1] 217 | self.cov_total = beam.attn[-1].sum(1) 218 | else: 219 | self.cov_total += torch.min(beam.attn[-1], 220 | beam.global_state['coverage']).sum(1) 221 | beam.global_state["coverage"] = beam.global_state["coverage"] \ 222 | .index_select(0, beam.prev_ks[-1]).add(beam.attn[-1]) 223 | 224 | prev_penalty = self.cov_penalty(beam, 225 | beam.global_state["coverage"], 226 | self.beta) 227 | beam.global_state["prev_penalty"] = prev_penalty 228 | -------------------------------------------------------------------------------- /NMT/translate/Penalties.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | 4 | 5 | class PenaltyBuilder(object): 6 | """ 7 | Returns the Length and Coverage Penalty function for Beam Search. 
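    The two "wu" options implement the GNMT formulas (alpha and beta are
    supplied by the caller; a_tj below is the attention weight on source
    position j at decoding step t):

        lp(Y)    = (5 + |Y|)^alpha / (5 + 1)^alpha             # length
        cp(X; Y) = -beta * sum_j log(min(sum_t a_tj, 1.0))     # coverage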
8 | 
9 |     Args:
10 |         cov_pen (str): option name of the coverage penalty
11 |         length_pen (str): option name of the length penalty
12 |     """
13 |     def __init__(self, cov_pen, length_pen):
14 |         self.length_pen = length_pen
15 |         self.cov_pen = cov_pen
16 | 
17 |     def coverage_penalty(self):
18 |         if self.cov_pen == "wu":
19 |             return self.coverage_wu
20 |         elif self.cov_pen == "summary":
21 |             return self.coverage_summary
22 |         else:
23 |             return self.coverage_none
24 | 
25 |     def length_penalty(self):
26 |         if self.length_pen == "wu":
27 |             return self.length_wu
28 |         elif self.length_pen == "avg":
29 |             return self.length_average
30 |         else:
31 |             return self.length_none
32 | 
33 |     """
34 |     Below are all the different penalty terms implemented so far
35 |     """
36 | 
37 |     def coverage_wu(self, beam, cov, beta=0.):
38 |         """
39 |         NMT coverage re-ranking score from
40 |         "Google's Neural Machine Translation System" :cite:`wu2016google`.
41 |         """
42 |         penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(1)
43 |         return beta * penalty
44 | 
45 |     def coverage_summary(self, beam, cov, beta=0.):
46 |         """
47 |         Our summary penalty.
48 |         """
49 |         penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(1)
50 |         penalty -= cov.size(1)
51 |         return beta * penalty
52 | 
53 |     def coverage_none(self, beam, cov, beta=0.):
54 |         """
55 |         returns zero as penalty
56 |         """
57 |         return beam.scores.clone().fill_(0.0)
58 | 
59 |     def length_wu(self, beam, logprobs, alpha=0.):
60 |         """
61 |         NMT length re-ranking score from
62 |         "Google's Neural Machine Translation System" :cite:`wu2016google`.
63 |         """
64 | 
65 |         modifier = (((5 + len(beam.next_ys)) ** alpha) /
66 |                     ((5 + 1) ** alpha))
67 |         return (logprobs / modifier)
68 | 
69 |     def length_average(self, beam, logprobs, alpha=0.):
70 |         """
71 |         Returns the average probability of tokens in a sequence.
72 |         """
73 |         return logprobs / len(beam.next_ys)
74 | 
75 |     def length_none(self, beam, logprobs, alpha=0., beta=0.):
76 |         """
77 |         Returns unmodified scores.
78 |         """
79 |         return logprobs
80 | 
--------------------------------------------------------------------------------
/NMT/translate/Translation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from Utils.DataLoader import EOS_WORD
3 | from Utils.DataLoader import UNK_WORD
4 | from Utils.utils import trace
5 | 
6 | 
7 | class TranslationBuilder(object):
8 |     """
9 |     Luong et al., 2015. Addressing the Rare Word Problem in Neural Machine Translation.
10 |     """
11 |     def __init__(self, src_vocab, trg_vocab, config):
12 |         """
13 |         Args:
14 |             src_vocab (Vocab): source vocabulary
15 |             trg_vocab (Vocab): target vocabulary
16 |             replace_unk (bool): replace unknown words using attention
17 |         """
18 |         self.src_vocab = src_vocab
19 |         self.trg_vocab = trg_vocab
20 |         self.replace_unk = config.replace_unk
21 |         self.k_best = config.k_best
22 | 
23 |     def _build_sentence(self, src, pred, vocab, attn):
24 |         """
25 |         Build a sentence from predicted word ids with the given vocabulary.
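        If `replace_unk` is set, each <unk> token is replaced by the source
        word with the highest attention weight at that position (the
        heuristic of Luong et al., 2015, cited above).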
26 | """ 27 | tokens = [] 28 | for wid in pred: 29 | token = vocab.itos[int(wid)] 30 | if token == EOS_WORD: 31 | break 32 | tokens.append(token) 33 | 34 | if self.replace_unk and (attn is not None) and (src is not None): 35 | for i in range(len(tokens)): 36 | if tokens[i] == UNK_WORD: 37 | _, max_ = attn[i].max(0) 38 | tokens[i] = src[int(max_[0])] 39 | return tokens 40 | 41 | def from_batch_translator_output(self, outputs): 42 | """ 43 | build translation from batch output 44 | """ 45 | batch = outputs["batch"] 46 | batch_size = batch.batch_size 47 | preds, pred_score, attns, gold_score = list(zip(*zip( 48 | outputs["predictions"], 49 | outputs["scores"], 50 | outputs["attention"], 51 | outputs["gold_score"]))) 52 | 53 | src = batch.src.data 54 | trg = batch.trg.data 55 | 56 | translations = [] 57 | for b in range(batch_size): 58 | pred_sents = [self._build_sentence( 59 | src[:,b:,0], preds[b][n], self.trg_vocab, attns[b][n]) for n in range(self.k_best)] 60 | gold = trg[1:,b:,0].squeeze().cpu().numpy() 61 | input = src[:,b:,0].squeeze().cpu().numpy() 62 | 63 | input_sent = self._build_sentence(src[:,b:,0], input, self.src_vocab, None) 64 | gold_sent = self._build_sentence(src[:,b:,0], gold, self.trg_vocab, None) 65 | 66 | translation = Translation(input_sent[1:], pred_sents, 67 | attns[b], pred_score[b], gold_sent, 68 | gold_score[b]) 69 | translations.append(translation) 70 | 71 | return translations 72 | 73 | 74 | class Translation(object): 75 | """ 76 | Container for a translated sentence. 77 | 78 | Attributes: 79 | src (`LongTensor`): src word ids 80 | src_raw ([str]): raw src words 81 | 82 | pred_sents ([[str]]): words from the n-best translations 83 | pred_scores ([[float]]): log-probs of n-best translations 84 | attns ([`FloatTensor`]) : attention distributions for each translation 85 | gold_sent ([str]): words from gold translation 86 | gold_score ([float]): log-prob of gold translation 87 | 88 | """ 89 | def __init__(self, src_sent, pred_sents, 90 | attns, pred_scores, trg_sent, gold_score): 91 | self.src_sent = src_sent 92 | self.pred_sents = pred_sents 93 | self.attns = attns 94 | self.pred_scores = pred_scores 95 | self.gold_sent = trg_sent 96 | self.gold_score = gold_score 97 | 98 | def log(self, sent_number): 99 | """ 100 | Log translation to stdout. 
101 | """ 102 | output = '\nINPUT {}: {}\n'.format(sent_number, " ".join(self.src_sent)) 103 | 104 | best_pred = self.pred_sents[0] 105 | best_score = self.pred_scores[0] 106 | pred_sent = ' '.join(best_pred) 107 | output += 'PRED {}: {}\n'.format(sent_number, pred_sent) 108 | trace("PRED SCORE: {:.4f}".format(best_score)) 109 | 110 | if self.gold_sent is not None: 111 | trg_sent = ' '.join(self.gold_sent) 112 | output += 'GOLD {}: {}\n'.format(sent_number, trg_sent) 113 | # output += ("GOLD SCORE: {:.4f}".format(self.gold_score)) 114 | trace("GOLD SCORE: {:.4f}".format(self.gold_score)) 115 | if len(self.pred_sents) > 1: 116 | trace('\nBEST HYP:') 117 | for score, sent in zip(self.pred_scores, self.pred_sents): 118 | output += "[{:.4f}] {}\n".format(score, sent) 119 | 120 | return output 121 | -------------------------------------------------------------------------------- /NMT/translate/Translator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn.functional as F 4 | from NMT.translate.Beam import Beam 5 | 6 | from Utils.DataLoader import PAD_WORD 7 | from Utils.DataLoader import BOS_WORD 8 | from Utils.DataLoader import EOS_WORD 9 | 10 | 11 | from NMT.Models import RNNEncoder 12 | from NMT.Models import VNMTModel 13 | 14 | class BatchTranslator(object): 15 | """ 16 | Uses a model to translate a batch of sentences. 17 | 18 | Args: 19 | model (:obj:`nmt.Modules.NMTModel`): 20 | NMT model to use for translation 21 | fields (dict of Fields): data fields 22 | beam_size (int): size of beam to use 23 | k_best (int): number of translations produced 24 | max_length (int): maximum length output to produce 25 | global_scores (:obj:`GlobalScorer`): 26 | object to rescore final translations 27 | """ 28 | 29 | def __init__(self, model, config, trg_vocab, global_scorer): 30 | self.config = config 31 | self.vocab = trg_vocab 32 | self.model = model 33 | self.k_best = config.k_best 34 | self.max_length = 100 35 | self.global_scorer = global_scorer 36 | self.beam_size = config.beam_size 37 | self.stepwise_penalty = config.stepwise_penalty 38 | self.block_ngram_repeat = config.block_ngram_repeat 39 | self.ignore_when_blocking = set(config.ignore_when_blocking) 40 | 41 | self.PAD_WID = trg_vocab.stoi[PAD_WORD] 42 | self.BOS_WID = trg_vocab.stoi[BOS_WORD] 43 | self.EOS_WID = trg_vocab.stoi[EOS_WORD] 44 | 45 | 46 | def beam_search(self, batch_size, encoder_outputs, decoder_states, src_lengths, z=None): 47 | """ 48 | beam search. 49 | 50 | Args: 51 | batch (`Batch`): a batch from a dataset object 52 | encoder_outputs (`Variable`): the outputs of encoder hidden layer 53 | decoder_states 54 | """ 55 | beam = [Beam(self.beam_size, self.PAD_WID, self.BOS_WID, self.EOS_WID, 56 | self.config, global_scorer=self.global_scorer) 57 | for _ in range(batch_size)] 58 | if z is not None: 59 | z = Variable(z.data.repeat(1, self.beam_size, 1), requires_grad=True) 60 | 61 | for i in range(self.max_length): 62 | if all((b.done() for b in beam)): 63 | break 64 | trg = Variable(torch.stack( 65 | [cad.get_current_state() for cad in beam] 66 | ).t().contiguous().view(1, -1)) 67 | trg = trg.unsqueeze(2) 68 | # Run one step. 
69 | 70 | decoder_input = self.model.trg_embedding(trg) 71 | if isinstance(self.model, VNMTModel): 72 | decoder_input = torch.cat([decoder_input, z], -1) 73 | decoder_output, decoder_states, attn = self.model.decoder( 74 | decoder_input, encoder_outputs, src_lengths, decoder_states)[:3] 75 | 76 | decoder_output = decoder_output.squeeze(0) # [b x H] 77 | 78 | # (b) Compute a vector of batch x beam word scores. 79 | output = F.log_softmax(self.model.generator(decoder_output).data, dim=-1) 80 | output = output.view(self.beam_size, batch_size, -1) 81 | beam_attn = attn["std"].view(self.beam_size, batch_size, -1) 82 | 83 | for j, b in enumerate(beam): 84 | b.advance(output[:, j], beam_attn.data[:, j, :src_lengths[j]]) 85 | decoder_states.beam_update( 86 | j, b.get_current_origin(), self.beam_size) 87 | return beam 88 | 89 | def translate_batch(self, batch, exclusion_tokens=[]): 90 | """ 91 | Translate a batch of sentences. 92 | 93 | Args: 94 | batch (Batch): a batch from a dataset object 95 | """ 96 | batch_size = batch.batch_size 97 | 98 | # 1. encoding 99 | src, src_lengths = batch.src, batch.src_Ls 100 | encoder_output, encoder_state = self.model.encoder( 101 | self.model.src_embedding(src), src_lengths) 102 | z = None 103 | if isinstance(self.model, VNMTModel): 104 | # re-parameterize 105 | z, mu, logvar = self.model.reparameterize(encoder_state) 106 | # encoder to decoder 107 | decoder_state = self.model.encoder2decoder(encoder_state) 108 | 109 | # decoder_input = torch.cat([ 110 | # self.model.trg_embedding(trg), 111 | # z.unsqueeze(0).repeat(trg.size(0) ,1, 1)], 112 | # -1) 113 | 114 | # 2. repeat source `beam_size` times. 115 | 116 | encoder_outputs = Variable( 117 | encoder_output.data.repeat( 118 | 1, self.beam_size, 1), 119 | requires_grad=False) 120 | 121 | src_lengths = src_lengths.repeat(self.beam_size) 122 | decoder_states = decoder_state 123 | decoder_states.repeat_beam_size_times(self.beam_size) 124 | 125 | # 3. Generate translations using beam search. 126 | beam = self.beam_search( 127 | batch_size, encoder_outputs, decoder_states, src_lengths, z) 128 | 129 | # 4. Extract sentences from beam. 130 | batch_trans = self.extract_trans_from_beam(beam) 131 | 132 | batch_trans["gold_score"] = self._run_target(batch) 133 | batch_trans["batch"] = batch 134 | return batch_trans 135 | 136 | def extract_trans_from_beam(self, beam): 137 | """ 138 | extract translations from beam. 139 | """ 140 | ret = {"predictions": [], 141 | "scores": [], 142 | "attention": [] 143 | } 144 | 145 | for b in beam: 146 | scores, ks = b.sort_finished(minimum=self.k_best) 147 | hypos, attn = [], [] 148 | for i, (times, k) in enumerate(ks[:self.k_best]): 149 | hypo, att = b.get_hypo(times, k) 150 | hypos.append(hypo) 151 | attn.append(att) 152 | ret["predictions"].append(hypos) 153 | ret["scores"].append(scores) 154 | ret["attention"].append(attn) 155 | return ret 156 | 157 | def _run_target(self, batch): 158 | src_lengths = batch.src_Ls 159 | src = batch.src 160 | trg_in = batch.trg[:-1] 161 | trg_out = batch.trg[1:] 162 | # (1) run the encoder on the src 163 | 164 | encoder_outputs, encoder_state = self.model.encoder( 165 | self.model.src_embedding(src), src_lengths) 166 | 167 | z = None 168 | if isinstance(self.model, VNMTModel): 169 | z, mu, logvar = self.model.reparameterize(encoder_state) 170 | 171 | decoder_state = self.model.encoder2decoder(encoder_state) 172 | 173 | # (2) if a target is specified, compute the 'goldScore' 174 | # (i.e. 
log likelihood) of the target under the model
175 |         gold_scores = torch.FloatTensor(batch.batch_size).fill_(0).cuda()
176 |         decoder_input = self.model.trg_embedding(trg_in)
177 |         if isinstance(self.model, VNMTModel):
178 |             decoder_input = torch.cat(
179 |                 [decoder_input,
180 |                  z.unsqueeze(0).repeat(trg_in.size(0), 1, 1)],
181 |                 -1)
182 |         decoder_output, decoder_states, attn = self.model.decoder(
183 |             decoder_input, encoder_outputs, src_lengths, decoder_state)[:3]
184 | 
185 | 
186 |         trg_pad = self.vocab.stoi[PAD_WORD]
187 |         for dec, trg in zip(decoder_output, trg_out.data):
188 |             # Log prob of each gold word; gather expects a [batch x 1] index.
189 |             out = F.log_softmax(self.model.generator(dec).data, dim=-1)
190 |             scores = out.gather(1, trg.unsqueeze(1)).squeeze(1)
191 |             scores.masked_fill_(trg.eq(trg_pad), 0)
192 |             gold_scores += scores.float()
193 |         return gold_scores
194 | 
--------------------------------------------------------------------------------
/NMT/translate/__init__.py:
--------------------------------------------------------------------------------
1 | from .Translator import BatchTranslator
2 | from .Translation import Translation, TranslationBuilder
3 | from .Beam import Beam, GNMTGlobalScorer
4 | from .Penalties import PenaltyBuilder
5 | 
6 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Variational Neural Machine Translation System
2 | ==========
3 | 
4 | Implemented in PyTorch 0.4; some modules are adapted from OpenNMT-py.
5 | 
6 | 
7 | 
8 | 
9 | ## References
10 | 
11 | 1. Su, Jinsong, et al. "Variational Recurrent Neural Machine Translation." arXiv preprint arXiv:1801.05119 (2018).
12 | 
13 | 2. Zhang, Biao, et al. "Variational Neural Machine Translation." arXiv preprint arXiv:1605.07869 (2016).
14 | 
15 | ## Differences
16 | 
17 | For Variational NMT,
18 | I did not use mean-pooling on either side (source or target);
19 | I found that using only the last source hidden state is sufficient to achieve good performance.
20 | 
21 | For Variational Recurrent NMT,
22 | I found that using only the current RNN state is sufficient to achieve good performance.
23 | 
24 | The paper
25 | 
26 | `Yang, Zichao, et al. "Improved variational autoencoders for text modeling using dilated convolutions." arXiv preprint arXiv:1702.08139 (2017).`
27 | 
28 | explains why a GRU is preferred over an LSTM for the RNN cell: in general, a VAE-LSTM decoder performs worse than a vanilla LSTM decoder.
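To make the first difference concrete, the sketch below shows the re-parameterization described above: the latent code z is inferred from the encoder's final hidden state alone, with no mean-pooling over time steps. The class and attribute names (`LastStateReparam`, `mu_layer`, `logvar_layer`) are illustrative stand-ins, not the actual names used in `NMT/Models`:

    import torch
    import torch.nn as nn

    class LastStateReparam(nn.Module):
        # Illustrative only: derive z from the encoder's last hidden state
        # h_T instead of a mean-pooled summary of all source states.
        def __init__(self, hidden_size, latent_size):
            super(LastStateReparam, self).__init__()
            self.mu_layer = nn.Linear(hidden_size, latent_size)
            self.logvar_layer = nn.Linear(hidden_size, latent_size)

        def forward(self, h_last):                  # h_last: [batch x hidden]
            mu = self.mu_layer(h_last)
            logvar = self.logvar_layer(h_last)
            std = torch.exp(0.5 * logvar)
            z = mu + std * torch.randn_like(std)    # re-parameterization trick
            return z, mu, logvar

The (mu, logvar) pair returned here is what the weighted KL-divergence term in the training loss pulls toward the standard normal prior.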
29 | 
30 | 
31 | ## Usage
32 | 
33 | Training
34 | 
35 |     python train.py --config config/nmt.ini
36 | 
37 | Test
38 | 
39 |     python translate.py --config config/nmt.ini
--------------------------------------------------------------------------------
/Utils/DataLoader.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import numpy as np
4 | import math
5 | import torch
6 | from collections import Counter, defaultdict
7 | from Utils.utils import aeq
8 | from Utils.utils import trace
9 | from itertools import chain
10 | import random
11 | 
12 | 
13 | PAD_WORD = '<pad>'   # 0
14 | UNK_WORD = '<unk>'   # 1
15 | BOS_WORD = '<s>'     # 2
16 | EOS_WORD = '</s>'    # 3
17 | 
18 | 
19 | class Vocab(object):
20 |     def __init__(self, lang=None, config=None, **kwargs):
21 |         self.specials = [PAD_WORD, UNK_WORD, BOS_WORD, EOS_WORD]
22 |         self.counter = Counter()
23 |         self.stoi = {}
24 |         self.itos = {}
25 |         self.lang = lang
26 |         self.weights = None
27 |         self.min_freq = config.min_freq
28 |     def make_vocab(self, dataset):
29 |         for x in dataset:
30 |             self.counter.update(x)
31 | 
32 |         if self.min_freq > 1:
33 |             self.counter = {w: c for w, c in filter(
34 |                 lambda x: x[1] >= self.min_freq, self.counter.items())}
35 |         self.vocab_size = 0
36 |         for w in self.specials:
37 |             self.stoi[w] = self.vocab_size
38 |             self.vocab_size += 1
39 | 
40 |         for w in self.counter.keys():
41 |             self.stoi[w] = self.vocab_size
42 |             self.vocab_size += 1
43 | 
44 |         self.itos = {i: w for w, i in self.stoi.items()}
45 | 
46 |     def load_pretrained_embedding(self, embed_path, embed_dim):
47 |         self.weights = np.zeros((self.vocab_size, int(embed_dim)))
48 |         with open(embed_path, "r", errors="replace") as embed_fin:
49 |             for line in embed_fin:
50 |                 cols = line.rstrip("\n").split()
51 |                 w = cols[0]
52 |                 if w in self.stoi:
53 |                     weight = np.array(cols[1:])
54 |                     self.weights[self.stoi[w]] = weight
55 |                 else:
56 |                     pass
57 |         embed_fin.close()
58 |         for i in range(1, 2):  # zero out the <unk> row
59 |             self.weights[i] = np.zeros((embed_dim,))
60 |             # self.weights[i] = np.random.random_sample(
61 |             #     (self.config.embed_dim,))
62 |         self.weights = torch.from_numpy(self.weights)
63 | 
64 |     def __getitem__(self, key):
65 |         return self.weights[key]
66 | 
67 |     def __len__(self):
68 |         return self.vocab_size
69 | 
70 | class DataSet(list):
71 |     def __init__(self, *args, config=None, is_train=True, dataset="train"):
72 |         self.config = config
73 |         self.is_train = is_train
74 |         self.src_lang = config.src_lang
75 |         self.trg_lang = config.trg_lang
76 |         self.dataset = dataset
77 |         self.data_path = (
78 |             os.path.join(self.config.data_path, dataset + "." + self.src_lang),
79 |             os.path.join(self.config.data_path, dataset + "." 
+ self.trg_lang) 80 | ) 81 | super(DataSet, self).__init__(*args) 82 | 83 | def read(self): 84 | with open(self.data_path[0], "r") as fin_src,\ 85 | open(self.data_path[1], "r") as fin_trg: 86 | for line1, line2 in zip(fin_src, fin_trg): 87 | src, trg = line1.rstrip("\r\n"), line2.rstrip("\r\n") 88 | src = src.split() 89 | trg = trg.split() 90 | if self.is_train: 91 | if len(src) <= self.config.max_seq_len and \ 92 | len(trg) <= self.config.max_seq_len: 93 | self.append((src, trg)) 94 | else: 95 | self.append((src, trg)) 96 | fin_src.close() 97 | fin_trg.close() 98 | 99 | def _numericalize(self, words, stoi): 100 | return [1 if x not in stoi else stoi[x] for x in words] 101 | 102 | 103 | def numericalize(self, src_w2id, trg_w2id): 104 | for i, example in enumerate(self): 105 | x, y = example 106 | x = self._numericalize(x, src_w2id) 107 | y = self._numericalize(y, trg_w2id) 108 | self[i] = (x, y) 109 | 110 | class DataBatchIterator(object): 111 | def __init__(self, config, dataset="train", 112 | is_train=True, batch_size=64, 113 | shuffle=False, sample=False, 114 | sort_in_batch=True): 115 | self.config = config 116 | self.examples = DataSet(config=config, is_train=is_train, dataset=dataset) 117 | self.src_vocab = Vocab(lang=config.src_lang, config=config) 118 | self.trg_vocab = Vocab(lang=config.trg_lang, config=config) 119 | self.is_train = (dataset == "train") 120 | self.max_seq_len = config.max_seq_len 121 | self.sort_in_batch = sort_in_batch 122 | self.is_shuffle = shuffle 123 | self.is_sample = sample 124 | self.batch_size = batch_size 125 | self.num_batches = 0 126 | 127 | def set_vocab(self, src_vocab, trg_vocab): 128 | self.src_vocab = src_vocab 129 | self.trg_vocab = trg_vocab 130 | 131 | def load(self, vocab_cache=None): 132 | if not vocab_cache and self.is_train: 133 | self.examples.read() 134 | self.src_vocab.make_vocab([x[0] for x in self.examples]) 135 | self.trg_vocab.make_vocab([x[1] for x in self.examples]) 136 | self.examples.numericalize( 137 | src_w2id=self.src_vocab.stoi, 138 | trg_w2id=self.trg_vocab.stoi) 139 | 140 | # self.src_vocab.load_pretrained_embedding( 141 | # self.config.embed_path+".%s"%self.config.source_lang, self.config.embed_dim) 142 | # self.trg_vocab.load_pretrained_embedding( 143 | # self.config.embed_path+".%s"%self.config.target_lang, self.config.embed_dim) 144 | 145 | if not self.is_train: 146 | self.examples.read() 147 | assert len(self.src_vocab) > 0 148 | self.examples.numericalize( 149 | src_w2id=self.src_vocab.stoi, 150 | trg_w2id=self.trg_vocab.stoi) 151 | # self.src_vocab.load_pretrained_embedding( 152 | # self.config.embed_path+".%s"%self.config.source_lang, self.config.embed_dim) 153 | # self.trg_vocab.load_pretrained_embedding( 154 | # self.config.embed_path+".%s"%self.config.target_lang, self.config.embed_dim) 155 | self.num_batches = math.ceil(len(self.examples)/self.batch_size) 156 | 157 | def _pad(self, sentence, max_L, w2id, add_bos=False, add_eos=False): 158 | if add_bos: 159 | sentence = [w2id[BOS_WORD]] + sentence 160 | if add_eos: 161 | sentence = sentence + [w2id[EOS_WORD]] 162 | if len(sentence) < max_L: 163 | sentence = sentence + [w2id[PAD_WORD]] * (max_L-len(sentence)) 164 | return [x for x in sentence] 165 | 166 | 167 | def pad_seq_pair(self, samples): 168 | if self.sort_in_batch: 169 | samples = sorted(samples, key=lambda x: len(x[0]), reverse=True) 170 | pairs = [pair for pair in samples] 171 | 172 | 173 | src_Ls = [len(pair[0])+2 for pair in pairs] 174 | trg_Ls = [len(pair[1])+2 for pair in pairs] 175 | 176 | 
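        # The +2 in src_Ls / trg_Ls above reserves room for the <s> and </s>
        # tokens that _pad() attaches below; sorting by source length
        # (sort_in_batch) keeps each batch compatible with packed-sequence
        # RNN encoders.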
max_trg_Ls = max(trg_Ls) 177 | max_src_Ls = max(src_Ls) 178 | src = [self._pad( 179 | src, max_src_Ls, self.src_vocab.stoi, add_bos=True, add_eos=True) for src, _ in pairs] 180 | trg = [self._pad( 181 | trg, max_trg_Ls, self.trg_vocab.stoi, add_bos=True, add_eos=True) for _, trg in pairs] 182 | 183 | batch = Batch() 184 | batch.src = torch.LongTensor(src).transpose(0, 1).cuda() 185 | batch.trg = torch.LongTensor(trg).transpose(0, 1).cuda() 186 | 187 | batch.src_Ls = torch.LongTensor(src_Ls).cuda() 188 | batch.trg_Ls = torch.LongTensor(trg_Ls).cuda() 189 | return batch 190 | 191 | 192 | def __iter__(self): 193 | if self.is_shuffle: 194 | random.shuffle(self.examples) 195 | total_num = len(self.examples) 196 | for i in range(self.num_batches): 197 | if self.is_sample: 198 | samples = random.sample(self.examples, self.batch_size) 199 | else: 200 | samples = self.examples[i * self.batch_size: \ 201 | min(total_num, self.batch_size*(i+1))] 202 | yield self.pad_seq_pair(samples) 203 | 204 | class Batch(object): 205 | def __init__(self): 206 | self.src = None 207 | self.trg = None 208 | self.src_Ls = None 209 | self.trg_Ls = None 210 | 211 | def __len__(self): 212 | return self.src_Ls.size(0) 213 | @property 214 | def batch_size(self): 215 | return self.src_Ls.size(0) 216 | 217 | -------------------------------------------------------------------------------- /Utils/args.py: -------------------------------------------------------------------------------- 1 | import time 2 | from argparse import ArgumentParser 3 | 4 | def parse_args(script_name): 5 | """load arguments from command line. Default values are stored in config.ini""" 6 | parser = ArgumentParser() 7 | 8 | parse_base_args(parser) 9 | parse_data_args(parser) 10 | parse_gpu_args(parser) 11 | parse_network_args(parser) 12 | 13 | parse_embed_args(parser) 14 | 15 | if script_name == "train": 16 | parse_optim_args(parser) 17 | parse_train_args(parser) 18 | 19 | elif script_name == "translate": 20 | parse_translate_args(parser) 21 | 22 | parse_loss_args(parser) 23 | parse_logging_args(parser) 24 | args = parser.parse_args() 25 | 26 | return args, parser 27 | 28 | 29 | def parse_base_args(parser): 30 | parser.add_argument('--config', type=str, required=True, 31 | help="path to datasets") 32 | parser.add_argument('--system', type=str, 33 | help="use which NMT system") 34 | 35 | def parse_data_args(parser): 36 | group = parser.add_argument_group('Data') 37 | group.add_argument('--data_path', type=str, 38 | help="path to datasets") 39 | 40 | group.add_argument('--save_vocab', type=str, 41 | help="path to vocab files") 42 | 43 | group.add_argument('--max_seq_len', type=int, default=50, 44 | help="Maximum sequence length") 45 | 46 | group.add_argument('--trg_lang', type=str, 47 | help="target language name suffix") 48 | 49 | group.add_argument('--src_lang', type=str, 50 | help="source language name suffix") 51 | 52 | def parse_gpu_args(parser): 53 | group = parser.add_argument_group('GPU') 54 | group.add_argument('--gpu_ids', default=[], nargs='+', type=int, 55 | help="Use CUDA on the listed devices.") 56 | group.add_argument('--use_cpu', default=False, action='store_true') 57 | 58 | def parse_embed_args(parser): 59 | group = parser.add_argument_group('Embedding') 60 | group.add_argument('--min_freq', type=int, default=5, 61 | help="min frequency for the prepared data") 62 | group.add_argument('--src_embed_dim', type=int, default=500, 63 | help='Word embedding size for src.') 64 | group.add_argument('--trg_embed_dim', type=int, default=500, 65 | 
help='Word embedding size for trg.') 66 | # group.add_argument('-feat_size', type=int, default=-1, 67 | # help="""If specified, feature embedding sizes 68 | # will be set to this. Otherwise, feat_vec_exponent 69 | # will be used.""") 70 | # group.add_argument('-feat_vec_exponent', type=float, default=0.7, 71 | # help="""If -feat_merge_size is not set, feature 72 | # embedding sizes will be set to N^feat_vec_exponent 73 | # where N is the number of values the feature takes.""") 74 | 75 | # parser.add_argument('--share_decoder_embeddings', action='store_true', 76 | # help="""Use a shared weight matrix for the input and 77 | # output word embeddings in the decoder.""") 78 | # group.add_argument('-share_embeddings', action='store_true', 79 | # help="""Share the word embeddings between encoder 80 | # and decoder. Need to use shared dictionary for this 81 | # option.""") 82 | # group.add_argument('-position_encoding', action='store_true', 83 | # help="""Use a sin to mark relative words positions. 84 | # Necessary for non-RNN style models. 85 | # """) 86 | # group.add_argument('--dynamic_dict', action='store_true', 87 | # help="Create dynamic dictionaries") 88 | # group.add_argument('--share_vocab', action='store_true', 89 | # help="Share source and target vocabulary") 90 | 91 | def parse_train_args(parser): 92 | # Model loading/saving options 93 | 94 | group = parser.add_argument_group('train') 95 | 96 | group.add_argument('--save_model', default='output/model', 97 | help="""Model filename validation perplexity""") 98 | 99 | group.add_argument('--checkpoint', default='', type=str, 100 | help="""If training from a checkpoint then this is the 101 | path to the pretrained model's state_dict.""") 102 | 103 | group.add_argument('--epochs', type=int, default=13, 104 | help='Number of training epochs') 105 | 106 | group.add_argument('--batch_size', type=int, default=64, 107 | help='Maximum batch size for training') 108 | 109 | group.add_argument('--valid_batch_size', type=int, default=32, 110 | help='Maximum batch size for training') 111 | 112 | group.add_argument('--param_init', type=float, default=0.1, 113 | help="""Parameters are initialized over uniform distribution 114 | with support (-param_init, param_init). 115 | Use 0 to not use initialization""") 116 | 117 | group.add_argument('--param_init_glorot', action='store_true', 118 | help="""Init parameters with xavier_uniform. 
119 | Required for transformer.""") 120 | 121 | 122 | 123 | 124 | 125 | def parse_network_args(parser): 126 | group = parser.add_argument_group('Network') 127 | 128 | 129 | group.add_argument('--enc_num_layers', type=int, default=2, 130 | help='Number of layers in the encoder') 131 | 132 | group.add_argument('--dec_num_layers', type=int, default=2, 133 | help='Number of layers in the decoder') 134 | group.add_argument('--bidirectional', action='store_true', 135 | help="""bidirectional encoding for encoder.""") 136 | 137 | group.add_argument('--hidden_size', type=int, default=500, 138 | help='Size of rnn hidden states') 139 | 140 | group.add_argument('--latent_size', type=int, default=300, 141 | help='Size of latent states') 142 | 143 | group.add_argument('--rnn_type', type=str, default='LSTM', 144 | choices=['LSTM', 'GRU'], 145 | help="""The gate type to use in the RNNs""") 146 | 147 | group.add_argument('--attn_type', type=str, default='general', 148 | choices=['dot', 'general', 'mlp'], 149 | help="""The attention type to use: 150 | dotprod or general (Luong) or MLP (Bahdanau)""") 151 | group.add_argument('--meanpool', action="store_true", 152 | help='use the mean pooling of encoder outputs for VAE') 153 | # group.add_argument('--use_target', action='store_true', 154 | # help="""use target for variational re-parameterization.""") 155 | group.add_argument('--dropout', type=float, default=0.3, 156 | help="Dropout probability; applied in RNN stacks.") 157 | 158 | def parse_loss_args(parser): 159 | # loss options. 160 | group = parser.add_argument_group('Loss Functions') 161 | group.add_argument("--kld_weight", default=0.05, type=float, 162 | help="weight for the Kullback-Leibler divergence Loss in total loss score.") 163 | group.add_argument("--start_increase_kld_at", default=8, type=int, 164 | help="start to increase KLD loss weight at .") 165 | 166 | 167 | def parse_optim_args(parser): 168 | group = parser.add_argument_group('Optimization- Rate') 169 | # Optimization options 170 | group = parser.add_argument_group('Optimizer') 171 | 172 | 173 | 174 | group.add_argument('--optim', default='sgd', 175 | choices=['sgd', 'adagrad', 'adadelta', 'adam', 176 | 'sparseadam'], 177 | help="""Optimization method.""") 178 | 179 | group.add_argument('--adagrad_accum_init', type=float, default=0, 180 | help="""Initializes the accumulator values in adagrad. 181 | Mirrors the initial_accumulator_value option 182 | in the tensorflow adagrad (use 0.1 for their default). 183 | """) 184 | 185 | group.add_argument('--max_grad_norm', type=float, default=5, 186 | help="""If the norm of the gradient vector exceeds this, 187 | renormalize it to have the norm equal to 188 | max_grad_norm""") 189 | 190 | 191 | group.add_argument('--lr', type=float, default=1.0, 192 | help="""Starting learning rate. 
193 | Recommended settings: sgd = 1, adagrad = 0.1, 194 | adadelta = 1, adam = 0.001""") 195 | group.add_argument('--lr_decay_rate', type=float, default=0.5, 196 | help="""If update_learning_rate, decay learning rate by 197 | this much if (i) perplexity does not decrease on the 198 | validation set or (ii) epoch has gone past 199 | start_decay_at""") 200 | group.add_argument('--start_decay_at', type=int, default=8, 201 | help="""Start decaying every epoch after and including this 202 | epoch""") 203 | 204 | group.add_argument('--decay_method', type=str, default="", 205 | choices=['noam'], help="Use a custom decay rate.") 206 | 207 | group.add_argument('--warmup_steps', type=int, default=4000, 208 | help="""Number of warmup steps for custom decay.""") 209 | group.add_argument('--alpha', type=float, default=0.9, 210 | help="""The alpha parameter used by RMSprop.""") 211 | group.add_argument('--eps', type=float, default=1e-8, 212 | help="""The eps parameter used by RMSprop/Adam.""") 213 | 214 | group.add_argument('--weight_decay', type=float, default=0, 215 | help="""The weight_decay parameter used by RMSprop.""") 216 | group.add_argument('--momentum', type=float, default=0, 217 | help="""The momentum parameter used by RMSprop[0]/SGD[0.9].""") 218 | group.add_argument('--adam_beta1', type=float, default=0.9, 219 | help="""The beta1 parameter used by Adam. 220 | Almost without exception a value of 0.9 is used in 221 | the literature, seemingly giving good results, 222 | so we would discourage changing this value from 223 | the default without due consideration.""") 224 | group.add_argument('--adam_beta2', type=float, default=0.999, 225 | help="""The beta2 parameter used by Adam. 226 | Typically a value of 0.999 is recommended, as this is 227 | the value suggested by the original paper describing 228 | Adam, and is also the value adopted in other frameworks 229 | such as Tensorflow and Keras, i.e. see: 230 | https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 231 | https://keras.io/optimizers/ . 232 | Whereas recently the paper "Attention is All You Need" 233 | suggested a value of 0.98 for beta2, this parameter may 234 | not work well for normal models / default 235 | baselines.""") 236 | 237 | def parse_translate_args(parser): 238 | 239 | group = parser.add_argument_group('Translator') 240 | group.add_argument('--output', default='pred.txt', 241 | help="""Path to output the predictions (each line will 242 | be the decoded sequence""") 243 | 244 | group.add_argument('--save_model', default='model', 245 | help="""Pre-trained models""") 246 | 247 | group.add_argument('--beam_size', type=int, default=5, 248 | help='Beam size') 249 | 250 | group.add_argument('--max_length', type=int, default=100, 251 | help='Maximum prediction length.') 252 | 253 | # Alpha and Beta values for Google Length + Coverage penalty 254 | # Described here: https://arxiv.org/pdf/1609.08144.pdf, Section 7 255 | group.add_argument('-stepwise_penalty', action='store_true', 256 | help="""Apply penalty at every decoding step. 
257 | Helpful for summary penalty.""") 258 | group.add_argument('--length_penalty', default='none', 259 | choices=['none', 'wu', 'avg'], 260 | help="""Length Penalty to use.""") 261 | group.add_argument('--coverage_penalty', default='none', 262 | choices=['none', 'wu', 'summary'], 263 | help="""Coverage Penalty to use.""") 264 | group.add_argument('--alpha', type=float, default=0., 265 | help="""Google NMT length penalty parameter 266 | (higher = longer generation)""") 267 | group.add_argument('--beta', type=float, default=-0., 268 | help="""Coverage penalty parameter""") 269 | 270 | group.add_argument('--block_ngram_repeat', type=int, default=0, 271 | help='Block repetition of ngrams during decoding.') 272 | 273 | group.add_argument('--ignore_when_blocking', nargs='+', type=str, 274 | default=[], 275 | help="""Ignore these strings when blocking repeats. 276 | You want to block sentence delimiters.""") 277 | 278 | group.add_argument('--replace_unk', action="store_true", 279 | help="""Replace the generated UNK tokens with the 280 | source token that had highest attention weight. If 281 | phrase_table is provided, it will lookup the 282 | identified source token and give the corresponding 283 | target token. If it is not provided(or the identified 284 | source token does not exist in the table) then it 285 | will copy the source token""") 286 | 287 | group.add_argument('-k_best', type=int, default=1, 288 | help="""If verbose is set, will output the k_best 289 | decoded sentences""") 290 | 291 | 292 | 293 | def parse_logging_args(parser): 294 | group = parser.add_argument_group('Logging') 295 | group.add_argument('--verbose', action="store_true", 296 | help='Print scores and predictions for each sentence') 297 | group.add_argument('--plot_attn', action="store_true", 298 | help='Plot attention matrix for each pair') 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | -------------------------------------------------------------------------------- /Utils/bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python implementation of BLEU and smooth-BLEU. 17 | This module provides a Python implementation of BLEU and smooth-BLEU. 18 | Smooth BLEU is computed following the method outlined in the paper: 19 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 20 | evaluation metrics for machine translation. COLING 2004. 21 | """ 22 | 23 | import collections 24 | import math 25 | 26 | 27 | def _get_ngrams(segment, max_order): 28 | """Extracts all n-grams upto a given maximum order from an input segment. 29 | Args: 30 | segment: text segment from which n-grams will be extracted. 31 | max_order: maximum length in tokens of the n-grams returned by this 32 | methods. 
33 | Returns: 34 | The Counter containing all n-grams upto max_order in segment 35 | with a count of how many times each n-gram occurred. 36 | """ 37 | ngram_counts = collections.Counter() 38 | for order in range(1, max_order + 1): 39 | for i in range(0, len(segment) - order + 1): 40 | ngram = tuple(segment[i:i+order]) 41 | ngram_counts[ngram] += 1 42 | return ngram_counts 43 | 44 | 45 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 46 | smooth=False): 47 | """Computes BLEU score of translated segments against one or more references. 48 | Args: 49 | reference_corpus: list of lists of references for each translation. Each 50 | reference should be tokenized into a list of tokens. 51 | translation_corpus: list of translations to score. Each translation 52 | should be tokenized into a list of tokens. 53 | max_order: Maximum n-gram order to use when computing BLEU score. 54 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 55 | Returns: 56 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 57 | precisions and brevity penalty. 58 | """ 59 | matches_by_order = [0] * max_order 60 | possible_matches_by_order = [0] * max_order 61 | reference_length = 0 62 | translation_length = 0 63 | for (references, translation) in zip(reference_corpus, 64 | translation_corpus): 65 | reference_length += min(len(r) for r in references) 66 | translation_length += len(translation) 67 | 68 | merged_ref_ngram_counts = collections.Counter() 69 | for reference in references: 70 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 71 | translation_ngram_counts = _get_ngrams(translation, max_order) 72 | overlap = translation_ngram_counts & merged_ref_ngram_counts 73 | for ngram in overlap: 74 | matches_by_order[len(ngram)-1] += overlap[ngram] 75 | for order in range(1, max_order+1): 76 | possible_matches = len(translation) - order + 1 77 | if possible_matches > 0: 78 | possible_matches_by_order[order-1] += possible_matches 79 | 80 | precisions = [0] * max_order 81 | for i in range(0, max_order): 82 | if smooth: 83 | precisions[i] = ((matches_by_order[i] + 1.) / 84 | (possible_matches_by_order[i] + 1.)) 85 | else: 86 | if possible_matches_by_order[i] > 0: 87 | precisions[i] = (float(matches_by_order[i]) / 88 | possible_matches_by_order[i]) 89 | else: 90 | precisions[i] = 0.0 91 | 92 | if min(precisions) > 0: 93 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 94 | geo_mean = math.exp(p_log_sum) 95 | else: 96 | geo_mean = 0 97 | 98 | ratio = float(translation_length) / reference_length 99 | 100 | if ratio > 1.0: 101 | bp = 1. 102 | else: 103 | bp = math.exp(1 - 1. 
/ ratio)
104 | 
105 |     bleu = geo_mean * bp
106 | 
107 |     return (bleu, precisions, bp, ratio, translation_length, reference_length)
--------------------------------------------------------------------------------
/Utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 | from configparser import SafeConfigParser
5 | from Utils.utils import trace
6 | 
7 | 
8 | def get_correct_args(config, items, section):
9 |     d = {}
10 |     for key, value in items:
11 |         val = None
12 |         if key == "gpu_ids":
13 |             value = eval(value)
14 |             if isinstance(value, list):
15 |                 val = value
16 |             else:
17 |                 val = [int(value)]
18 |         else:
19 |             try:
20 |                 val = config[section].getint(key)
21 |             except Exception:
22 |                 try:
23 |                     val = config[section].getfloat(key)
24 |                 except Exception:
25 |                     try:
26 |                         val = config[section].getboolean(key)
27 |                     except Exception:
28 |                         val = value
29 |         d[key] = val
30 |     return d
31 | 
32 | 
33 | def read_config(args, args_parser, config_file=None):
34 | 
35 |     if config_file is None:
36 |         return args_parser.parse_args()
37 |     if not os.path.isfile(config_file):
38 |         trace("""# Cannot find the configuration file.
39 |         {} does not exist! Please check.""".format(config_file))
40 |         sys.exit(1)
41 |     config = SafeConfigParser()
42 |     config.read(config_file)
43 |     for section in config.sections():
44 |         default = get_correct_args(config, config.items(section), section)
45 |         args_parser.set_defaults(**{
46 |             k: v for k, v in filter(
47 |                 lambda x: hasattr(args, x[0]), default.items())})
48 | 
49 |     args = args_parser.parse_args()
50 |     return args
51 | 
52 | def config_debug(debug):
53 |     print(os.environ)
54 | 
55 | def format_config(args):
56 |     ret = "\n"
57 |     pattern = r"<class '(\w+)'>"  # matches e.g. "<class 'int'>"
58 |     for key, value in vars(args).items():
59 |         class_type = re.search(pattern, str(type(value))).group(1)
60 |         class_type = "[{}]".format(class_type)
61 |         value_string = str(value)
62 |         if len(value_string) > 80:
63 |             value_string = "/".join(value_string.split("/")[:2]) +\
64 |                 "/.../" + "/".join(value_string.split("/")[-2:])
65 |         ret += "  {}\t{}\t{}\n".format(key.ljust(15), class_type.ljust(8), value_string)
66 |     return ret
67 | 
68 | 
69 | def config_device(config):
70 |     if config.device_ids:
71 |         config.device_ids = [int(x) for x in config.device_ids.split(",")]
72 |     else:
73 |         os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
74 |             [str(idx) for idx in list(
75 |                 range(config.gpu_ids, config.gpu_ids + config.num_gpus))])
76 |         config.device_ids = list(range(config.num_gpus))
77 | 
--------------------------------------------------------------------------------
/Utils/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import matplotlib.ticker as ticker
3 | import numpy as np
4 | 
5 | from matplotlib.font_manager import FontProperties
6 | from Utils.DataLoader import BOS_WORD, EOS_WORD
7 | JaFont = FontProperties(fname='/usr/share/fonts/truetype/takao-gothic/TakaoGothic.ttf')
8 | ZhFont = FontProperties(fname='/usr/share/fonts/truetype/wqy/wqy-microhei.ttc')
9 | 
10 | def plot_attn(source, target, attn):
11 |     # Set up figure with colorbar
12 |     fig = plt.figure()
13 |     ax = fig.add_subplot(111)
14 |     cax = ax.matshow(attn.numpy(), cmap='bone')
15 |     fig.colorbar(cax)
16 |     # Set up axes
17 | 
18 |     ax.set_xticklabels([''] + [BOS_WORD] + source + [EOS_WORD], rotation=90, fontproperties=ZhFont)
19 |     ax.set_yticklabels([''] + target + [EOS_WORD], fontproperties=JaFont)
20 | 
21 |     # Show label at every tick
22 | 
ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) 23 | ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) 24 | 25 | plt.show() 26 | -------------------------------------------------------------------------------- /Utils/rouge.py: -------------------------------------------------------------------------------- 1 | """ROUGE metric implementation. 2 | Copy from tf_seq2seq/seq2seq/metrics/rouge.py. 3 | This is a modified and slightly extended verison of 4 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py. 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | from __future__ import unicode_literals 11 | 12 | import itertools 13 | import numpy as np 14 | 15 | #pylint: disable=C0103 16 | 17 | 18 | def _get_ngrams(n, text): 19 | """Calcualtes n-grams. 20 | Args: 21 | n: which n-grams to calculate 22 | text: An array of tokens 23 | Returns: 24 | A set of n-grams 25 | """ 26 | ngram_set = set() 27 | text_length = len(text) 28 | max_index_ngram_start = text_length - n 29 | for i in range(max_index_ngram_start + 1): 30 | ngram_set.add(tuple(text[i:i + n])) 31 | return ngram_set 32 | 33 | 34 | def _split_into_words(sentences): 35 | """Splits multiple sentences into words and flattens the result""" 36 | return list(itertools.chain(*[_.split(" ") for _ in sentences])) 37 | 38 | 39 | def _get_word_ngrams(n, sentences): 40 | """Calculates word n-grams for multiple sentences. 41 | """ 42 | assert len(sentences) > 0 43 | assert n > 0 44 | 45 | words = _split_into_words(sentences) 46 | return _get_ngrams(n, words) 47 | 48 | 49 | def _len_lcs(x, y): 50 | """ 51 | Returns the length of the Longest Common Subsequence between sequences x 52 | and y. 53 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 54 | Args: 55 | x: sequence of words 56 | y: sequence of words 57 | Returns 58 | integer: Length of LCS between x and y 59 | """ 60 | table = _lcs(x, y) 61 | n, m = len(x), len(y) 62 | return table[n, m] 63 | 64 | 65 | def _lcs(x, y): 66 | """ 67 | Computes the length of the longest common subsequence (lcs) between two 68 | strings. The implementation below uses a DP programming algorithm and runs 69 | in O(nm) time where n = len(x) and m = len(y). 70 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 71 | Args: 72 | x: collection of words 73 | y: collection of words 74 | Returns: 75 | Table of dictionary of coord and len lcs 76 | """ 77 | n, m = len(x), len(y) 78 | table = dict() 79 | for i in range(n + 1): 80 | for j in range(m + 1): 81 | if i == 0 or j == 0: 82 | table[i, j] = 0 83 | elif x[i - 1] == y[j - 1]: 84 | table[i, j] = table[i - 1, j - 1] + 1 85 | else: 86 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 87 | return table 88 | 89 | 90 | def _recon_lcs(x, y): 91 | """ 92 | Returns the Longest Subsequence between x and y. 
93 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 94 | Args: 95 | x: sequence of words 96 | y: sequence of words 97 | Returns: 98 | sequence: LCS of x and y 99 | """ 100 | i, j = len(x), len(y) 101 | table = _lcs(x, y) 102 | 103 | def _recon(i, j): 104 | """private recon calculation""" 105 | if i == 0 or j == 0: 106 | return [] 107 | elif x[i - 1] == y[j - 1]: 108 | return _recon(i - 1, j - 1) + [(x[i - 1], i)] 109 | elif table[i - 1, j] > table[i, j - 1]: 110 | return _recon(i - 1, j) 111 | else: 112 | return _recon(i, j - 1) 113 | 114 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j))) 115 | return recon_tuple 116 | 117 | 118 | def rouge_n(evaluated_sentences, reference_sentences, n=2): 119 | """ 120 | Computes ROUGE-N of two text collections of sentences. 121 | Sourece: http://research.microsoft.com/en-us/um/people/cyl/download/ 122 | papers/rouge-working-note-v1.3.1.pdf 123 | Args: 124 | evaluated_sentences: The sentences that have been picked by the summarizer 125 | reference_sentences: The sentences from the referene set 126 | n: Size of ngram. Defaults to 2. 127 | Returns: 128 | A tuple (f1, precision, recall) for ROUGE-N 129 | Raises: 130 | ValueError: raises exception if a param has len <= 0 131 | """ 132 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 133 | raise ValueError("Collections must contain at least 1 sentence.") 134 | 135 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences) 136 | reference_ngrams = _get_word_ngrams(n, reference_sentences) 137 | reference_count = len(reference_ngrams) 138 | evaluated_count = len(evaluated_ngrams) 139 | 140 | # Gets the overlapping ngrams between evaluated and reference 141 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams) 142 | overlapping_count = len(overlapping_ngrams) 143 | 144 | # Handle edge case. This isn't mathematically correct, but it's good enough 145 | if evaluated_count == 0: 146 | precision = 0.0 147 | else: 148 | precision = overlapping_count / evaluated_count 149 | 150 | if reference_count == 0: 151 | recall = 0.0 152 | else: 153 | recall = overlapping_count / reference_count 154 | 155 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) 156 | 157 | # return overlapping_count / reference_count 158 | return f1_score, precision, recall 159 | 160 | 161 | def _f_p_r_lcs(llcs, m, n): 162 | """ 163 | Computes the LCS-based F-measure score 164 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 165 | rouge-working-note-v1.3.1.pdf 166 | Args: 167 | llcs: Length of LCS 168 | m: number of words in reference summary 169 | n: number of words in candidate summary 170 | Returns: 171 | Float. LCS-based F-measure score 172 | """ 173 | r_lcs = llcs / m 174 | p_lcs = llcs / n 175 | beta = p_lcs / (r_lcs + 1e-12) 176 | num = (1 + (beta**2)) * r_lcs * p_lcs 177 | denom = r_lcs + ((beta**2) * p_lcs) 178 | f_lcs = num / (denom + 1e-12) 179 | return f_lcs, p_lcs, r_lcs 180 | 181 | 182 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences): 183 | """ 184 | Computes ROUGE-L (sentence level) of two text collections of sentences. 
185 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 186 | rouge-working-note-v1.3.1.pdf 187 | Calculated according to: 188 | R_lcs = LCS(X,Y)/m 189 | P_lcs = LCS(X,Y)/n 190 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 191 | where: 192 | X = reference summary 193 | Y = Candidate summary 194 | m = length of reference summary 195 | n = length of candidate summary 196 | Args: 197 | evaluated_sentences: The sentences that have been picked by the summarizer 198 | reference_sentences: The sentences from the referene set 199 | Returns: 200 | A float: F_lcs 201 | Raises: 202 | ValueError: raises exception if a param has len <= 0 203 | """ 204 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 205 | raise ValueError("Collections must contain at least 1 sentence.") 206 | reference_words = _split_into_words(reference_sentences) 207 | evaluated_words = _split_into_words(evaluated_sentences) 208 | m = len(reference_words) 209 | n = len(evaluated_words) 210 | lcs = _len_lcs(evaluated_words, reference_words) 211 | return _f_p_r_lcs(lcs, m, n) 212 | 213 | 214 | def _union_lcs(evaluated_sentences, reference_sentence): 215 | """ 216 | Returns LCS_u(r_i, C) which is the LCS score of the union longest common 217 | subsequence between reference sentence ri and candidate summary C. For example 218 | if r_i= w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and 219 | c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is 220 | "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5". The 221 | union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5" and 222 | LCS_u(r_i, C) = 4/5. 223 | Args: 224 | evaluated_sentences: The sentences that have been picked by the summarizer 225 | reference_sentence: One of the sentences in the reference summaries 226 | Returns: 227 | float: LCS_u(r_i, C) 228 | ValueError: 229 | Raises exception if a param has len <= 0 230 | """ 231 | if len(evaluated_sentences) <= 0: 232 | raise ValueError("Collections must contain at least 1 sentence.") 233 | 234 | lcs_union = set() 235 | reference_words = _split_into_words([reference_sentence]) 236 | combined_lcs_length = 0 237 | for eval_s in evaluated_sentences: 238 | evaluated_words = _split_into_words([eval_s]) 239 | lcs = set(_recon_lcs(reference_words, evaluated_words)) 240 | combined_lcs_length += len(lcs) 241 | lcs_union = lcs_union.union(lcs) 242 | 243 | union_lcs_count = len(lcs_union) 244 | union_lcs_value = union_lcs_count / combined_lcs_length 245 | return union_lcs_value 246 | 247 | 248 | def rouge_l_summary_level(evaluated_sentences, reference_sentences): 249 | """ 250 | Computes ROUGE-L (summary level) of two text collections of sentences. 
251 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 252 | rouge-working-note-v1.3.1.pdf 253 | Calculated according to: 254 | R_lcs = SUM(1, u)[LCS(r_i,C)]/m 255 | P_lcs = SUM(1, u)[LCS(r_i,C)]/n 256 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 257 | where: 258 | SUM(i,u) = SUM from i through u 259 | u = number of sentences in reference summary 260 | C = Candidate summary made up of v sentences 261 | m = number of words in reference summary 262 | n = number of words in candidate summary 263 | Args: 264 | evaluated_sentences: The sentences that have been picked by the summarizer 265 | reference_sentence: One of the sentences in the reference summaries 266 | Returns: 267 | A float: F_lcs 268 | Raises: 269 | ValueError: raises exception if a param has len <= 0 270 | """ 271 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 272 | raise ValueError("Collections must contain at least 1 sentence.") 273 | 274 | # total number of words in reference sentences 275 | m = len(_split_into_words(reference_sentences)) 276 | 277 | # total number of words in evaluated sentences 278 | n = len(_split_into_words(evaluated_sentences)) 279 | 280 | union_lcs_sum_across_all_references = 0 281 | for ref_s in reference_sentences: 282 | union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences, 283 | ref_s) 284 | return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n) 285 | 286 | 287 | def rouge(hypotheses, references): 288 | """Calculates average rouge scores for a list of hypotheses and 289 | references""" 290 | 291 | # Filter out hyps that are of 0 length 292 | # hyps_and_refs = zip(hypotheses, references) 293 | # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0] 294 | # hypotheses, references = zip(*hyps_and_refs) 295 | 296 | # Calculate ROUGE-1 F1, precision, recall scores 297 | rouge_1 = [ 298 | rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references) 299 | ] 300 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1)) 301 | 302 | # Calculate ROUGE-2 F1, precision, recall scores 303 | rouge_2 = [ 304 | rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references) 305 | ] 306 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2)) 307 | 308 | # Calculate ROUGE-L F1, precision, recall scores 309 | rouge_l = [ 310 | rouge_l_sentence_level([hyp], [ref]) 311 | for hyp, ref in zip(hypotheses, references) 312 | ] 313 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l)) 314 | 315 | return { 316 | "rouge_1/f_score": rouge_1_f, 317 | "rouge_1/r_score": rouge_1_r, 318 | "rouge_1/p_score": rouge_1_p, 319 | "rouge_2/f_score": rouge_2_f, 320 | "rouge_2/r_score": rouge_2_r, 321 | "rouge_2/p_score": rouge_2_p, 322 | "rouge_l/f_score": rouge_l_f, 323 | "rouge_l/r_score": rouge_l_r, 324 | "rouge_l/p_score": rouge_l_p, 325 | } -------------------------------------------------------------------------------- /Utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import random 5 | import math 6 | import datetime 7 | import numpy as np 8 | from Utils.bleu import compute_bleu 9 | from Utils.rouge import rouge 10 | 11 | def aeq(*args): 12 | """ 13 | Assert all arguments have the same value 14 | """ 15 | arguments = (arg for arg in args) 16 | first = next(arguments) 17 | assert all(arg == first for arg in arguments), \ 18 | "Not all arguments have the same value: " + str(args) 19 | 20 | 21 | def 
sequence_mask(lengths, max_len=None): 22 | """ 23 | Creates a boolean mask from sequence lengths. 24 | """ 25 | batch_size = lengths.numel() 26 | max_len = max_len or lengths.max() 27 | return (torch.arange(0, max_len) 28 | .type_as(lengths) 29 | .repeat(batch_size, 1) 30 | .lt(lengths.unsqueeze(1))) 31 | 32 | def report_stats(stats, epoch, batch, n_batches, step_time, lr): 33 | """Write out statistics to stdout. 34 | 35 | Args: 36 | epoch (int): current epoch 37 | batch (int): current batch 38 | n_batch (int): total batches 39 | """ 40 | sys.stderr.flush() 41 | sys.stderr.write(( 42 | """Epoch {0:d},[{1:d}/{2:d}] Acc: {3:.2f}; PPL: {4:.2f}; Loss: {5:.2f}; CELoss: {6:.2f}, KLDLoss: {7:.2f} \r""").format( 43 | epoch, batch, n_batches, 44 | stats.accuracy(), stats.ppl(), 45 | stats.loss, stats.loss_detail()[0], stats.loss_detail()[1])) 46 | sys.stderr.flush() 47 | 48 | 49 | def debug_trace(*args, file=sys.stderr): 50 | print(datetime.datetime.now().strftime( 51 | '%Y/%m/%d %H:%M:%S'), '[DEBUG]', *args, file=file, flush=True) 52 | 53 | 54 | def trace(*args, file=sys.stderr): 55 | print(datetime.datetime.now().strftime( 56 | '%Y/%m/%d %H:%M:%S'), *args, file=file, flush=True) 57 | 58 | def check_save_path(path): 59 | save_path = os.path.abspath(path) 60 | dirname = os.path.dirname(save_path) 61 | if not os.path.exists(dirname): 62 | os.makedirs(dirname) 63 | 64 | 65 | def report_bleu(reference_corpus, translation_corpus): 66 | 67 | bleu, precions, bp, ratio, trans_length, ref_length =\ 68 | compute_bleu([[x] for x in reference_corpus], translation_corpus) 69 | trace("BLEU: %.2f [%.2f/%.2f/%.2f/%.2f] Pred_len:%d, Ref_len:%d"%( 70 | bleu*100, *precions, trans_length, ref_length)) 71 | 72 | 73 | def report_rouge(reference_corpus, translation_corpus): 74 | 75 | scores = rouge([" ".join(x) for x in translation_corpus], 76 | [" ".join(x) for x in reference_corpus]) 77 | 78 | 79 | trace("ROUGE-1:%.2f, ROUGE-2:%.2f"%( 80 | scores["rouge_1/f_score"]*100, scores["rouge_2/f_score"]*100)) -------------------------------------------------------------------------------- /config/nmt.ini: -------------------------------------------------------------------------------- 1 | ################################### 2 | # NMT configuration # 3 | ################################### 4 | 5 | [DEFAULT] 6 | 7 | src_lang=zh 8 | trg_lang=ja 9 | 10 | system=NMT 11 | 12 | User=XXX 13 | workspace=%(system)s_%(src_lang)s-%(trg_lang)s 14 | save_log=/home/%(User)s/workspace/result/%(workspace)s/log 15 | save_model=/home/%(User)s/workspace/result/%(workspace)s/model 16 | save_vocab=/home/%(User)s/workspace/result/%(workspace)s/vocab 17 | 18 | [Data] 19 | 20 | 21 | max_seq_len=10 22 | data_path=/itigo/Uploads/nmt/data 23 | 24 | 25 | 26 | 27 | 28 | 29 | [GPU] 30 | 31 | gpu_ids=[0] 32 | use_gpu=True 33 | 34 | [Embedding] 35 | 36 | min_freq=5 37 | src_embed_dim=500 38 | trg_embed_dim=500 39 | 40 | 41 | 42 | [Train] 43 | 44 | epochs=10 45 | batch_size=128 46 | valid_batch_size=64 47 | 48 | 49 | [Optim] 50 | 51 | lr=1.0 52 | optim=Adadelta 53 | max_grad_norm=5 54 | 55 | 56 | [Network] 57 | 58 | dropout=0.3 59 | enc_num_layers=2 60 | dec_num_layers=2 61 | bidirectional=True 62 | rnn_type=GRU 63 | hidden_size=500 64 | attn_type=general 65 | 66 | 67 | 68 | [Translate] 69 | 70 | k_best=1 71 | beam_size=10 72 | output=/home/User/workspace/result/%(workspace)s/output 73 | 74 | -------------------------------------------------------------------------------- /config/vnmt.ini: 
-------------------------------------------------------------------------------- 1 | ################################### 2 | # VNMT configuration # 3 | ################################### 4 | 5 | [DEFAULT] 6 | 7 | src_lang=zh 8 | trg_lang=ja 9 | 10 | system=VNMT 11 | User=XXX 12 | workspace=%(system)s_%(src_lang)s-%(trg_lang)s 13 | save_log=/home/%(User)s/workspace/result/%(workspace)s/log 14 | save_model=/home/%(User)s/workspace/result/%(workspace)s/model 15 | save_vocab=/home/%(User)s/workspace/result/%(workspace)s/vocab 16 | 17 | [Data] 18 | 19 | 20 | max_seq_len=10 21 | data_path=/itigo/Uploads/nmt/data 22 | 23 | 24 | 25 | 26 | [GPU] 27 | 28 | gpu_ids=[0] 29 | use_gpu=True 30 | 31 | [Embedding] 32 | 33 | min_freq=5 34 | src_embed_dim=500 35 | trg_embed_dim=500 36 | 37 | 38 | 39 | [Train] 40 | 41 | epochs=10 42 | batch_size=128 43 | valid_batch_size=64 44 | 45 | 46 | [Optim] 47 | 48 | lr=1.0 49 | optim=Adadelta 50 | max_grad_norm=5 51 | 52 | 53 | 54 | [Network] 55 | 56 | dropout=0.3 57 | enc_num_layers=2 58 | dec_num_layers=2 59 | bidirectional=True 60 | rnn_type=GRU 61 | hidden_size=500 62 | attn_type=general 63 | latent_size=500 64 | meanpool=False 65 | 66 | [Translate] 67 | 68 | k_best=1 69 | beam_size=10 70 | output=/home/%(User)s/workspace/result/%(workspace)s/output 71 | 72 | 73 | [Loss] 74 | 75 | kld_weight=1 76 | start_increase_kld_at=8 77 | -------------------------------------------------------------------------------- /config/vrnmt.ini: -------------------------------------------------------------------------------- 1 | ################################### 2 | # VNMT configuration # 3 | ################################### 4 | 5 | [DEFAULT] 6 | 7 | src_lang=zh 8 | trg_lang=ja 9 | 10 | system=VRNMT 11 | User=XXX 12 | workspace=%(system)s_%(src_lang)s-%(trg_lang)s 13 | save_log=/home/%(User)s/workspace/result/%(workspace)s/log 14 | save_model=/home/%(User)s/workspace/result/%(workspace)s/model 15 | save_vocab=/home/%(User)s/workspace/result/%(workspace)s/vocab 16 | 17 | [Data] 18 | 19 | 20 | max_seq_len=10 21 | data_path=/itigo/Uploads/nmt/data 22 | 23 | 24 | [GPU] 25 | 26 | gpu_ids=[0] 27 | use_gpu=True 28 | 29 | [Embedding] 30 | 31 | min_freq=5 32 | src_embed_dim=500 33 | trg_embed_dim=500 34 | 35 | 36 | 37 | [Train] 38 | 39 | epochs=10 40 | batch_size=128 41 | valid_batch_size=64 42 | 43 | 44 | [Optim] 45 | 46 | lr=1.0 47 | optim=Adadelta 48 | max_grad_norm=5 49 | 50 | 51 | 52 | [Network] 53 | 54 | dropout=0.3 55 | enc_num_layers=2 56 | dec_num_layers=2 57 | bidirectional=True 58 | rnn_type=GRU 59 | hidden_size=500 60 | attn_type=general 61 | latent_size=500 62 | meanpool=False 63 | 64 | [Translate] 65 | 66 | k_best=1 67 | beam_size=10 68 | output=/home/%(User)s/workspace/result/%(workspace)s/output 69 | 70 | 71 | 72 | [Loss] 73 | 74 | kld_weight=1 75 | start_increase_kld_at=8 76 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | import glob 6 | import random 7 | import argparse 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from Utils.utils import trace 13 | from Utils.utils import check_save_path 14 | from Utils.args import parse_args 15 | from Utils.config import read_config 16 | from Utils.config import format_config 17 | from Utils.DataLoader import DataBatchIterator 18 | from Utils.DataLoader import PAD_WORD 19 | 20 | 21 | from NMT import Trainer 22 | from NMT import 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import os
import sys
import glob
import random
import argparse

import torch
import torch.nn as nn

from Utils.utils import trace
from Utils.utils import check_save_path
from Utils.args import parse_args
from Utils.config import read_config
from Utils.config import format_config
from Utils.DataLoader import DataBatchIterator
from Utils.DataLoader import PAD_WORD

from NMT import Trainer
from NMT import Statistics
from NMT import NMTLoss
from NMT import Optimizer
from NMT import model_factory


def train_model(model, optimizer, loss_func,
                train_data_iter, valid_data_iter, config):

    trainer = Trainer(model, loss_func, optimizer, config)

    for epoch in range(1, config.epochs + 1):
        train_iter = iter(train_data_iter)
        valid_iter = iter(valid_data_iter)

        # train
        train_stats = trainer.train(
            train_iter, epoch, train_data_iter.num_batches)

        print('')
        trace('Epoch %d, Train acc: %g, ppl: %g' %
              (epoch, train_stats.accuracy(), train_stats.ppl()))

        # validate
        valid_stats = trainer.validate(valid_iter)
        trace('Epoch %d, Valid acc: %g, ppl: %g' %
              (epoch, valid_stats.accuracy(), valid_stats.ppl()))

        # # log
        # train_stats.log("train", config.model_name, optimizer.lr)
        # valid_stats.log("valid", config.model_name, optimizer.lr)

        # update the learning rate
        trainer.lr_step(valid_stats.ppl(), epoch)

        # dump a checkpoint if needed
        trainer.dump_checkpoint(epoch, config, train_stats)


def build_optimizer(model, config):
    optimizer = Optimizer(config.optim, config)
    optimizer.set_parameters(model.named_parameters())
    return optimizer


def main():
    args, parser = parse_args("train")
    config = read_config(args, parser, args.config)
    trace(format_config(config))

    train_data_iter = DataBatchIterator(
        config=config,
        is_train=True,
        dataset="train",
        batch_size=config.batch_size,
        shuffle=True)
    train_data_iter.load()

    src_vocab = train_data_iter.src_vocab
    trg_vocab = train_data_iter.trg_vocab

    # Persist the vocabularies so translate.py can reload them later.
    check_save_path(config.save_vocab)
    torch.save(src_vocab, config.save_vocab + "." + config.src_lang)
    torch.save(trg_vocab, config.save_vocab + "." + config.trg_lang)

    valid_data_iter = DataBatchIterator(
        config=config,
        is_train=True,
        dataset="dev",
        batch_size=config.valid_batch_size)
    valid_data_iter.set_vocab(src_vocab, trg_vocab)
    valid_data_iter.load()

    # Build model.
    model = model_factory(config, src_vocab, trg_vocab)
    # if len(config.gpu_ids) > 1:
    #     trace('Multi gpu training: ', config.gpu_ids)
    #     model = nn.DataParallel(model, device_ids=config.gpu_ids, dim=1)
    trace(model)

    # Build optimizer.
    optimizer = build_optimizer(model, config)

    # Build the loss function (shared by training and validation).
    padding_idx = trg_vocab.stoi[PAD_WORD]
    loss_func = NMTLoss(config, padding_idx)

    # Do training.
    train_model(model, optimizer, loss_func,
                train_data_iter, valid_data_iter, config)


if __name__ == "__main__":
    main()
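
Note that train.py always builds a fresh model: model_factory accepts a checkpoint argument (see NMT/ModelConstructor.py), but main() never passes one, even though Trainer.dump_checkpoint writes a checkpoint every epoch. Resuming could look like the sketch below; the --resume flag and the checkpoint layout are assumptions, not part of the current parse_args or Trainer:

import torch

# Hypothetical resume path: neither the --resume flag nor the exact
# checkpoint file layout is defined anywhere in this repository.
checkpoint = None
if getattr(args, "resume", None):
    checkpoint = torch.load(args.resume, map_location="cpu")

model = model_factory(config, src_vocab, trg_vocab,
                      train_mode=True, checkpoint=checkpoint)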
--------------------------------------------------------------------------------
/translate.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import os
import argparse
import math
import codecs
import torch

from itertools import count

from Utils.args import parse_args
from Utils.utils import trace
from Utils.config import read_config
from Utils.DataLoader import DataBatchIterator
from NMT.ModelConstructor import model_factory
from NMT.Optimizer import Optimizer

from NMT.Loss import NMTLoss
from NMT.Trainer import Trainer
from NMT.Trainer import Statistics
from NMT.translate import BatchTranslator
from NMT.translate import TranslationBuilder
from NMT.translate import GNMTGlobalScorer
from Utils.plot import plot_attn
from Utils.utils import report_bleu
from Utils.utils import report_rouge


def report_score(name, score_total, words_total):
    print("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
        name, score_total / words_total,
        name, math.exp(-score_total / words_total)))


def main():
    args, parser = parse_args("translate")
    config = read_config(args, parser, args.config)
    config.batch_size = 1  # translate one sentence at a time
    test_data_iter = DataBatchIterator(
        config=config, is_train=False, dataset="test",
        batch_size=config.batch_size)

    src_vocab = torch.load(config.save_vocab + "." + config.src_lang)
    trg_vocab = torch.load(config.save_vocab + "." + config.trg_lang)

    test_data_iter.set_vocab(src_vocab, trg_vocab)
    test_data_iter.load()

    checkpoint = torch.load(config.save_model + ".pt")
    # Load the model.
    model = model_factory(
        config, src_vocab, trg_vocab, train_mode=False, checkpoint=checkpoint)
    if config.verbose:
        trace(model)

    # Files to write predictions, references, and sources to.
    pred_file = codecs.open(config.output + ".pred.txt", 'w', 'utf-8')
    ref_file = codecs.open(config.output + ".ref.txt", 'w', 'utf-8')
    src_file = codecs.open(config.output + ".src.txt", 'w', 'utf-8')

    # Batches are sorted by decreasing sentence length, as PyTorch's packed
    # sequences require (sort=False means "use the dataset's sortkey
    # instead of the iterator's").

    # Translator
    scorer = GNMTGlobalScorer(config.alpha, config.beta,
                              config.coverage_penalty, config.length_penalty)
    translator = BatchTranslator(model, config, trg_vocab, global_scorer=scorer)

    data_iter = iter(test_data_iter)

    builder = TranslationBuilder(src_vocab, trg_vocab, config)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    pred_list = []
    gold_list = []
    for batch in data_iter:
        outputs = translator.translate_batch(batch)
        batch_trans = builder.from_batch_translator_output(outputs)

        for trans in batch_trans:
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            pred_list.append(trans.pred_sents[0])

            gold_score_total += trans.gold_score
            gold_words_total += len(trans.gold_sent) + 1  # count the EOS token as well
            gold_list.append(trans.gold_sent)

            k_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:config.k_best]]
            pred_file.write('\n'.join(k_best_preds) + "\n")
            ref_file.write(" ".join(trans.gold_sent) + '\n')
            src_file.write(" ".join(trans.src_sent) + '\n')
            if config.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                os.write(1, output.encode('utf-8'))

    report_score('PRED', pred_score_total, pred_words_total)
    report_score('GOLD', gold_score_total, gold_words_total)
    if config.plot_attn:
        # Plot the attention weights of the last translated sentence.
        plot_attn(trans.src_sent, trans.pred_sents[0], trans.attns[0].cpu())

    report_bleu(gold_list, pred_list)
    report_rouge(gold_list, pred_list)

    # if config.dump_beam:
    #     import json
    #     json.dump(translator.beam_accum,
    #               codecs.open(config.dump_beam, 'w', 'utf-8'))


if __name__ == "__main__":
    main()
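
report_score converts a summed log-probability into an average per-word score and a perplexity (PPL = exp of the negative average). A quick numeric sanity check:

import math

# A model that assigns each of 100 words a probability of 0.1
# accumulates a total log-probability of 100 * ln(0.1) ~= -230.26.
score_total, words_total = -230.26, 100

avg_score = score_total / words_total       # -2.3026: average log-prob per word
ppl = math.exp(-score_total / words_total)  # ~10.0: perplexity of a 1-in-10 guess
print("AVG SCORE: %.4f, PPL: %.4f" % (avg_score, ppl))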
--------------------------------------------------------------------------------
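
Finally, note the corpus format the reporting helpers in Utils/utils.py expect: both take lists of tokenized sentences, not strings (report_bleu wraps each reference in a single-reference list before calling compute_bleu, and report_rouge joins tokens with spaces for the rouge scorer). A minimal call with toy data:

from Utils.utils import report_bleu, report_rouge

# Each corpus is a list of token lists, one per sentence.
references = [["the", "cat", "sat"], ["hello", "world"]]
hypotheses = [["the", "cat", "sat"], ["hello", "there"]]

report_bleu(references, hypotheses)   # prints BLEU and the four n-gram precisions
report_rouge(references, hypotheses)  # prints ROUGE-1/ROUGE-2 F-scores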