├── nmt ├── modules │ ├── __init__.py │ ├── Embedding.py │ ├── StackedRNN.py │ ├── Encoder.py │ ├── Attention.py │ ├── Beam.py │ └── Decoder.py ├── utils │ ├── __init__.py │ ├── data_utils.py │ └── misc_utils.py ├── __init__.py ├── Optim.py ├── Loss.py ├── Translator.py ├── Trainer.py ├── model_helper.py └── Model.py ├── .gitignore ├── hard ├── translate.sh ├── train_critic.sh ├── train.sh └── config.yml ├── soft ├── translate.sh ├── train.sh └── config.yml ├── load_stopwords.py ├── pretrain ├── train.sh ├── translate.sh ├── case.sh ├── config.yml └── final_step.py ├── template ├── train.sh ├── test.sh ├── clean.py └── config.yml ├── read.py ├── rread.py ├── mclean.py ├── readtgt.py ├── README.md ├── score.py ├── translate.py ├── data.py ├── train.py ├── template.py ├── joint_train.py └── maskGAN.py /nmt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nmt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | -------------------------------------------------------------------------------- /hard/translate.sh: -------------------------------------------------------------------------------- 1 | NMT_DIR=.. 2 | python3 ${NMT_DIR}/translate.py \ 3 | -test_file ../data/douban/human_dev \ 4 | -model_type CAS \ 5 | -tgt_out ./soft_out \ 6 | -model ./checkpoint_epoch5.pkl \ 7 | -src_vocab ../data/douban/vocab_src \ 8 | -tgt_vocab ../data/douban/vocab_tgt 9 | -------------------------------------------------------------------------------- /soft/translate.sh: -------------------------------------------------------------------------------- 1 | NMT_DIR=.. 2 | python3 ${NMT_DIR}/translate.py \ 3 | -test_file ../data/douban/human_dev \ 4 | -model_type soft \ 5 | -tgt_out ./soft_out \ 6 | -model ./checkpoint_epoch5.pkl \ 7 | -src_vocab ../data/douban/vocab_src \ 8 | -tgt_vocab ../data/douban/vocab_tgt 9 | -------------------------------------------------------------------------------- /load_stopwords.py: -------------------------------------------------------------------------------- 1 | from data import Vocab 2 | import sys 3 | 4 | vocab = sys.argv[1] 5 | stopwords = sys.argv[2] 6 | 7 | 8 | v = Vocab(vocab)#, noST = True) 9 | 10 | for line in open(stopwords, encoding = "gbk").readlines(): 11 | w = line.strip() 12 | if w in v.stoi: 13 | print (v.stoi[line.strip()]) 14 | -------------------------------------------------------------------------------- /pretrain/train.sh: -------------------------------------------------------------------------------- 1 | NMT_DIR=.. 2 | python3 ${NMT_DIR}/train.py \ 3 | -model_type rg \ 4 | -config config.yml \ 5 | -nmt_dir ${NMT_DIR} \ 6 | -src_vocab ../data/golden/vocab_src \ 7 | -tgt_vocab ../data/golden/vocab_tgt \ 8 | -train_file ../data/golden/train \ 9 | -valid_file ../data/golden/dev 10 | 11 | -------------------------------------------------------------------------------- /soft/train.sh: -------------------------------------------------------------------------------- 1 | NMT_DIR=.. 2 | python3 ${NMT_DIR}/joint_train.py \ 3 | -model_type JNT \ 4 | -config config.yml \ 5 | -nmt_dir ${NMT_DIR} \ 6 | -src_vocab ../data/douban/vocab_src \ 7 | -tgt_vocab ../data/douban/vocab_tgt \ 8 | -train_file ../data/douban/train \ 9 | -valid_file ../data/douban/dev_dev 10 | -------------------------------------------------------------------------------- /pretrain/translate.sh: -------------------------------------------------------------------------------- 1 | NMT_DIR=.. 2 | python3 ${NMT_DIR}/translate.py \ 3 | -test_file ../data/golden/ske\ 4 | -model_type rg \ 5 | -tgt_out ablation_response \ 6 | -model ./this.is.the.out/out_dir/checkpoint_epoch19.pkl \ 7 | -src_vocab ../data/golden/vocab_src \ 8 | -tgt_vocab ../data/golden/vocab_tgt 9 | 10 | 11 | -------------------------------------------------------------------------------- /template/train.sh: -------------------------------------------------------------------------------- 1 | NMT_DIR=.. 2 | python3 ${NMT_DIR}/template.py \ 3 | -config config.yml \ 4 | -nmt_dir ${NMT_DIR} \ 5 | -src_vocab ../data/golden/vocab_src \ 6 | -tgt_vocab ../data/golden/vocab_tgt \ 7 | -train_file ../data/golden/train \ 8 | -valid_file ../data/golden/dev \ 9 | -mode train \ 10 | -stop_words ../data/golden/stop_words 11 | -------------------------------------------------------------------------------- /nmt/__init__.py: -------------------------------------------------------------------------------- 1 | import nmt.model_helper 2 | from nmt.Loss import NMTLossCompute 3 | from nmt.Trainer import Trainer, Statistics, Scorer 4 | from nmt.Translator import Translator 5 | from nmt.Optim import Optim 6 | from nmt.modules.Beam import Beam 7 | from nmt.utils import misc_utils, data_utils 8 | __all__ = [nmt.model_helper, NMTLossCompute, Trainer, Translator, Scorer, Optim, Statistics, Beam, misc_utils, data_utils] 9 | -------------------------------------------------------------------------------- /template/test.sh: -------------------------------------------------------------------------------- 1 | NMT_DIR=.. 2 | python3 ${NMT_DIR}/template.py \ 3 | -config config.yml \ 4 | -nmt_dir ${NMT_DIR} \ 5 | -src_vocab ../data/golden/vocab_src \ 6 | -tgt_vocab ../data/golden/vocab_tgt \ 7 | -model ./out/out_dir/checkpoint_epoch19.pkl \ 8 | -test_file ../data/golden/train \ 9 | -out_file ./TMP \ 10 | -stop_words ../data/golden/stop_words \ 11 | -mode test 12 | python clean.py TMP > final_output -------------------------------------------------------------------------------- /hard/train_critic.sh: -------------------------------------------------------------------------------- 1 | NMT_DIR=.. 2 | python3 ${NMT_DIR}/joint_train.py \ 3 | -model_type CAS \ 4 | -config config.yml \ 5 | -nmt_dir ${NMT_DIR} \ 6 | -src_vocab ../data/douban/vocab_src \ 7 | -tgt_vocab ../data/douban/vocab_tgt \ 8 | -train_file ../data/douban/train \ 9 | -valid_file ../data/douban/dev_dev \ 10 | -rg_model ../douban_pretrain/checkpoint_epoch19.pkl.clean \ 11 | -tg_model ../douban_template/checkpoint_epoch19.pkl 12 | -------------------------------------------------------------------------------- /hard/train.sh: -------------------------------------------------------------------------------- 1 | NMT_DIR=.. 2 | python3 ${NMT_DIR}/joint_train.py \ 3 | -model_type CAS \ 4 | -config config.yml \ 5 | -nmt_dir ${NMT_DIR} \ 6 | -src_vocab ../data/douban/vocab_src \ 7 | -tgt_vocab ../data/douban/vocab_tgt \ 8 | -train_file ../data/douban/train \ 9 | -valid_file ../data/douban/dev_dev \ 10 | -rg_model ../douban_pretrain/checkpoint_epoch19.pkl.clean \ 11 | -tg_model ../douban_template/checkpoint_epoch19.pkl \ 12 | -critic_model ./checkpoint_epoch_critic2.pkl 13 | -------------------------------------------------------------------------------- /nmt/modules/Embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Embedding(nn.Module): 5 | def __init__(self, input_size, embedding_size, padding_idx=1): 6 | super(Embedding, self).__init__() 7 | self.padding_idx = padding_idx 8 | self.embedding_size = embedding_size 9 | self.embedding = nn.Embedding(input_size, embedding_size, padding_idx=padding_idx) 10 | def forward(self, input_seqs): 11 | embedded = self.embedding(input_seqs) 12 | return embedded 13 | -------------------------------------------------------------------------------- /read.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from data import Vocab 3 | if __name__ == "__main__": 4 | prefix = './data/'+sys.argv[1]+'/vocab_' 5 | vocabs = [Vocab(prefix+'src', noST = True), Vocab(prefix+'tgt')] 6 | with open(sys.argv[2]) as f: 7 | for line in f.readlines(): 8 | res = [] 9 | sents = line.split('|') 10 | for idx, s in enumerate(sents): 11 | res.append(' '.join( [vocabs[idx%2].itos[int(x)] for x in s.split()] )) 12 | print ('|'.join(res)) 13 | 14 | -------------------------------------------------------------------------------- /template/clean.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | with open(sys.argv[1]) as f: 5 | for line in f.readlines(): 6 | x = line.strip().split('|') 7 | y = x[-1] 8 | z = [ int(t) for t in y.split()] 9 | iszero = False 10 | new_z = [] 11 | for w in z: 12 | if iszero and w == 0: 13 | continue 14 | else: 15 | new_z.append(w) 16 | iszero = (w==0) 17 | x[-1] = ' '.join([str(t) for t in new_z]) 18 | print '|'.join(x) 19 | 20 | -------------------------------------------------------------------------------- /rread.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from data import Vocab 3 | if __name__ == "__main__": 4 | prefix = './data/'+sys.argv[1]+'/vocab_' 5 | vocabs = [Vocab(prefix+'src', noST = True), Vocab(prefix+'tgt')] 6 | with open(sys.argv[2]) as f: 7 | for line in f.readlines(): 8 | res = [] 9 | sents = line.split('|') 10 | for idx, s in enumerate(sents): 11 | res.append(' '.join( [str(vocabs[idx%2].stoi.get(x, vocabs[idx%2].stoi[vocabs[idx%2].UNK])) for x in s.split()] )) 12 | print ('|'.join(res)) 13 | 14 | -------------------------------------------------------------------------------- /mclean.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from data import Vocab 3 | 4 | vocab_tgt = Vocab('../data/golden/vocab_tgt') 5 | 6 | with open(sys.argv[1]) as f: 7 | for line in f.readlines(): 8 | x = line.strip().split('|') 9 | y = x[-1] 10 | z = [ int(t) for t in y.split()] 11 | iszero = False 12 | new_z = [] 13 | for w in z: 14 | if iszero and w == 0: 15 | continue 16 | else: 17 | new_z.append(w) 18 | iszero = (w==0) 19 | print (' '.join([vocab_tgt.i2s(w) for w in new_z])) 20 | 21 | -------------------------------------------------------------------------------- /pretrain/case.sh: -------------------------------------------------------------------------------- 1 | python final_step.py 2 | NMT_DIR=.. 3 | python3 ${NMT_DIR}/template.py \ 4 | -config ../template/config.yml \ 5 | -nmt_dir ${NMT_DIR} \ 6 | -src_vocab ../data/douban/vocab_src \ 7 | -tgt_vocab ../data/douban/vocab_tgt \ 8 | -model ../template/checkpoint_epoch8.pkl \ 9 | -test_file ./in \ 10 | -out_file ./in_tem \ 11 | -stop_words ../data/douban/stop_words \ 12 | -mode test 13 | 14 | python3 ../read.py douban in_tem > output_skeleton 15 | 16 | python3 ${NMT_DIR}/translate.py \ 17 | -test_file ./in_tem \ 18 | -model_type rg \ 19 | -tgt_out ./output_case \ 20 | -model ./hard.pkl \ 21 | -src_vocab ../data/douban/vocab_src \ 22 | -tgt_vocab ../data/douban/vocab_tgt 23 | -------------------------------------------------------------------------------- /readtgt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from data import Vocab 3 | if __name__ == "__main__": 4 | prefix = '/home/jcykcai/JunNMT/data/'+sys.argv[1]+'/vocab_' 5 | vocabs = [Vocab(prefix+'src', noST = True), Vocab(prefix+'tgt')] 6 | with open(sys.argv[2]) as f, open(sys.argv[3]) as f1: 7 | line1s = f1.readlines() 8 | cur = 0 9 | for line in f.readlines(): 10 | ooo = line1s[cur].strip().split('|') 11 | cur += 1 12 | sent = line.strip().split('\t') 13 | #assert len(sent) == 2 14 | sent = sent[0] 15 | ooo[1] = ' '.join( [ str(vocabs[1].stoi.get(x, vocabs[1].stoi[vocabs[1].UNK])) for x in sent.split()[:-1]]) 16 | res = ooo[:2] 17 | print ('|'.join(res)) 18 | 19 | -------------------------------------------------------------------------------- /template/config.yml: -------------------------------------------------------------------------------- 1 | # Network 2 | encoder_type: 3 | decoder_type: AttnDecoderRNN #InputFeedDecoder, AttnDecoderRNN, ScheduledDecoder 4 | rnn_type: LSTM #RNN, LSTM, GRU 5 | bidirectional: true 6 | embedding_size: 500 7 | hidden_size: 500 8 | num_layers: 2 9 | dropout: 0.3 10 | dropout_ev: 0.1 11 | atten_model: dot #general, dot, none !!!!! base use dot 12 | 13 | dhidden_size: 128 14 | aux_size: 256 15 | aux_nums: 4 16 | 17 | mem_gate: true 18 | gate_vector: true 19 | src_attention: true 20 | 21 | # Misc 22 | use_cuda: true 23 | random_seed: 19940117 24 | 25 | # Train 26 | optim_method: adam #adadelta, adam, sgd 27 | max_grad_norm: 5 28 | learning_rate: 0.001 29 | learning_rate_decay: 0.9 30 | start_decay_at: 8 31 | weight_decay: 0.000001 # weight decay(L2 penalty) 32 | num_train_epochs: 20 33 | steps_per_stats: 100 34 | steps_per_eval: 1000 35 | train_batch_size: 128 36 | train_shard_size: 32 37 | start_epoch_at: 38 | 39 | out_dir: ./out/out_dir # path to save model 40 | -------------------------------------------------------------------------------- /pretrain/config.yml: -------------------------------------------------------------------------------- 1 | # Network 2 | encoder_type: 3 | decoder_type: AttnDecoderRNN #InputFeedDecoder, AttnDecoderRNN, ScheduledDecoder 4 | rnn_type: LSTM #RNN, LSTM, GRU 5 | bidirectional: true 6 | embedding_size: 300 7 | hidden_size: 300 8 | num_layers: 2 9 | dropout: 0.33 10 | dropout_ev: 0.33 11 | atten_model: dot #general, dot, none !!!!! base use dot 12 | 13 | dhidden_size: 128 14 | aux_size: 256 15 | aux_nums: 4 16 | 17 | mem_gate: true 18 | gate_vector: true 19 | src_attention: true 20 | 21 | # Misc 22 | use_cuda: true 23 | random_seed: 19940117 24 | 25 | # Train 26 | optim_method: adam #adadelta, adam, sgd 27 | max_grad_norm: 5 28 | learning_rate: 0.001 29 | learning_rate_decay: 0.9 30 | start_decay_at: 8 31 | weight_decay: 0.000001 # weight decay(L2 penalty) 32 | num_train_epochs: 20 33 | steps_per_stats: 100 34 | steps_per_eval: 1000 35 | train_batch_size: 128 36 | train_shard_size: 32 37 | start_epoch_at: 38 | valid_batch_size: 32 39 | 40 | only_train_mem: false 41 | use_ev: false 42 | out_dir: ./out/out_dir # path to save model 43 | -------------------------------------------------------------------------------- /pretrain/final_step.py: -------------------------------------------------------------------------------- 1 | with open('../data/douban/vocab_src') as f: 2 | words = [line.strip() for line in f.readlines()] 3 | vocab_src = dict(zip(words, range(0, len(words)))) 4 | 5 | with open('../data/douban/vocab_tgt') as f: 6 | words = [line.strip() for line in f.readlines()] 7 | vocab_tgt = dict(zip(words, range(0, len(words)))) 8 | 9 | def write(res, fo, SrcnoST = False): 10 | vocabs = [vocab_src, vocab_tgt] 11 | line = [] 12 | if SrcnoST: 13 | tmp = [1, 3] 14 | else: 15 | tmp = [3, 3] 16 | for vid, s in enumerate(res): 17 | words = s.strip().split() 18 | line.append(' '.join([str(tmp[vid%2] + vocabs[vid%2].get(w, 0)) for w in words])) 19 | line = '|'.join(line) 20 | fo.write(line + '\n') 21 | 22 | with open('input_case') as fsrc, \ 23 | open('in', 'w') as fox: 24 | last_src, last_tgt = None, None 25 | result = [] 26 | for line in fsrc.readlines(): 27 | result = [ t.strip() for t in line.strip().split('|')] 28 | assert len(result) ==4 29 | write(result, fox, True) 30 | -------------------------------------------------------------------------------- /hard/config.yml: -------------------------------------------------------------------------------- 1 | # Network 2 | encoder_type: 3 | decoder_type: AttnDecoderRNN #InputFeedDecoder, AttnDecoderRNN, ScheduledDecoder 4 | rnn_type: LSTM #RNN, LSTM, GRU 5 | bidirectional: true 6 | embedding_size: 500 7 | hidden_size: 500 8 | num_layers: 2 9 | dropout: 0.3 10 | dropout_ev: 0.3 11 | atten_model: dot #general, dot, none !!!!! base use dot 12 | 13 | dhidden_size: 128 14 | aux_size: 256 15 | aux_nums: 4 16 | 17 | mem_gate: true 18 | gate_vector: true 19 | src_attention: true 20 | 21 | # Misc 22 | use_cuda: true 23 | random_seed: 19940117 24 | 25 | # Train 26 | optim_method: adam #adadelta, adam, sgd 27 | max_grad_norm: 5 28 | learning_rate_R: 0.001 29 | learning_rate_T: 0.001 30 | learning_rate_C: 0.001 31 | learning_rate_decay: 0.9 32 | start_decay_at: 8 33 | weight_decay: 0.000001 # weight decay(L2 penalty) 34 | num_train_epochs: 30 35 | steps_per_stats: 100 36 | steps_per_eval: 1000 37 | train_batch_size: 128 38 | train_shard_size: 32 39 | start_epoch_at: 40 | valid_batch_size: 32 41 | 42 | 43 | use_ev: false 44 | use_critic: true 45 | out_dir: ./out/out_dir # path to save model 46 | -------------------------------------------------------------------------------- /soft/config.yml: -------------------------------------------------------------------------------- 1 | # Network 2 | encoder_type: 3 | decoder_type: AttnDecoderRNN #InputFeedDecoder, AttnDecoderRNN, ScheduledDecoder 4 | rnn_type: LSTM #RNN, LSTM, GRU 5 | bidirectional: true 6 | embedding_size: 500 7 | hidden_size: 500 8 | num_layers: 2 9 | dropout: 0.3 10 | dropout_ev: 0.3 11 | atten_model: dot #general, dot, none !!!!! base use dot 12 | 13 | dhidden_size: 128 14 | aux_size: 256 15 | aux_nums: 4 16 | 17 | mem_gate: true 18 | gate_vector: true 19 | src_attention: true 20 | 21 | # Misc 22 | use_cuda: true 23 | random_seed: 19940117 24 | 25 | # Train 26 | optim_method: adam #adadelta, adam, sgd 27 | max_grad_norm: 5 28 | learning_rate_R: 0.001 29 | learning_rate_T: 0.001 30 | learning_rate_C: 0.001 31 | learning_rate_decay: 0.9 32 | start_decay_at: 8 33 | weight_decay: 0.000001 # weight decay(L2 penalty) 34 | num_train_epochs: 30 35 | steps_per_stats: 100 36 | steps_per_eval: 1000 37 | train_batch_size: 128 38 | train_shard_size: 32 39 | start_epoch_at: 40 | valid_batch_size: 32 41 | 42 | 43 | use_ev: false 44 | use_critic: false 45 | out_dir: ./out/out_dir # path to save model 46 | -------------------------------------------------------------------------------- /nmt/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import nmt 2 | from torch.autograd import Variable 3 | import torch 4 | 5 | 6 | def sequence_mask(lengths, max_len=None): 7 | """ 8 | Creates a boolean mask from sequence lengths. 9 | """ 10 | lengths = torch.LongTensor(lengths).cuda() 11 | batch_size = lengths.numel() 12 | max_len = max_len or lengths.max() 13 | return (torch.arange(0, max_len) 14 | .type_as(lengths) 15 | .repeat(batch_size, 1) 16 | .lt(lengths.unsqueeze(1))) 17 | 18 | # Pad a with the PAD symbol 19 | 20 | def pad_seq(seq, max_length, padding_idx): 21 | seq += [padding_idx for i in range(max_length - len(seq))] 22 | return seq 23 | 24 | 25 | def seq2indices(seq, word2index, max_len=None): 26 | seq_idx = [] 27 | words_in = seq.split(' ') 28 | if max_len is not None: 29 | words_in = words_in[:max_len] 30 | for w in words_in: 31 | seq_idx.append(word2index[w]) 32 | 33 | return seq_idx 34 | 35 | 36 | def indices2words(idxs, index2word): 37 | words_list = [index2word[idx] for idx in idxs] 38 | return words_list 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Skeleton-to-Response 2 | 3 | ### Requirement 4 | 5 | pytorch==0.3.1 6 | 7 | ### Usage 8 | 9 | - Cascaded Model 10 | 1. pretrain skeleton generator: go to the `template` folder, use `train.sh` 11 | 2. pretrain response generator: go to the `pretrain` folder, use `train.sh`. 12 | 3. pretrain critic: go to the `hard` folder, use `train_critic.sh` 13 | 4. train both skeleton generator and response generator with RL: go to the `hard` folder, use `train.sh` 14 | 5. Test: go for `hard/translate.sh` 15 | - Joint Model 16 | - Use the `train.sh` and `translate.sh` in the `soft` folder 17 | 18 | ### Data 19 | 20 | The data we used in our paper are from [Wu et al, 2019](https://github.com/MarkWuNLP/ResponseEdit) 21 | 22 | some sample data are in the `data` folder. The format is 23 | 24 | `query | response | retrieved query | retrieved response` 25 | 26 | (sentences in each line are split by the symbol `|`) 27 | 28 | ### Citation 29 | 30 | ``` 31 | @inproceedings{cai-etal-2019-skeleton, 32 | title = "Skeleton-to-Response: Dialogue Generation Guided by Retrieval Memory", 33 | author = "Cai, Deng and Wang, Yan and Bi, Wei and Tu, Zhaopeng and Liu, Xiaojiang and Lam, Wai and Shi, Shuming", 34 | booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)", 35 | month = jun, 36 | year = "2019", 37 | address = "Minneapolis, Minnesota", 38 | publisher = "Association for Computational Linguistics", 39 | url = "https://www.aclweb.org/anthology/N19-1124", 40 | pages = "1219--1228" 41 | } 42 | ``` 43 | -------------------------------------------------------------------------------- /nmt/Optim.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from torch.nn.utils import clip_grad_norm 3 | 4 | class Optim(object): 5 | 6 | def __init__(self, method, lr, max_grad_norm, 7 | lr_decay=1, weight_decay=0, 8 | start_decay_at=None, 9 | beta1=0.9, beta2=0.98): 10 | self.last_ppl = None 11 | self.lr = lr 12 | self.max_grad_norm = max_grad_norm 13 | self.method = method 14 | self.lr_decay = lr_decay 15 | self.weight_decay = weight_decay 16 | self.start_decay_at = start_decay_at 17 | self.start_decay = False 18 | self._step = 0 19 | self.betas = [beta1, beta2] 20 | 21 | def _setRate(self, lr): 22 | self.lr = lr 23 | self.optimizer.param_groups[0]['lr'] = self.lr 24 | 25 | def set_parameters(self, params): 26 | self.params = [p for p in params if p.requires_grad] 27 | if self.method == 'sgd': 28 | self.optimizer = optim.SGD(self.params, lr=self.lr, weight_decay=self.weight_decay) 29 | elif self.method == 'adagrad': 30 | self.optimizer = optim.Adagrad(self.params, lr=self.lr, weight_decay=self.weight_decay) 31 | elif self.method == 'adadelta': 32 | self.optimizer = optim.Adadelta(self.params, lr=self.lr, weight_decay=self.weight_decay) 33 | elif self.method == 'adam': 34 | self.optimizer = optim.Adam(self.params, lr=self.lr, 35 | betas=self.betas, eps=1e-9, 36 | weight_decay=self.weight_decay) 37 | else: 38 | raise RuntimeError("Invalid optim method: " + self.method) 39 | 40 | def step(self): 41 | "Compute gradients norm." 42 | self._step += 1 43 | 44 | if self.max_grad_norm: 45 | clip_grad_norm(self.params, self.max_grad_norm) 46 | self.lr = self.optimizer.param_groups[0]['lr'] 47 | self.optimizer.step() 48 | -------------------------------------------------------------------------------- /nmt/modules/StackedRNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class StackedLSTM(nn.Module): 5 | """ 6 | Our own implementation of stacked LSTM. 7 | Needed for the decoder, because we do input feeding. 8 | """ 9 | def __init__(self, num_layers, input_size, rnn_size, dropout): 10 | super(StackedLSTM, self).__init__() 11 | self.dropout = nn.Dropout(dropout) 12 | self.num_layers = num_layers 13 | self.layers = nn.ModuleList() 14 | 15 | for i in range(num_layers): 16 | self.layers.append(nn.LSTMCell(input_size, rnn_size)) 17 | input_size = rnn_size 18 | 19 | def forward(self, input, hidden): 20 | h_0, c_0 = hidden 21 | h_1, c_1 = [], [] 22 | for i, layer in enumerate(self.layers): 23 | h_1_i, c_1_i = layer(input, (h_0[i], c_0[i])) 24 | input = h_1_i 25 | if i + 1 != self.num_layers: 26 | input = self.dropout(input) 27 | h_1 += [h_1_i] 28 | c_1 += [c_1_i] 29 | 30 | h_1 = torch.stack(h_1) 31 | c_1 = torch.stack(c_1) 32 | 33 | return input, (h_1, c_1) 34 | 35 | 36 | class StackedGRU(nn.Module): 37 | 38 | def __init__(self, num_layers, input_size, rnn_size, dropout): 39 | super(StackedGRU, self).__init__() 40 | self.dropout = nn.Dropout(dropout) 41 | self.num_layers = num_layers 42 | self.layers = nn.ModuleList() 43 | 44 | for i in range(num_layers): 45 | self.layers.append(nn.GRUCell(input_size, rnn_size)) 46 | input_size = rnn_size 47 | 48 | def forward(self, input, hidden): 49 | h_1 = [] 50 | for i, layer in enumerate(self.layers): 51 | h_1_i = layer(input, hidden[0][i]) 52 | input = h_1_i 53 | if i + 1 != self.num_layers: 54 | input = self.dropout(input) 55 | h_1 += [h_1_i] 56 | 57 | h_1 = torch.stack(h_1) 58 | return input, (h_1,) -------------------------------------------------------------------------------- /nmt/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch.utils.data as data 3 | import codecs 4 | import math 5 | import os 6 | import sys, time 7 | import nmt 8 | class HParams(object): 9 | def __init__(self, **entries): 10 | self.__dict__.update(entries) 11 | 12 | 13 | def load_hparams(config_file): 14 | with codecs.open(config_file, 'r', encoding='utf8') as f: 15 | configs = yaml.load(f) 16 | hparams = HParams(**configs) 17 | return hparams 18 | 19 | 20 | def safe_exp(value): 21 | """Exponentiation with catching of overflow error.""" 22 | try: 23 | ans = math.exp(value) 24 | except OverflowError: 25 | ans = float("inf") 26 | return ans 27 | 28 | 29 | def latest_checkpoint(model_dir): 30 | cnpt_file = os.path.join(model_dir,'checkpoint') 31 | try: 32 | cnpt = open(cnpt_file,'r').readline().strip().split(':')[-1] 33 | except: 34 | return None 35 | cnpt = os.path.join(model_dir,cnpt) 36 | return cnpt 37 | 38 | 39 | 40 | class ShowProcess(): 41 | """ 42 | 显示处理进度的类 43 | 调用该类相关函数即可实现处理进度的显示 44 | """ 45 | i = 1 # 当前的处理进度 46 | max_steps = 0 # 总共需要处理的次数 47 | max_arrow = 50 #进度条的长度 48 | 49 | # 初始化函数,需要知道总共的处理次数 50 | def __init__(self, max_steps): 51 | self.max_steps = max_steps 52 | self.i = 1 53 | 54 | # 显示函数,根据当前的处理进度i显示进度 55 | # 效果为[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]100.00% 56 | def show_process(self, i=None): 57 | if i is not None: 58 | self.i = i 59 | num_arrow = int(self.i * self.max_arrow / self.max_steps) #计算显示多少个'>' 60 | num_line = self.max_arrow - num_arrow #计算显示多少个'-' 61 | percent = self.i * 100.0 / self.max_steps #计算完成进度,格式为xx.xx% 62 | process_bar = '[' + '>' * num_arrow + '-' * num_line + ']'\ 63 | + '%.2f' % percent + '%' + '\r' #带输出的字符串,'\r'表示不换行回到最左边 64 | sys.stdout.write(process_bar) #这两句打印字符到终端 65 | sys.stdout.flush() 66 | self.i += 1 67 | 68 | def close(self, words='done'): 69 | print('') 70 | print(words) 71 | self.i = 1 72 | -------------------------------------------------------------------------------- /score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import codecs 4 | import nmt 5 | import json 6 | from torch import cuda 7 | import time 8 | from train import vocab_wrapper 9 | from data import Data_Loader, Vocab 10 | 11 | def score_file(scorer, test_iter, tgt_fout, fields, use_cuda): 12 | print('start translating ...') 13 | process_bar = nmt.misc_utils.ShowProcess(len(test_iter)) 14 | with codecs.open(tgt_fout, 'w', 'utf8') as tgt_file: 15 | for batch in test_iter: 16 | process_bar.show_process() 17 | src, src_lengths = batch.src 18 | tgt, tgt_lengths = batch.tgt 19 | ref_src, ref_src_lengths = batch.ref_src 20 | ref_tgt, ref_tgt_lengths = batch.ref_tgt 21 | 22 | _, ret = scorer.score_batch(src, tgt, ref_src, ref_tgt, src_lengths, tgt_lengths, ref_src_lengths, ref_tgt_lengths, normalization = True) 23 | for s in ret: 24 | tgt_file.write(str(s)+'\n') 25 | 26 | def main(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("-model_type", type=str) 29 | parser.add_argument("-test_file", type=str) 30 | parser.add_argument("-tgt_out", type=str) 31 | parser.add_argument("-model", type=str) 32 | parser.add_argument('-gpuid', default=[0], nargs='+', type=int) 33 | parser.add_argument('-src_vocab', type=str) 34 | parser.add_argument('-tgt_vocab', type=str) 35 | parser.add_argument('-gan', type = bool, default = False) 36 | args = parser.parse_args() 37 | opt = torch.load(args.model)['opt'] 38 | 39 | 40 | fields = dict() 41 | vocab_src = Vocab(args.src_vocab, noST = True) 42 | vocab_tgt = Vocab(args.tgt_vocab) 43 | fields['src'] = vocab_wrapper(vocab_src) 44 | fields['tgt'] = vocab_wrapper(vocab_tgt) 45 | 46 | use_cuda = False 47 | if args.gpuid: 48 | cuda.set_device(args.gpuid[0]) 49 | use_cuda = True 50 | 51 | if args.model_type == "base": 52 | model = nmt.model_helper.create_base_model(opt, fields) 53 | if args.model_type == "bibase": 54 | model = nmt.model_helper.create_bibase_model(opt, fields) 55 | if args.model_type == "ref": 56 | model = nmt.model_helper.create_ref_model(opt, fields) 57 | if args.model_type == "ev": 58 | model = nmt.model_helper.create_ev_model(opt, fields) 59 | if args.model_type == "rg": 60 | model = nmt.model_helper.create_response_generator(opt, fields) 61 | if args.model_type == "joint": 62 | model = nmt.model_helper.create_joint_model(opt, fields) 63 | print('Loading parameters ...') 64 | 65 | if args.gan: 66 | ckpt = torch.load(args.model) 67 | model.load_state_dict(ckpt['generator_dict']) 68 | else: 69 | model.load_checkpoint(args.model) 70 | 71 | if use_cuda: 72 | model = model.cuda() 73 | 74 | scorer = nmt.Scorer(model, fields['tgt'].vocab, None, None, opt) 75 | mask_end = (args.model_type == 'ev') or (args.model_type == 'joint') 76 | test_iter = Data_Loader(args.test_file, opt.train_batch_size, train = False, mask_end = mask_end) 77 | score_file(scorer, test_iter, args.tgt_out, fields, use_cuda) 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /nmt/modules/Encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 6 | from torch.nn.utils.rnn import pack_padded_sequence as pack 7 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 8 | import numpy as np 9 | 10 | class EncoderBase(nn.Module): 11 | """ 12 | EncoderBase class for sharing code among various encoder. 13 | """ 14 | 15 | def forward(self, input, lengths=None, hidden=None): 16 | """ 17 | Args: 18 | input (LongTensor): len x batch x nfeat. 19 | lengths (LongTensor): batch 20 | hidden: Initial hidden state. 21 | Returns: 22 | hidden_t (Variable): Pair of layers x batch x rnn_size - final 23 | encoder state 24 | outputs (FloatTensor): len x batch x rnn_size - Memory bank 25 | """ 26 | raise NotImplementedError 27 | 28 | 29 | class EncoderRNN(EncoderBase): 30 | """ The standard RNN encoder. """ 31 | def __init__(self, rnn_type, input_size, 32 | hidden_size, num_layers=1, 33 | dropout=0.1, bidirectional=False): 34 | super(EncoderRNN, self).__init__() 35 | 36 | num_directions = 2 if bidirectional else 1 37 | assert hidden_size % num_directions == 0 38 | hidden_size = hidden_size // num_directions 39 | self.rnn_type = rnn_type 40 | self.hidden_size = hidden_size 41 | self.num_layers = num_layers 42 | self.bidirectional = bidirectional 43 | 44 | self.rnn = getattr(nn, rnn_type)( 45 | input_size=input_size, 46 | hidden_size=hidden_size, 47 | num_layers=num_layers, 48 | dropout=dropout, 49 | bidirectional=bidirectional) 50 | 51 | def forward(self, input, lengths, hidden=None): 52 | """ See EncoderBase.forward() for description of args and returns.""" 53 | emb = input 54 | is_sorted = lambda a: np.all(np.array(a[:-1]) >= np.array(a[1:])) 55 | packed_emb = emb 56 | changedorder = False 57 | if lengths is not None: 58 | # Lengths data is wrapped inside a Variable. 59 | if not is_sorted(lengths): 60 | inds = (np.argsort(lengths)[::-1]).copy() 61 | inds_tensor = Variable(torch.LongTensor(inds).cuda()) 62 | emb = emb.index_select(1, inds_tensor) 63 | len_sub = list(np.array(lengths)[inds]) 64 | packed_emb = pack(emb, len_sub) 65 | changedorder = True 66 | else: 67 | packed_emb = pack(emb, lengths) 68 | 69 | outputs, hidden_t = self.rnn(packed_emb, hidden) 70 | 71 | if lengths is not None: 72 | outputs = unpack(outputs)[0] 73 | 74 | if self.bidirectional: 75 | # The encoder hidden is (layers*directions) x batch x dim. 76 | # We need to convert it to layers x batch x (directions*dim). 77 | if self.rnn_type != 'LSTM': 78 | hidden_t = torch.cat([hidden_t[0:hidden_t.size(0):2], hidden_t[1:hidden_t.size(0):2]], 2) 79 | else: 80 | h_n, c_n = hidden_t 81 | h_n = torch.cat([h_n[0:h_n.size(0):2], h_n[1:h_n.size(0):2]], 2) 82 | c_n = torch.cat([c_n[0:c_n.size(0):2], c_n[1:c_n.size(0):2]], 2) 83 | hidden_t = (h_n, c_n) 84 | if changedorder: 85 | rinds = np.argsort(inds) 86 | rinds_tensor = Variable(torch.LongTensor(rinds).cuda()) 87 | outputs = outputs.index_select(1, rinds_tensor) 88 | if self.rnn_type == 'LSTM': 89 | h_n, c_n = hidden_t 90 | hidden_t = (h_n.index_select(1, rinds_tensor), c_n.index_select(1, rinds_tensor)) 91 | else: 92 | hidden_t = hidden_t.index_select(1, rinds_tensor) 93 | return outputs, hidden_t 94 | -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import codecs 4 | import nmt 5 | import json 6 | from torch import cuda 7 | import time 8 | from train import vocab_wrapper 9 | from data import Data_Loader, Vocab 10 | 11 | 12 | def get_sentence(idx, vocab): 13 | return ' '.join([vocab.itos[x] for x in idx]) 14 | 15 | def translate_file(translator, test_iter, tgt_fout, fields, use_cuda): 16 | print('start translating ...') 17 | process_bar = nmt.misc_utils.ShowProcess(len(test_iter)) 18 | with codecs.open(tgt_fout, 'w', 'utf8') as tgt_file: 19 | for batch in test_iter: 20 | process_bar.show_process() 21 | src, src_lengths = batch.src 22 | ref_src, ref_src_lengths = batch.ref_src 23 | ref_tgt, ref_tgt_lengths = batch.ref_tgt 24 | ret = translator.translate_batch(src, ref_src, ref_tgt, src_lengths, ref_src_lengths, ref_tgt_lengths, batch) 25 | for raw, s in zip(ret['predictions'], ret['scores']): 26 | for idx, hyp_inds in enumerate(raw): 27 | sentence_out = get_sentence(hyp_inds, fields['tgt'].vocab) 28 | tgt_file.write(sentence_out+'\n') 29 | #ms = s[idx]/ len(hyp_inds) 30 | #tgt_file.write(sentence_out+'\t'+str(ms)+'\n') 31 | def main(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("-model_type", type=str) 34 | parser.add_argument("-test_file", type=str) 35 | parser.add_argument("-tgt_out", type=str) 36 | parser.add_argument("-model", type=str) 37 | parser.add_argument('-gpuid', default=[0], nargs='+', type=int) 38 | parser.add_argument('-beam_size', default = 5, type=int) 39 | parser.add_argument('-decode_max_length', default = 20, type = int) 40 | parser.add_argument('-n_best', default= 1, type = int) 41 | parser.add_argument('-src_vocab', type=str) 42 | parser.add_argument('-tgt_vocab', type=str) 43 | parser.add_argument('-gan', type = bool, default = False) 44 | args = parser.parse_args() 45 | opt = torch.load(args.model)['opt'] 46 | 47 | 48 | fields = dict() 49 | vocab_src = Vocab(args.src_vocab, noST = True) 50 | vocab_tgt = Vocab(args.tgt_vocab) 51 | fields['src'] = vocab_wrapper(vocab_src) 52 | fields['tgt'] = vocab_wrapper(vocab_tgt) 53 | 54 | use_cuda = False 55 | if args.gpuid: 56 | cuda.set_device(args.gpuid[0]) 57 | use_cuda = True 58 | 59 | if args.model_type == "base": 60 | model = nmt.model_helper.create_base_model(opt, fields) 61 | if args.model_type == "bibase": 62 | model = nmt.model_helper.create_bibase_model(opt, fields) 63 | if args.model_type == "ref": 64 | model = nmt.model_helper.create_ref_model(opt, fields) 65 | if args.model_type == "ev": 66 | model = nmt.model_helper.create_ev_model(opt, fields) 67 | if args.model_type == "rg": 68 | model = nmt.model_helper.create_response_generator(opt, fields) 69 | if args.model_type == "CAS": 70 | model = nmt.model_helper.create_joint_model(opt, fields) 71 | if args.model_type == "JNT": 72 | model = nmt.model_helper.create_joint_template_response_model(opt, fields) 73 | 74 | print('Loading parameters ...') 75 | 76 | if args.gan: 77 | ckpt = torch.load(args.model) 78 | model.load_state_dict(ckpt['generator_dict']) 79 | else: 80 | model.load_checkpoint(args.model) 81 | 82 | if use_cuda: 83 | model = model.cuda() 84 | 85 | translator = nmt.Translator(model=model, 86 | fields=fields, 87 | beam_size = args.beam_size, 88 | n_best= args.n_best, 89 | max_length= args.decode_max_length, 90 | global_scorer=None, 91 | cuda=use_cuda) 92 | mask_end = (args.model_type == 'ev') or (args.model_type == 'CAS') or (args.model_type == 'JNT') 93 | test_iter = Data_Loader(args.test_file, 1, train = False, mask_end = mask_end) 94 | translate_file(translator, test_iter, args.tgt_out, fields, use_cuda) 95 | 96 | if __name__ == '__main__': 97 | main() 98 | -------------------------------------------------------------------------------- /nmt/modules/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GlobalAttention(nn.Module): 7 | def __init__(self, dim, attn_type="dot", context_gate = False, gate_vector = False): 8 | super(GlobalAttention, self).__init__() 9 | self.dim = dim 10 | self.attn_type = attn_type 11 | assert (self.attn_type in ["dot", "general", "mlp"]), ( 12 | "Please select a valid attention type.") 13 | if self.attn_type == "general": 14 | self.linear_in = nn.Linear(dim, dim, bias=False) 15 | if self.attn_type == "mlp": 16 | self.s2s = nn.Linear(dim, dim, bias = False) 17 | self.t2s = nn.Linear(dim, dim) 18 | self.v = nn.Linear(dim, 1, bias = False) 19 | 20 | self.context_gate = context_gate 21 | self.gate_vector = gate_vector 22 | if context_gate: 23 | self.gate_linear = nn.Linear(dim*2, (dim if gate_vector else 1)) 24 | self.linear_out = nn.Linear(dim*2, dim, bias = False) 25 | self.sm = nn.Softmax(dim=-1) 26 | self.tanh = nn.Tanh() 27 | 28 | def score(self, h_t, h_s): 29 | # Check input sizes 30 | src_batch, src_len, src_dim = h_s.size() 31 | tgt_batch, tgt_len, tgt_dim = h_t.size() 32 | 33 | if self.attn_type in ["general", "dot"]: 34 | if self.attn_type == "general": 35 | h_t_ = h_t.view(tgt_batch*tgt_len, tgt_dim) 36 | h_t_ = self.linear_in(h_t_) 37 | h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim) 38 | h_s_ = h_s.transpose(1, 2) 39 | # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) 40 | return torch.bmm(h_t, h_s_) 41 | elif self.attn_type == "mlp": 42 | _t = self.t2s(h_t.view(tgt_batch*tgt_len, tgt_dim)) 43 | _t = _t.view(tgt_batch, tgt_len, 1, tgt_dim) 44 | _t = _t.expand(tgt_batch, tgt_len, src_len, tgt_dim) 45 | 46 | _s = self.s2s(h_s.view(src_batch*src_len, src_dim)) 47 | _s = _s.view(src_batch, 1, src_len, src_dim) 48 | _s = _s.expand(src_batch, tgt_len, src_len, src_dim) 49 | 50 | #print (_t.size(), _s.size()) 51 | return self.v(self.tanh(_t+_s).view(-1, tgt_dim)).view(tgt_batch, tgt_len, src_len) 52 | 53 | 54 | def forward(self, input, context, context_values = None, mask = None): 55 | 56 | # print (input.size(), context.size(), mask.size()) 57 | # one step input 58 | if input.dim() == 2: 59 | one_step = True 60 | input = input.unsqueeze(1) 61 | else: 62 | one_step = False 63 | 64 | 65 | batch, sourceL, dim = context.size() 66 | batch_, targetL, dim_ = input.size() 67 | #print (sourceL, targetL) 68 | # compute attention scores, as in Luong et al. 69 | align = self.score(input, context) 70 | if mask is not None: 71 | mask = mask.unsqueeze(1) 72 | align.data.masked_fill_(1 - mask, -float('inf')) 73 | # Softmax to normalize attention weights 74 | align_vectors = self.sm(align.view(batch*targetL, sourceL)) 75 | align_vectors = align_vectors.view(batch, targetL, sourceL) 76 | #print (align_vectors) 77 | # each context vector c_t is the weighted average 78 | # over all the source hidden states 79 | if context_values is not None: 80 | c = torch.bmm(align_vectors, context_values) 81 | else: 82 | c = torch.bmm(align_vectors, context) 83 | 84 | # concatenate 85 | concat_c = torch.cat([c, input], 2).view(batch*targetL, dim*2) 86 | if self.context_gate: 87 | gates = F.sigmoid(self.gate_linear(concat_c).view(batch, targetL, (dim if self.gate_vector else 1))) 88 | attn_h = (1-gates) * c + gates * input 89 | attn_h = attn_h.view(batch, targetL, dim) 90 | else: 91 | attn_h = self.linear_out(concat_c).view(batch, targetL, dim) 92 | attn_h = self.tanh(attn_h) 93 | 94 | if one_step: 95 | attn_h = attn_h.squeeze(1) 96 | align_vectors = align_vectors.squeeze(1) 97 | else: 98 | attn_h = attn_h.transpose(0, 1).contiguous() 99 | align_vectors = align_vectors.transpose(0, 1).contiguous() 100 | return attn_h, align_vectors 101 | -------------------------------------------------------------------------------- /nmt/Loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional 4 | from torch.autograd import Variable 5 | from nmt.Trainer import Statistics 6 | 7 | 8 | class NMTLossCompute(nn.Module): 9 | """ 10 | Standard NMT Loss Computation. 11 | """ 12 | def __init__(self, generator, tgt_vocab): 13 | super(NMTLossCompute, self).__init__() 14 | self.generator = generator 15 | self.tgt_vocab = tgt_vocab 16 | self.padding_idx = tgt_vocab.stoi[tgt_vocab.PAD] 17 | weight = torch.ones(len(tgt_vocab)) 18 | weight[self.padding_idx] = 0 19 | self.criterion = nn.NLLLoss(weight, size_average=False) 20 | 21 | def make_shard_state(self, batch, output): 22 | """ See base class for args description. """ 23 | return { 24 | "output": output, 25 | "target": batch.tgt[0][1:], 26 | } 27 | 28 | def compute_loss(self, batch, weight, output, target): 29 | scores = self.generator(self.bottle(output)) 30 | if weight is not None: 31 | scores = self.unbottle(scores, batch.batch_size) 32 | scores = self.bottle((scores * weight.view(1, -1, 1))) 33 | target = target.view(-1) 34 | loss = self.criterion(scores,target) 35 | 36 | loss_data = loss.data.clone() 37 | stats = self.stats(loss_data, scores.data, target.data) 38 | return loss, stats 39 | 40 | def sharded_compute_loss(self, batch, output, shard_size, weight = None): 41 | """ 42 | Compute the loss in shards for efficiency. 43 | """ 44 | batch_stats = Statistics() 45 | shard_state = self.make_shard_state(batch, output) 46 | 47 | for shard in shards(shard_state, shard_size): 48 | loss, stats = self.compute_loss(batch, weight, **shard) 49 | loss.div(batch.batch_size).backward() 50 | batch_stats.update(stats) 51 | 52 | return batch_stats 53 | 54 | def monolithic_compute_loss(self, batch, output): 55 | """ 56 | Compute the loss monolithically, not dividing into shards. 57 | """ 58 | 59 | shard_state = self.make_shard_state(batch, output) 60 | _, batch_stats = self.compute_loss(batch, None, **shard_state) 61 | 62 | return batch_stats 63 | 64 | def stats(self, loss, scores, target): 65 | """ 66 | Compute and return a Statistics object. 67 | Args: 68 | loss(Tensor): the loss computed by the loss criterion. 69 | scores(Tensor): a sequence of predict output with scores. 70 | """ 71 | pred = scores.max(1)[1] 72 | non_padding = target.ne(self.padding_idx) 73 | num_correct = pred.eq(target) \ 74 | .masked_select(non_padding) \ 75 | .sum() 76 | return Statistics(loss.item(), non_padding.sum().item(), num_correct) 77 | 78 | def bottle(self, v): 79 | return v.view(-1, v.size(2)) 80 | 81 | def unbottle(self, v, batch_size): 82 | return v.view(-1, batch_size, v.size(1)) 83 | 84 | 85 | def filter_shard_state(state): 86 | for k, v in state.items(): 87 | if v is not None: 88 | if isinstance(v, Variable) and v.requires_grad: 89 | v = Variable(v.data, requires_grad=True, volatile=False) 90 | yield k, v 91 | 92 | 93 | def shards(state, shard_size, eval=False): 94 | """ 95 | Args: 96 | state: A dictionary which corresponds to the output of 97 | *LossCompute.make_shard_state(). The values for 98 | those keys are Tensor-like or None. 99 | shard_size: The maximum size of the shards yielded by the model. 100 | eval: If True, only yield the state, nothing else. 101 | Otherwise, yield shards. 102 | Yields: 103 | Each yielded shard is a dict. 104 | Side effect: 105 | After the last shard, this function does back-propagation. 106 | """ 107 | if eval: 108 | yield state 109 | else: 110 | # non_none: the subdict of the state dictionary where the values 111 | # are not None. 112 | non_none = dict(filter_shard_state(state)) 113 | 114 | # Now, the iteration: 115 | # state is a dictionary of sequences of tensor-like but we 116 | # want a sequence of dictionaries of tensors. 117 | # First, unzip the dictionary into a sequence of keys and a 118 | # sequence of tensor-like sequences. 119 | keys, values = zip(*((k, torch.split(v, shard_size)) 120 | for k, v in non_none.items())) 121 | 122 | # Now, yield a dictionary for each shard. The keys are always 123 | # the same. values is a sequence of length #keys where each 124 | # element is a sequence of length #shards. We want to iterate 125 | # over the shards, not over the keys: therefore, the values need 126 | # to be re-zipped by shard and then each shard can be paired 127 | # with the keys. 128 | for shard_tensors in zip(*values): 129 | yield dict(zip(keys, shard_tensors)) 130 | 131 | # Assumed backprop'd 132 | variables = ((state[k], v.grad.data) for k, v in non_none.items() 133 | if isinstance(v, Variable) and v.grad is not None) 134 | inputs, grads = zip(*variables) 135 | torch.autograd.backward(inputs, grads, retain_graph = True) 136 | -------------------------------------------------------------------------------- /nmt/modules/Beam.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | 4 | 5 | class Beam(object): 6 | """ 7 | Class for managing the internals of the beam search process. 8 | Takes care of beams, back pointers, and scores. 9 | Args: 10 | size (int): beam size 11 | pad, bos, eos (int): indices of padding, beginning, and ending. 12 | n_best (int): nbest size to use 13 | cuda (bool): use gpu 14 | global_scorer (:obj:`GlobalScorer`) 15 | """ 16 | def __init__(self, size, pad, bos, eos, 17 | n_best=1, cuda=False, 18 | global_scorer=None, 19 | min_length=0): 20 | 21 | self.size = size 22 | self.tt = torch.cuda if cuda else torch 23 | 24 | # The score for each translation on the beam. 25 | self.scores = self.tt.FloatTensor(size).zero_() 26 | self.all_scores = [] 27 | 28 | # The backpointers at each time-step. 29 | self.prev_ks = [] 30 | 31 | # The outputs at each time-step. 32 | self.next_ys = [self.tt.LongTensor(size) 33 | .fill_(pad)] 34 | self.next_ys[0][0] = bos 35 | 36 | # Has EOS topped the beam yet. 37 | self._eos = eos 38 | self.eos_top = False 39 | 40 | # The attentions (matrix) for each time. 41 | self.attn = [] 42 | 43 | # Time and k pair for finished. 44 | self.finished = [] 45 | self.n_best = n_best 46 | 47 | # Information for global scoring. 48 | self.global_scorer = global_scorer 49 | self.global_state = {} 50 | 51 | # Minimum prediction length 52 | self.min_length = min_length 53 | 54 | def get_current_state(self): 55 | "Get the outputs for the current timestep." 56 | return self.next_ys[-1] 57 | 58 | def get_current_origin(self): 59 | "Get the backpointers for the current timestep." 60 | return self.prev_ks[-1] 61 | 62 | # def advance(self, word_probs, attn_out): 63 | def advance(self, word_probs): 64 | """ 65 | Given prob over words for every last beam `wordLk` and attention 66 | `attn_out`: Compute and update the beam search. 67 | Parameters: 68 | * `word_probs`- probs of advancing from the last step (K x words) 69 | * `attn_out`- attention at the last step 70 | Returns: True if beam search is complete. 71 | """ 72 | num_words = word_probs.size(1) 73 | 74 | # force the output to be longer than self.min_length 75 | cur_len = len(self.next_ys) 76 | if cur_len < self.min_length: 77 | for k in range(len(word_probs)): 78 | word_probs[k][self._eos] = -1e20 79 | 80 | for k in range(len(word_probs)): 81 | word_probs[k][3] = -1e20 82 | 83 | # Sum the previous scores. 84 | if len(self.prev_ks) > 0: 85 | beam_scores = word_probs + \ 86 | self.scores.unsqueeze(1).expand_as(word_probs) 87 | 88 | # Don't let EOS have children. 89 | for i in range(self.next_ys[-1].size(0)): 90 | if self.next_ys[-1][i] == self._eos: 91 | beam_scores[i] = -1e20 92 | else: 93 | beam_scores = word_probs[0] 94 | flat_beam_scores = beam_scores.view(-1) 95 | best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0, 96 | True, True) 97 | 98 | 99 | self.scores = best_scores 100 | self.all_scores.append(self.scores) 101 | # best_scores_id is flattened beam x word array, so calculate which 102 | # word and beam each score came from 103 | prev_k = best_scores_id / num_words 104 | self.prev_ks.append(prev_k) 105 | self.next_ys.append((best_scores_id - prev_k * num_words)) 106 | # self.attn.append(attn_out.index_select(0, prev_k)) 107 | 108 | if self.global_scorer is not None: 109 | self.global_scorer.update_global_state(self) 110 | 111 | for i in range(self.next_ys[-1].size(0)): 112 | if self.next_ys[-1][i] == self._eos: 113 | s = self.scores[i] 114 | if self.global_scorer is not None: 115 | global_scores = self.global_scorer.score(self, self.scores) 116 | s = global_scores[i] 117 | self.finished.append((s, len(self.next_ys) - 1, i)) 118 | 119 | # End condition is when top-of-beam is EOS and no global score. 120 | if self.next_ys[-1][0] == self._eos: 121 | # self.all_scores.append(self.scores) 122 | self.eos_top = True 123 | 124 | def done(self): 125 | return self.eos_top and len(self.finished) >= self.n_best 126 | 127 | def sort_finished(self, minimum=None): 128 | if minimum is not None: 129 | i = 0 130 | # Add from beam until we have minimum outputs. 131 | while len(self.finished) < minimum: 132 | s = self.scores[i] 133 | if self.global_scorer is not None: 134 | global_scores = self.global_scorer.score(self, self.scores) 135 | s = global_scores[i] 136 | self.finished.append((s, len(self.next_ys) - 1, i)) 137 | 138 | self.finished.sort(key=lambda a: -a[0]) 139 | scores = [sc for sc, _, _ in self.finished] 140 | ks = [(t, k) for _, t, k in self.finished] 141 | return scores, ks 142 | 143 | def get_hyp(self, timestep, k): 144 | """ 145 | Walk back to construct the full hypothesis. 146 | """ 147 | hyp, attn = [], [] 148 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1): 149 | hyp.append(self.next_ys[j+1][k]) 150 | # attn.append(self.attn[j][k]) 151 | k = self.prev_ks[j][k] 152 | return hyp[::-1] #, torch.stack(attn[::-1]) 153 | 154 | 155 | class GNMTGlobalScorer(object): 156 | """ 157 | NMT re-ranking score from 158 | "Google's Neural Machine Translation System" :cite:`wu2016google` 159 | Args: 160 | alpha (float): length parameter 161 | beta (float): coverage parameter 162 | """ 163 | def __init__(self, alpha, beta): 164 | self.alpha = alpha 165 | self.beta = beta 166 | 167 | def score(self, beam, logprobs): 168 | "Additional term add to log probability" 169 | cov = beam.global_state["coverage"] 170 | pen = self.beta * torch.min(cov, cov.clone().fill_(1.0)).log().sum(1) 171 | l_term = (((5 + len(beam.next_ys)) ** self.alpha) / 172 | ((5 + 1) ** self.alpha)) 173 | return (logprobs / l_term) + pen 174 | 175 | def update_global_state(self, beam): 176 | "Keeps the coverage vector as sum of attens" 177 | if len(beam.prev_ks) == 1: 178 | beam.global_state["coverage"] = beam.attn[-1] 179 | else: 180 | beam.global_state["coverage"] = beam.global_state["coverage"] \ 181 | .index_select(0, beam.prev_ks[-1]).add(beam.attn[-1]) 182 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import math 3 | import torch 4 | from torch.autograd import Variable 5 | import random 6 | PAD, EOS, EOT, UNK = '', '', '', '' 7 | PAD_idx, EOS_idx, EOT_idx, UNK_idx = 0, 1, 2, 3 8 | 9 | class Vocab(object): 10 | def __init__(self, vocab_file, noST = False): 11 | with open(vocab_file, encoding="utf-8") as f: 12 | if noST: 13 | self.itos = [PAD] + [ token.strip() for token in f.readlines() ] 14 | else: 15 | self.itos = [PAD, EOS, EOT] + [ token.strip() for token in f.readlines() ] 16 | self.stoi = dict(zip(self.itos, range(len(self.itos)))) 17 | self.PAD = PAD 18 | #assert (self.itos[3] == UNK), (self.itos[3], "first word must be ") 19 | self.UNK = UNK 20 | self.EOS = EOS 21 | self.EOT = EOT 22 | 23 | def i2s(self, idx): 24 | return self.itos[idx] 25 | 26 | def __len__(self): 27 | return len(self.itos) 28 | 29 | def ListsToTensor(xs, tgt = False): 30 | batch_size = len(xs) 31 | lens = [ len(x)+(2 if tgt else 0) for x in xs] 32 | mx_len = max( max(lens),1) 33 | ys = [] 34 | for i, x in enumerate(xs): 35 | y = ([EOS_idx] if tgt else [] )+ x + ([EOT_idx] if tgt else []) + ([PAD_idx]*(mx_len - lens[i])) 36 | ys.append(y) 37 | 38 | lens = [ max(1, x) for x in lens] 39 | data = Variable(torch.LongTensor(ys).t_()) 40 | 41 | data = data.cuda() 42 | 43 | return (data, lens) 44 | 45 | def LCS_mask(src, tgt, stop_words): 46 | m = len(src) 47 | n = len(tgt) 48 | if stop_words is None: 49 | stop_words = set() 50 | mat = [[0] * (n+1) for row in range(m+1)] 51 | for row in range(1, m+1): 52 | for col in range(1, n+1): 53 | if src[row - 1] == tgt[col - 1] and (src[row-1] not in stop_words): 54 | mat[row][col] = mat[row - 1][col - 1] + 1 55 | else: 56 | mat[row][col] = max(mat[row][col - 1], mat[row - 1][col]) 57 | x,y = m,n 58 | mask = [] 59 | while y >0 and x >0: 60 | if mat[x][y] == mat[x-1][y-1] + 1: 61 | x -=1 62 | y -=1 63 | mask.append(1) 64 | elif mat[x][y] == mat[x][y-1]: 65 | y -= 1 66 | mask.append(0) 67 | else: 68 | x -= 1 69 | while y>0: 70 | y -= 1 71 | mask.append(0) 72 | return mask[::-1] 73 | 74 | class Batch(object): 75 | def __init__(self, qs, rs, rqs, rrs, ss, mask_end = False, stop_words = None): 76 | if mask_end: 77 | I, D, ref_tgt, mask,ref_src = [], [], [], [], [] 78 | for q ,r,rq,rr in zip(qs, rs, rqs, rrs): 79 | rq = rq[0] 80 | rr = rr[0] 81 | q_set = set(q) 82 | rq_set = set(rq) 83 | ins = q_set&rq_set 84 | I.append( list(q_set - ins)) 85 | D.append( list(rq_set - ins)) 86 | ref_tgt.append(rr) 87 | ref_src.append(rq) 88 | mask.append(LCS_mask(r, rr, stop_words)) 89 | self.I = ListsToTensor(I) 90 | self.D = ListsToTensor(D) 91 | self.ref_tgt = ListsToTensor(ref_tgt) 92 | self.mask = ListsToTensor(mask) 93 | self.src = ListsToTensor(qs) 94 | self.tgt = ListsToTensor(rs, tgt=True) 95 | self.ref_src = ListsToTensor(ref_src) 96 | self.batch_size = len(qs) 97 | self.score = ss 98 | return 99 | self.src = ListsToTensor(qs) 100 | self.tgt = ListsToTensor(rs, tgt=True) 101 | num_rq = max( 1, max(len(rq) for rq in rqs)) 102 | num_rr = max( 1, max(len(rr) for rr in rrs)) 103 | ref_src =[ ListsToTensor([ 104 | rq[i] if i < len(rq) else [PAD_idx] 105 | for rq in rqs]) 106 | for i in range(num_rq)] 107 | ref_tgt =[ ListsToTensor([ 108 | rr[i] if i < len(rr) else [PAD_idx] 109 | for rr in rrs], tgt = True) 110 | for i in range(num_rr)] 111 | self.ref_src = ([x[0] for x in ref_src[:5]], [x[1] for x in ref_src[:5]]) 112 | self.ref_tgt = ([x[0] for x in ref_tgt[:5]], [x[1] for x in ref_tgt[:5]]) 113 | self.score = ss 114 | self.batch_size = len(qs) 115 | 116 | class Data_Loader(object): 117 | def __init__(self, fname, batch_size = 32, train = True, score = None, mask_end = False, stop_words = None): 118 | all_q = [] 119 | all_r = [] 120 | all_rq = [] 121 | all_rr = [] 122 | with open(fname) as f: 123 | for line in f.readlines(): 124 | x = line.split('|') 125 | q, r = x[:2] 126 | q = [int(x) for x in q.split()] 127 | r = [int(x) for x in r.split()] 128 | if train and (len(q) <= 2 or len(r)<=2): 129 | continue 130 | q = q[:30] 131 | r = r[:30] 132 | all_q.append(q) 133 | all_r.append(r) 134 | rq, rr = [], [] 135 | for q,r in zip(x[2::2], x[3::2]): 136 | q = [int(x) for x in q.split()] 137 | r = [int(x) for x in r.split()] 138 | q = q[:30] 139 | r = r[:30] 140 | rq.append(q) 141 | rr.append(r) 142 | all_rq.append(rq) 143 | all_rr.append(rr) 144 | 145 | self.all_s = None 146 | if score: 147 | self.all_s = [float(x.strip()) for x in open(score).readlines()] 148 | self.batch_size = batch_size 149 | self.all_q = all_q 150 | self.all_r = all_r 151 | self.all_rq = all_rq 152 | self.all_rr= all_rr 153 | self.train = train 154 | self.mask_end = mask_end 155 | 156 | if stop_words is not None: 157 | stop_words = set( [ int(x.strip()) for x in open(stop_words).readlines()]) 158 | self.stop_words = stop_words 159 | else: 160 | self.stop_words = None 161 | def __len__(self): 162 | return math.ceil(len(self.all_q)/self.batch_size) 163 | 164 | def __iter__(self): 165 | idx = list(range(len(self.all_q))) 166 | if self.train: 167 | random.shuffle(idx) 168 | cur = 0 169 | while cur < len(idx): 170 | batch = idx[cur:cur + self.batch_size] 171 | cur_q = [self.all_q[x] for x in batch] 172 | cur_r = [self.all_r[x] for x in batch] 173 | cur_rq = [self.all_rq[x] for x in batch] 174 | cur_rr = [self.all_rr[x] for x in batch] 175 | if self.all_s is not None: 176 | cur_s = [self.all_s[x] for x in batch] 177 | else: 178 | cur_s = None 179 | yield Batch(cur_q, cur_r, cur_rq, cur_rr, cur_s, self.mask_end, self.stop_words) 180 | cur += self.batch_size 181 | raise StopIteration 182 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import nmt.utils.misc_utils as utils 2 | import argparse 3 | import codecs 4 | import os 5 | import shutil 6 | import re 7 | import torch 8 | import torch.nn as nn 9 | from torch import cuda 10 | import nmt 11 | import random 12 | from data import Vocab, Data_Loader 13 | import numpy 14 | 15 | use_cuda = True 16 | 17 | 18 | def report_func(opt, global_step, epoch, batch, num_batches, 19 | start_time, lr, report_stats): 20 | """ 21 | This is the user-defined batch-level traing progress 22 | report function. 23 | Args: 24 | epoch(int): current epoch count. 25 | batch(int): current batch count. 26 | num_batches(int): total number of batches. 27 | start_time(float): last report time. 28 | lr(float): current learning rate. 29 | report_stats(Statistics): old Statistics instance. 30 | Returns: 31 | report_stats(Statistics): updated Statistics instance. 32 | """ 33 | if batch % opt.steps_per_stats == -1 % opt.steps_per_stats: 34 | report_stats.print_out(epoch, batch+1, num_batches, start_time) 35 | report_stats = nmt.Statistics() 36 | 37 | return report_stats 38 | 39 | def build_or_load_model(args, model_opt, fields): 40 | if args.model_type == "base": 41 | model = nmt.model_helper.create_base_model(model_opt, fields) 42 | if args.model_type == "bibase": 43 | model = nmt.model_helper.create_bibase_model(model_opt, fields) 44 | if args.model_type == "ref": 45 | model = nmt.model_helper.create_ref_model(model_opt, fields) 46 | if args.model_type == "ev": 47 | model = nmt.model_helper.create_ev_model(model_opt, fields) 48 | if args.model_type == "rg": 49 | model = nmt.model_helper.create_response_generator(model_opt, fields) 50 | if args.model_type == "joint": 51 | model = nmt.model_helper.create_joint_model(model_opt, fields) 52 | model.response_generator.load_checkpoint(args.rg_model) 53 | model.template_generator.load_checkpoint(args.tg_model) 54 | latest_ckpt = nmt.misc_utils.latest_checkpoint(model_opt.out_dir) 55 | start_epoch_at = 0 56 | if model_opt.start_epoch_at is not None: 57 | ckpt = 'checkpoint_epoch%d.pkl'%(model_opt.start_epoch_at) 58 | ckpt = os.path.join(model_opt.out_dir,ckpt) 59 | else: 60 | ckpt = latest_ckpt 61 | 62 | if ckpt: 63 | print('Loding model from %s...'%(ckpt)) 64 | start_epoch_at = model.load_checkpoint(ckpt) 65 | else: 66 | print('Building model...') 67 | print(model) 68 | #model.save_checkpoint(0, model_opt, os.path.join(model_opt.out_dir,"checkpoint_epoch0.pkl")) 69 | return model, start_epoch_at 70 | 71 | 72 | def build_optim(model, optim_opt): 73 | optim = nmt.Optim(optim_opt.optim_method, 74 | optim_opt.learning_rate, 75 | optim_opt.max_grad_norm, 76 | optim_opt.learning_rate_decay, 77 | optim_opt.weight_decay, 78 | optim_opt.start_decay_at) 79 | 80 | optim.set_parameters(model.parameters()) 81 | return optim 82 | 83 | def build_lr_scheduler(optimizer, opt): 84 | 85 | lr_lambda = lambda epoch: opt.learning_rate_decay ** epoch 86 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, 87 | lr_lambda=[lr_lambda]) 88 | return scheduler 89 | 90 | def check_save_model_path(args, opt): 91 | if not os.path.exists(opt.out_dir): 92 | os.makedirs(opt.out_dir) 93 | print('saving config file to %s ...'%(opt.out_dir)) 94 | shutil.copy(args.config, os.path.join(opt.out_dir,'config.yml')) 95 | 96 | 97 | def train_model(opt, model, train_iter, valid_iter, fields, optim, lr_scheduler, start_epoch_at): 98 | train_loss = nmt.NMTLossCompute(model.generator,fields['tgt'].vocab) 99 | valid_loss = nmt.NMTLossCompute(model.generator,fields['tgt'].vocab) 100 | 101 | if use_cuda: 102 | train_loss = train_loss.cuda() 103 | valid_loss = valid_loss.cuda() 104 | 105 | shard_size = opt.train_shard_size 106 | trainer = nmt.Trainer(opt, model, 107 | train_iter, 108 | valid_iter, 109 | train_loss, 110 | valid_loss, 111 | optim, 112 | shard_size) 113 | 114 | num_train_epochs = opt.num_train_epochs 115 | print('start training...') 116 | for step_epoch in range(start_epoch_at+1, num_train_epochs): 117 | 118 | if step_epoch >= opt.start_decay_at: 119 | lr_scheduler.step() 120 | # 1. Train for one epoch on the training set. 121 | train_stats = trainer.train(step_epoch, report_func) 122 | print('Train perplexity: %g' % train_stats.ppl()) 123 | 124 | #2. Validate on the validation set. 125 | valid_stats = trainer.validate() 126 | print('Validation perplexity: %g' % valid_stats.ppl()) 127 | 128 | trainer.epoch_step(step_epoch, out_dir=opt.out_dir) 129 | 130 | model.train() 131 | 132 | class vocab_wrapper(object): 133 | def __init__(self, vocab): 134 | self.vocab = vocab 135 | 136 | def main(): 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument("-config", type=str) 139 | parser.add_argument("-nmt_dir", type=str) 140 | parser.add_argument("-model_type", type=str) 141 | parser.add_argument('-gpuid', default=[0], nargs='+', type=int) 142 | parser.add_argument("-valid_file", type=str) 143 | parser.add_argument("-train_file", type=str) 144 | parser.add_argument("-train_score", type=str, default= None) 145 | parser.add_argument("-src_vocab", type = str) 146 | parser.add_argument("-tgt_vocab", type = str) 147 | parser.add_argument("-rg_model", type = str, default = None) 148 | parser.add_argument("-tg_model", type = str, default = None) 149 | args = parser.parse_args() 150 | opt = utils.load_hparams(args.config) 151 | cuda.set_device(args.gpuid[0]) 152 | if opt.random_seed > 0: 153 | random.seed(opt.random_seed) 154 | torch.manual_seed(opt.random_seed) 155 | numpy.random.seed(opt.random_seed) 156 | 157 | fields = dict() 158 | vocab_src = Vocab(args.src_vocab, noST = True) 159 | vocab_tgt = Vocab(args.tgt_vocab) 160 | 161 | fields['src'] = vocab_wrapper(vocab_src) 162 | fields['tgt'] = vocab_wrapper(vocab_tgt) 163 | 164 | mask_end = (args.model_type == "ev") or (args.model_type == "joint") 165 | train = Data_Loader(args.train_file, opt.train_batch_size, score = args.train_score, mask_end = mask_end) 166 | valid = Data_Loader(args.valid_file, opt.train_batch_size, mask_end = mask_end) 167 | 168 | # Build model. 169 | 170 | model, start_epoch_at = build_or_load_model(args, opt, fields) 171 | check_save_model_path(args, opt) 172 | 173 | # Build optimizer. 174 | optim = build_optim(model, opt) 175 | lr_scheduler = build_lr_scheduler(optim.optimizer, opt) 176 | 177 | if use_cuda: 178 | model = model.cuda() 179 | 180 | # Do training. 181 | 182 | train_model(opt, model, train, valid, fields, optim, lr_scheduler, start_epoch_at) 183 | x = 1 184 | while True: 185 | x = (x+1)%5 186 | if __name__ == '__main__': 187 | main() 188 | -------------------------------------------------------------------------------- /nmt/modules/Decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from nmt.modules.Attention import GlobalAttention 4 | # from nmt.modules.SRU import SRU 5 | from nmt.modules.StackedRNN import StackedGRU, StackedLSTM 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import math 9 | 10 | class DecoderBase(nn.Module): 11 | def forward(self, input, context, state): 12 | """ 13 | Forward through the decoder. 14 | Args: 15 | input (LongTensor): a sequence of input tokens tensors 16 | of size (len x batch x nfeats). 17 | context (FloatTensor): output(tensor sequence) from the encoder 18 | RNN of size (src_len x batch x hidden_size). 19 | state (FloatTensor): hidden state from the encoder RNN for 20 | initializing the decoder. 21 | Returns: 22 | outputs (FloatTensor): a Tensor sequence of output from the decoder 23 | of shape (len x batch x hidden_size). 24 | state (FloatTensor): final hidden state from the decoder. 25 | attns (dict of (str, FloatTensor)): a dictionary of different 26 | type of attention Tensor from the decoder 27 | of shape (src_len x batch). 28 | """ 29 | raise NotImplementedError 30 | 31 | class KVAttnDecoderRNN(DecoderBase): 32 | def __init__(self, rnn_type, attn_type, input_size, 33 | hidden_size, num_layers=1, dropout=0.1, src_attention = False, mem_gate = False, gate_vector = False, return_original = False): 34 | super(KVAttnDecoderRNN, self).__init__() 35 | # Basic attributes. 36 | self.rnn_type = rnn_type 37 | self.attn_type = attn_type 38 | self.num_layers = num_layers 39 | self.hidden_size = hidden_size 40 | self.dropout = nn.Dropout(dropout) 41 | 42 | self.rnn = getattr(nn, rnn_type)( 43 | input_size=input_size, 44 | hidden_size=hidden_size, 45 | num_layers=num_layers, 46 | dropout=dropout) 47 | 48 | self.src_attention = src_attention 49 | if src_attention: 50 | self.src_attn = GlobalAttention(hidden_size, attn_type) 51 | self.mem_attn = GlobalAttention(hidden_size, "general", context_gate= mem_gate, gate_vector = gate_vector) 52 | self.return_original = return_original 53 | 54 | def forward(self, input, context_keys, context_values, state, mem_mask =None, src_context = None, src_mask = None): 55 | emb = input 56 | rnn_outputs, hidden = self.rnn(emb, state) 57 | 58 | if self.src_attention: 59 | originals, src_score = self.src_attn( 60 | rnn_outputs.transpose(0, 1).contiguous(), 61 | src_context.transpose(0, 1), 62 | mask = src_mask 63 | ) 64 | else: 65 | orginals = rnn_outputs 66 | original = self.dropout(originals) 67 | 68 | if not self.return_original: 69 | rnn_outputs = originals 70 | # Calculate the attention. 71 | attn_outputs, attn_scores = self.mem_attn( 72 | rnn_outputs.transpose(0, 1).contiguous(), # (output_len, batch, d) 73 | context_keys.transpose(0, 1), 74 | context_values.transpose(0, 1), # (contxt_len, batch, d) 75 | mask = mem_mask 76 | ) 77 | outputs = self.dropout(attn_outputs) # (input_len, batch, d) 78 | if self.return_original: 79 | return outputs, hidden, attn_scores, originals 80 | return outputs , hidden, attn_scores 81 | 82 | 83 | class AttnDecoderRNN(DecoderBase): 84 | """ The GlobalAttention-based RNN decoder. """ 85 | def __init__(self, rnn_type, attn_type, input_size, 86 | hidden_size, num_layers=1, dropout=0.1): 87 | super(AttnDecoderRNN, self).__init__() 88 | # Basic attributes. 89 | self.rnn_type = rnn_type 90 | self.attn_type = attn_type 91 | self.num_layers = num_layers 92 | self.hidden_size = hidden_size 93 | self.dropout = nn.Dropout(dropout) 94 | 95 | self.rnn = getattr(nn, rnn_type)( 96 | input_size=input_size, 97 | hidden_size=hidden_size, 98 | num_layers=num_layers, 99 | dropout=dropout) 100 | 101 | if self.attn_type != 'none': 102 | self.attn = GlobalAttention(hidden_size, attn_type) 103 | 104 | def forward(self, input, context, state, attn_mask = None): 105 | emb = input 106 | rnn_outputs, hidden = self.rnn(emb, state) 107 | 108 | if self.attn_type != 'none': 109 | # Calculate the attention. 110 | attn_outputs, attn_scores = self.attn( 111 | rnn_outputs.transpose(0, 1).contiguous(), # (output_len, batch, d) 112 | context.transpose(0, 1), # (contxt_len, batch, d) 113 | mask = attn_mask 114 | ) 115 | 116 | outputs = self.dropout(attn_outputs) # (input_len, batch, d) 117 | attn = attn_scores 118 | else: 119 | outputs = self.dropout(rnn_outputs) 120 | attn = None 121 | 122 | return outputs , hidden, attn 123 | 124 | class AuxDecoderRNN(DecoderBase): 125 | """ The GlobalAttention-based RNN decoder. """ 126 | def __init__(self, rnn_type, attn_type, input_size, 127 | hidden_size, num_layers=1, dropout=0.1): 128 | super(AuxDecoderRNN, self).__init__() 129 | # Basic attributes. 130 | self.rnn_type = rnn_type 131 | self.attn_type = attn_type 132 | self.num_layers = num_layers 133 | self.hidden_size = hidden_size 134 | self.dropout = nn.Dropout(dropout) 135 | 136 | self.rnn = getattr(nn, rnn_type)( 137 | input_size=input_size, 138 | hidden_size=hidden_size, 139 | num_layers=num_layers, 140 | dropout=dropout) 141 | 142 | if self.attn_type != 'none': 143 | self.attn = GlobalAttention(hidden_size, attn_type) 144 | 145 | def forward(self, input, context, state, aux, attn_mask = None): 146 | emb = input 147 | aux_input = aux.unsqueeze(0).repeat(emb.size()[0], 1, 1) 148 | 149 | emb = torch.cat([emb, aux_input], 2) 150 | rnn_outputs, hidden = self.rnn(emb, state) 151 | 152 | if self.attn_type != 'none': 153 | # Calculate the attention. 154 | attn_outputs, attn_scores = self.attn( 155 | rnn_outputs.transpose(0, 1).contiguous(), # (output_len, batch, d) 156 | context.transpose(0, 1), # (contxt_len, batch, d) 157 | mask = attn_mask 158 | ) 159 | 160 | outputs = self.dropout(attn_outputs) # (input_len, batch, d) 161 | attn = attn_outputs 162 | else: 163 | outputs = self.dropout(rnn_outputs) 164 | attn = None 165 | return outputs , hidden, attn 166 | 167 | class AuxMemDecoderRNN(DecoderBase): 168 | """ The GlobalAttention-based RNN decoder. """ 169 | def __init__(self, rnn_type, attn_type, input_size, 170 | hidden_size, num_layers=1, dropout=0.1, src_attention = False, mem_gate = False, gate_vector = False): 171 | super(AuxMemDecoderRNN, self).__init__() 172 | # Basic attributes. 173 | self.rnn_type = rnn_type 174 | self.attn_type = attn_type 175 | self.num_layers = num_layers 176 | self.hidden_size = hidden_size 177 | self.dropout = nn.Dropout(dropout) 178 | 179 | self.rnn = getattr(nn, rnn_type)( 180 | input_size=input_size, 181 | hidden_size=hidden_size, 182 | num_layers=num_layers, 183 | dropout=dropout) 184 | 185 | self.src_attention = src_attention 186 | if src_attention: 187 | self.src_attn = GlobalAttention(hidden_size, attn_type) 188 | self.mem_attn = GlobalAttention(hidden_size, "general", context_gate= mem_gate, gate_vector = gate_vector) 189 | 190 | def forward(self, input, mem_context, state, aux, mem_mask =None, src_context = None, src_mask = None): 191 | emb = input 192 | aux_input = aux.unsqueeze(0).repeat(emb.size()[0], 1, 1) 193 | 194 | emb = torch.cat([emb, aux_input], 2) 195 | 196 | rnn_outputs, hidden = self.rnn(emb, state) 197 | 198 | if self.src_attention: 199 | originals, src_score = self.src_attn( 200 | rnn_outputs.transpose(0, 1).contiguous(), 201 | src_context.transpose(0, 1), 202 | mask = src_mask 203 | ) 204 | else: 205 | orginals = rnn_outputs 206 | rnn_outputs = self.dropout(originals) 207 | 208 | # Calculate the attention. 209 | attn_outputs, attn_scores = self.mem_attn( 210 | rnn_outputs.transpose(0, 1).contiguous(), # (output_len, batch, d) 211 | mem_context.transpose(0, 1), 212 | mask = mem_mask 213 | ) 214 | outputs = self.dropout(attn_outputs) # (input_len, batch, d) 215 | return outputs , hidden, attn_scores 216 | -------------------------------------------------------------------------------- /template.py: -------------------------------------------------------------------------------- 1 | import nmt.utils.misc_utils as utils 2 | import argparse 3 | import codecs 4 | import os, sys 5 | import shutil 6 | import re 7 | import torch 8 | import torch.nn as nn 9 | from torch import cuda 10 | import nmt 11 | import random 12 | from data import Vocab, Data_Loader 13 | from nmt.utils.data_utils import sequence_mask 14 | import torch.nn.functional as F 15 | use_cuda = True 16 | 17 | def build_or_load_model(args, model_opt, fields): 18 | model = nmt.model_helper.create_template_generator(model_opt, fields) 19 | 20 | latest_ckpt = nmt.misc_utils.latest_checkpoint(model_opt.out_dir) 21 | start_epoch_at = 0 22 | if model_opt.start_epoch_at is not None: 23 | ckpt = 'checkpoint_epoch%d.pkl'%(model_opt.start_epoch_at) 24 | ckpt = os.path.join(model_opt.out_dir,ckpt) 25 | else: 26 | ckpt = latest_ckpt 27 | 28 | if ckpt: 29 | print('Loding model from %s...'%(ckpt)) 30 | start_epoch_at = model.load_checkpoint(ckpt) 31 | else: 32 | print('Building model...') 33 | print(model) 34 | 35 | return model, start_epoch_at 36 | 37 | 38 | def build_optim(model, optim_opt): 39 | optim = nmt.Optim(optim_opt.optim_method, 40 | optim_opt.learning_rate, 41 | optim_opt.max_grad_norm, 42 | optim_opt.learning_rate_decay, 43 | optim_opt.weight_decay, 44 | optim_opt.start_decay_at) 45 | optim.set_parameters(model.parameters()) 46 | return optim 47 | 48 | def build_lr_scheduler(optimizer, opt): 49 | 50 | lr_lambda = lambda epoch: opt.learning_rate_decay ** epoch 51 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, 52 | lr_lambda=[lr_lambda]) 53 | return scheduler 54 | 55 | def check_save_model_path(args, opt): 56 | if not os.path.exists(opt.out_dir): 57 | os.makedirs(opt.out_dir) 58 | print('saving config file to %s ...'%(opt.out_dir)) 59 | shutil.copy(args.config, os.path.join(opt.out_dir,'config.yml')) 60 | 61 | def save_per_epoch(model, epoch, opt): 62 | f = open(os.path.join(opt.out_dir,'checkpoint'),'w') 63 | f.write('latest_checkpoint:checkpoint_epoch%d.pkl'%(epoch)) 64 | f.close() 65 | model.save_checkpoint(epoch, opt, os.path.join(opt.out_dir,"checkpoint_epoch%d.pkl"%(epoch))) 66 | 67 | def train_model(opt, model, train_iter, valid_iter, fields, optim, lr_scheduler, start_epoch_at): 68 | sys.stdout.flush() 69 | for step_epoch in range(start_epoch_at+1, opt.num_train_epochs): 70 | for batch in train_iter: 71 | model.zero_grad() 72 | I_word, I_word_length = batch.I 73 | D_word, D_word_length = batch.D 74 | target, _ = batch.mask 75 | ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt 76 | preds = model(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths) 77 | preds = preds.squeeze(2) 78 | mask = sequence_mask(ref_tgt_lengths).transpose(0, 1) 79 | tot = mask.float().sum() 80 | 81 | reserved = target.float().sum() 82 | w1 = (0.5 * tot / reserved).data[0] 83 | w2 = (0.5 * tot / (tot - reserved)).data[0] 84 | #w1, w2 = 1., 1. 85 | weight = torch.FloatTensor(mask.size()).zero_().cuda() 86 | weight.masked_fill_(mask, w2) 87 | weight.masked_fill_(torch.eq(target, 1).data, w1) 88 | 89 | loss = F.binary_cross_entropy(preds, target.float(), weight) 90 | loss.backward() 91 | optim.step() 92 | 93 | loss = 0. 94 | acc = 0. 95 | ntokens = 0. 96 | reserved, targeted, received = 0., 0., 0. 97 | model.eval() 98 | for batch in valid_iter: 99 | I_word, I_word_length = batch.I 100 | D_word, D_word_length = batch.D 101 | target, _ = batch.mask 102 | target= target.float() 103 | ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt 104 | preds = model(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths) 105 | preds = preds.squeeze(2) 106 | mask = sequence_mask(ref_tgt_lengths).transpose(0, 1).float() 107 | loss += F.binary_cross_entropy(preds, target, mask, size_average = False).data[0] 108 | ans = torch.ge(preds, 0.5).float() 109 | acc += (torch.eq(ans, target).float().data * mask ).sum() 110 | received += (ans.data * target.data * mask).sum() 111 | reserved += (ans.data * mask).sum() 112 | targeted += (target.data * mask).sum() 113 | ntokens += mask.sum() 114 | print ("epoch: ", step_epoch, "valid_loss: ", loss/ ntokens, "valid_acc: ", acc/ntokens, "precision: ", received/reserved, "recall: ", received/targeted) 115 | 116 | if step_epoch >= opt.start_decay_at: 117 | lr_scheduler.step() 118 | model.train() 119 | save_per_epoch(model, step_epoch, opt) 120 | sys.stdout.flush() 121 | 122 | class vocab_wrapper(object): 123 | def __init__(self, vocab): 124 | self.vocab = vocab 125 | 126 | def get_sentence(idx, vocab): 127 | return ' '.join([vocab.itos[x] for x in idx]) 128 | 129 | def Tensor2List(x, xlen, tgt = False): 130 | y = x.transpose(0, 1).data.tolist() 131 | return [ z[1:l-1] if tgt else z[:l] for z,l in zip(y, xlen) ] 132 | 133 | def output_results(ans, batch, fo, vocab_tgt, for_train = True): 134 | src = Tensor2List(batch.src[0], batch.src[1]) 135 | tgt = Tensor2List(batch.tgt[0], batch.tgt[1], tgt = True) 136 | ref_src = Tensor2List(batch.ref_src[0], batch.ref_src[1]) 137 | #batch.ref_tgt[0].data.masked_fill_( torch.lt((batch.mask[0]).float(), 1.).data , 0) 138 | batch.ref_tgt[0].data.masked_fill_( torch.lt(ans, 1.).data , 0) 139 | ref_tgt = Tensor2List(batch.ref_tgt[0], batch.ref_tgt[1]) 140 | 141 | if not for_train: 142 | for x in ref_tgt: 143 | fo.write(get_sentence(x, vocab_tgt)+'\n') 144 | return 145 | 146 | for x, y, z , w in zip(src, tgt, ref_src, ref_tgt): 147 | a = ' '.join( [str(t) for t in x ]) 148 | b = ' '.join( [str(t) for t in y ]) 149 | c = ' '.join( [str(t) for t in z ]) 150 | d = ' '.join( [str(t) for t in w ]) 151 | fo.write('|'.join([a,b,c,d])+'\n') 152 | 153 | def main(): 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument("-config", type=str) 156 | parser.add_argument("-nmt_dir", type=str) 157 | parser.add_argument('-gpuid', default=[0], nargs='+', type=int) 158 | parser.add_argument("-valid_file", type=str) 159 | parser.add_argument("-train_file", type=str) 160 | parser.add_argument("-test_file", type = str) 161 | parser.add_argument("-model", type=str) 162 | parser.add_argument("-src_vocab", type = str) 163 | parser.add_argument("-tgt_vocab", type = str) 164 | parser.add_argument("-mode", type = str) 165 | parser.add_argument("-out_file", type = str) 166 | parser.add_argument("-stop_words", type = str, default = None) 167 | parser.add_argument("-for_train", type = bool, default = True) 168 | args = parser.parse_args() 169 | opt = utils.load_hparams(args.config) 170 | 171 | if opt.random_seed > 0: 172 | random.seed(opt.random_seed) 173 | torch.manual_seed(opt.random_seed) 174 | 175 | fields = dict() 176 | vocab_src = Vocab(args.src_vocab, noST = True) 177 | vocab_tgt = Vocab(args.tgt_vocab) 178 | fields['src'] = vocab_wrapper(vocab_src) 179 | fields['tgt'] = vocab_wrapper(vocab_tgt) 180 | 181 | if args.mode == "test": 182 | model = nmt.model_helper.create_template_generator(opt, fields) 183 | if use_cuda: 184 | model = model.cuda() 185 | model.load_checkpoint(args.model) 186 | model.eval() 187 | test = Data_Loader(args.test_file, opt.train_batch_size, train = False, mask_end = True, stop_words = args.stop_words) 188 | fo = open(args.out_file, 'w') 189 | loss, acc, ntokens = 0., 0., 0. 190 | reserved, targeted, received = 0., 0., 0. 191 | for batch in test: 192 | I_word, I_word_length = batch.I 193 | D_word, D_word_length = batch.D 194 | target, _ = batch.mask 195 | target= target.float() 196 | ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt 197 | preds = model(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths) 198 | preds = preds.squeeze(2) 199 | mask = sequence_mask(ref_tgt_lengths).transpose(0, 1).float() 200 | loss += F.binary_cross_entropy(preds, target, mask, size_average = False).data[0] 201 | ans = torch.ge(preds, 0.5).float() 202 | output_results(ans, batch, fo, vocab_tgt, args.for_train) 203 | acc += (torch.eq(ans, target).float().data * mask ).sum() 204 | received += (ans.data * target.data * mask).sum() 205 | reserved += (ans.data * mask).sum() 206 | targeted += (target.data * mask).sum() 207 | ntokens += mask.sum() 208 | print ("test_loss: ", loss/ ntokens, "test_acc: ", acc/ntokens, "precision:", received/reserved, "recall: ", received/targeted, "leave percentage", targeted/ntokens) 209 | fo.close() 210 | #x = 1 211 | #while True: 212 | # x = (x+1)%5 213 | return 214 | 215 | train = Data_Loader(args.train_file, opt.train_batch_size, mask_end = True, stop_words = args.stop_words) 216 | valid = Data_Loader(args.valid_file, opt.train_batch_size, mask_end = True, stop_words = args.stop_words) 217 | 218 | # Build model. 219 | 220 | model, start_epoch_at = build_or_load_model(args, opt, fields) 221 | check_save_model_path(args, opt) 222 | 223 | # Build optimizer. 224 | optim = build_optim(model, opt) 225 | lr_scheduler = build_lr_scheduler(optim.optimizer, opt) 226 | 227 | if use_cuda: 228 | model = model.cuda() 229 | 230 | # Do training. 231 | 232 | train_model(opt, model, train, valid, fields, optim, lr_scheduler, start_epoch_at) 233 | print ("DONE") 234 | if __name__ == '__main__': 235 | main() 236 | -------------------------------------------------------------------------------- /nmt/Translator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import nmt 5 | from nmt.utils.data_utils import sequence_mask 6 | import time 7 | 8 | class Translator(object): 9 | """ 10 | Uses a model to translate a batch of sentences. 11 | Args: 12 | model (:obj:`onmt.modules.NMTModel`): 13 | NMT model to use for translation 14 | fields (dict of Fields): data fields 15 | beam_size (int): size of beam to use 16 | n_best (int): number of translations produced 17 | max_length (int): maximum length output to produce 18 | global_scores (:obj:`GlobalScorer`): 19 | object to rescore final translations 20 | cuda (bool): use cuda 21 | beam_trace (bool): trace beam search for debugging 22 | """ 23 | def __init__(self, model, fields, 24 | beam_size, n_best=1, 25 | max_length=100, 26 | global_scorer=None, cuda=False, 27 | beam_trace=False, min_length=0): 28 | self.model = model 29 | # Set model in eval mode. 30 | self.model.eval() 31 | self.fields = fields 32 | self.n_best = n_best 33 | self.max_length = max_length 34 | self.global_scorer = global_scorer 35 | self.beam_size = max(beam_size, n_best) 36 | self.cuda = cuda 37 | self.min_length = min_length 38 | 39 | # for debugging 40 | self.beam_accum = None 41 | if beam_trace: 42 | self.beam_accum = { 43 | "predicted_ids": [], 44 | "beam_parent_ids": [], 45 | "scores": []} 46 | 47 | def translate_batch(self, src, ref_src, ref_tgt, src_lengths, ref_src_lengths, ref_tgt_lengths, batch = None): 48 | """ 49 | Translate a batch of sentences. 50 | Mostly a wrapper around :obj:`Beam`. 51 | Args: 52 | batch (:obj:`Batch`): a batch from a dataset object 53 | data (:obj:`Dataset`): the dataset object 54 | Todo: 55 | Shouldn't need the original dataset. 56 | """ 57 | 58 | # (0) Prep each of the components of the search. 59 | # And helper method for reducing verbosity. 60 | #last_time = time.time() 61 | beam_size = self.beam_size 62 | batch_size = len(src_lengths) 63 | vocab = self.fields["tgt"].vocab 64 | beam = [nmt.Beam(beam_size, n_best=self.n_best, 65 | cuda=self.cuda, 66 | global_scorer=self.global_scorer, 67 | pad=vocab.stoi[vocab.PAD], 68 | eos=vocab.stoi[vocab.EOT], 69 | bos=vocab.stoi[vocab.EOS], 70 | min_length=self.min_length) 71 | for __ in range(batch_size)] 72 | 73 | # Help functions for working with beams and batches 74 | def var(a): return Variable(a, volatile=True) 75 | 76 | def rvar(a): return var(a.repeat(1, beam_size, 1)) 77 | 78 | def bottle(m): 79 | return m.view(batch_size * beam_size, -1) 80 | 81 | def unbottle(m): 82 | return m.view(beam_size, batch_size, -1) 83 | 84 | # (1) Run the encoder on the src. 85 | model_type = self.model.__class__.__name__ 86 | if model_type == "refNMTModel": 87 | context, enc_states, context_keys, context_mask, src_context, src_mask = self.model.encode(src, ref_src, ref_tgt, src_lengths, ref_src_lengths, ref_tgt_lengths) 88 | dec_states = self.model.init_decoder_state(enc_states, context) 89 | context_mask = context_mask.repeat(beam_size, 1) 90 | context = rvar(context.data) 91 | context_keys = rvar(context_keys.data) 92 | src_context = rvar(src_context.data) 93 | src_mask = src_mask.repeat(beam_size, 1) 94 | if model_type == "vanillaNMTModel": 95 | context, enc_states, context_mask = self.model.encode(src, src_lengths) 96 | dec_states = self.model.init_decoder_state(enc_states, context) 97 | context = rvar(context.data) 98 | context_mask = context_mask.repeat(beam_size, 1) 99 | if model_type == "bivanillaNMTModel": 100 | context, enc_states = self.model.encode(src, ref_tgt, src_lengths, ref_tgt_lengths) 101 | dec_states = self.model.init_decoder_state(enc_states, context) 102 | context = rvar(context.data) 103 | if model_type == "evNMTModel": 104 | I_word, I_word_length = batch.I 105 | D_word, D_word_length = batch.D 106 | context, enc_hidden, context_mask, dist = self.model.encode(src, src_lengths, I_word, I_word_length, D_word, D_word_length, ref_tgt, ref_tgt_lengths) 107 | dec_states = self.model.init_decoder_state(enc_hidden, context) 108 | dist = dist.repeat(beam_size, 1) 109 | context_mask = context_mask.repeat(beam_size, 1) 110 | context = rvar(context.data) 111 | if model_type == "responseGenerator": 112 | context, enc_states, context_mask, dist, src_context, src_mask = self.model.encode(src, ref_tgt, src_lengths, ref_tgt_lengths) 113 | dec_states = self.model.init_decoder_state(enc_states, context) 114 | context_mask = context_mask.repeat(beam_size, 1) 115 | context = rvar(context.data) 116 | src_context = rvar(src_context.data) 117 | src_mask = src_mask.repeat(beam_size, 1) 118 | if model_type == "jointTemplateResponseGenerator": 119 | I_word, I_word_length = batch.I 120 | D_word, D_word_length = batch.D 121 | context, enc_states, context_mask, dist, src_context, src_mask, preds = self.model.encode(I_word, I_word_length, D_word, D_word_length, ref_tgt, ref_tgt_lengths, src, src_lengths) 122 | dec_states = self.model.init_decoder_state(enc_states, context) 123 | if dist is not None: 124 | dist = dist.repeat(beam_size, 1) 125 | context_mask = context_mask.repeat(beam_size, 1) 126 | context = rvar(context.data) 127 | src_context = rvar(src_context.data) 128 | src_mask = src_mask.repeat(beam_size, 1) 129 | if model_type == "tem_resNMTModel": 130 | I_word, I_word_length = batch.I 131 | D_word, D_word_length = batch.D 132 | preds, ev = self.model.template_generator(I_word, I_word_length, D_word, D_word_length, ref_tgt, ref_tgt_lengths, return_ev = True) 133 | preds = preds.squeeze(2) 134 | template, template_lengths = self.model.template_generator.do_mask_and_clean(preds, ref_tgt, ref_tgt_lengths) 135 | context, enc_states, context_mask, dist, src_context, src_mask = self.model.response_generator.encode(src, template, src_lengths, template_lengths, ev) 136 | dec_states = self.model.response_generator.init_decoder_state(enc_states, context) 137 | if dist is not None: 138 | dist = dist.repeat(beam_size, 1) 139 | context_mask = context_mask.repeat(beam_size, 1) 140 | context = rvar(context.data) 141 | src_context = rvar(src_context.data) 142 | src_mask = src_mask.repeat(beam_size, 1) 143 | # (2) Repeat src objects `beam_size` times. 144 | if not isinstance(dec_states, tuple): # GRU 145 | dec_states = Variable(dec_states.data.repeat(1, beam_size, 1)) 146 | else: # LSTM 147 | dec_states = ( 148 | Variable(dec_states[0].data.repeat(1, beam_size, 1)), 149 | Variable(dec_states[1].data.repeat(1, beam_size, 1)), 150 | ) 151 | 152 | # (3) run the decoder to generate sentences, using beam search. 153 | for i in range(self.max_length): 154 | if all((b.done() for b in beam)): 155 | break 156 | 157 | # Construct batch x beam_size nxt words. 158 | # Get all the pending current beam words and arrange for forward. 159 | inp = var(torch.stack([b.get_current_state() for b in beam]) 160 | .t().contiguous().view(1, -1)) 161 | 162 | 163 | # Temporary kludge solution to handle changed dim expectation 164 | # in the decoder 165 | # inp = inp.unsqueeze(2) 166 | 167 | # Run one step. 168 | if model_type == "refNMTModel": 169 | dec_out, dec_states, attn = self.model.decode(inp, context_keys, context, dec_states, context_mask, src_context, src_mask) 170 | if model_type == "vanillaNMTModel": 171 | dec_out, dec_states, attn = self.model.decode(inp, context, dec_states, context_mask) 172 | if model_type == "bivanillaNMTModel": 173 | dec_out, dec_states, attn = self.model.decode(inp, context, dec_states) 174 | if model_type == "evNMTModel": 175 | dec_out, dec_states, attn = self.model.decode(inp, context, dec_states, dist, context_mask) 176 | if model_type == "responseGenerator": 177 | dec_out, dec_states, attn = self.model.decode(inp, context, dec_states, None, context_mask, src_context, src_mask) 178 | if model_type == "tem_resNMTModel": 179 | dec_out, dec_states, attn = self.model.response_generator.decode(inp, context, dec_states, dist, context_mask, src_context, src_mask) 180 | if model_type == "jointTemplateResponseGenerator": 181 | dec_out, dec_states, attn = self.model.decode(inp, context, dec_states, dist, context_mask, src_context, src_mask) 182 | dec_out = dec_out.squeeze(0) 183 | 184 | # (b) Compute a vector of batch*beam word scores. 185 | out = self.model.generator(dec_out).data 186 | out = unbottle(out) 187 | # beam x batch_size x tgt_vocab 188 | # (c) Advance each beam. 189 | for j, b in enumerate(beam): 190 | b.advance(out[:, j]) 191 | self.beam_update(j, b.get_current_origin(), beam_size, dec_states) 192 | 193 | # (4) Extract sentences from beam. 194 | ret = self._from_beam(beam) 195 | 196 | 197 | return ret 198 | 199 | def beam_update(self, idx, positions, beam_size, states): 200 | if not isinstance(states, tuple): 201 | states = (states, ) 202 | 203 | for e in states: 204 | sizes = e.size() 205 | br = sizes[1] 206 | sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[:, :, idx] 207 | sent_states.data.copy_( 208 | sent_states.data.index_select(1, positions)) 209 | 210 | def _from_beam(self, beam): 211 | ret = {"predictions": [], 212 | "scores": []} 213 | for b in beam: 214 | if self.beam_accum: 215 | self.beam_accum['predicted_ids'].append(torch.stack(b.next_ys[1:]).tolist()) 216 | self.beam_accum['beam_parent_ids'].append(torch.stack(b.prev_ks).tolist()) 217 | self.beam_accum['scores'].append(torch.stack(b.all_scores).tolist()) 218 | 219 | n_best = self.n_best 220 | scores, ks = b.sort_finished(minimum=n_best) 221 | hyps = [] 222 | for i, (times, k) in enumerate(ks[:n_best]): 223 | hyp = b.get_hyp(times, k) 224 | hyps.append(hyp) 225 | ret["predictions"].append(hyps) 226 | ret["scores"].append(scores) 227 | return ret 228 | -------------------------------------------------------------------------------- /joint_train.py: -------------------------------------------------------------------------------- 1 | import nmt.utils.misc_utils as utils 2 | import argparse 3 | import codecs 4 | import os 5 | import shutil 6 | import re 7 | import torch 8 | import torch.nn as nn 9 | from torch import cuda 10 | import nmt 11 | import random 12 | from data import Vocab, Data_Loader 13 | import numpy as np 14 | from maskGAN import sample 15 | from nmt.Trainer import Statistics 16 | from torch.autograd import Variable 17 | use_cuda = True 18 | 19 | def report_func(opt, global_step, epoch, batch, num_batches, 20 | start_time, lr, report_stats): 21 | if batch % opt.steps_per_stats == -1 % opt.steps_per_stats: 22 | report_stats.print_out(epoch, batch+1, num_batches, start_time) 23 | report_stats = nmt.Statistics() 24 | 25 | return report_stats 26 | 27 | def build_or_load_model(args, model_opt, fields): 28 | if args.model_type == "CAS": 29 | model = nmt.model_helper.create_joint_model(model_opt, fields) 30 | if args.rg_model is not None: 31 | model.response_generator.load_checkpoint(args.rg_model) 32 | if args.tg_model is not None: 33 | model.template_generator.load_checkpoint(args.tg_model) 34 | if args.model_type == "JNT": 35 | model = nmt.model_helper.create_joint_template_response_model(model_opt, fields) 36 | if model_opt.use_critic: 37 | critic = nmt.model_helper.create_critic_model(model_opt, fields) 38 | if args.critic_model is not None: 39 | critic.load_checkpoint(args.critic_model) 40 | else: 41 | critic = None 42 | latest_ckpt = nmt.misc_utils.latest_checkpoint(model_opt.out_dir) 43 | start_epoch_at = 0 44 | if model_opt.start_epoch_at is not None: 45 | ckpt = 'checkpoint_epoch%d.pkl'%(model_opt.start_epoch_at) 46 | ckpt = os.path.join(model_opt.out_dir,ckpt) 47 | else: 48 | ckpt = latest_ckpt 49 | 50 | if ckpt: 51 | print('Loding model from %s...'%(ckpt)) 52 | start_epoch_at = model.load_checkpoint(ckpt) 53 | else: 54 | print('Building model...') 55 | print(model) 56 | #model.save_checkpoint(0, model_opt, os.path.join(model_opt.out_dir,"checkpoint_epoch0.noskeleton.pkl")) 57 | return model, critic, start_epoch_at 58 | 59 | 60 | def build_optims_and_schedulers(model, critic, opt): 61 | if model.__class__.__name__ == "jointTemplateResponseGenerator": 62 | optimR = nmt.Optim(opt.optim_method, 63 | opt.learning_rate_R, 64 | opt.max_grad_norm, 65 | opt.learning_rate_decay, 66 | opt.weight_decay, 67 | opt.start_decay_at) 68 | optimR.set_parameters(model.parameters()) 69 | lr_lambda = lambda epoch: opt.learning_rate_decay ** epoch 70 | schedulerR = torch.optim.lr_scheduler.LambdaLR(optimizer=optimR.optimizer, 71 | lr_lambda=[lr_lambda]) 72 | return optimR, schedulerR, None, None, None, None 73 | 74 | optimR = nmt.Optim(opt.optim_method, 75 | opt.learning_rate_R, 76 | opt.max_grad_norm, 77 | opt.learning_rate_decay, 78 | opt.weight_decay, 79 | opt.start_decay_at) 80 | 81 | optimR.set_parameters(model.response_generator.parameters()) 82 | 83 | lr_lambda = lambda epoch: opt.learning_rate_decay ** epoch 84 | schedulerR = torch.optim.lr_scheduler.LambdaLR(optimizer=optimR.optimizer, 85 | lr_lambda=[lr_lambda]) 86 | optimT = nmt.Optim(opt.optim_method, 87 | opt.learning_rate_T, 88 | opt.max_grad_norm, 89 | opt.learning_rate_decay, 90 | opt.weight_decay, 91 | opt.start_decay_at) 92 | optimT.set_parameters(model.template_generator.parameters()) 93 | schedulerT = torch.optim.lr_scheduler.LambdaLR(optimizer=optimT.optimizer, 94 | lr_lambda=[lr_lambda]) 95 | 96 | if critic is not None: 97 | optimC = nmt.Optim(opt.optim_method, 98 | opt.learning_rate_C, 99 | opt.max_grad_norm, 100 | opt.learning_rate_decay, 101 | opt.weight_decay, 102 | opt.start_decay_at) 103 | optimC.set_parameters(critic.parameters()) 104 | schedulerC = torch.optim.lr_scheduler.LambdaLR(optimizer=optimC.optimizer, 105 | lr_lambda=[lr_lambda]) 106 | else: 107 | optimC, schedulerC = None, None 108 | return optimR, schedulerR, optimT, schedulerT, optimC, schedulerC 109 | 110 | def check_save_model_path(args, opt): 111 | if not os.path.exists(opt.out_dir): 112 | os.makedirs(opt.out_dir) 113 | print('saving config file to %s ...'%(opt.out_dir)) 114 | shutil.copy(args.config, os.path.join(opt.out_dir,'config.yml')) 115 | 116 | 117 | def train_model(train_critic, opt, model, critic, train_iter, valid_iter, fields, optimR, lr_schedulerR, optimT, lr_schedulerT, optimC, lr_schedulerC, start_epoch_at): 118 | train_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab) 119 | valid_loss = nmt.NMTLossCompute(model.generator, fields['tgt'].vocab) 120 | 121 | if use_cuda: 122 | train_loss = train_loss.cuda() 123 | valid_loss = valid_loss.cuda() 124 | 125 | shard_size = opt.train_shard_size 126 | trainer = nmt.Trainer(opt, model, 127 | train_iter, 128 | valid_iter, 129 | train_loss, 130 | valid_loss, 131 | optimR, 132 | shard_size) 133 | 134 | scorer = nmt.Scorer(model, fields['tgt'].vocab, fields['src'].vocab, train_loss, opt) 135 | num_train_epochs = opt.num_train_epochs 136 | print('start training...') 137 | global_step = 0 138 | for step_epoch in range(start_epoch_at+1, num_train_epochs): 139 | 140 | if step_epoch >= opt.start_decay_at: 141 | lr_schedulerR.step() 142 | if lr_schedulerT is not None: 143 | lr_schedulerT.step() 144 | if lr_schedulerC is not None: 145 | lr_schedulerC.step() 146 | 147 | total_stats = Statistics() 148 | report_stats = Statistics() 149 | for step_batch, batch in enumerate(train_iter): 150 | global_step += 1 151 | if global_step % 6 == -1 % global_step: 152 | T_turn = True 153 | C_turn = False 154 | R_turn = False 155 | else: 156 | T_turn = False 157 | C_turn = False 158 | R_turn = True 159 | 160 | if train_critic: 161 | T_turn = False 162 | C_turn = True 163 | R_turn = False 164 | if C_turn: 165 | model.template_generator.eval() 166 | model.response_generator.eval() 167 | critic.train() 168 | optimC.optimizer.zero_grad() 169 | src_inputs, src_lengths = batch.src 170 | tgt_inputs, tgt_lengths = batch.tgt 171 | ref_src_inputs, ref_src_lengths = batch.ref_src 172 | ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt 173 | I_word, I_word_length = batch.I 174 | D_word, D_word_length = batch.D 175 | preds, ev = model.template_generator(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths, return_ev = True) 176 | preds = preds.squeeze(2) 177 | template, template_lengths = model.template_generator.do_mask_and_clean(preds, ref_tgt_inputs, ref_tgt_lengths) 178 | 179 | #x = template.t().data.tolist() 180 | #vocab = fields['tgt'].vocab 181 | #for t in x: 182 | # print ("---", ' '.join([vocab.itos[tt] for tt in t])) 183 | (response, response_length), logp = sample(model.response_generator, src_inputs, None, template, src_lengths, None, template_lengths, max_len = 20) 184 | 185 | enc_embedding = model.response_generator.enc_embedding 186 | dec_embedding = model.response_generator.dec_embedding 187 | inds = np.arange(len(tgt_lengths)) 188 | np.random.shuffle(inds) 189 | inds_tensor = Variable(torch.LongTensor(inds).cuda()) 190 | random_tgt = tgt_inputs.index_select(1, inds_tensor) 191 | random_tgt_len = [tgt_lengths[i] for i in inds] 192 | 193 | #vocab = fields['tgt'].vocab 194 | #vocab_src = fields['src'].vocab 195 | #w = src_inputs.t().data.tolist() 196 | #x = tgt_inputs.t().data.tolist() 197 | #y = response.t().data.tolist() 198 | #z = random_tgt.t().data.tolist() 199 | #for tw, tx, ty, tz in zip(w, x, y, z): 200 | # print (' '.join([vocab_src.itos[tt] for tt in tw]), '|||||', ' '.join([vocab.itos[tt] for tt in tx]), '|||||', ' '.join([vocab.itos[tt] for tt in ty]), '|||||',' '.join([vocab.itos[tt] for tt in tz])) 201 | 202 | x, y, z = critic(enc_embedding(src_inputs),src_lengths, 203 | dec_embedding(tgt_inputs), tgt_lengths, 204 | dec_embedding(response), response_length, 205 | dec_embedding(random_tgt), random_tgt_len 206 | ) 207 | loss = torch.mean(-x) 208 | #print (loss.data[0]) 209 | loss.backward() 210 | optimC.step() 211 | stats = Statistics() 212 | elif T_turn: 213 | model.template_generator.train() 214 | model.response_generator.eval() 215 | critic.eval() 216 | stats = scorer.update(batch, optimT, 'T', sample, critic) 217 | elif R_turn: 218 | if not ( model.__class__.__name__ == "jointTemplateResponseGenerator"): 219 | model.template_generator.eval() 220 | model.response_generator.train() 221 | critic.eval() 222 | if global_step % 2 ==0: 223 | stats = trainer.update(batch) 224 | else: 225 | stats = scorer.update(batch, optimR, 'R', sample, critic) 226 | else: 227 | stats = trainer.update(batch) 228 | report_stats.update(stats) 229 | total_stats.update(stats) 230 | report_func(opt, global_step, step_epoch, step_batch, len(train_iter), total_stats.start_time, optimR.lr, report_stats) 231 | 232 | if critic is not None: 233 | critic.save_checkpoint(step_epoch, opt, os.path.join(opt.out_dir,"checkpoint_epoch_critic%d.pkl"%step_epoch)) 234 | print('Train perplexity: %g' % total_stats.ppl()) 235 | 236 | #2. Validate on the validation set. 237 | valid_stats = trainer.validate() 238 | print('Validation perplexity: %g' % valid_stats.ppl()) 239 | 240 | trainer.epoch_step(step_epoch, out_dir=opt.out_dir) 241 | 242 | model.train() 243 | 244 | class vocab_wrapper(object): 245 | def __init__(self, vocab): 246 | self.vocab = vocab 247 | 248 | def main(): 249 | parser = argparse.ArgumentParser() 250 | parser.add_argument("-config", type=str) 251 | parser.add_argument("-nmt_dir", type=str) 252 | parser.add_argument("-model_type", type=str) 253 | parser.add_argument('-gpuid', default=[0], nargs='+', type=int) 254 | parser.add_argument("-valid_file", type=str) 255 | parser.add_argument("-train_file", type=str) 256 | parser.add_argument("-train_score", type=str, default= None) 257 | parser.add_argument("-src_vocab", type = str) 258 | parser.add_argument("-tgt_vocab", type = str) 259 | parser.add_argument("-rg_model", type = str, default = None) 260 | parser.add_argument("-tg_model", type = str, default = None) 261 | parser.add_argument("-critic_model", type = str, default = None) 262 | 263 | args = parser.parse_args() 264 | opt = utils.load_hparams(args.config) 265 | cuda.set_device(args.gpuid[0]) 266 | if opt.random_seed > 0: 267 | random.seed(opt.random_seed) 268 | torch.manual_seed(opt.random_seed) 269 | np.random.seed(opt.random_seed) 270 | 271 | fields = dict() 272 | vocab_src = Vocab(args.src_vocab, noST = True) 273 | vocab_tgt = Vocab(args.tgt_vocab) 274 | 275 | fields['src'] = vocab_wrapper(vocab_src) 276 | fields['tgt'] = vocab_wrapper(vocab_tgt) 277 | 278 | mask_end = True 279 | train = Data_Loader(args.train_file, opt.train_batch_size, score = args.train_score, mask_end = mask_end) 280 | valid = Data_Loader(args.valid_file, opt.train_batch_size, mask_end = mask_end) 281 | 282 | # Build model. 283 | model, critic, start_epoch_at = build_or_load_model(args, opt, fields) 284 | check_save_model_path(args, opt) 285 | 286 | # Build optimizer. 287 | optimR, lr_schedulerR, optimT, lr_schedulerT, optimC, lr_schedulerC = build_optims_and_schedulers(model, critic, opt) 288 | 289 | if use_cuda: 290 | model = model.cuda() 291 | if opt.use_critic: 292 | critic = critic.cuda() 293 | 294 | # Do training. 295 | train_critic = (args.critic_model is None) 296 | train_model(train_critic, opt, model, critic, train, valid, fields, optimR, lr_schedulerR, optimT, lr_schedulerT, optimC, lr_schedulerC, start_epoch_at) 297 | 298 | if __name__ == '__main__': 299 | main() 300 | -------------------------------------------------------------------------------- /nmt/Trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import nmt.utils.misc_utils as utils 3 | import torch 4 | from torch.autograd import Variable 5 | import os 6 | import sys 7 | import math 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from nmt.utils.data_utils import sequence_mask 11 | import numpy as np 12 | 13 | class Statistics(object): 14 | """ 15 | Train/validate loss statistics. 16 | """ 17 | def __init__(self, loss=0, n_words=1e-12, n_correct=0): 18 | self.loss = loss 19 | self.n_words = n_words 20 | self.n_correct = n_correct 21 | self.n_src_words = 0 22 | self.start_time = time.time() 23 | 24 | def update(self, stat): 25 | self.loss += stat.loss 26 | self.n_words += stat.n_words 27 | self.n_correct += stat.n_correct 28 | 29 | def ppl(self): 30 | return utils.safe_exp(self.loss / self.n_words) 31 | 32 | def accuracy(self): 33 | return 100 * (self.n_correct / self.n_words) 34 | 35 | def elapsed_time(self): 36 | return time.time() - self.start_time 37 | 38 | def print_out(self, epoch, batch, n_batches, start): 39 | t = self.elapsed_time() 40 | 41 | out_info = ("Epoch %2d, %5d/%5d| acc: %6.2f| ppl: %6.2f| " + \ 42 | "%3.0f tgt tok/s| %4.0f s elapsed") % \ 43 | (epoch, batch, n_batches, 44 | self.accuracy(), 45 | self.ppl(), 46 | self.n_words / (t + 1e-5), 47 | time.time() - self.start_time) 48 | 49 | print(out_info) 50 | sys.stdout.flush() 51 | 52 | class Scorer(object): 53 | def __init__(self, model, tgt_vocab, src_vocab, train_loss, opt): 54 | self.model = model 55 | self.tgt_vocab = tgt_vocab 56 | self.src_vocab = src_vocab 57 | padding_idx = tgt_vocab.stoi[tgt_vocab.PAD] 58 | weight = torch.ones(len(tgt_vocab)) 59 | weight[padding_idx] = 0 60 | self.criterion = nn.NLLLoss(weight, reduce=False).cuda() 61 | self.global_step = 0 62 | self.train_loss = train_loss 63 | self.opt = opt 64 | 65 | def score_batch(self, src, tgt, ref_src, ref_tgt, src_lengths, tgt_lengths, ref_src_lengths, ref_tgt_lengths, normalization = False): 66 | # src : seq_len x batch_size 67 | self.model.eval() 68 | model_type = self.model.__class__.__name__ 69 | if model_type == "vanillaNMTModel": 70 | outputs, attn = self.model(src, tgt[:-1], src_lengths) 71 | if model_type == "tem_resNMTModel": 72 | outputs, attn = self.model.response_generator(src, tgt[:-1], ref_tgt, src_lengths, ref_tgt_lengths) 73 | 74 | 75 | log_probs = self.model.generator(outputs) 76 | 77 | tgt_out = tgt[1:].view(-1) 78 | 79 | batch_size = src.size(1) 80 | 81 | 82 | log_probs = log_probs.view(-1, log_probs.size(2)) 83 | 84 | log_probs = self.criterion(log_probs, tgt_out).view(-1, batch_size).data 85 | logp = (-torch.sum(log_probs, 0)).tolist() 86 | 87 | 88 | if not normalization: 89 | return outputs, logp 90 | 91 | 92 | ret = [] 93 | for lp, l in zip(logp, tgt_lengths): 94 | ret.append( lp / (l+ 1e-12) ) 95 | return outputs, ret 96 | 97 | def update(self, batch, optim, update_what, sample_func = None, critic = None): 98 | optim.optimizer.zero_grad() 99 | src_inputs, src_lengths = batch.src 100 | tgt_inputs, tgt_lengths = batch.tgt 101 | ref_src_inputs, ref_src_lengths = batch.ref_src 102 | ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt 103 | I_word, I_word_length = batch.I 104 | D_word, D_word_length = batch.D 105 | preds, ev = self.model.template_generator(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths, return_ev = True) 106 | preds = preds.squeeze(2) 107 | template, template_lengths = self.model.template_generator.do_mask_and_clean(preds, ref_tgt_inputs, ref_tgt_lengths) 108 | 109 | if sample_func is None: 110 | outputs, scores = self.score_batch(src_inputs, tgt_inputs, None, template, src_lengths, tgt_lengths, None, template_lengths, normalization = True) 111 | avg = sum(scores)/len(scores) 112 | scores = [ t-avg for t in scores] 113 | else: 114 | (response, response_length), logp = sample_func(self.model.response_generator, src_inputs, None, template, src_lengths, None, template_lengths, max_len = 20, show_sample = False) 115 | enc_embedding = self.model.response_generator.enc_embedding 116 | dec_embedding = self.model.response_generator.dec_embedding 117 | inds = np.arange(len(tgt_lengths)) 118 | np.random.shuffle(inds) 119 | inds_tensor = Variable(torch.LongTensor(inds).cuda()) 120 | random_tgt = tgt_inputs.index_select(1, inds_tensor) 121 | random_tgt_len = [tgt_lengths[i] for i in inds] 122 | 123 | vocab = self.tgt_vocab 124 | vocab_src = self.src_vocab 125 | w = src_inputs.t().data.tolist() 126 | x = tgt_inputs.t().data.tolist() 127 | y = response.t().data.tolist() 128 | z = random_tgt.t().data.tolist() 129 | for tw, tx, ty, tz, ww, xx, yy, zz in zip(w, x, y, z, src_lengths, tgt_lengths, response_length, random_tgt_len): 130 | print (' '.join([vocab_src.itos[tt] for tt in tw[:ww]]), '|||||', ' '.join([vocab.itos[tt] for tt in tx[1:xx-1]]), '|||||', ' '.join([vocab.itos[tt] for tt in ty[1:yy-1]]), '|||||',' '.join([vocab.itos[tt] for tt in tz[1:zz-1]])) 131 | 132 | x, y, z = critic(enc_embedding(src_inputs), src_lengths, 133 | dec_embedding(tgt_inputs), tgt_lengths, 134 | dec_embedding(response), response_length, 135 | dec_embedding(random_tgt), random_tgt_len 136 | ) 137 | scores = y.data.tolist() 138 | 139 | if update_what == "R": 140 | logp = logp.sum(0) 141 | scores = torch.FloatTensor(scores) 142 | scores = torch.exp(Variable(scores.cuda())) 143 | #print (logp, scores) 144 | loss = -(logp * scores).mean() 145 | print (loss.data[0]) 146 | loss.backward() 147 | optim.step() 148 | stats = Statistics() 149 | return stats 150 | 151 | ans = torch.ge(preds, 0.5) 152 | mask = sequence_mask(ref_tgt_lengths).transpose(0, 1) 153 | weight = torch.FloatTensor(mask.size()).zero_().cuda() 154 | weight.masked_fill_(mask, 1.) 155 | 156 | 157 | for i,x in enumerate(scores): 158 | weight[:,i] *= x 159 | 160 | loss = F.binary_cross_entropy(preds, Variable(ans.float().data), weight) 161 | 162 | stats = Statistics() #self.train_loss.monolithic_compute_loss(batch, outputs) 163 | loss.backward() 164 | optim.step() 165 | return stats 166 | 167 | def train(self, epoch, train_iter, optim, report_func): 168 | total_stats = Statistics() 169 | report_stats = Statistics() 170 | for step_batch, batch in enumerate(train_iter): 171 | self.global_step += 1 172 | stats = self.update(batch, optim) 173 | report_stats.update(stats) 174 | total_stats.update(stats) 175 | if report_func is not None: 176 | report_stats = report_func(self.opt, self.global_step, 177 | epoch, step_batch, len(train_iter), 178 | total_stats.start_time, optim.lr, report_stats) 179 | return total_stats 180 | 181 | class Trainer(object): 182 | def __init__(self, opt, model, train_iter, valid_iter, 183 | train_loss, valid_loss, optim, shard_size=32): 184 | 185 | self.opt = opt 186 | self.model = model 187 | self.train_iter = train_iter 188 | self.valid_iter = valid_iter 189 | self.train_loss = train_loss 190 | self.valid_loss = valid_loss 191 | self.optim = optim 192 | 193 | self.shard_size = shard_size 194 | 195 | # Set model in training mode. 196 | self.model.train() 197 | 198 | self.global_step = 0 199 | self.step_epoch = 0 200 | 201 | def update(self, batch): 202 | self.model.zero_grad() 203 | src_inputs, src_lengths = batch.src 204 | tgt_inputs = batch.tgt[0][:-1] 205 | 206 | ref_src_inputs, ref_src_lengths = batch.ref_src 207 | ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt 208 | 209 | model_type = self.model.__class__.__name__ 210 | if model_type == "vanillaNMTModel": 211 | outputs, attn = self.model(src_inputs, tgt_inputs, src_lengths) 212 | if model_type == "bivanillaNMTModel": 213 | outputs, attn = self.model(src_inputs, tgt_inputs, ref_tgt_inputs, src_lengths, ref_tgt_lengths) 214 | if model_type == "refNMTModel": 215 | outputs, attn, outputs_f = self.model(src_inputs, tgt_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths) 216 | if model_type == "evNMTModel": 217 | I_word, I_word_length = batch.I 218 | D_word, D_word_length = batch.D 219 | outputs, attn = self.model(src_inputs, tgt_inputs, src_lengths, I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths) 220 | if model_type == "responseGenerator": 221 | outputs, attn = self.model(src_inputs, tgt_inputs, ref_tgt_inputs, src_lengths, ref_tgt_lengths) 222 | if model_type == "tem_resNMTModel": 223 | I_word, I_word_length = batch.I 224 | D_word, D_word_length = batch.D 225 | outputs, attn = self.model(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths, src_inputs, tgt_inputs, src_lengths) 226 | if model_type == "jointTemplateResponseGenerator": 227 | I_word, I_word_length = batch.I 228 | D_word, D_word_length = batch.D 229 | target, _ = batch.mask 230 | 231 | outputs, attn, preds = self.model(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths, src_inputs, tgt_inputs, src_lengths) 232 | mask = sequence_mask(ref_tgt_lengths).transpose(0, 1) 233 | tot = mask.float().sum() 234 | 235 | reserved = target.float().sum() 236 | w1 = (0.5 * tot / reserved).data[0] 237 | w2 = (0.5 * tot / (tot - reserved)).data[0] 238 | #w1, w2 = 1., 1. 239 | weight = torch.FloatTensor(mask.size()).zero_().cuda() 240 | weight.masked_fill_(mask, w2) 241 | weight.masked_fill_(torch.eq(target, 1).data, w1) 242 | 243 | loss = F.binary_cross_entropy(preds, target.float(), weight) 244 | loss.backward(retain_graph = True) 245 | if batch.score is not None: 246 | score = Variable(torch.FloatTensor(batch.score)).cuda() 247 | else: 248 | score = None 249 | 250 | stats = self.train_loss.sharded_compute_loss(batch, outputs, self.shard_size, weight = score ) 251 | 252 | self.optim.step() 253 | return stats 254 | 255 | def train(self, epoch, report_func=None): 256 | """ Called for each epoch to train. """ 257 | total_stats = Statistics() 258 | report_stats = Statistics() 259 | 260 | for step_batch, batch in enumerate(self.train_iter): 261 | self.global_step += 1 262 | stats = self.update(batch) 263 | 264 | report_stats.update(stats) 265 | total_stats.update(stats) 266 | 267 | if report_func is not None: 268 | report_stats = report_func(self.opt, self.global_step, 269 | epoch, step_batch, len(self.train_iter), 270 | total_stats.start_time, self.optim.lr, report_stats) 271 | 272 | 273 | return total_stats 274 | 275 | def validate(self): 276 | self.model.eval() 277 | valid_stats = Statistics() 278 | 279 | for batch in self.valid_iter: 280 | src_inputs, src_lengths = batch.src 281 | tgt_inputs = batch.tgt[0][:-1] 282 | 283 | ref_src_inputs, ref_src_lengths = batch.ref_src 284 | ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt 285 | 286 | model_type = self.model.__class__.__name__ 287 | if model_type == "vanillaNMTModel": 288 | outputs, attn = self.model(src_inputs, tgt_inputs, src_lengths) 289 | if model_type == "bivanillaNMTModel": 290 | outputs, attn = self.model(src_inputs, tgt_inputs, ref_tgt_input, src_lengths, ref_tgt_lengths) 291 | if model_type == "refNMTModel": 292 | outputs, attn, outputs_f = self.model(src_inputs, tgt_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths) 293 | if model_type == "evNMTModel": 294 | I_word, I_word_length = batch.I 295 | D_word, D_word_length = batch.D 296 | outputs, attn = self.model(src_inputs, tgt_inputs, src_lengths, I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths) 297 | if model_type == "responseGenerator": 298 | outputs, attn = self.model(src_inputs, tgt_inputs, ref_tgt_inputs, src_lengths, ref_tgt_lengths) 299 | if model_type == "tem_resNMTModel": 300 | I_word, I_word_length = batch.I 301 | D_word, D_word_length = batch.D 302 | outputs, attn = self.model(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths, src_inputs, tgt_inputs, src_lengths) 303 | if model_type == "jointTemplateResponseGenerator": 304 | I_word, I_word_length = batch.I 305 | D_word, D_word_length = batch.D 306 | outputs, attn, preds = self.model(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths, src_inputs, tgt_inputs, src_lengths) 307 | stats = self.valid_loss.monolithic_compute_loss(batch, outputs) 308 | valid_stats.update(stats) 309 | # Set model back to training mode. 310 | self.model.train() 311 | return valid_stats 312 | 313 | def save_per_epoch(self, epoch, out_dir): 314 | f = open(os.path.join(out_dir,'checkpoint'),'w') 315 | f.write('latest_checkpoint:checkpoint_epoch%d.pkl'%(epoch)) 316 | f.close() 317 | self.model.save_checkpoint(epoch, self.opt, 318 | os.path.join(out_dir,"checkpoint_epoch%d.pkl"%(epoch))) 319 | 320 | def epoch_step(self, epoch, out_dir): 321 | """ save ckpt """ 322 | self.save_per_epoch(epoch, out_dir) 323 | -------------------------------------------------------------------------------- /maskGAN.py: -------------------------------------------------------------------------------- 1 | import nmt.utils.misc_utils as utils 2 | import argparse 3 | import codecs 4 | import os 5 | import shutil 6 | import re 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch import cuda 11 | import nmt 12 | import random 13 | from data import Vocab, Data_Loader, ListsToTensor 14 | from torch.autograd import Variable 15 | import sys 16 | 17 | 18 | use_cuda = True 19 | 20 | class GAN(nn.Module): 21 | def __init__(self, generator, discriminator, critic): 22 | super(GAN, self).__init__() 23 | self.generator = generator 24 | self.discriminator = discriminator 25 | self.critic = critic 26 | 27 | def save_checkpoint(self, epoch, opt, filename): 28 | torch.save({'generator_dict': self.generator.state_dict(), 29 | 'discriminator_dict': self.discriminator.state_dict(), 30 | 'critic_dict': self.critic.state_dict(), 31 | 'opt': opt, 32 | 'epoch': epoch, 33 | }, 34 | filename) 35 | 36 | def load_checkpoint(self, filename): 37 | ckpt = torch.load(filename) 38 | self.generator.load_state_dict(ckpt['generator_dict']) 39 | self.discriminator.load_state_dict(ckpt['discriminator_dict']) 40 | self.critic.load_state_dict(ckpt['critic_dict']) 41 | epoch = ckpt['epoch'] 42 | return epoch 43 | 44 | def sample(model, src, ref_src, ref_tgt, src_lengths, ref_src_lengths, ref_tgt_lengths, max_len, show_sample = False): 45 | model_type = model.__class__.__name__ 46 | if model_type =="refNMTModel": 47 | context, enc_states, context_keys, context_mask, src_context, src_mask = model.encode(src, ref_src, ref_tgt, src_lengths, ref_src_lengths, ref_tgt_lengths) 48 | if model_type == "responseGenerator": 49 | context, enc_states, context_mask, dist, src_context, src_mask = model.encode(src, ref_tgt, src_lengths, ref_tgt_lengths) 50 | dec_states = model.init_decoder_state(enc_states, context) 51 | 52 | vocab = model.fields['tgt'].vocab 53 | EOS_idx = vocab.stoi[vocab.EOS] 54 | PAD_idx = vocab.stoi[vocab.PAD] 55 | EOT_idx = vocab.stoi[vocab.EOT] 56 | batch_size = src.size(1) 57 | 58 | notyet = torch.ByteTensor(batch_size).fill_(1) 59 | inp = Variable(torch.LongTensor(batch_size).fill_(EOS_idx)) 60 | 61 | pad_mask = torch.LongTensor([PAD_idx]) 62 | if use_cuda: 63 | notyet = notyet.cuda() 64 | inp = inp.cuda() 65 | pad_mask = pad_mask.cuda() 66 | 67 | result = [inp] 68 | eps = 1e-12 69 | log_prob= [] 70 | 71 | while notyet.any() and len(result)<= max_len: 72 | inp = inp.unsqueeze(0) 73 | if model_type =="refNMTModel": 74 | dec_out, dec_states, attn = model.decode(inp, context_keys, context, dec_states, context_mask, src_context, src_mask) 75 | if model_type =="responseGenerator": 76 | dec_out, dec_states, attn = model.decode(inp, context, dec_states, None, context_mask, src_context, src_mask) 77 | dec_out = dec_out.squeeze(0) 78 | cur_log_prob = model.generator(dec_out) 79 | cur_log_prob.data.index_fill_(1, pad_mask, -float('inf')) 80 | word_prob = torch.exp(cur_log_prob + eps) 81 | #inp = torch.multinomial(word_prob, 1).squeeze(-1) 82 | _, inp = torch.max(cur_log_prob, -1) 83 | cur_log_prob = torch.gather(cur_log_prob, -1, inp.view(-1, 1)).squeeze(-1) 84 | cur_log_prob.data.masked_fill_(1-notyet, 0.) 85 | log_prob.append(cur_log_prob) 86 | inp.data.masked_fill_( 1-notyet, PAD_idx) # batch_size 87 | result.append(inp) 88 | 89 | endding = torch.eq(inp, EOT_idx) 90 | notyet.masked_fill_(endding.data, 0) 91 | 92 | result = torch.stack(result, 0) 93 | log_prob = torch.stack(log_prob, 0) 94 | 95 | x = result.t().data.tolist() 96 | new_x = [] 97 | for t in x: 98 | new_t = [] 99 | for tt in t: 100 | if tt != PAD_idx: 101 | new_t.append(tt) 102 | new_x.append(new_t) 103 | x = new_x 104 | 105 | if show_sample: 106 | for t in x: 107 | print (' '.join([vocab.itos[tt] for tt in t])) 108 | return ListsToTensor(x, tgt = False), log_prob 109 | 110 | def report_func(opt, global_step, epoch, batch, num_batches, 111 | start_time, lr, report_stats): 112 | """ 113 | This is the user-defined batch-level traing progress 114 | report function. 115 | Args: 116 | epoch(int): current epoch count. 117 | batch(int): current batch count. 118 | num_batches(int): total number of batches. 119 | start_time(float): last report time. 120 | lr(float): current learning rate. 121 | report_stats(Statistics): old Statistics instance. 122 | Returns: 123 | report_stats(Statistics): updated Statistics instance. 124 | """ 125 | if batch % opt.steps_per_stats == -1 % opt.steps_per_stats: 126 | report_stats.print_out(epoch, batch+1, num_batches, start_time) 127 | report_stats = nmt.Statistics() 128 | 129 | return report_stats 130 | 131 | def build_or_load_model(args, model_opt, fields): 132 | if args.model_type == "ref": 133 | generator, discriminator, critic = nmt.model_helper.create_GAN_model(model_opt, fields) 134 | model = GAN(generator, discriminator, critic) 135 | if args.start_point is None: 136 | generator.load_checkpoint("init_point") 137 | discriminator.base_model.load_checkpoint('init_point') 138 | critic.base_model.load_checkpoint('init_point') 139 | else: 140 | model.load_checkpoint(args.start_point) 141 | 142 | latest_ckpt = nmt.misc_utils.latest_checkpoint(model_opt.out_dir) 143 | start_epoch_at = 0 144 | if model_opt.start_epoch_at is not None: 145 | ckpt = 'checkpoint_epoch%d.pkl'%(model_opt.start_epoch_at) 146 | ckpt = os.path.join(model_opt.out_dir,ckpt) 147 | else: 148 | ckpt = latest_ckpt 149 | 150 | if ckpt: 151 | print('Loding model from %s...'%(ckpt)) 152 | start_epoch_at = model.load_checkpoint(ckpt) 153 | else: 154 | print('Building model...') 155 | print(model) 156 | 157 | return model, start_epoch_at 158 | 159 | def build_optims_and_lr_schedulers(model, opt): 160 | optimG = nmt.Optim(opt.optim_method, 161 | opt.learning_rate, 162 | opt.max_grad_norm, 163 | opt.learning_rate_decay, 164 | opt.weight_decay, 165 | opt.start_decay_at) 166 | 167 | optimG.set_parameters(model.generator.parameters()) 168 | 169 | lr_lambda = lambda epoch: opt.learning_rate_decay ** epoch 170 | schedulerG = torch.optim.lr_scheduler.LambdaLR(optimizer=optimG.optimizer, lr_lambda=[lr_lambda]) 171 | optimD = nmt.Optim(opt.optim_method, 172 | opt.learning_rate_D, 173 | opt.max_grad_norm, 174 | opt.learning_rate_decay, 175 | opt.weight_decay, 176 | opt.start_decay_at) 177 | optimD.set_parameters( [ x for x in model.discriminator.parameters() ] + [ y for y in model.critic.parameters()] ) 178 | schedulerD = torch.optim.lr_scheduler.LambdaLR(optimizer=optimD.optimizer, lr_lambda=[lr_lambda]) 179 | return optimG, schedulerG, optimD, schedulerD 180 | 181 | def check_save_model_path(args, opt): 182 | if not os.path.exists(opt.out_dir): 183 | os.makedirs(opt.out_dir) 184 | print('saving config file to %s ...'%(opt.out_dir)) 185 | shutil.copy(args.config, os.path.join(opt.out_dir,'config.yml')) 186 | 187 | def save_per_epoch(model, epoch, opt): 188 | f = open(os.path.join(opt.out_dir,'checkpoint'),'w') 189 | f.write('latest_checkpoint:checkpoint_epoch%d.pkl'%(epoch)) 190 | f.close() 191 | model.save_checkpoint(epoch, opt, os.path.join(opt.out_dir,"checkpoint_epoch%d.pkl"%(epoch))) 192 | 193 | def pretrain_discriminators(opt, model, train_iter, valid_iter, fields, optim, lr_scheduler, start_epoch_at): 194 | for step_epoch in range(start_epoch_at+1, opt.num_train_epochs): 195 | for batch in train_iter: 196 | model.zero_grad() 197 | src_inputs, src_lengths = batch.src 198 | tgt_inputs = batch.tgt[0] 199 | ref_src_inputs, ref_src_lengths = batch.ref_src 200 | ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt 201 | (fake_tgt_inputs, _), fake_log_prob = sample(model.generator, src_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths, opt.max_sample_len) 202 | real_output = model.discriminator(src_inputs, tgt_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths) 203 | fake_output = model.discriminator(src_inputs, fake_tgt_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths) 204 | real_output = real_output[1:] 205 | fake_output = fake_output[1:] 206 | 207 | target = torch.ones_like(real_output) 208 | loss_real = F.binary_cross_entropy_with_logits(real_output, target, torch.ne(tgt_inputs[1:], 0).float(), size_average = False) 209 | target = torch.zeros_like(fake_output) 210 | loss_fake = F.binary_cross_entropy_with_logits(fake_output, target, torch.ne(fake_tgt_inputs[1:], 0).float(), size_average = False) 211 | 212 | loss = (loss_real + loss_fake)/ (2 * batch.batch_size) 213 | loss.backward() 214 | optim.step() 215 | save_per_epoch(model, step_epoch, opt) 216 | sys.stdout.flush() 217 | 218 | def G_turn(model, batch, optim, opt): 219 | model.zero_grad() 220 | advantages, log_probs, mask = D_turn(model, batch, None, opt, forG = True) 221 | loss = -(advantages * log_probs) * mask.float() 222 | loss = torch.sum(loss)/ batch.batch_size 223 | loss.backward() 224 | optim.step() 225 | 226 | def D_turn(model, batch, optim, opt, forG = False, show_sample = False): 227 | if not forG: 228 | model.zero_grad() 229 | src_inputs, src_lengths = batch.src 230 | tgt_inputs = batch.tgt[0] 231 | ref_src_inputs, ref_src_lengths = batch.ref_src 232 | ref_tgt_inputs, ref_tgt_lengths = batch.ref_tgt 233 | 234 | if show_sample: 235 | sample(model.generator, src_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths, opt.max_sample_len, show_sample = True) 236 | return 237 | (fake_tgt_inputs, _), fake_log_prob = sample(model.generator, src_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths, opt.max_sample_len) 238 | 239 | real_output = model.discriminator(src_inputs, tgt_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths) 240 | fake_output = model.discriminator(src_inputs, fake_tgt_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths) 241 | real_output = real_output[1:] 242 | fake_output = fake_output[1:] 243 | 244 | target = torch.ones_like(real_output) 245 | loss_real = F.binary_cross_entropy_with_logits(real_output, target, torch.ne(tgt_inputs[1:], 0).float(), size_average = False) 246 | target = torch.zeros_like(fake_output) 247 | fake_tgt_mask = torch.ne(fake_tgt_inputs[1:], 0) 248 | loss_fake = F.binary_cross_entropy_with_logits(fake_output, target, fake_tgt_mask.float(), size_average = False) 249 | 250 | loss = (loss_real + loss_fake)/ (2 * batch.batch_size) 251 | eps = 1e-12 252 | 253 | estimated_rewards = model.critic(src_inputs, fake_tgt_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths) 254 | estimated_rewards = estimated_rewards[:-1] 255 | 256 | rewards = torch.log(F.sigmoid(fake_output) + eps) 257 | rewards.data.masked_fill_(1 - fake_tgt_mask.data, 0.) 258 | split_rewards = torch.split(rewards, 1, dim = 0) 259 | 260 | sum_rewards = [] 261 | cur = 0. 262 | for r in split_rewards[::-1]: 263 | cur = cur * opt.gamma + r 264 | sum_rewards.append(cur) 265 | sum_rewards = torch.cat(sum_rewards[::-1], 0) 266 | 267 | if forG: 268 | return (sum_rewards - estimated_rewards).detach(), fake_log_prob, fake_tgt_mask 269 | critic_loss = (sum_rewards - estimated_rewards)**2 270 | critic_loss.data.masked_fill_(1 - fake_tgt_mask.data, 0.) 271 | critic_loss = torch.sum(critic_loss)/ batch.batch_size 272 | loss = loss + critic_loss 273 | loss.backward() 274 | optim.step() 275 | 276 | def train_model(opt, model, train_iter, valid_iter, fields, optimG, lr_schedulerG, optimD, lr_schedulerD, start_epoch_at): 277 | num_train_epochs = opt.num_train_epochs 278 | num_updates = 0 279 | print('start training...') 280 | valid_loss = nmt.NMTLossCompute(model.generator.generator,fields['tgt'].vocab) 281 | if use_cuda: 282 | valid_loss = valid_loss.cuda() 283 | shard_size = opt.train_shard_size 284 | trainer = nmt.Trainer(opt, model.generator, train_iter, valid_iter, valid_loss, valid_loss, optimG, lr_schedulerG, shard_size, train_loss_b = None) 285 | 286 | for step_epoch in range(start_epoch_at+1, num_train_epochs): 287 | for batch in train_iter: 288 | if num_updates % (opt.D_turns+1) == -1 % (opt.D_turns+1): 289 | G_turn(model, batch, optimG, opt) 290 | else: 291 | D_turn(model, batch, optimD, opt) 292 | if num_updates % (opt.show_sample_every) == -1 %(opt.show_sample_every): 293 | D_turn(model, batch, optimD, opt, show_sample = True) 294 | num_updates += 1 295 | sys.stdout.flush() 296 | valid_stats = trainer.validate() 297 | print('Validation perplexity: %g' % valid_stats.ppl()) 298 | sys.stdout.flush() 299 | if step_epoch >= opt.start_decay_at: 300 | lr_schedulerD.step() 301 | lr_schedulerG.step() 302 | save_per_epoch(model, step_epoch, opt) 303 | model.train() 304 | 305 | class vocab_wrapper(object): 306 | def __init__(self, vocab): 307 | self.vocab = vocab 308 | 309 | def main(): 310 | parser = argparse.ArgumentParser() 311 | parser.add_argument("-config", type=str) 312 | parser.add_argument("-nmt_dir", type=str) 313 | parser.add_argument("-model_type", type=str) 314 | parser.add_argument('-gpuid', default=[0], nargs='+', type=int) 315 | parser.add_argument("-valid_file", type=str) 316 | parser.add_argument("-train_file", type=str) 317 | parser.add_argument("-train_score", type=str, default= None) 318 | parser.add_argument("-src_vocab", type = str) 319 | parser.add_argument("-tgt_vocab", type = str) 320 | parser.add_argument("-start_point", type = str, default = None) 321 | 322 | args = parser.parse_args() 323 | opt = utils.load_hparams(args.config) 324 | 325 | if opt.random_seed > 0: 326 | random.seed(opt.random_seed) 327 | torch.manual_seed(opt.random_seed) 328 | 329 | fields = dict() 330 | vocab_src = Vocab(args.src_vocab, noST = True) 331 | vocab_tgt = Vocab(args.tgt_vocab) 332 | fields['src'] = vocab_wrapper(vocab_src) 333 | fields['tgt'] = vocab_wrapper(vocab_tgt) 334 | 335 | train = Data_Loader(args.train_file, opt.train_batch_size, score = args.train_score, mask_end = (args.model_type == "ev")) 336 | valid = Data_Loader(args.valid_file, opt.train_batch_size, mask_end = (args.model_type == "ev")) 337 | 338 | # Build model. 339 | 340 | model, start_epoch_at = build_or_load_model(args, opt, fields) 341 | check_save_model_path(args, opt) 342 | 343 | optimG, schedulerG, optimD, schedulerD = build_optims_and_lr_schedulers(model, opt) 344 | 345 | if use_cuda: 346 | model = model.cuda() 347 | 348 | # Do training. 349 | #pretrain_discriminators(opt, model, train, valid, fields, optimD, schedulerD, start_epoch_at) 350 | train_model(opt, model, train, valid, fields, optimG, schedulerG, optimD, schedulerD, start_epoch_at) 351 | print("DONE") 352 | x = 0 353 | while True: 354 | x = (x +1)%5 355 | if __name__ == '__main__': 356 | main() 357 | -------------------------------------------------------------------------------- /nmt/model_helper.py: -------------------------------------------------------------------------------- 1 | from nmt.modules.Encoder import EncoderRNN 2 | from nmt.modules.Decoder import AttnDecoderRNN, KVAttnDecoderRNN, AuxDecoderRNN, AuxMemDecoderRNN 3 | from nmt.modules.Attention import GlobalAttention 4 | from nmt.modules.Embedding import Embedding 5 | from nmt.Model import vanillaNMTModel, refNMTModel, bivanillaNMTModel, editVectorGenerator, templateGenerator, evNMTModel, Discriminator, responseGenerator, tem_resNMTModel, Critic, jointTemplateResponseGenerator 6 | import torch 7 | import torch.nn as nn 8 | 9 | def create_emb_for_encoder_and_decoder(src_vocab_size, 10 | tgt_vocab_size, 11 | src_embed_size, 12 | tgt_embed_size, 13 | padding_idx): 14 | 15 | embedding_encoder = Embedding(src_vocab_size,src_embed_size,padding_idx) 16 | embedding_decoder = Embedding(tgt_vocab_size,tgt_embed_size,padding_idx) 17 | 18 | 19 | return embedding_encoder, embedding_decoder 20 | 21 | 22 | def create_emb_for_encoders_and_decoder(src_vocab_size, 23 | ref_vocab_size, 24 | tgt_vocab_size, 25 | src_embed_size, 26 | ref_embed_size, 27 | tgt_embed_size, 28 | padding_idx): 29 | 30 | embedding_encoder_src = Embedding(src_vocab_size, src_embed_size,padding_idx) 31 | embedding_encoder_ref = Embedding(ref_vocab_size, ref_embed_size,padding_idx) 32 | embedding_decoder = Embedding(tgt_vocab_size, tgt_embed_size,padding_idx) 33 | 34 | return embedding_encoder_src, embedding_encoder_ref, embedding_decoder 35 | 36 | def create_encoder(opt): 37 | encoder = EncoderRNN(opt.rnn_type, 38 | opt.embedding_size, 39 | opt.hidden_size, 40 | opt.num_layers, 41 | opt.dropout, 42 | opt.bidirectional) 43 | 44 | return encoder 45 | 46 | def create_decoder(opt): 47 | if opt.decoder_type == 'AttnDecoderRNN': 48 | decoder = AttnDecoderRNN(opt.rnn_type, 49 | opt.atten_model, 50 | opt.embedding_size, 51 | opt.hidden_size, 52 | opt.num_layers, 53 | opt.dropout) 54 | 55 | return decoder 56 | 57 | def create_generator(input_size, output_size): 58 | generator = nn.Sequential( 59 | nn.Linear(input_size, output_size), 60 | nn.LogSoftmax(dim=-1)) 61 | return generator 62 | 63 | def weights_init(m): 64 | if isinstance(m, nn.Linear): 65 | nn.init.xavier_uniform(m.weight.data) 66 | if isinstance(m, nn.LSTM) or isinstance(m, nn.GRU): 67 | for layer in range(m.num_layers): 68 | nn.init.orthogonal(getattr(m,"weight_ih_l%d"%(layer)).data) 69 | nn.init.orthogonal(getattr(m,"weight_hh_l%d"%(layer)).data) 70 | 71 | def create_ref_model(opt, fields): 72 | src_vocab_size = len(fields['src'].vocab) 73 | tgt_vocab_size = len(fields['tgt'].vocab) 74 | padding_idx = fields['src'].vocab.stoi[fields['src'].vocab.PAD] 75 | enc_embedding, dec_embedding = \ 76 | create_emb_for_encoder_and_decoder(src_vocab_size, 77 | tgt_vocab_size, 78 | opt.embedding_size, 79 | opt.embedding_size, 80 | padding_idx) 81 | encoder_src = create_encoder(opt) 82 | encoder_ref = create_encoder(opt) 83 | decoder_ref = AttnDecoderRNN(opt.rnn_type, 84 | (opt.atten_model if opt.src_attention else "none" ), 85 | opt.embedding_size, 86 | opt.hidden_size, 87 | opt.num_layers, 88 | opt.dropout) 89 | decoder = KVAttnDecoderRNN(opt.rnn_type, 90 | opt.atten_model, 91 | opt.embedding_size, 92 | opt.hidden_size, 93 | opt.num_layers, 94 | opt.dropout, 95 | opt.src_attention, 96 | opt.mem_gate, 97 | opt.gate_vector) 98 | 99 | generator = create_generator(opt.hidden_size, tgt_vocab_size) 100 | model = refNMTModel(enc_embedding, 101 | dec_embedding, 102 | encoder_src, 103 | encoder_ref, 104 | decoder_ref, 105 | decoder, 106 | generator, 107 | fields) 108 | 109 | model.apply(weights_init) 110 | return model 111 | 112 | def create_base_model(opt, fields): 113 | src_vocab_size = len(fields['src'].vocab) 114 | tgt_vocab_size = len(fields['tgt'].vocab) 115 | padding_idx = fields['src'].vocab.stoi[fields['src'].vocab.PAD] 116 | enc_embedding, dec_embedding = \ 117 | create_emb_for_encoder_and_decoder(src_vocab_size, 118 | tgt_vocab_size, 119 | opt.embedding_size, 120 | opt.embedding_size, 121 | padding_idx) 122 | encoder = create_encoder(opt) 123 | decoder = create_decoder(opt) 124 | generator = create_generator(opt.hidden_size, tgt_vocab_size) 125 | model = vanillaNMTModel(enc_embedding, 126 | dec_embedding, 127 | encoder, 128 | decoder, 129 | generator) 130 | 131 | model.apply(weights_init) 132 | return model 133 | 134 | def create_bibase_model(opt, fields): 135 | src_vocab_size = len(fields['src'].vocab) 136 | tgt_vocab_size = len(fields['tgt'].vocab) 137 | padding_idx = fields['src'].vocab.stoi[fields['src'].vocab.PAD] 138 | enc_embedding, enc_embedding_ref, dec_embedding = \ 139 | create_emb_for_encoders_and_decoder(src_vocab_size, 140 | tgt_vocab_size, 141 | tgt_vocab_size, 142 | opt.embedding_size, 143 | opt.embedding_size, 144 | opt.embedding_size, 145 | padding_idx) 146 | encoder = create_encoder(opt) 147 | encoder_ref = create_encoder(opt) 148 | decoder = create_decoder(opt) 149 | generator = create_generator(opt.hidden_size, tgt_vocab_size) 150 | model = bivanillaNMTModel(enc_embedding, 151 | enc_embedding_ref, 152 | dec_embedding, 153 | encoder, 154 | encoder_ref, 155 | decoder, 156 | generator) 157 | 158 | model.apply(weights_init) 159 | return model 160 | 161 | def create_template_generator(opt, fields): 162 | src_vocab_size = len(fields['src'].vocab) 163 | tgt_vocab_size = len(fields['tgt'].vocab) 164 | padding_idx = fields['src'].vocab.stoi[fields['src'].vocab.PAD] 165 | enc_embedding, dec_embedding = \ 166 | create_emb_for_encoder_and_decoder(src_vocab_size, 167 | tgt_vocab_size, 168 | opt.embedding_size, 169 | opt.embedding_size, 170 | padding_idx) 171 | encoder_ref = EncoderRNN("GRU", opt.embedding_size, opt.hidden_size, 1, opt.dropout_ev, opt.bidirectional) 172 | attention_src = GlobalAttention(opt.embedding_size, attn_type="mlp") 173 | attention_ref = GlobalAttention(opt.embedding_size, attn_type="mlp") 174 | 175 | ev_generator = editVectorGenerator(enc_embedding, dec_embedding, encoder_ref, attention_src, attention_ref, nn.Dropout(opt.dropout_ev)) 176 | masker = nn.Sequential(nn.Linear(3*opt.embedding_size, opt.embedding_size), nn.ReLU(), nn.Linear(opt.embedding_size, 1), nn.Sigmoid()) 177 | model = templateGenerator(ev_generator, masker, nn.Dropout(opt.dropout_ev)) 178 | model.apply(weights_init) 179 | return model 180 | 181 | def create_ev_model(opt, fields): 182 | src_vocab_size = len(fields['src'].vocab) 183 | tgt_vocab_size = len(fields['tgt'].vocab) 184 | padding_idx = fields['src'].vocab.stoi[fields['src'].vocab.PAD] 185 | enc_embedding, dec_embedding = \ 186 | create_emb_for_encoder_and_decoder(src_vocab_size, 187 | tgt_vocab_size, 188 | opt.embedding_size, 189 | opt.embedding_size, 190 | padding_idx) 191 | encoder_ref = EncoderRNN("GRU", opt.embedding_size, opt.hidden_size, 1, opt.dropout_ev, opt.bidirectional) 192 | attention_src = GlobalAttention(opt.embedding_size, attn_type="mlp") 193 | attention_ref = GlobalAttention(opt.embedding_size, attn_type="mlp") 194 | ev_generator = editVectorGenerator(enc_embedding, dec_embedding, encoder_ref, attention_src, attention_ref, nn.Dropout(opt.dropout_ev)) 195 | bridge = nn.Sequential(nn.Linear(2*opt.embedding_size, opt.aux_size), nn.ReLU()) 196 | encoder_src = create_encoder(opt) 197 | decoder = AuxDecoderRNN(opt.rnn_type, 198 | opt.atten_model, 199 | opt.embedding_size + opt.aux_size, 200 | opt.hidden_size, 201 | opt.num_layers, 202 | opt.dropout) 203 | 204 | generator = create_generator(opt.hidden_size, tgt_vocab_size) 205 | 206 | model = evNMTModel(enc_embedding, 207 | dec_embedding, 208 | encoder_src, 209 | decoder, 210 | generator, 211 | ev_generator, 212 | bridge, 213 | fields) 214 | 215 | model.apply(weights_init) 216 | return model 217 | 218 | def create_response_generator(opt, fields): 219 | src_vocab_size = len(fields['src'].vocab) 220 | tgt_vocab_size = len(fields['tgt'].vocab) 221 | padding_idx = fields['src'].vocab.stoi[fields['src'].vocab.PAD] 222 | enc_embedding, dec_embedding = create_emb_for_encoder_and_decoder(src_vocab_size, tgt_vocab_size, 223 | opt.embedding_size, 224 | opt.embedding_size, 225 | padding_idx) 226 | encoder_src = create_encoder(opt) 227 | if opt.use_ev: 228 | decoder = AuxMemDecoderRNN(opt.rnn_type, 229 | opt.atten_model, 230 | opt.embedding_size + opt.aux_size, 231 | opt.hidden_size, 232 | opt.num_layers, 233 | opt.dropout, 234 | opt.src_attention, 235 | opt.mem_gate, 236 | opt.gate_vector) 237 | 238 | bridge = nn.Sequential(nn.Linear(2*opt.embedding_size, opt.aux_size), nn.ReLU()) 239 | else: 240 | decoder = KVAttnDecoderRNN(opt.rnn_type, 241 | opt.atten_model, 242 | opt.embedding_size, 243 | opt.hidden_size, 244 | opt.num_layers, 245 | opt.dropout, 246 | opt.src_attention, 247 | opt.mem_gate, 248 | opt.gate_vector) 249 | bridge = None 250 | encoder_ref = create_encoder(opt) 251 | generator = create_generator(opt.hidden_size, tgt_vocab_size) 252 | model = responseGenerator(enc_embedding, dec_embedding, encoder_src, decoder, generator, encoder_ref, bridge, fields) 253 | model.apply(weights_init) 254 | return model 255 | 256 | def create_joint_model(opt, fields): 257 | template_generator = create_template_generator(opt, fields) 258 | response_generator = create_response_generator(opt, fields) 259 | model = tem_resNMTModel(template_generator, response_generator, opt.use_ev) 260 | return model 261 | 262 | def create_joint_template_response_model(opt, fields): 263 | src_vocab_size = len(fields['src'].vocab) 264 | tgt_vocab_size = len(fields['tgt'].vocab) 265 | padding_idx = fields['src'].vocab.stoi[fields['src'].vocab.PAD] 266 | enc_embedding, dec_embedding = create_emb_for_encoder_and_decoder(src_vocab_size, tgt_vocab_size, 267 | opt.embedding_size, 268 | opt.embedding_size, 269 | padding_idx) 270 | encoder_src = create_encoder(opt) 271 | if opt.use_ev: 272 | decoder = AuxMemDecoderRNN(opt.rnn_type, 273 | opt.atten_model, 274 | opt.embedding_size + opt.aux_size, 275 | opt.hidden_size, 276 | opt.num_layers, 277 | opt.dropout, 278 | opt.src_attention, 279 | opt.mem_gate, 280 | opt.gate_vector) 281 | 282 | bridge = nn.Sequential(nn.Linear(2*opt.embedding_size, opt.aux_size), nn.ReLU()) 283 | else: 284 | decoder = KVAttnDecoderRNN(opt.rnn_type, 285 | opt.atten_model, 286 | opt.embedding_size, 287 | opt.hidden_size, 288 | opt.num_layers, 289 | opt.dropout, 290 | opt.src_attention, 291 | opt.mem_gate, 292 | opt.gate_vector) 293 | bridge = None 294 | encoder_ref = EncoderRNN("GRU", opt.embedding_size, opt.hidden_size, 1, opt.dropout_ev, opt.bidirectional) 295 | attention_src = GlobalAttention(opt.embedding_size, attn_type="mlp") 296 | attention_ref = GlobalAttention(opt.embedding_size, attn_type="mlp") 297 | ev_generator = editVectorGenerator(enc_embedding, dec_embedding, encoder_ref, attention_src, attention_ref, nn.Dropout(opt.dropout_ev)) 298 | generator = create_generator(opt.hidden_size, tgt_vocab_size) 299 | masker = nn.Sequential(nn.Linear(3*opt.embedding_size, opt.embedding_size), nn.ReLU(), nn.Linear(opt.embedding_size, 1), nn.Sigmoid()) 300 | model = jointTemplateResponseGenerator(ev_generator, masker, nn.Dropout(opt.dropout_ev), enc_embedding, dec_embedding, encoder_src, decoder, generator, bridge, fields) 301 | model.apply(weights_init) 302 | return model 303 | 304 | def create_critic_model(opt, fields): 305 | encoder_src = EncoderRNN("GRU", opt.embedding_size, opt.hidden_size, opt.num_layers, opt.dropout, opt.bidirectional) 306 | encoder_tgt = EncoderRNN("GRU", opt.embedding_size, opt.hidden_size, opt.num_layers, opt.dropout, opt.bidirectional) 307 | model = Critic(encoder_src, encoder_tgt, opt.dropout) 308 | model.apply(weights_init) 309 | return model 310 | 311 | def create_GAN_model(opt, fields): 312 | # For now, we only consider the ev model for GAN training 313 | generator = create_ref_model(opt, fields) 314 | disc = create_ref_model(opt, fields) 315 | #critic = create_ref_model(opt, fields) 316 | 317 | discriminator = Discriminator(disc, nn.Linear(opt.hidden_size, 1)) 318 | critic = Discriminator(disc, nn.Linear(opt.hidden_size, 1)) 319 | discriminator.adaptor.apply(weights_init) 320 | critic.adaptor.weight.data.zero_() 321 | critic.adaptor.bias.data.zero_() 322 | return generator, discriminator, critic 323 | -------------------------------------------------------------------------------- /nmt/Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from nmt.utils.data_utils import sequence_mask 6 | 7 | def ListsToTensor(xs): 8 | batch_size = len(xs) 9 | lens = [ len(x) for x in xs] 10 | mx_len = max( max(lens),1) 11 | ys = [] 12 | for i, x in enumerate(xs): 13 | y = x + ([0]*(mx_len - lens[i])) 14 | ys.append(y) 15 | 16 | lens = [ max(1, x) for x in lens] 17 | data = Variable(torch.LongTensor(ys).t_()) 18 | 19 | data = data.cuda() 20 | 21 | return (data, lens) 22 | 23 | class editVectorGenerator(nn.Module): 24 | def __init__(self, enc_embedding, dec_embedding, encoder_ref, attention_src, attention_ref, dropout): 25 | super(editVectorGenerator, self).__init__() 26 | self.enc_embedding = enc_embedding 27 | self.dec_embedding = dec_embedding 28 | self.encoder_ref = encoder_ref 29 | self.attention_src = attention_src 30 | self.attention_ref = attention_ref 31 | self.dropout = dropout 32 | 33 | def forward(self, I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths): 34 | 35 | enc_outputs, enc_hidden = self.encoder_ref(self.dec_embedding(ref_tgt_inputs), ref_tgt_lengths, None) 36 | I_context = self.enc_embedding(I_word) 37 | D_context = self.enc_embedding(D_word) 38 | enc_hidden = enc_hidden.squeeze(0) 39 | 40 | I_context = self.dropout(I_context) 41 | D_context = self.dropout(D_context) 42 | enc_hidden = self.dropout(enc_hidden) 43 | 44 | I_context = I_context.transpose(0, 1).contiguous() 45 | D_context = D_context.transpose(0, 1).contiguous() 46 | 47 | I, _ = self.attention_src(enc_hidden, I_context, mask = sequence_mask(I_word_length)) 48 | D, _ = self.attention_ref(enc_hidden, D_context, mask = sequence_mask(D_word_length)) 49 | 50 | return torch.cat([I,D], 1), enc_outputs 51 | 52 | class jointTemplateResponseGenerator(nn.Module): 53 | def __init__(self, ev_generator, masker, masker_dropout, enc_embedding, dec_embedding, encoder_src, decoder, generator, bridge, fields): 54 | super(jointTemplateResponseGenerator, self).__init__() 55 | self.ev_generator = ev_generator 56 | self.masker = masker 57 | self.masker_dropout = masker_dropout 58 | self.enc_embedding = enc_embedding 59 | self.dec_embedding = dec_embedding 60 | self.encoder_src = encoder_src 61 | self.decoder = decoder 62 | self.generator = generator 63 | self.bridge = bridge 64 | self.fields = fields 65 | 66 | 67 | def forward(self, I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths, src_inputs, tgt_inputs, src_lengths): 68 | ref_contexts, enc_hidden, ref_mask, dist, src_contexts, src_mask, preds = self.encode(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths, src_inputs, src_lengths) 69 | 70 | dec_init_hidden = self.init_decoder_state(enc_hidden, ref_contexts) 71 | 72 | dec_outputs , dec_hiddens, attn = self.decode( 73 | tgt_inputs, ref_contexts, dec_init_hidden, dist, ref_mask, src_contexts, src_mask 74 | ) 75 | return dec_outputs, attn, preds 76 | 77 | 78 | def init_decoder_state(self, enc_hidden, context): 79 | return enc_hidden 80 | 81 | def encode(self, I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths, src_inputs, src_lengths): 82 | ev, enc_outputs = self.ev_generator(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths) 83 | ev = self.masker_dropout(ev) 84 | ev_for_return = ev 85 | enc_outputs = self.masker_dropout(enc_outputs) 86 | _, _dim = ev.size() 87 | _len, _batch, _ = enc_outputs.size() 88 | 89 | if self.bridge is not None: 90 | dist = self.bridge(ev) 91 | else: 92 | dist = None 93 | 94 | ev = ev.unsqueeze(0) 95 | ev = ev.expand(_len, _batch, _dim) 96 | preds = self.masker(torch.cat([ev, enc_outputs], 2)) 97 | preds = preds.squeeze(2) 98 | 99 | emb_src = self.enc_embedding(src_inputs) 100 | src_contexts, enc_hidden = self.encoder_src(emb_src, src_lengths, None) 101 | 102 | ref_mask = sequence_mask(ref_tgt_lengths) 103 | src_mask = sequence_mask(src_lengths) 104 | return enc_outputs, enc_hidden, ref_mask, dist, src_contexts, src_mask, preds 105 | 106 | def decode(self, input, context, state, dist, context_mask, src_context, src_context_mask): 107 | emb = self.dec_embedding(input) 108 | if dist is not None: 109 | dec_outputs , dec_hiddens, attn = self.decoder( 110 | emb, context, state, dist, context_mask, src_context, src_context_mask) 111 | else: 112 | dec_outputs , dec_hiddens, attn = self.decoder( 113 | emb, context, context, state, context_mask, src_context, src_context_mask) 114 | 115 | return dec_outputs, dec_hiddens, attn 116 | 117 | def save_checkpoint(self, epoch, opt, filename): 118 | torch.save({ 'ev_generator_dict': self.ev_generator.state_dict(), 119 | 'masker_dict': self.masker.state_dict(), 120 | 'masker_dropout_dict': self.masker_dropout.state_dict(), 121 | 'enc_embedding_dict': self.enc_embedding.state_dict(), 122 | 'dec_embedding_dict': self.dec_embedding.state_dict(), 123 | 'encoder_src_dict': self.encoder_src.state_dict(), 124 | 'decoder_dict': self.decoder.state_dict(), 125 | 'generator_dict': self.generator.state_dict(), 126 | 'epoch': epoch, 127 | 'opt': opt, 128 | # 'bridge_dict': self.bridge.state_dict() 129 | }, filename) 130 | 131 | def load_checkpoint(self, filename): 132 | ckpt = torch.load(filename) 133 | self.ev_generator.load_state_dict(ckpt['ev_generator_dict']) 134 | self.masker.load_state_dict(ckpt['masker_dict']) 135 | self.masker_dropout.load_state_dict(ckpt['masker_dropout_dict']) 136 | self.enc_embedding.load_state_dict(ckpt['enc_embedding_dict']) 137 | self.dec_embedding.load_state_dict(ckpt['dec_embedding_dict']) 138 | self.encoder_src.load_state_dict(ckpt['encoder_src_dict']) 139 | self.decoder.load_state_dict(ckpt['decoder_dict']) 140 | self.generator.load_state_dict(ckpt['generator_dict']) 141 | if self.bridge is not None: 142 | self.bridge.load_state_dict(ckpt['bridge_dict']) 143 | epoch = ckpt['epoch'] 144 | return epoch 145 | 146 | class templateGenerator(nn.Module): 147 | def __init__(self, ev_generator, masker, dropout): 148 | super(templateGenerator, self).__init__() 149 | self.ev_generator = ev_generator 150 | self.masker = masker 151 | self.dropout = dropout 152 | 153 | def forward(self, I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths, return_ev = False): 154 | ev, enc_outputs = self.ev_generator(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths) 155 | ev = self.dropout(ev) 156 | ev_for_return = ev 157 | enc_outputs = self.dropout(enc_outputs) 158 | _, _dim = ev.size() 159 | _len, _batch, _ = enc_outputs.size() 160 | ev = ev.unsqueeze(0) 161 | ev = ev.expand(_len, _batch, _dim) 162 | preds = self.masker(torch.cat([ev, enc_outputs], 2)) 163 | if return_ev: 164 | return preds, ev_for_return 165 | return preds 166 | 167 | def save_checkpoint(self, epoch, opt, filename): 168 | torch.save({'ev_generator_dict': self.ev_generator.state_dict(), 169 | 'masker_dict': self.masker.state_dict(), 170 | 'opt': opt, 171 | 'epoch': epoch, 172 | }, 173 | filename) 174 | 175 | def load_checkpoint(self, filename): 176 | ckpt = torch.load(filename) 177 | self.ev_generator.load_state_dict(ckpt['ev_generator_dict']) 178 | self.masker.load_state_dict(ckpt['masker_dict']) 179 | epoch = ckpt['epoch'] 180 | return epoch 181 | 182 | def do_mask_and_clean(self, preds, ref_tgt_inputs, ref_tgt_lengths): 183 | mask = sequence_mask(ref_tgt_lengths).transpose(0, 1).float() 184 | ans = torch.ge(preds, 0.5) 185 | ref_tgt_inputs.data.masked_fill_(1-ans.data, 0) 186 | y = ref_tgt_inputs.transpose(0, 1).data.tolist() 187 | data = [ z[:l] for z,l in zip(y, ref_tgt_lengths) ] 188 | new_data = [] 189 | for z in data: 190 | new_z = [] 191 | iszero = False 192 | for w in z: 193 | if iszero and w == 0: 194 | continue 195 | else: 196 | new_z.append(w) 197 | iszero = (w==0) 198 | new_data.append([1] + new_z+ [2]) 199 | return ListsToTensor(new_data) 200 | 201 | class responseGenerator(nn.Module): 202 | def __init__(self, enc_embedding, dec_embedding, encoder_src, decoder, generator, encoder_ref, bridge, fields): 203 | super(responseGenerator, self).__init__() 204 | self.enc_embedding = enc_embedding 205 | self.dec_embedding = dec_embedding 206 | self.encoder_src = encoder_src 207 | self.decoder = decoder 208 | self.generator = generator 209 | self.encoder_ref = encoder_ref 210 | self.bridge = bridge 211 | self.fields = fields 212 | 213 | def forward(self, src_inputs, tgt_inputs, template_inputs, src_lengths, template_lengths, ev = None): 214 | # Run words through encoder 215 | ref_contexts, enc_hidden, ref_mask, dist, src_contexts, src_mask = self.encode(src_inputs, template_inputs, src_lengths, template_lengths, ev) 216 | dec_init_hidden = self.init_decoder_state(enc_hidden, ref_contexts) 217 | dec_outputs , dec_hiddens, attn = self.decode( 218 | tgt_inputs, ref_contexts, dec_init_hidden, dist, ref_mask, src_contexts, src_mask 219 | ) 220 | return dec_outputs, attn 221 | 222 | def encode(self, src_inputs, template_inputs, src_lengths, template_lengths, ev = None): 223 | emb_src = self.enc_embedding(src_inputs) 224 | src_contexts, enc_hidden = self.encoder_src(emb_src, src_lengths, None) 225 | if ev is not None and self.bridge is not None: 226 | dist = self.bridge(ev) 227 | else: 228 | dist = None 229 | 230 | ref_contexts, ref_mask = [], [] 231 | for template_input, template_length in zip(template_inputs, template_lengths): 232 | emb_ref = self.dec_embedding(template_input) 233 | ref_context, _ = self.encoder_ref(emb_ref, template_length) 234 | ref_mask_ = sequence_mask(template_length) 235 | ref_contexts.append(ref_context) 236 | ref_mask.append(ref_mask_) 237 | ref_contexts = torch.cat(ref_contexts, 0) 238 | ref_mask = torch.cat(ref_mask, 1) 239 | src_mask = sequence_mask(src_lengths) 240 | return ref_contexts, enc_hidden, ref_mask, dist, src_contexts, src_mask 241 | 242 | def init_decoder_state(self, enc_hidden, context): 243 | return enc_hidden 244 | 245 | def decode(self, input, context, state, dist, context_mask, src_context, src_context_mask): 246 | emb = self.dec_embedding(input) 247 | if dist is not None: 248 | dec_outputs , dec_hiddens, attn = self.decoder( 249 | emb, context, state, dist, context_mask, src_context, src_context_mask) 250 | else: 251 | dec_outputs , dec_hiddens, attn = self.decoder( 252 | emb, context, context, state, context_mask, src_context, src_context_mask) 253 | 254 | return dec_outputs, dec_hiddens, attn 255 | 256 | def save_checkpoint(self, epoch, opt, filename): 257 | torch.save({'encoder_src_dict': self.encoder_src.state_dict(), 258 | 'decoder_dict': self.decoder.state_dict(), 259 | 'enc_embedding_dict': self.enc_embedding.state_dict(), 260 | 'dec_embedding_dict': self.dec_embedding.state_dict(), 261 | 'generator_dict': self.generator.state_dict(), 262 | 'encoder_ref_dict': self.encoder_ref.state_dict(), 263 | #'bridge_dict': self.bridge.state_dict(), 264 | 'opt': opt, 265 | 'epoch': epoch, 266 | }, 267 | filename) 268 | 269 | def load_checkpoint(self, filename): 270 | ckpt = torch.load(filename) 271 | self.encoder_src.load_state_dict(ckpt['encoder_src_dict']) 272 | self.decoder.load_state_dict(ckpt['decoder_dict']) 273 | self.enc_embedding.load_state_dict(ckpt['enc_embedding_dict']) 274 | self.dec_embedding.load_state_dict(ckpt['dec_embedding_dict']) 275 | self.generator.load_state_dict(ckpt['generator_dict']) 276 | self.encoder_ref.load_state_dict(ckpt['encoder_ref_dict']) 277 | #self.bridge.load_state_dict(ckpt['bridge_dict']) 278 | epoch = ckpt['epoch'] 279 | return epoch 280 | 281 | class tem_resNMTModel(nn.Module): 282 | 283 | def __init__(self, template_generator, response_generator, use_ev): 284 | super(tem_resNMTModel, self).__init__() 285 | self.template_generator = template_generator 286 | self.response_generator = response_generator 287 | self.generator = self.response_generator.generator 288 | self.use_ev = use_ev 289 | 290 | def forward(self, I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths, src_inputs, tgt_inputs, src_lengths): 291 | preds, ev = self.template_generator(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths, return_ev = True) 292 | preds = preds.squeeze(2) 293 | ev = ev.detach() 294 | template_inputs, template_lengths = self.template_generator.do_mask_and_clean(preds, ref_tgt_inputs, ref_tgt_lengths) 295 | 296 | dec_outputs, attn = self.response_generator(src_inputs, tgt_inputs, template_inputs, src_lengths, template_lengths, ev = ( None if not self.use_ev else ev) ) 297 | return dec_outputs, attn 298 | 299 | def save_checkpoint(self, epoch, opt, filename): 300 | torch.save({'template_generator_dict': self.template_generator.state_dict(), 301 | 'response_generator_dict': self.response_generator.state_dict(), 302 | 'opt': opt, 303 | 'epoch': epoch, 304 | }, 305 | filename) 306 | 307 | def load_checkpoint(self, filename): 308 | ckpt = torch.load(filename) 309 | self.template_generator.load_state_dict(ckpt['template_generator_dict']) 310 | self.response_generator.load_state_dict(ckpt['response_generator_dict']) 311 | epoch = ckpt['epoch'] 312 | return epoch 313 | 314 | class vanillaNMTModel(nn.Module): 315 | def __init__(self, enc_embedding, dec_embedding, encoder, decoder, generator): 316 | super(vanillaNMTModel, self).__init__() 317 | self.enc_embedding = enc_embedding 318 | self.dec_embedding = dec_embedding 319 | self.encoder = encoder 320 | self.decoder = decoder 321 | self.generator = generator 322 | 323 | def forward(self, src_inputs, tgt_inputs, src_lengths): 324 | 325 | # Run wrods through encoder 326 | 327 | enc_outputs, enc_hidden, enc_mask = self.encode(src_inputs, src_lengths, None) 328 | 329 | dec_init_hidden = self.init_decoder_state(enc_hidden, enc_outputs) 330 | 331 | dec_outputs, dec_hiddens, attn = self.decode( 332 | tgt_inputs, enc_outputs, dec_init_hidden, enc_mask 333 | ) 334 | 335 | return dec_outputs, attn 336 | 337 | def encode(self, input, lengths=None, hidden=None): 338 | emb = self.enc_embedding(input) 339 | enc_outputs, enc_hidden = self.encoder(emb, lengths, None) 340 | enc_mask = sequence_mask(lengths) 341 | return enc_outputs, enc_hidden, enc_mask 342 | 343 | def init_decoder_state(self, enc_hidden, context): 344 | return enc_hidden 345 | 346 | def decode(self, input, context, state, mask): 347 | emb = self.dec_embedding(input) 348 | 349 | dec_outputs , dec_hiddens, attn = self.decoder( 350 | emb, context, state, mask 351 | ) 352 | 353 | return dec_outputs, dec_hiddens, attn 354 | 355 | def save_checkpoint(self, epoch, opt, filename): 356 | torch.save({'encoder_dict': self.encoder.state_dict(), 357 | 'decoder_dict': self.decoder.state_dict(), 358 | 'enc_embedding_dict': self.enc_embedding.state_dict(), 359 | 'dec_embedding_dict': self.dec_embedding.state_dict(), 360 | 'generator_dict': self.generator.state_dict(), 361 | 'decoder_rnn_dict' : self.decoder.rnn.state_dict(), 362 | 'decoder_attn_dict': self.decoder.attn.state_dict(), 363 | 'opt': opt, 364 | 'epoch': epoch, 365 | }, 366 | filename) 367 | 368 | def load_checkpoint(self, filename): 369 | ckpt = torch.load(filename) 370 | self.enc_embedding.load_state_dict(ckpt['enc_embedding_dict']) 371 | self.dec_embedding.load_state_dict(ckpt['dec_embedding_dict']) 372 | self.encoder.load_state_dict(ckpt['encoder_dict']) 373 | self.decoder.load_state_dict(ckpt['decoder_dict']) 374 | self.generator.load_state_dict(ckpt['generator_dict']) 375 | epoch = ckpt['epoch'] 376 | return epoch 377 | 378 | class bivanillaNMTModel(nn.Module): 379 | def __init__(self, enc_embedding, enc_embedding_ref, dec_embedding, encoder, encoder_ref, decoder, generator): 380 | super(bivanillaNMTModel, self).__init__() 381 | self.enc_embedding = enc_embedding 382 | self.dec_embedding = dec_embedding 383 | self.enc_embedding_ref = enc_embedding_ref 384 | self.encoder = encoder 385 | self.encoder_ref = encoder_ref 386 | self.decoder = decoder 387 | self.generator = generator 388 | self.bridge_h = nn.Linear(2*encoder.hidden_size + 2*encoder_ref.hidden_size, 2*encoder.hidden_size) 389 | self.bridge_c = nn.Linear(2*encoder.hidden_size + 2*encoder_ref.hidden_size, 2*encoder.hidden_size) 390 | 391 | def forward(self, src_inputs, tgt_inputs, ref_tgt_inputs, src_lengths, ref_tgt_lengths): 392 | 393 | # Run words through encoder 394 | 395 | enc_outputs, enc_hidden = self.encode(src_inputs, ref_tgt_inputs, src_lengths, ref_tgt_lengths, None) 396 | 397 | dec_init_hidden = self.init_decoder_state(enc_hidden, enc_outputs) 398 | 399 | dec_outputs , dec_hiddens, attn = self.decode( 400 | tgt_inputs, enc_outputs, dec_init_hidden 401 | ) 402 | 403 | return dec_outputs, attn 404 | 405 | 406 | 407 | def encode(self, src, ref_tgt, src_lengths=None, ref_tgt_lengths = None, hidden=None): 408 | emb = self.enc_embedding(src) 409 | emb_ref = self.enc_embedding_ref(ref_tgt) 410 | enc_outputs, enc_hidden = self.encoder(emb, src_lengths, None) 411 | enc_outputs_x, enc_hidden_x = self.encoder_ref(emb_ref, ref_tgt_lengths, None) 412 | 413 | h = torch.cat([enc_hidden[0], enc_hidden_x[0]], -1) 414 | c = torch.cat([enc_hidden[1], enc_hidden_x[1]], -1) 415 | h = self.bridge_h(h) 416 | c = self.bridge_c(c) 417 | return enc_outputs, (h, c) 418 | 419 | def init_decoder_state(self, enc_hidden, context): 420 | return enc_hidden 421 | 422 | def decode(self, input, context, state): 423 | emb = self.dec_embedding(input) 424 | dec_outputs , dec_hiddens, attn = self.decoder( 425 | emb, context, state 426 | ) 427 | 428 | return dec_outputs, dec_hiddens, attn 429 | def save_checkpoint(self, epoch, opt, filename): 430 | torch.save({'encoder_dict': self.encoder.state_dict(), 431 | 'encoder_ref_dict': self.encoder_ref.state_dict(), 432 | 'decoder_dict': self.decoder.state_dict(), 433 | 'enc_embedding_dict': self.enc_embedding.state_dict(), 434 | "enc_embedding_ref":self.enc_embedding_ref.state_dict(), 435 | 'dec_embedding_dict': self.dec_embedding.state_dict(), 436 | 'generator_dict': self.generator.state_dict(), 437 | 'bridge_h_dict': self.bridge_h.state_dict(), 438 | 'bridge_c_dict': self.bridge_c.state_dict(), 439 | 'opt': opt, 440 | 'epoch': epoch, 441 | }, 442 | filename) 443 | 444 | def load_checkpoint(self, filename): 445 | ckpt = torch.load(filename) 446 | self.enc_embedding.load_state_dict(ckpt['enc_embedding_dict']) 447 | self.enc_embedding_ref.load_state_dict(ckpt['enc_embedding_ref']) 448 | self.dec_embedding.load_state_dict(ckpt['dec_embedding_dict']) 449 | self.encoder.load_state_dict(ckpt['encoder_dict']) 450 | self.encoder_ref.load_state_dict(ckpt['encoder_ref_dict']) 451 | self.decoder.load_state_dict(ckpt['decoder_dict']) 452 | self.generator.load_state_dict(ckpt['generator_dict']) 453 | self.bridge_h.load_state_dict(ckpt['bridge_h_dict']) 454 | self.bridge_c.load_state_dict(ckpt['bridge_c_dict']) 455 | epoch = ckpt['epoch'] 456 | return epoch 457 | 458 | class refNMTModel(nn.Module): 459 | def __init__(self, enc_embedding, dec_embedding, encoder_src, encoder_ref, decoder_ref, decoder, generator, fields): 460 | super(refNMTModel, self).__init__() 461 | self.enc_embedding = enc_embedding 462 | self.dec_embedding = dec_embedding 463 | self.encoder_src = encoder_src 464 | self.encoder_ref = encoder_ref 465 | self.decoder_ref = decoder_ref 466 | self.decoder = decoder 467 | self.generator = generator 468 | self.fields = fields 469 | 470 | def forward(self, src_inputs, tgt_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths): 471 | 472 | # Run words through encoder 473 | ref_values, enc_hidden, ref_keys, ref_mask, src_context, src_mask = self.encode(src_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths, None) 474 | 475 | dec_init_hidden = self.init_decoder_state(enc_hidden, ref_values) 476 | 477 | dec_outputs , dec_hiddens, attn = self.decode( 478 | tgt_inputs, ref_keys, ref_values, dec_init_hidden, ref_mask, src_context, src_mask 479 | ) 480 | return dec_outputs, attn, ref_keys 481 | 482 | def encode(self, src_inputs, ref_src_inputs, ref_tgt_inputs, src_lengths, ref_src_lengths, ref_tgt_lengths, hidden=None): 483 | emb_src = self.enc_embedding(src_inputs) 484 | embs_ref_src = [ self.enc_embedding(ref_src_input) for ref_src_input in ref_src_inputs ] 485 | embs_ref_tgt = [ self.dec_embedding(ref_tgt_input) for ref_tgt_input in ref_tgt_inputs ] 486 | 487 | ref_values, ref_keys, ref_mask = [], [], [] 488 | for emb_ref_src, emb_ref_tgt, ref_src_length, ref_tgt_length in zip(embs_ref_src, embs_ref_tgt, ref_src_lengths, ref_tgt_lengths): 489 | ref_src_context, enc_ref_hidden = self.encoder_src(emb_ref_src, ref_src_length, None) 490 | ref_src_mask = sequence_mask(ref_src_length) 491 | ref_key, _, _ = self.decoder_ref(emb_ref_tgt, ref_src_context, enc_ref_hidden, ref_src_mask) 492 | ref_value, _ = self.encoder_ref(emb_ref_tgt, ref_tgt_length, None) 493 | ref_msk = sequence_mask([ x-1 for x in ref_tgt_length]) 494 | ref_values.append(ref_value[1:]) 495 | ref_keys.append(ref_key[:-1]) 496 | ref_mask.append(ref_msk) 497 | ref_values = torch.cat(ref_values, 0) 498 | ref_keys = torch.cat(ref_keys, 0) 499 | ref_mask = torch.cat(ref_mask, 1) 500 | 501 | src_context, enc_hidden = self.encoder_src(emb_src, src_lengths, None) 502 | src_mask = sequence_mask(src_lengths) 503 | 504 | return ref_values, enc_hidden, ref_keys, ref_mask, src_context, src_mask 505 | 506 | def init_decoder_state(self, enc_hidden, context): 507 | return enc_hidden 508 | 509 | def decode(self, input, context_key, context_value, state, context_mask, src_context, src_mask): 510 | emb = self.dec_embedding(input) 511 | dec_outputs , dec_hiddens, attn = self.decoder( 512 | emb, context_key, context_value, state, context_mask, src_context, src_mask 513 | ) 514 | return dec_outputs, dec_hiddens, attn 515 | 516 | def save_checkpoint(self, epoch, opt, filename): 517 | torch.save({'encoder_src_dict': self.encoder_src.state_dict(), 518 | 'encoder_ref_dict': self.encoder_ref.state_dict(), 519 | "decoder_ref_dict": self.decoder_ref.state_dict(), 520 | 'decoder_dict': self.decoder.state_dict(), 521 | 'enc_embedding_dict': self.enc_embedding.state_dict(), 522 | 'dec_embedding_dict': self.dec_embedding.state_dict(), 523 | 'generator_dict': self.generator.state_dict(), 524 | 'opt': opt, 525 | 'epoch': epoch, 526 | }, 527 | filename) 528 | 529 | def load_checkpoint(self, filename): 530 | ckpt = torch.load(filename) 531 | self.enc_embedding.load_state_dict(ckpt['enc_embedding_dict']) 532 | self.dec_embedding.load_state_dict(ckpt['dec_embedding_dict']) 533 | self.encoder_src.load_state_dict(ckpt['encoder_src_dict']) 534 | self.encoder_ref.load_state_dict(ckpt['encoder_ref_dict']) 535 | self.decoder.load_state_dict(ckpt['decoder_dict']) 536 | self.decoder_ref.load_state_dict(ckpt['decoder_ref_dict']) 537 | self.generator.load_state_dict(ckpt['generator_dict']) 538 | epoch = ckpt['epoch'] 539 | return epoch 540 | 541 | class evNMTModel(nn.Module): 542 | def __init__(self, enc_embedding, dec_embedding, encoder_src, decoder, generator, ev_generator, bridge, fields): 543 | super(evNMTModel, self).__init__() 544 | self.enc_embedding = enc_embedding 545 | self.dec_embedding = dec_embedding 546 | self.encoder_src = encoder_src 547 | self.ev_generator = ev_generator 548 | self.decoder = decoder 549 | self.generator = generator 550 | self.bridge = bridge 551 | self.fields = fields 552 | 553 | def forward(self, src_inputs, tgt_inputs, src_lengths, I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths): 554 | 555 | # Run words through encoder 556 | ref_contexts, enc_hidden, ref_mask, dist = self.encode(src_inputs, src_lengths, I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths) 557 | 558 | dec_init_hidden = self.init_decoder_state(enc_hidden, ref_contexts) 559 | dec_outputs , dec_hiddens, attn = self.decode( 560 | tgt_inputs, ref_contexts, dec_init_hidden, dist, ref_mask 561 | ) 562 | return dec_outputs, attn 563 | 564 | def encode(self, src_inputs, src_lengths, I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths): 565 | emb_src = self.enc_embedding(src_inputs) 566 | _, enc_hidden = self.encoder_src(emb_src, src_lengths, None) 567 | ev, ref_contexts = self.ev_generator(I_word, I_word_length, D_word, D_word_length, ref_tgt_inputs, ref_tgt_lengths) 568 | dist = self.bridge(ev) 569 | ref_mask = sequence_mask(ref_tgt_lengths) 570 | return ref_contexts, enc_hidden, ref_mask, dist 571 | 572 | def init_decoder_state(self, enc_hidden, context): 573 | return enc_hidden 574 | 575 | def decode(self, input,context, state, dist, context_mask): 576 | emb = self.dec_embedding(input) 577 | dec_outputs , dec_hiddens, attn = self.decoder( 578 | emb, context ,state, dist, context_mask 579 | ) 580 | return dec_outputs, dec_hiddens, attn 581 | 582 | def save_checkpoint(self, epoch, opt, filename): 583 | torch.save({'encoder_src_dict': self.encoder_src.state_dict(), 584 | 'decoder_dict': self.decoder.state_dict(), 585 | 'enc_embedding_dict': self.enc_embedding.state_dict(), 586 | 'encoder_ref_dict': self.ev_generator.encoder_ref.state_dict(), 587 | 'attention_src_dict': self.ev_generator.attention_src.state_dict(), 588 | 'attention_ref_dict': self.ev_generator.attention_ref.state_dict(), 589 | # 'bridge_dict': self.bridge.state_dict(), 590 | 'dec_embedding_dict': self.dec_embedding.state_dict(), 591 | 'generator_dict': self.generator.state_dict(), 592 | 'opt': opt, 593 | 'epoch': epoch, 594 | }, 595 | filename) 596 | 597 | def load_checkpoint(self, filename): 598 | ckpt = torch.load(filename) 599 | self.enc_embedding.load_state_dict(ckpt['enc_embedding_dict']) 600 | self.dec_embedding.load_state_dict(ckpt['dec_embedding_dict']) 601 | self.encoder_src.load_state_dict(ckpt['encoder_src_dict']) 602 | self.decoder.load_state_dict(ckpt['decoder_dict']) 603 | self.generator.load_state_dict(ckpt['generator_dict']) 604 | #self.bridge.load_state_dict(ckpt['bridge_dict']) 605 | self.ev_generator.encoder_ref.load_state_dict(ckpt['encoder_ref_dict']) 606 | self.ev_generator.attention_src.load_state_dict(ckpt['attention_src_dict']) 607 | self.ev_generator.attention_ref.load_state_dict(ckpt['attention_ref_dict']) 608 | epoch = ckpt['epoch'] 609 | return epoch 610 | 611 | class Critic(nn.Module): 612 | def __init__(self, encoder_src, encoder_tgt, dropout): 613 | super(Critic, self).__init__() 614 | self.encoder_src = encoder_src 615 | self.encoder_tgt = encoder_tgt 616 | self.dropout = nn.Dropout(dropout) 617 | self.linear_out = nn.Linear(encoder_src.hidden_size*2, encoder_tgt.hidden_size*2) 618 | self.log_softmax = nn.LogSoftmax(dim = -1) 619 | 620 | def forward(self, src_inputs_emb, src_lengths, *args): 621 | _, src = self.encoder_src(src_inputs_emb, src_lengths) 622 | assert len(args)%2==0 623 | ret = [] 624 | src = torch.squeeze(src[-1], 0) 625 | src = self.linear_out(self.dropout(src)) 626 | for input_emb, lengths in zip(args[::2], args[1::2]): 627 | _, tgti = self.encoder_tgt(input_emb, lengths) 628 | tgti = torch.squeeze(tgti[-1], 0) 629 | tgti = self.dropout(tgti) 630 | score = torch.sum(src * tgti, 1) 631 | ret.append(score) 632 | cat = torch.stack(ret, 1) 633 | logp = self.log_softmax(cat) 634 | 635 | 636 | _, max_i = torch.max(logp, 1) 637 | print ('-',torch.mean(torch.eq(max_i, 0).float()).data[0]) 638 | print ('-',torch.mean(torch.eq(max_i, 1).float()).data[0]) 639 | print ('-',torch.mean(torch.eq(max_i, 2).float()).data[0]) 640 | x, y, z = torch.split(logp, 1, -1) 641 | return torch.squeeze(x, 1), torch.squeeze(y, 1), torch.squeeze(z, 1) 642 | 643 | def save_checkpoint(self, epoch, opt, filename): 644 | torch.save({ 645 | 'encoder_src_dict': self.encoder_src.state_dict(), 646 | 'encoder_tgt_dict': self.encoder_tgt.state_dict(), 647 | 'linear_out_dict': self.linear_out.state_dict(), 648 | 'opt': opt, 649 | 'epoch': epoch 650 | } ,filename) 651 | 652 | def load_checkpoint(self, filename): 653 | ckpt = torch.load(filename) 654 | self.encoder_src.load_state_dict(ckpt['encoder_src_dict']) 655 | self.encoder_tgt.load_state_dict(ckpt['encoder_tgt_dict']) 656 | self.linear_out.load_state_dict(ckpt['linear_out_dict']) 657 | epoch = ckpt['epoch'] 658 | return epoch 659 | 660 | class Discriminator(nn.Module): 661 | def __init__(self, base_model, adaptor): 662 | super(Discriminator, self).__init__() 663 | self.base_model = base_model 664 | self.adaptor = adaptor 665 | 666 | def forward(self, *args, **kwargs): 667 | outputs, _, _ = self.base_model.forward(*args, **kwargs) 668 | logits = self.adaptor(outputs) 669 | logits = logits.squeeze(-1) 670 | return logits 671 | 672 | def load_base_checkpoint(self, filename): 673 | self.base_model.load_checkpoint(filename) 674 | 675 | --------------------------------------------------------------------------------