├── slides.pdf
├── train_wmt.sh
├── train_iwslt.sh
├── decode_wmt.sh
├── decode_iwslt.sh
├── tune_wmt.sh
├── tune_iwslt.sh
├── rf_wmt.sh
├── rf_iwslt.sh
├── test.py
├── LICENSE
├── LICENSE_nyu
├── README.md
├── distill.py
├── mscoco.py
├── data.py
├── decode.py
├── train.py
├── utils.py
├── run.py
└── model.py

/slides.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/RSI-NAT/HEAD/slides.pdf
--------------------------------------------------------------------------------
/train_wmt.sh:
--------------------------------------------------------------------------------
1 | python run.py --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation
2 | 
3 | 
--------------------------------------------------------------------------------
/train_iwslt.sh:
--------------------------------------------------------------------------------
1 | python run.py --dataset iwslt-ende --vocab_size 40000 --gpu 0 --ffw_block highway --params small --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation
2 | 
3 | 
--------------------------------------------------------------------------------
/decode_wmt.sh:
--------------------------------------------------------------------------------
1 | python run.py --load_vocab --dataset wmt14-ende --vocab_size 60000 --gpu 2 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --mode test --remove_repeats --debug --trg_len_option predict --use_predicted_trg_len --load_from
2 | 
3 | 
--------------------------------------------------------------------------------
/decode_iwslt.sh:
--------------------------------------------------------------------------------
1 | python run.py --load_vocab --dataset iwslt-ende --vocab_size 40000 --gpu 2 --ffw_block highway --params small --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --mode test --remove_repeats --debug --trg_len_option predict --use_predicted_trg_len --load_from
2 | 
3 | 
--------------------------------------------------------------------------------
/tune_wmt.sh:
--------------------------------------------------------------------------------
1 | python run.py --dataset wmt14-ende --vocab_size 60000 --gpu 2 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict
2 | 
3 | 
4 | 
--------------------------------------------------------------------------------
/tune_iwslt.sh:
--------------------------------------------------------------------------------
1 | python run.py --dataset iwslt-ende --vocab_size 40000 --gpu 2 --ffw_block highway --params small --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict
2 | 
3 | 
4 | 
--------------------------------------------------------------------------------
/rf_wmt.sh:
--------------------------------------------------------------------------------
1 | python
run.py --resume --load_from --batch_size 1024 --eval_every 100 --sample_method stepwise --nat_finetune --train_repeat_dec 2 --dataset wmt14-ende --vocab_size 60000 --gpu 1 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation 2 | 3 | 4 | -------------------------------------------------------------------------------- /rf_iwslt.sh: -------------------------------------------------------------------------------- 1 | python run.py --resume --load_from --batch_size 1024 --eval_every 100 --sample_method stepwise --nat_finetune --train_repeat_dec 2 --dataset iwslt-ende --vocab_size 40000 --gpu 1 --ffw_block highway --params small --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation 2 | 3 | 4 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import ipdb 3 | import torch 4 | import numpy as np 5 | from torch.autograd import Variable 6 | from utils import corrupt_target 7 | 8 | def convert(lst): 9 | vocab = "what I 've come to realize about Afghanistan , and this is something that is often dismissed in the West".split() 10 | dd = {idx+4 : word for idx, word in enumerate(vocab)} 11 | dd[0] = "UNK" 12 | dd[1] = "PAD" 13 | dd[2] = "BOS" 14 | dd[3] = "EOS" 15 | return " ".join( dd[xx] for xx in lst ) 16 | 17 | trg = [ [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 3, 1, 1] ] 18 | decoder_masks = [ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0] ] 19 | weight = float(sys.argv[1]) 20 | 21 | cor_p = sys.argv[2] # repeat / drop / repeat and drop next / swap / add random word 22 | cor_p = [int(xx) for xx in cor_p.split("-")] 23 | cor_p = [xx/sum(cor_p) for xx in cor_p] 24 | 25 | trg = Variable( torch.from_numpy( np.array( trg ) ) ) 26 | decoder_masks = torch.from_numpy( np.array( decoder_masks ) ) 27 | 28 | print ( convert( trg.data.numpy().tolist()[0] ) ) 29 | print ( convert( corrupt_target( trg, decoder_masks, 15, weight, cor_p ).data.numpy().tolist()[0] ) ) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, University of Chinese Academy of Sciences (Chenze Shao) 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /LICENSE_nyu: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, New York University (Kyunghyun Cho, Jason Lee, Elman Mansimov) 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Retrieving Sequential Information for Non-Autoregressive Neural Machine Translation 2 | ================================== 3 | PyTorch implementation of the models described in the paper [Retrieving Sequential Information for Non-Autoregressive Neural Machine Translation 4 | ](https://arxiv.org/abs/1906.09444 "Retrieving Sequential Information for Non-Autoregressive Neural Machine Translation"). 5 | 6 | Dependencies 7 | ------------------ 8 | ### Python 9 | * Python 3.6 10 | * PyTorch 0.4 11 | * Numpy 12 | * NLTK 13 | * torchtext 14 | * torchvision 15 | * revtok 16 | * multiset 17 | * ipdb 18 | 19 | ### GPU 20 | * CUDA (we recommend using the latest version. 
Version 8.0 was used in all our experiments.)
21 | 
22 | ### Related code
23 | * This code is based on [dl4mt-nonauto](https://github.com/nyu-dl/dl4mt-nonauto "dl4mt-nonauto"). We mainly modified [`model.py`](https://github.com/ictnlp/RSI-NAT/blob/master/model.py "model.py") (lines 1103-1199).
24 | 
25 | Downloading Datasets
26 | ------------------
27 | The original translation corpora can be downloaded from ([IWSLT'16 En-De](https://wit3.fbk.eu/), [WMT'16 En-Ro](http://www.statmt.org/wmt16/translation-task.html), [WMT'14 En-De](http://www.statmt.org/wmt14/translation-task.html)). We recommend downloading the preprocessed corpora released in [dl4mt-nonauto](https://github.com/nyu-dl/dl4mt-nonauto/tree/multigpu "dl4mt-nonauto").
28 | 
29 | Before you run the code
30 | ------------------
31 | Set the correct path to the data in the `data_path()` function located in [`data.py`](https://github.com/ictnlp/RSI-NAT/blob/master/data.py).
32 | 
33 | Training New Models
34 | ------------------
35 | Train a NAT model using the cross-entropy loss. This process usually takes about 10 days.
36 | You can download our pretrained models [here](https://share.weiyun.com/5lnIanI "here").
37 | #### IWSLT
38 | ```bash
39 | $ sh train_iwslt.sh
40 | ```
41 | 
42 | #### WMT14 En-De
43 | ```bash
44 | $ sh train_wmt.sh
45 | ```
46 | Finetuning (RF-NAT)
47 | ------------------
48 | Take a checkpoint of a pre-trained non-autoregressive model and finetune it using the RF-NAT algorithm. This process usually takes about 1 day.
49 | If you want to use the original REINFORCE algorithm, change the flag `--nat_finetune` to `--rf_finetune`.
50 | #### IWSLT
51 | ```bash
52 | $ sh rf_iwslt.sh
53 | ```
54 | 
55 | #### WMT14 En-De
56 | ```bash
57 | $ sh rf_wmt.sh
58 | ```
59 | Training the Length Prediction Model
60 | ------------------
61 | Take a finetuned checkpoint and train the length prediction model. This process usually takes about 1 day.
62 | #### IWSLT
63 | ```bash
64 | $ sh tune_iwslt.sh
65 | ```
66 | 
67 | #### WMT14 En-De
68 | ```bash
69 | $ sh tune_wmt.sh
70 | ```
71 | Decoding
72 | ------------------
73 | Decode the test set. This process usually takes about 20 seconds.
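The decode scripts also pass `--remove_repeats`, which collapses consecutive duplicate tokens in the hypotheses, a common artifact of non-autoregressive decoding. The repository's own implementation is `remove_repeats` in `utils.py`; the sketch below is only a minimal illustration of the idea, assuming whitespace-tokenized output:

```python
def collapse_repeats(sentence):
    # Keep a token only when it differs from the previous one, e.g.
    # "we we went to to the store" -> "we went to the store".
    out = []
    for tok in sentence.split():
        if not out or tok != out[-1]:
            out.append(tok)
    return " ".join(out)

print(collapse_repeats("we we went to to the store"))  # we went to the store
```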
74 | #### IWSLT
75 | ```bash
76 | $ sh decode_iwslt.sh
77 | ```
78 | 
79 | #### WMT14 En-De
80 | ```bash
81 | $ sh decode_wmt.sh
82 | ```
83 | Citation
84 | ------------------
85 | If you find the resources in this repository useful, please consider citing:
86 | ```
87 | @inproceedings{shao2019retrieving,
88 |     title = "Retrieving Sequential Information for Non-Autoregressive Neural Machine Translation",
89 |     author = "Shao, Chenze and
90 |       Feng, Yang and
91 |       Zhang, Jinchao and
92 |       Meng, Fandong and
93 |       Chen, Xilin and
94 |       Zhou, Jie",
95 |     booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
96 |     month = jul,
97 |     year = "2019",
98 |     url = "https://www.aclweb.org/anthology/P19-1288",
99 |     pages = "3013--3024",
100 | }
101 | ```
102 | 
--------------------------------------------------------------------------------
/distill.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import ipdb
3 | import math
4 | import os
5 | import torch
6 | import numpy as np
7 | import time
8 | 
9 | from torch.nn import functional as F
10 | from torch.autograd import Variable
11 | 
12 | from tqdm import tqdm, trange
13 | from model import Transformer, FastTransformer, INF, TINY, softmax
14 | from data import NormalField, NormalTranslationDataset, TripleTranslationDataset, ParallelDataset, data_path
15 | from utils import Metrics, Best, computeBLEU, Batch, masked_sort, computeGroupBLEU, organise_trg_len_dic, make_decoder_masks, double_source_masks, remove_repeats, remove_repeats_tensor, print_bleu
16 | from time import gmtime, strftime
17 | import copy
18 | from multiset import Multiset
19 | import json
20 | 
21 | tokenizer = lambda x: x.replace('@@ ', '').split()
22 | 
23 | def distill_model(args, model, dev, evaluate=True,
24 |                   distill_path=None, names=None, maxsteps=None):
25 | 
26 |     if not args.no_tqdm:
27 |         progressbar = tqdm(total=200, desc='start decoding')
28 | 
29 |     trg_len_dic = None
30 | 
31 |     args.logger.info("decoding, f_size={}, beam_size={}, alpha={}".format(args.f_size, args.beam_size, args.alpha))
32 |     dev.train = False  # make iterator volatile=True
33 | 
34 |     model.eval()
35 |     if distill_path is not None:
36 |         if args.dataset != "mscoco":
37 |             handles = [open(os.path.join(distill_path, name), 'w') for name in names]
38 |         else:
39 |             distill_annots = []
40 |             distill_filepath = os.path.join(str(distill_path), "train.bpe.fixed.distill")
41 | 
42 | 
43 |     corpus_size = 0
44 |     src_outputs, trg_outputs, dec_outputs, timings = [], [], [], []
45 |     all_decs = [ [] for idx in range(args.valid_repeat_dec)]
46 |     decoded_words, target_words, decoded_info = 0, 0, 0
47 | 
48 |     attentions = None
49 |     decoder = model.decoder[0] if args.model is FastTransformer else model.decoder
50 |     pad_id = decoder.field.vocab.stoi['<pad>']
51 |     eos_id = decoder.field.vocab.stoi['<eos>']
52 | 
53 |     curr_time = 0
54 |     cum_bs = 0
55 | 
56 |     for iters, dev_batch in enumerate(dev):
57 | 
58 |         start_t = time.time()
59 | 
60 |         if args.dataset != "mscoco":
61 |             decoder_inputs, decoder_masks,\
62 |             targets, target_masks,\
63 |             sources, source_masks,\
64 |             encoding, batch_size, rest = model.quick_prepare(dev_batch, fast=(type(model) is FastTransformer), trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio, trg_len_dic=trg_len_dic, bp=args.bp)
65 |         else:
66 |             all_captions = dev_batch[1]
67 |             all_img_names = dev_batch[2]
68 |             dev_batch[1] = dev_batch[1][0]
69 |             decoder_inputs, decoder_masks,\
70 |             targets, target_masks,\
71 |             _, source_masks,\
72 |             encoding, batch_size, rest = model.quick_prepare_mscoco(dev_batch, all_captions=all_captions, fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec, trg_len_option=args.trg_len_option, max_len=args.max_offset, trg_len_dic=trg_len_dic, bp=args.bp)
73 | 
74 | 
75 |         corpus_size += batch_size
76 |         cum_bs += batch_size  # count decoded sentences for the timing summary below
77 | 
78 |         batch_size, src_len, hsize = encoding[0].size()
79 | 
80 |         # for now
81 |         if type(model) is Transformer:
82 |             all_decodings = []
83 |             decoding = model(encoding, source_masks, decoder_inputs, decoder_masks,
84 |                             beam=args.beam_size, alpha=args.alpha, \
85 |                             decoding=True, feedback=attentions)
86 |             all_decodings.append( decoding )
87 |             curr_iter = [0]
88 | 
89 |         used_t = time.time() - start_t
90 |         curr_time += used_t
91 | 
92 |         real_mask = 1 - ((decoding.data == eos_id) + (decoding.data == pad_id)).float()
93 |         if args.dataset != "mscoco":
94 |             outputs = [model.output_decoding(d, False) for d in [('src', sources), ('trg', targets), ('trg', decoding)]]
95 | 
96 |             for s, t, d in zip(outputs[0], outputs[1], outputs[-1]):
97 |                 #s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '')
98 |                 print(s, file=handles[0], flush=True)
99 |                 print(t, file=handles[1], flush=True)
100 |                 print(d, file=handles[2], flush=True)
101 |         else:
102 |             outputs = [model.output_decoding(d, unbpe=False) for d in [('trg', targets), ('trg', decoding)]]
103 | 
104 |             for c, (t, d) in enumerate(zip(outputs[0], outputs[1])):
105 |                 annot = {}
106 |                 annot['bpes'] = [d]
107 |                 annot['img_name'] = all_img_names[c]
108 |                 distill_annots.append(annot)
109 | 
110 |             json.dump(distill_annots, open(distill_filepath, 'w'))
111 | 
112 |         if not args.no_tqdm:
113 |             progressbar.update(iters)
114 |             progressbar.set_description('finishing sentences={}/batches={}, \
115 |                 length={}/average iter={}, speed={} sec/batch'.format(\
116 |                 corpus_size, iters, src_len, np.mean(np.array(curr_iter)), curr_time / (1 + iters)))
117 | 
118 |     if args.dataset == "mscoco":
119 |         json.dump(distill_annots, open(distill_filepath, 'w'))
120 | 
121 |     args.logger.info("{:.3f} ms / sentence".format(curr_time / float(cum_bs) * 1000))
--------------------------------------------------------------------------------
/mscoco.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import _pickle as pickle
7 | import json
8 | import numpy as np
9 | import time
10 | import random
11 | from collections import OrderedDict
12 | import ipdb
13 | import sys
14 | def process_json(dataPath, annFile, max_len=None, size=None):
15 |     annPath = os.path.join(dataPath, annFile)
16 | 
17 |     # load dataset
18 |     annots = json.load(open(annPath, 'r'))
19 |     if size != None:
20 |         annots = annots[:size]
21 | 
22 |     bpes = []
23 |     features_path = []
24 |     bpe2img = {}
25 |     img2bpes = {}
26 | 
27 |     bpe_i, feature_i = 0, 0
28 | 
29 |     for annot in annots:
30 |         bpes_i = []
31 |         for bpe in annot['bpes']:
32 |             len_bpe = len(bpe.split(' '))
33 |             if max_len != None and len_bpe > max_len:
34 |                 continue
35 |             bpes.append(bpe)
36 |             bpe2img[bpe_i] = feature_i
37 |             bpes_i.append(bpe_i)
38 |             bpe_i = bpe_i + 1
39 |         img2bpes[feature_i] = bpes_i
40 |         img_name = annot['img_name'] + '.npy'
41 |         if 'train' in img_name:
42 |             load_path = os.path.join(dataPath, 'train2014_features')
43 |         elif 'val' in img_name:
44 |             load_path = os.path.join(dataPath, 'val2014_features')
45 |         else:
46 |             sys.exit()
47 |         features_path.append(os.path.join(load_path, img_name))
48 |         feature_i = 
feature_i + 1 49 | 50 | return bpes, features_path, bpe2img, img2bpes 51 | 52 | def minibatch_same_length(lengths, batch_size): 53 | # make sure all of them are integers 54 | all(isinstance(ll, int) for ll in lengths) 55 | 56 | # sort them out 57 | len_unique = np.unique(lengths) 58 | 59 | # indices of unique lengths 60 | len_indices = OrderedDict() 61 | len_counts = OrderedDict() 62 | for ll in len_unique: 63 | len_indices[ll] = np.where(lengths == ll)[0] 64 | len_counts[ll] = len(len_indices[ll]) 65 | 66 | # sort indicies into minibatches 67 | minibatches = [] 68 | len_indices_keys = list(len_indices.keys()) 69 | for k in len_indices_keys: 70 | avg_samples = max(1, int(batch_size / k)) 71 | for j in range(0, len_counts[k], avg_samples): 72 | minibatches.append(len_indices[k][j:j+avg_samples]) 73 | 74 | return minibatches 75 | 76 | class BatchSamplerCaptionsSameLength(object): 77 | def __init__(self, dataset, batch_size): 78 | assert (type(dataset) == CocoCaptionsIndexedCaption) 79 | self.bpes = dataset.bpes 80 | lengths = [] 81 | 82 | for bpe in self.bpes: 83 | len_bpe = len(bpe.split(' ')) 84 | lengths.append(len_bpe) 85 | 86 | self.minibatches = minibatch_same_length(lengths, batch_size) 87 | random.shuffle(self.minibatches) 88 | 89 | def __iter__(self): 90 | # randomly sample minibatch index 91 | for i in range(len(self.minibatches)): 92 | minibatch = self.minibatches[i] 93 | yield minibatch 94 | 95 | def __len__(self): 96 | return len(self.minibatches) 97 | 98 | class BatchSamplerImagesSameLength(object): 99 | def __init__(self, dataset, batch_size): 100 | assert (type(dataset) == CocoCaptionsIndexedImage or type(dataset) == CocoCaptionsIndexedImageDistill) 101 | self.img2bpes = dataset.img2bpes 102 | self.bpes = dataset.bpes 103 | 104 | # calculate average length of 5 captions for each image 105 | lengths = [] 106 | img_keys = self.img2bpes.keys() 107 | for i in img_keys: 108 | length_i = [] 109 | for bpe_i in self.img2bpes[i]: 110 | length_i.append(len(self.bpes[bpe_i].split())) 111 | lengths.append(int(np.mean(np.array(length_i)))) 112 | 113 | self.minibatches = minibatch_same_length(lengths, batch_size) 114 | random.shuffle(self.minibatches) 115 | 116 | 117 | def __iter__(self): 118 | # randomly sample minibatch index 119 | for i in range(len(self.minibatches)): 120 | minibatch = self.minibatches[i] 121 | yield minibatch 122 | 123 | def __len__(self): 124 | return len(self.minibatches) 125 | 126 | # dataset indexed based on images 127 | class CocoCaptionsIndexedImage(data.Dataset): 128 | def __init__(self, bpes, features_path, bpe2img, img2bpes): 129 | self.bpes = bpes 130 | self.features_path = features_path 131 | self.bpe2img = bpe2img 132 | self.img2bpes = img2bpes 133 | 134 | def __getitem__(self, index): 135 | feature = np.float32(np.load(self.features_path[index])) 136 | bpes = [] 137 | for i in self.img2bpes[index]: 138 | bpes.append(self.bpes[i]) 139 | return torch.from_numpy(feature), bpes 140 | 141 | def __len__(self): 142 | return len(self.img2bpes.keys()) 143 | 144 | class CocoCaptionsIndexedImageDistill(data.Dataset): 145 | def __init__(self, bpes, features_path, bpe2img, img2bpes): 146 | self.bpes = bpes 147 | self.features_path = features_path 148 | self.bpe2img = bpe2img 149 | self.img2bpes = img2bpes 150 | 151 | def __getitem__(self, index): 152 | feature = np.float32(np.load(self.features_path[index])) 153 | img_name = self.features_path[index].split('/')[-1].split('.')[0] 154 | bpes = [] 155 | for i in self.img2bpes[index]: 156 | bpes.append(self.bpes[i]) 157 | 
return torch.from_numpy(feature), bpes, img_name 158 | 159 | def __len__(self): 160 | return len(self.img2bpes.keys()) 161 | 162 | # dataset indexed based on captions 163 | class CocoCaptionsIndexedCaption(data.Dataset): 164 | def __init__(self, bpes, features_path, bpe2img, img2bpes): 165 | self.bpes = bpes 166 | self.features_path = features_path 167 | self.bpe2img = bpe2img 168 | self.img2bpes = img2bpes 169 | 170 | def __getitem__(self, index): 171 | bpe = self.bpes[index] 172 | feature = np.float32(np.load(self.features_path[self.bpe2img[index]])) 173 | return torch.from_numpy(feature), bpe 174 | 175 | def __len__(self): 176 | return len(self.bpe2img.keys()) 177 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import math 2 | import ipdb 3 | import torch 4 | import random 5 | import numpy as np 6 | import _pickle as pickle 7 | import revtok 8 | import os 9 | from itertools import groupby 10 | import getpass 11 | from collections import Counter 12 | 13 | from torch.autograd import Variable 14 | from torchtext import data, datasets 15 | from nltk.translate.gleu_score import sentence_gleu, corpus_gleu 16 | from nltk.translate.bleu_score import closest_ref_length, brevity_penalty, modified_precision, SmoothingFunction 17 | from contextlib import ExitStack 18 | from collections import OrderedDict 19 | import fractions 20 | 21 | from mscoco import CocoCaptionsIndexedImage, CocoCaptionsIndexedCaption, CocoCaptionsIndexedImageDistill, \ 22 | BatchSamplerImagesSameLength, BatchSamplerCaptionsSameLength 23 | from mscoco import process_json 24 | 25 | try: 26 | fractions.Fraction(0, 1000, _normalize=False) 27 | from fractions import Fraction 28 | except TypeError: 29 | from nltk.compat import Fraction 30 | 31 | def data_path(dataset): 32 | if dataset == "iwslt-ende" or dataset == "iwslt-deen": 33 | path="../IWSLT/en-de/" 34 | elif dataset == "wmt14-ende" or dataset == "wmt14-deen": 35 | path="../wmt14/en-de/" 36 | elif dataset == "wmt16-enro" or dataset == "wmt16-roen": 37 | path="../wmt16/en-ro/" 38 | elif dataset == "wmt17-enlv" or dataset == "wmt17-lven": 39 | path="../wmt17/en-lv/" 40 | elif dataset == "mscoco": 41 | path="mscoco" 42 | 43 | return path 44 | 45 | # load the dataset + reversible tokenization 46 | class NormalField(data.Field): 47 | 48 | def reverse(self, batch, unbpe=True): 49 | if not self.batch_first: 50 | batch.t_() 51 | 52 | with torch.cuda.device_of(batch): 53 | batch = batch.tolist() 54 | 55 | batch = [[self.vocab.itos[ind] for ind in ex] for ex in batch] # denumericalize 56 | 57 | def trim(s, t): 58 | sentence = [] 59 | for w in s: 60 | if w == t: 61 | break 62 | sentence.append(w) 63 | return sentence 64 | 65 | batch = [trim(ex, self.eos_token) for ex in batch] # trim past frst eos 66 | def filter_special(tok): 67 | return tok not in (self.init_token, self.pad_token) 68 | 69 | if unbpe: 70 | batch = [" ".join(filter(filter_special, ex)).replace("@@ ","") for ex in batch] 71 | else: 72 | batch = [" ".join(filter(filter_special, ex)) for ex in batch] 73 | return batch 74 | 75 | class MSCOCOVocab(object): 76 | """Simple vocabulary wrapper.""" 77 | def __init__(self): 78 | self.stoi = {} 79 | self.itos = {} 80 | self.idx = 0 81 | 82 | def add_word(self, word): 83 | if not word in self.stoi: 84 | self.stoi[word] = self.idx 85 | self.itos[self.idx] = word 86 | self.idx += 1 87 | 88 | def __call__(self, word): 89 | if not word in self.stoi: 90 | 
return self.stoi['<unk>']
91 |         return self.stoi[word]
92 | 
93 |     def __len__(self):
94 |         return len(self.stoi)
95 | 
96 | class MSCOCODataset(object):
97 |     def __init__(self, path, batch_size, max_len=None, valid_size=None, distill=False, use_distillation=False):
98 |         self.path = path
99 | 
100 |         if distill:
101 |             self.train_data, self.train_sampler = self.prepare_distill_data(path, 'karpathy_split/train.json.bpe.fixed', batch_size, max_len=max_len, size=None)
102 |         else:
103 |             train_f = 'karpathy_split/train.json.bpe.fixed'
104 |             if use_distillation:
105 |                 train_f = 'karpathy_split/train.json.bpe.fixed.high.distill'
106 |             self.train_data, self.train_sampler = self.prepare_train_data(path, train_f, batch_size, max_len=max_len, size=None)
107 | 
108 |         self.valid_data, self.valid_sampler = self.prepare_test_data(path, 'karpathy_split/valid.json.bpe.fixed', batch_size, max_len=None, size=valid_size)
109 |         self.test_data, self.test_sampler = self.prepare_test_data(path, 'karpathy_split/test.json.bpe.fixed', batch_size, max_len=None, size=valid_size)
110 | 
111 |         self.unk_token = 0
112 |         self.pad_token = 1
113 |         self.init_token = 2
114 |         self.eos_token = 3
115 | 
116 |     def prepare_train_data(self, dataPath, annFile, batch_size, max_len=None, size=None):
117 |         bpes, features_path, bpe2img, img2bpes = process_json(dataPath, annFile, max_len=max_len, size=size)
118 | 
119 |         # get max len of dataset
120 |         self.max_dataset_length = 0
121 |         for bpe in bpes:
122 |             len_bpe = len(bpe.split(' '))
123 |             if len_bpe > self.max_dataset_length:
124 |                 self.max_dataset_length = len_bpe
125 | 
126 |         dataset_captions = CocoCaptionsIndexedCaption(bpes, features_path, bpe2img, img2bpes)
127 |         sampler_captions = BatchSamplerCaptionsSameLength(dataset_captions, batch_size=batch_size)
128 |         return dataset_captions, sampler_captions
129 | 
130 |     def prepare_test_data(self, dataPath, annFile, batch_size, max_len=None, size=None):
131 |         bpes, features_path, bpe2img, img2bpes = process_json(dataPath, annFile, max_len=max_len, size=size)
132 | 
133 |         dataset_images = CocoCaptionsIndexedImage(bpes, features_path, bpe2img, img2bpes)
134 |         sampler_images = BatchSamplerImagesSameLength(dataset_images, batch_size=batch_size)
135 |         return dataset_images, sampler_images
136 | 
137 |     def prepare_distill_data(self, dataPath, annFile, batch_size, max_len=None, size=None):
138 |         bpes, features_path, bpe2img, img2bpes = process_json(dataPath, annFile, max_len=max_len, size=size)
139 | 
140 |         dataset_images = CocoCaptionsIndexedImageDistill(bpes, features_path, bpe2img, img2bpes)
141 |         sampler_images = BatchSamplerImagesSameLength(dataset_images, batch_size=batch_size)
142 |         return dataset_images, sampler_images
143 | 
144 | 
145 |     def build_vocab(self):
146 |         """Build a simple vocabulary wrapper."""
147 |         from collections import Counter
148 | 
149 |         bpes = self.train_data.bpes
150 | 
151 |         counter = Counter()
152 |         for bpe in bpes:
153 |             counter.update(bpe.split())
154 | 
155 |         words = [word for word, cnt in counter.items()]
156 | 
157 |         # Creates a vocab wrapper and add some special tokens.
158 |         # MAKE SURE CONSTANTS ARE CONSISTENT WITH TRANSLATION DATASETS !!!
159 |         self.vocab = MSCOCOVocab()
160 |         self.vocab.add_word('<unk>')
161 |         self.vocab.add_word('<pad>')
162 |         self.vocab.add_word('<init>')
163 |         self.vocab.add_word('<eos>')
164 | 
165 |         # Adds the words to the vocabulary.
166 |         for i, word in enumerate(words):
167 |             self.vocab.add_word(word)
168 | 
169 |     def reverse(self, batch, unbpe=True):
170 |         #batch = batch.t()
171 |         with torch.cuda.device_of(batch):
172 |             batch = batch.tolist()
173 |         batch = [[self.vocab.itos[ind] for ind in ex] for ex in batch]  # denumericalize
174 | 
175 |         def trim(s, t):
176 |             sentence = []
177 |             for w in s:
178 |                 if w == t:
179 |                     break
180 |                 sentence.append(w)
181 |             return sentence
182 | 
183 |         batch = [trim(ex, '<eos>') for ex in batch]  # trim past first eos
184 | 
185 |         def filter_special(tok):
186 |             return tok not in ('<init>', '<pad>')
187 | 
188 |         #batch = [filter(filter_special, ex) for ex in batch]
189 |         if unbpe:
190 |             batch = [" ".join(filter(filter_special, ex)).replace("@@ ","") for ex in batch]
191 |         else:
192 |             batch = [" ".join(filter(filter_special, ex)) for ex in batch]
193 |         return batch
194 | 
195 | class TranslationDataset(data.Dataset):
196 |     """Defines a dataset for machine translation."""
197 | 
198 |     @staticmethod
199 |     def sort_key(ex):
200 |         return data.interleave_keys(len(ex.src), len(ex.trg))
201 | 
202 |     def __init__(self, path, exts, fields, **kwargs):
203 |         """Create a TranslationDataset given paths and fields.
204 |         Arguments:
205 |             path: Common prefix of paths to the data files for both languages.
206 |             exts: A tuple containing the extension to path for each language.
207 |             fields: A tuple containing the fields that will be used for data
208 |                 in each language.
209 |             Remaining keyword arguments: Passed to the constructor of
210 |                 data.Dataset.
211 |         """
212 |         if not isinstance(fields[0], (tuple, list)):
213 |             fields = [('src', fields[0]), ('trg', fields[1])]
214 | 
215 |         src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)
216 | 
217 |         examples = []
218 |         with open(src_path) as src_file, open(trg_path) as trg_file:
219 |             for src_line, trg_line in zip(src_file, trg_file):
220 |                 src_line, trg_line = src_line.strip(), trg_line.strip()
221 |                 if src_line != '' and trg_line != '':
222 |                     examples.append(data.Example.fromlist(
223 |                         [src_line, trg_line], fields))
224 | 
225 |         super(TranslationDataset, self).__init__(examples, fields, **kwargs)
226 | 
227 |     @classmethod
228 |     def splits(cls, path, exts, fields, root='.data',
229 |                train='train', validation='val', test='test', **kwargs):
230 |         """Create dataset objects for splits of a TranslationDataset.
231 |         Arguments:
232 |             root: Root dataset storage directory. Default is '.data'.
233 |             exts: A tuple containing the extension to path for each language.
234 |             fields: A tuple containing the fields that will be used for data
235 |                 in each language.
236 |             train: The prefix of the train data. Default: 'train'.
237 |             validation: The prefix of the validation data. Default: 'val'.
238 |             test: The prefix of the test data. Default: 'test'.
239 |             Remaining keyword arguments: Passed to the splits method of
240 |                 Dataset.
241 | """ 242 | #path = cls.download(root) 243 | 244 | train_data = None if train is None else cls( 245 | os.path.join(path, train), exts, fields, **kwargs) 246 | val_data = None if validation is None else cls( 247 | os.path.join(path, validation), exts, fields, **kwargs) 248 | test_data = None if test is None else cls( 249 | os.path.join(path, test), exts, fields, **kwargs) 250 | return tuple(d for d in (train_data, val_data, test_data) 251 | if d is not None) 252 | 253 | 254 | class NormalTranslationDataset(TranslationDataset): 255 | """Defines a dataset for machine translation.""" 256 | 257 | def __init__(self, path, exts, fields, load_dataset=False, save_dataset=False, prefix='', **kwargs): 258 | """Create a TranslationDataset given paths and fields. 259 | 260 | Arguments: 261 | path: Common prefix of paths to the data files for both languages. 262 | exts: A tuple containing the extension to path for each language. 263 | fields: A tuple containing the fields that will be used for data 264 | in each language. 265 | Remaining keyword arguments: Passed to the constructor of 266 | data.Dataset. 267 | """ 268 | if not isinstance(fields[0], (tuple, list)): 269 | fields = [('src', fields[0]), ('trg', fields[1])] 270 | 271 | src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts) 272 | if load_dataset and (os.path.exists(path + '.processed.{}.pt'.format(prefix))): 273 | examples = pickle.load(open(path + '.processed.{}.pt'.format(prefix), "rb")) 274 | print ("Loaded TorchText dataset") 275 | else: 276 | examples = [] 277 | with open(src_path,encoding='utf-8') as src_file, open(trg_path,encoding='utf-8') as trg_file: 278 | for src_line, trg_line in zip(src_file, trg_file): 279 | src_line, trg_line = src_line.strip(), trg_line.strip() 280 | if src_line != '' and trg_line != '': 281 | examples.append(data.Example.fromlist( 282 | [src_line, trg_line], fields)) 283 | if save_dataset: 284 | pickle.dump(examples, open(path + '.processed.{}.pt'.format(prefix), "wb")) 285 | print ("Saved TorchText dataset") 286 | 287 | super(TranslationDataset, self).__init__(examples, fields, **kwargs) 288 | 289 | class TripleTranslationDataset(datasets.TranslationDataset): 290 | """Define a triple-translation dataset: src, trg, dec(output of a pre-trained teacher)""" 291 | 292 | def __init__(self, path, exts, fields, load_dataset=False, prefix='', **kwargs): 293 | if not isinstance(fields[0], (tuple, list)): 294 | fields = [('src', fields[0]), ('trg', fields[1]), ('dec', fields[2])] 295 | 296 | src_path, trg_path, dec_path = tuple(os.path.expanduser(path + x) for x in exts) 297 | if load_dataset and (os.path.exists(path + '.processed.{}.pt'.format(prefix))): 298 | examples = torch.load(path + '.processed.{}.pt'.format(prefix)) 299 | else: 300 | examples = [] 301 | with open(src_path) as src_file, open(trg_path) as trg_file, open(dec_path) as dec_file: 302 | for src_line, trg_line, dec_line in zip(src_file, trg_file, dec_file): 303 | src_line, trg_line, dec_line = src_line.strip(), trg_line.strip(), dec_line.strip() 304 | if src_line != '' and trg_line != '' and dec_line != '': 305 | examples.append(data.Example.fromlist( 306 | [src_line, trg_line, dec_line], fields)) 307 | if load_dataset: 308 | torch.save(examples, path + '.processed.{}.pt'.format(prefix)) 309 | 310 | super(datasets.TranslationDataset, self).__init__(examples, fields, **kwargs) 311 | 312 | class ParallelDataset(datasets.TranslationDataset): 313 | """ Define a N-parallel dataset: supports abitriry numbers of input streams""" 314 | 315 
| def __init__(self, path=None, exts=None, fields=None, 316 | load_dataset=False, prefix='', examples=None, **kwargs): 317 | 318 | if examples is None: 319 | assert len(exts) == len(fields), 'N parallel dataset must match' 320 | self.N = len(fields) 321 | 322 | paths = tuple(os.path.expanduser(path + x) for x in exts) 323 | if load_dataset and (os.path.exists(path + '.processed.{}.pt'.format(prefix))): 324 | examples = torch.load(path + '.processed.{}.pt'.format(prefix)) 325 | else: 326 | examples = [] 327 | with ExitStack() as stack: 328 | files = [stack.enter_context(open(fname)) for fname in paths] 329 | for lines in zip(*files): 330 | lines = [line.strip() for line in lines] 331 | if not any(line == '' for line in lines): 332 | examples.append(data.Example.fromlist(lines, fields)) 333 | if load_dataset: 334 | torch.save(examples, path + '.processed.{}.pt'.format(prefix)) 335 | 336 | super(datasets.TranslationDataset, self).__init__(examples, fields, **kwargs) 337 | -------------------------------------------------------------------------------- /decode.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import ipdb 3 | import math 4 | import os 5 | import torch 6 | import numpy as np 7 | import time 8 | 9 | from torch.nn import functional as F 10 | from torch.autograd import Variable 11 | 12 | from tqdm import tqdm, trange 13 | from model import Transformer, FastTransformer, INF, TINY, softmax 14 | from data import NormalField, NormalTranslationDataset, TripleTranslationDataset, ParallelDataset, data_path 15 | from utils import Metrics, Best, computeBLEU, computeBLEUMSCOCO, Batch, masked_sort, computeGroupBLEU, organise_trg_len_dic, make_decoder_masks, \ 16 | double_source_masks, remove_repeats, remove_repeats_tensor, print_bleu, oracle_converged, equality_converged, jaccard_converged 17 | from time import gmtime, strftime 18 | import copy 19 | from multiset import Multiset 20 | 21 | tokenizer = lambda x: x.replace('@@ ', '').split() 22 | 23 | def run_fast_transformer(decoder_inputs, decoder_masks,\ 24 | sources, source_masks,\ 25 | targets,\ 26 | encoding,\ 27 | model, args, use_argmax=True): 28 | 29 | trg_unidx = model.output_decoding( ('trg', targets),unbpe=False) 30 | src_unidx = model.output_decoding( ('src', sources),unbpe=False) 31 | batch_size, src_len, hsize = encoding[0].size() 32 | #s = open("decoding/dec_source","a") 33 | #r = open("decoding/dec_ref","a") 34 | #l = len(src_unidx) 35 | #for i in range(l): 36 | # s.write(src_unidx[i]+'\n') 37 | # r.write(trg_unidx[i]+'\n') 38 | all_decodings = [] 39 | all_probs = [] 40 | iter_ = 0 41 | bleu_hist = [ [] for xx in range(batch_size) ] 42 | output_hist = [ [] for xx in range(batch_size) ] 43 | multiset_hist = [ [] for xx in range(batch_size) ] 44 | num_iters = [ 0 for xx in range(batch_size) ] 45 | done_ = [False for xx in range(batch_size)] 46 | final_decoding = [ None for xx in range(batch_size) ] 47 | 48 | while True: 49 | curr_iter = min(iter_, args.num_decs-1) 50 | next_iter = min(iter_+1, args.num_decs-1) 51 | 52 | decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks, 53 | decoding=True, return_probs=True, iter_=curr_iter) 54 | 55 | dec_output = decoding.data.cpu().numpy().tolist() 56 | #out_unidx = model.output_decoding( ('trg', decoding ),unbpe=False ) 57 | #o = open("decoding/decode_out" + str(iter_),"a") 58 | #l = len(src_unidx) 59 | #for i in range(l): 60 | # o.write(out_unidx[i]+'\n') 61 | 62 | """ 63 | if args.trg_len_option != 
"reference": 64 | decoder_masks = 0. * decoder_masks 65 | for bidx in range(batch_size): 66 | try: 67 | decoder_masks[bidx,:(dec_output[bidx].index(3))+1] = 1. 68 | except: 69 | decoder_masks[bidx,:] = 1. 70 | """ 71 | 72 | if args.adaptive_decoding == "oracle": 73 | out_unidx = model.output_decoding( ('trg', decoding ) ) 74 | sentence_bleus = computeBLEU(out_unidx, trg_unidx, corpus=False, tokenizer=tokenizer) 75 | 76 | for bidx in range(batch_size): 77 | output_hist[bidx].append( dec_output[bidx] ) 78 | bleu_hist[bidx].append(sentence_bleus[bidx]) 79 | 80 | converged = oracle_converged( bleu_hist, num_items=args.adaptive_window ) 81 | for bidx in range(batch_size): 82 | if not done_[bidx] and converged[bidx] and num_iters[bidx] == 0: 83 | num_iters[bidx] = iter_ + 1 - (args.adaptive_window -1) 84 | done_[bidx] = True 85 | final_decoding[bidx] = output_hist[bidx][-args.adaptive_window] 86 | 87 | elif args.adaptive_decoding == "equality": 88 | for bidx in range(batch_size): 89 | #if 3 in dec_output[bidx]: 90 | # dec_output[bidx] = dec_output[bidx][:dec_output[bidx].index(3)] 91 | output_hist[bidx].append( dec_output[bidx] ) 92 | 93 | converged = equality_converged( output_hist, num_items=args.adaptive_window ) 94 | 95 | for bidx in range(batch_size): 96 | if not done_[bidx] and converged[bidx] and num_iters[bidx] == 0: 97 | num_iters[bidx] = iter_ + 1 98 | done_[bidx] = True 99 | final_decoding[bidx] = output_hist[bidx][-1] 100 | 101 | elif args.adaptive_decoding == "jaccard": 102 | for bidx in range(batch_size): 103 | #if 3 in dec_output[bidx]: 104 | # dec_output[bidx] = dec_output[bidx][:dec_output[bidx].index(3)] 105 | output_hist[bidx].append( dec_output[bidx] ) 106 | multiset_hist[bidx].append( Multiset(dec_output[bidx]) ) 107 | 108 | converged = jaccard_converged( multiset_hist, num_items=args.adaptive_window ) 109 | 110 | for bidx in range(batch_size): 111 | if not done_[bidx] and converged[bidx] and num_iters[bidx] == 0: 112 | num_iters[bidx] = iter_ + 1 113 | done_[bidx] = True 114 | final_decoding[bidx] = output_hist[bidx][-1] 115 | 116 | all_decodings.append( decoding ) 117 | all_probs.append(probs) 118 | 119 | decoder_inputs = 0 120 | if args.next_dec_input in ["both", "emb"]: 121 | if use_argmax: 122 | _, argmax = torch.max(probs, dim=-1) 123 | else: 124 | probs_sz = probs.size() 125 | probs_ = Variable(probs.data, requires_grad=False) 126 | argmax = torch.multinomial(probs_.contiguous().view(-1, probs_sz[-1]), 1).view(*probs_sz[:-1]) 127 | emb = F.embedding(argmax, model.decoder[next_iter].out.weight * math.sqrt(args.d_model)) 128 | decoder_inputs += emb 129 | 130 | if args.next_dec_input in ["both", "out"]: 131 | decoder_inputs += out 132 | 133 | iter_ += 1 134 | if iter_ == args.valid_repeat_dec or (False not in done_): 135 | break 136 | 137 | if args.adaptive_decoding != None: 138 | for bidx in range(batch_size): 139 | if num_iters[bidx] == 0: 140 | num_iters[bidx] = 20 141 | if final_decoding[bidx] == None: 142 | if args.adaptive_decoding == "oracle": 143 | final_decoding[bidx] = output_hist[bidx][np.argmax(bleu_hist[bidx])] 144 | else: 145 | final_decoding[bidx] = output_hist[bidx][-1] 146 | 147 | decoding = Variable(torch.LongTensor(np.array(final_decoding))) 148 | if decoder_masks.is_cuda: 149 | decoding = decoding.cuda() 150 | 151 | return decoding, all_decodings, num_iters, all_probs 152 | 153 | def decode_model(args, model, dev, evaluate=True, trg_len_dic=None, 154 | decoding_path=None, names=None, maxsteps=None): 155 | 156 | args.logger.info("decoding, f_size={}, 
beam_size={}, alpha={}".format(args.f_size, args.beam_size, args.alpha))
157 |     dev.train = False  # make iterator volatile=True
158 | 
159 |     if not args.no_tqdm:
160 |         progressbar = tqdm(total=200, desc='start decoding')
161 | 
162 |     model.eval()
163 |     if not args.debug:
164 |         decoding_path.mkdir(parents=True, exist_ok=True)
165 |         handles = [(decoding_path / name ).open('w') for name in names]
166 | 
167 |     corpus_size = 0
168 |     src_outputs, trg_outputs, dec_outputs, timings = [], [], [], []
169 |     all_decs = [ [] for idx in range(args.valid_repeat_dec)]
170 |     decoded_words, target_words, decoded_info = 0, 0, 0
171 | 
172 |     attentions = None
173 |     decoder = model.decoder[0] if args.model is FastTransformer else model.decoder
174 |     pad_id = decoder.field.vocab.stoi['<pad>']
175 |     eos_id = decoder.field.vocab.stoi['<eos>']
176 | 
177 |     curr_time = 0
178 |     cum_sentences = 0
179 |     cum_tokens = 0
180 |     cum_images = 0  # used for mscoco
181 |     num_iters_total = []
182 | 
183 |     for iters, dev_batch in enumerate(dev):
184 |         start_t = time.time()
185 | 
186 |         if args.dataset != "mscoco":
187 |             decoder_inputs, decoder_masks,\
188 |             targets, target_masks,\
189 |             sources, source_masks,\
190 |             encoding, batch_size, rest = model.quick_prepare(dev_batch, fast=(type(model) is FastTransformer), trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio, trg_len_dic=trg_len_dic, bp=args.bp)
191 |         else:
192 |             # only use first caption for calculating log likelihood
193 |             all_captions = dev_batch[1]
194 |             dev_batch[1] = dev_batch[1][0]
195 |             decoder_inputs, decoder_masks,\
196 |             targets, target_masks,\
197 |             _, source_masks,\
198 |             encoding, batch_size, rest = model.quick_prepare_mscoco(dev_batch, all_captions=all_captions, fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec, trg_len_option=args.trg_len_option, max_len=args.max_len, trg_len_dic=trg_len_dic, bp=args.bp, gpu=args.gpu)
199 |             sources = None
200 | 
201 |         cum_sentences += batch_size
202 | 
203 |         batch_size, src_len, hsize = encoding[0].size()
204 | 
205 |         # for now
206 |         if type(model) is Transformer:
207 |             all_decodings = []
208 |             decoding = model(encoding, source_masks, decoder_inputs, decoder_masks,
209 |                             beam=args.beam_size, alpha=args.alpha, \
210 |                             decoding=True, feedback=attentions)
211 |             all_decodings.append( decoding )
212 |             num_iters = [0]
213 | 
214 |         elif type(model) is FastTransformer:
215 |             decoding, all_decodings, num_iters, argmax_all_probs = run_fast_transformer(decoder_inputs, decoder_masks, \
216 |                 sources, source_masks, targets, encoding, model, args, use_argmax=True)
217 |             num_iters_total.extend( num_iters )
218 | 
219 |             if not args.use_argmax:
220 |                 for _ in range(args.num_samples):
221 |                     _, _, _, sampled_all_probs = run_fast_transformer(decoder_inputs, decoder_masks, \
222 |                         sources, source_masks, targets, encoding, model, args, use_argmax=False)
223 |                     for iter_ in range(args.valid_repeat_dec):
224 |                         argmax_all_probs[iter_] = argmax_all_probs[iter_] + sampled_all_probs[iter_]
225 | 
226 |                 all_decodings = []
227 |                 for iter_ in range(args.valid_repeat_dec):
228 |                     argmax_all_probs[iter_] = argmax_all_probs[iter_] / args.num_samples
229 |                     all_decodings.append(torch.max(argmax_all_probs[iter_], dim=-1)[-1])
230 |                 decoding = all_decodings[-1]
231 | 
232 |         used_t = time.time() - start_t
233 |         curr_time += used_t
234 | 
235 |         if args.dataset != "mscoco":
236 |             if args.remove_repeats:
237 |                 outputs_unidx = [model.output_decoding(d) for d in [('src', sources), ('trg', targets), ('trg', remove_repeats_tensor(decoding))]]
238 |             else:
239 |                 outputs_unidx
= [model.output_decoding(d) for d in [('src', sources), ('trg', targets), ('trg', decoding)]] 240 | 241 | else: 242 | # make sure that 5 captions per each example 243 | num_captions = len(all_captions[0]) 244 | for c in range(1, len(all_captions)): 245 | assert (num_captions == len(all_captions[c])) 246 | 247 | # untokenize reference captions 248 | for n_ref in range(len(all_captions)): 249 | n_caps = len(all_captions[0]) 250 | for c in range(n_caps): 251 | all_captions[n_ref][c] = all_captions[n_ref][c].replace("@@ ","") 252 | 253 | outputs_unidx = [ list(map(list, zip(*all_captions))) ] 254 | 255 | if args.remove_repeats: 256 | all_dec_outputs = [model.output_decoding(d) for d in [('trg', remove_repeats_tensor(all_decodings[ii])) for ii in range(len(all_decodings))]] 257 | else: 258 | all_dec_outputs = [model.output_decoding(d) for d in [('trg', all_decodings[ii]) for ii in range(len(all_decodings))]] 259 | 260 | corpus_size += batch_size 261 | if args.dataset != "mscoco": 262 | cum_tokens += sum([len(xx.split(" ")) for xx in outputs_unidx[0]]) # NOTE source tokens, not target 263 | 264 | if args.dataset != "mscoco": 265 | src_outputs += outputs_unidx[0] 266 | trg_outputs += outputs_unidx[1] 267 | if args.remove_repeats: 268 | dec_outputs += remove_repeats(outputs_unidx[-1]) 269 | else: 270 | dec_outputs += outputs_unidx[-1] 271 | 272 | else: 273 | trg_outputs += outputs_unidx[0] 274 | 275 | for idx, each_output in enumerate(all_dec_outputs): 276 | if args.remove_repeats: 277 | all_decs[idx] += remove_repeats(each_output) 278 | else: 279 | all_decs[idx] += each_output 280 | 281 | #if True: 282 | if False and decoding_path is not None: 283 | for sent_i in range(len(outputs_unidx[0])): 284 | if args.dataset != "mscoco": 285 | print ('SRC') 286 | print (outputs_unidx[0][sent_i]) 287 | for ii in range(len(all_decodings)): 288 | print ('DEC iter {}'.format(ii)) 289 | print (all_dec_outputs[ii][sent_i]) 290 | print ('TRG') 291 | print (outputs_unidx[1][sent_i]) 292 | else: 293 | print ('TRG') 294 | trg = outputs_unidx[0] 295 | for subsent_i in range(len(trg[sent_i])): 296 | print ('TRG {}'.format(subsent_i)) 297 | print (trg[sent_i][subsent_i]) 298 | for ii in range(len(all_decodings)): 299 | print ('DEC iter {}'.format(ii)) 300 | print (all_dec_outputs[ii][sent_i]) 301 | print ('---------------------------') 302 | 303 | timings += [used_t] 304 | 305 | if not args.debug: 306 | for s, t, d in zip(outputs_unidx[0], outputs_unidx[1], outputs_unidx[2]): 307 | s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '') 308 | print(s, file=handles[0], flush=True) 309 | print(t, file=handles[1], flush=True) 310 | print(d, file=handles[2], flush=True) 311 | 312 | if not args.no_tqdm: 313 | progressbar.update(iters) 314 | progressbar.set_description('finishing sentences={}/batches={}, \ 315 | length={}/average iter={}, speed={} sec/batch'.format(\ 316 | corpus_size, iters, src_len, np.mean(np.array(num_iters)), curr_time / (1 + iters))) 317 | 318 | if evaluate: 319 | for idx, each_dec in enumerate(all_decs): 320 | if len(all_decs[idx]) != len(trg_outputs): 321 | break 322 | if args.dataset != "mscoco": 323 | bleu_output = computeBLEU(each_dec, trg_outputs, corpus=True, tokenizer=tokenizer) 324 | else: 325 | bleu_output = computeBLEUMSCOCO(each_dec, trg_outputs, corpus=True, tokenizer=tokenizer) 326 | args.logger.info("iter {} | {}".format(idx+1, print_bleu(bleu_output))) 327 | 328 | if args.adaptive_decoding != None: 329 | 
args.logger.info("----------------------------------------------") 330 | args.logger.info("Average # iters {}".format(np.mean(num_iters_total))) 331 | bleu_output = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer) 332 | args.logger.info("Adaptive BLEU | {}".format(print_bleu(bleu_output))) 333 | 334 | args.logger.info("----------------------------------------------") 335 | args.logger.info("Decoding speed analysis :") 336 | args.logger.info("{} sentences".format(cum_sentences)) 337 | if args.dataset != "mscoco": 338 | args.logger.info("{} tokens".format(cum_tokens)) 339 | args.logger.info("{:.3f} seconds".format(curr_time)) 340 | 341 | args.logger.info("{:.3f} ms / sentence".format((curr_time / float(cum_sentences) * 1000))) 342 | if args.dataset != "mscoco": 343 | args.logger.info("{:.3f} ms / token".format((curr_time / float(cum_tokens) * 1000))) 344 | 345 | args.logger.info("{:.3f} sentences / s".format(float(cum_sentences) / curr_time)) 346 | if args.dataset != "mscoco": 347 | args.logger.info("{:.3f} tokens / s".format(float(cum_tokens) / curr_time)) 348 | args.logger.info("----------------------------------------------") 349 | 350 | if args.decode_which > 0: 351 | args.logger.info("Writing to special file") 352 | parent = decoding_path / "speed" / "b_{}{}".format(args.beam_size if args.model is Transformer else args.valid_repeat_dec, 353 | "" if args.model is Transformer else "_{}".format(args.adaptive_decoding != None)) 354 | args.logger.info(str(parent)) 355 | parent.mkdir(parents=True, exist_ok=True) 356 | speed_handle = (parent / "results.{}".format(args.decode_which) ).open('w') 357 | 358 | print("----------------------------------------------", file=speed_handle, flush=True) 359 | print("Decoding speed analysis :", file=speed_handle, flush=True) 360 | print("{} sentences".format(cum_sentences), file=speed_handle, flush=True) 361 | if args.dataset != "mscoco": 362 | print("{} tokens".format(cum_tokens), file=speed_handle, flush=True) 363 | print("{:.3f} seconds".format(curr_time), file=speed_handle, flush=True) 364 | 365 | print("{:.3f} ms / sentence".format((curr_time / float(cum_sentences) * 1000)), file=speed_handle, flush=True) 366 | if args.dataset != "mscoco": 367 | print("{:.3f} ms / token".format((curr_time / float(cum_tokens) * 1000)), file=speed_handle, flush=True) 368 | 369 | print("{:.3f} sentences / s".format(float(cum_sentences) / curr_time), file=speed_handle, flush=True) 370 | if args.dataset != "mscoco": 371 | print("{:.3f} tokens / s".format(float(cum_tokens) / curr_time), file=speed_handle, flush=True) 372 | print("----------------------------------------------", file=speed_handle, flush=True) 373 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import torch 3 | import numpy as np 4 | import math 5 | import gc 6 | import os 7 | 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | from torch.autograd import Variable 11 | 12 | from tqdm import tqdm, trange 13 | from model import Transformer, FastTransformer, INF, TINY, softmax 14 | from data import NormalField, NormalTranslationDataset, TripleTranslationDataset, ParallelDataset, data_path 15 | from utils import Metrics, Best, TargetLength, computeBLEU, computeBLEUMSCOCO, compute_bp, Batch, masked_sort, computeGroupBLEU, \ 16 | corrupt_target, remove_repeats, remove_repeats_tensor, print_bleu, corrupt_target_fix, set_eos, 
organise_trg_len_dic 17 | from time import gmtime, strftime 18 | 19 | # helper functions 20 | def export(x): 21 | try: 22 | with torch.cuda.device_of(x): 23 | return x.data.cpu().float().mean() 24 | except Exception: 25 | return 0 26 | 27 | tokenizer = lambda x: x.replace('@@ ', '').split() 28 | 29 | def valid_model(args, model, dev, dev_metrics=None, dev_metrics_trg=None, dev_metrics_average=None, 30 | print_out=False, teacher_model=None, trg_len_dic=None): 31 | print_seq = (['REF '] if args.dataset == "mscoco" else ['SRC ', 'REF ']) + ['HYP{}'.format(ii+1) for ii in range(args.valid_repeat_dec)] 32 | 33 | trg_outputs = [] 34 | real_all_outputs = [ [] for ii in range(args.valid_repeat_dec)] 35 | short_all_outputs = [ [] for ii in range(args.valid_repeat_dec)] 36 | outputs_data = {} 37 | 38 | model.eval() 39 | for j, dev_batch in enumerate(dev): 40 | if args.dataset == "mscoco": 41 | # only use first caption for calculating log likelihood 42 | all_captions = dev_batch[1] 43 | dev_batch[1] = dev_batch[1][0] 44 | decoder_inputs, decoder_masks,\ 45 | targets, target_masks,\ 46 | _, source_masks,\ 47 | encoding, batch_size, rest = model.quick_prepare_mscoco(dev_batch, all_captions=all_captions, fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec, trg_len_option=args.trg_len_option, max_len=args.max_offset, trg_len_dic=trg_len_dic, bp=args.bp) 48 | 49 | else: 50 | decoder_inputs, decoder_masks,\ 51 | targets, target_masks,\ 52 | sources, source_masks,\ 53 | encoding, batch_size, rest = model.quick_prepare(dev_batch, fast=(type(model) is FastTransformer), trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio, trg_len_dic=trg_len_dic, bp=args.bp) 54 | 55 | losses, all_decodings = [], [] 56 | if type(model) is Transformer: 57 | decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks, beam=1, decoding=True, return_probs=True) 58 | loss = model.cost(targets, target_masks, out=out) 59 | losses.append(loss) 60 | all_decodings.append( decoding ) 61 | 62 | elif type(model) is FastTransformer: 63 | for iter_ in range(args.valid_repeat_dec): 64 | curr_iter = min(iter_, args.num_decs-1) 65 | next_iter = min(curr_iter + 1, args.num_decs-1) 66 | 67 | decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks, decoding=True, return_probs=True, iter_=curr_iter) 68 | 69 | #loss = model.cost(targets, target_masks, out=out, iter_=curr_iter) 70 | #losses.append(loss) 71 | all_decodings.append( decoding ) 72 | 73 | decoder_inputs = 0 74 | if args.next_dec_input in ["both", "emb"]: 75 | _, argmax = torch.max(probs, dim=-1) 76 | emb = F.embedding(argmax, model.decoder[next_iter].out.weight * math.sqrt(args.d_model)) 77 | decoder_inputs += emb 78 | 79 | if args.next_dec_input in ["both", "out"]: 80 | decoder_inputs += out 81 | 82 | if args.dataset == "mscoco": 83 | # make sure that 5 captions per each example 84 | num_captions = len(all_captions[0]) 85 | for c in range(1, len(all_captions)): 86 | assert (num_captions == len(all_captions[c])) 87 | 88 | # untokenize reference captions 89 | for n_ref in range(len(all_captions)): 90 | n_caps = len(all_captions[0]) 91 | for c in range(n_caps): 92 | all_captions[n_ref][c] = all_captions[n_ref][c].replace("@@ ","") 93 | 94 | src_ref = [ list(map(list, zip(*all_captions))) ] 95 | else: 96 | src_ref = [ model.output_decoding(d) for d in [('src', sources), ('trg', targets)] ] 97 | 98 | real_outputs = [ model.output_decoding(d) for d in [('trg', xx) for xx in all_decodings] ] 99 | 100 | if print_out: 
101 | if args.dataset != "mscoco": 102 | for k, d in enumerate(src_ref + real_outputs): 103 | args.logger.info("{} ({}): {}".format(print_seq[k], len(d[0].split(" ")), d[0])) 104 | else: 105 | for k in range(len(all_captions[0])): 106 | for c in range(len(all_captions)): 107 | args.logger.info("REF ({}): {}".format(len(all_captions[c][k].split(" ")), all_captions[c][k])) 108 | 109 | for c in range(len(real_outputs)): 110 | args.logger.info("HYP {} ({}): {}".format(c+1, len(real_outputs[c][k].split(" ")), real_outputs[c][k])) 111 | args.logger.info('------------------------------------------------------------------') 112 | 113 | trg_outputs += src_ref[-1] 114 | for ii, d_outputs in enumerate(real_outputs): 115 | real_all_outputs[ii] += d_outputs 116 | 117 | #if dev_metrics is not None: 118 | # dev_metrics.accumulate(batch_size, *losses) 119 | if dev_metrics_trg is not None: 120 | dev_metrics_trg.accumulate(batch_size, *[rest[0], rest[1], rest[2]]) 121 | if dev_metrics_average is not None: 122 | dev_metrics_average.accumulate(batch_size, *[rest[3], rest[4]]) 123 | 124 | if args.dataset != "mscoco": 125 | real_bleu = [computeBLEU(ith_output, trg_outputs, corpus=True, tokenizer=tokenizer) for ith_output in real_all_outputs] 126 | else: 127 | real_bleu = [computeBLEUMSCOCO(ith_output, trg_outputs, corpus=True, tokenizer=tokenizer) for ith_output in real_all_outputs] 128 | 129 | outputs_data['real'] = real_bleu 130 | 131 | if "predict" in args.trg_len_option: 132 | outputs_data['pred_target_len_loss'] = getattr(dev_metrics_trg, 'pred_target_len_loss') 133 | outputs_data['pred_target_len_correct'] = getattr(dev_metrics_trg, 'pred_target_len_correct') 134 | outputs_data['pred_target_len_approx'] = getattr(dev_metrics_trg, 'pred_target_len_approx') 135 | outputs_data['average_target_len_correct'] = getattr(dev_metrics_average, 'average_target_len_correct') 136 | outputs_data['average_target_len_approx'] = getattr(dev_metrics_average, 'average_target_len_approx') 137 | 138 | #if dev_metrics is not None: 139 | # args.logger.info(dev_metrics) 140 | if dev_metrics_trg is not None: 141 | args.logger.info(dev_metrics_trg) 142 | if dev_metrics_average is not None: 143 | args.logger.info(dev_metrics_average) 144 | 145 | for idx in range(args.valid_repeat_dec): 146 | print_str = "iter {} | {}".format(idx+1, print_bleu(real_bleu[idx], verbose=False)) 147 | args.logger.info( print_str ) 148 | 149 | return outputs_data 150 | 151 | def train_model(args, model, train, dev, src=None, trg=None, trg_len_dic=None, teacher_model=None, save_path=None, maxsteps=None): 152 | 153 | if args.tensorboard and (not args.debug): 154 | from tensorboardX import SummaryWriter 155 | writer = SummaryWriter(str(args.event_path / args.id_str)) 156 | 157 | if type(model) is FastTransformer and args.denoising_prob > 0.0: 158 | denoising_weights = [args.denoising_weight for idx in range(args.train_repeat_dec)] 159 | denoising_out_weights = [args.denoising_out_weight for idx in range(args.train_repeat_dec)] 160 | 161 | if type(model) is FastTransformer and args.layerwise_denoising_weight: 162 | start, end = 0.9, 0.1 163 | diff = (start-end)/(args.train_repeat_dec-1) 164 | denoising_weights = np.arange(start=end, stop=start, step=diff).tolist()[::-1] + [0.1] 165 | 166 | # optimizer 167 | for k, p in zip(model.state_dict().keys(), model.parameters()): 168 | # only finetune layers that are responsible to predicting target len 169 | if args.finetune_trg_len: 170 | if "pred_len" not in k: 171 | p.requires_grad = False 172 | else: 173 | 
print(k) 174 | else: 175 | if "pred_len" in k: 176 | p.requires_grad = False 177 | 178 | params = [p for p in model.parameters() if p.requires_grad] 179 | if args.optimizer == 'Adam': 180 | opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9) 181 | else: 182 | raise NotImplementedError 183 | 184 | # if resume training 185 | if (args.load_from is not None) and (args.resume) and not args.finetune_trg_len: 186 | with torch.cuda.device(args.gpu): # very important. 187 | offset, opt_states = torch.load(str(args.model_path / args.load_from) + '.pt.states', 188 | map_location=lambda storage, loc: storage.cuda()) 189 | opt.load_state_dict(opt_states) 190 | else: 191 | offset = 0 192 | 193 | if not args.finetune_trg_len: 194 | best = Best(max, *['BLEU_dec{}'.format(ii+1) for ii in range(args.valid_repeat_dec)], 195 | 'i', model=model, opt=opt, path=str(args.model_path / args.id_str), gpu=args.gpu, 196 | which=range(args.valid_repeat_dec)) 197 | else: 198 | best = Best(max, *['pred_target_len_correct'], 199 | 'i', model=model, opt=opt, path=str(args.model_path / args.id_str), gpu=args.gpu, 200 | which=[0]) 201 | train_metrics = Metrics('train loss', *['loss_{}'.format(idx+1) for idx in range(args.train_repeat_dec)], data_type = "avg") 202 | dev_metrics = Metrics('dev loss', *['loss_{}'.format(idx+1) for idx in range(args.valid_repeat_dec)], data_type = "avg") 203 | 204 | if "predict" in args.trg_len_option: 205 | train_metrics_trg = Metrics('train loss target', *["pred_target_len_loss", "pred_target_len_correct", "pred_target_len_approx"], data_type="avg") 206 | train_metrics_average = Metrics('train loss average', *["average_target_len_correct", "average_target_len_approx"], data_type="avg") 207 | dev_metrics_trg = Metrics('dev loss target', *["pred_target_len_loss", "pred_target_len_correct", "pred_target_len_approx"], data_type="avg") 208 | dev_metrics_average = Metrics('dev loss average', *["average_target_len_correct", "average_target_len_approx"], data_type="avg") 209 | else: 210 | train_metrics_trg = None 211 | train_metrics_average = None 212 | dev_metrics_trg = None 213 | dev_metrics_average = None 214 | 215 | if not args.no_tqdm: 216 | progressbar = tqdm(total=args.eval_every, desc='start training.') 217 | 218 | if maxsteps is None: 219 | maxsteps = args.maximum_steps 220 | 221 | #targetlength = TargetLength() 222 | for iters, train_batch in enumerate(train): 223 | #targetlength.accumulate( train_batch ) 224 | #continue 225 | 226 | iters += offset 227 | 228 | if args.save_every > 0 and iters % args.save_every == 0: 229 | args.logger.info('save (back-up) checkpoints at iter={}'.format(iters)) 230 | with torch.cuda.device(args.gpu): 231 | torch.save(best.model.state_dict(), '{}_iter={}.pt'.format(str(args.model_path / args.id_str), iters)) 232 | torch.save([iters, best.opt.state_dict()], '{}_iter={}.pt.states'.format(str(args.model_path / args.id_str), iters)) 233 | 234 | if (iters+1) % args.eval_every == 0: 235 | torch.cuda.empty_cache() 236 | gc.collect() 237 | dev_metrics.reset() 238 | if dev_metrics_trg is not None: 239 | dev_metrics_trg.reset() 240 | if dev_metrics_average is not None: 241 | dev_metrics_average.reset() 242 | outputs_data = valid_model(args, model, dev, dev_metrics, dev_metrics_trg=dev_metrics_trg, dev_metrics_average=dev_metrics_average, teacher_model=None, print_out=False, trg_len_dic=trg_len_dic) 243 | #outputs_data = [0, [0,0,0,0], 0, 0] 244 | if args.tensorboard and (not args.debug): 245 | for ii in range(args.valid_repeat_dec): 246 | 
writer.add_scalar('dev/single/Loss_{}'.format(ii + 1), getattr(dev_metrics, "loss_{}".format(ii+1)), iters) # NLL averaged over dev corpus 247 | writer.add_scalar('dev/single/BLEU_{}'.format(ii + 1), outputs_data['real'][ii][0], iters) # NOTE corpus bleu 248 | 249 | if "predict" in args.trg_len_option: 250 | writer.add_scalar("dev/single/pred_target_len_loss", outputs_data["pred_target_len_loss"], iters) 251 | writer.add_scalar("dev/single/pred_target_len_correct", outputs_data["pred_target_len_correct"], iters) 252 | writer.add_scalar("dev/single/pred_target_len_approx", outputs_data["pred_target_len_approx"], iters) 253 | writer.add_scalar("dev/single/average_target_len_correct", outputs_data["average_target_len_correct"], iters) 254 | writer.add_scalar("dev/single/average_target_len_approx", outputs_data["average_target_len_approx"], iters) 255 | 256 | """ 257 | writer.add_scalars('dev/total/BLEUs', {"iter_{}".format(idx+1):bleu for idx, bleu in enumerate(outputs_data['bleu']) }, iters) 258 | writer.add_scalars('dev/total/Losses', 259 | { "iter_{}".format(idx+1):getattr(dev_metrics, "loss_{}".format(idx+1)) 260 | for idx in range(args.valid_repeat_dec) }, 261 | iters ) 262 | """ 263 | 264 | if not args.debug: 265 | if not args.finetune_trg_len: 266 | best.accumulate(*[xx[0] for xx in outputs_data['real']], iters) 267 | 268 | values = list( best.metrics.values() ) 269 | args.logger.info("best model : {}, {}".format( "BLEU=[{}]".format(", ".join( [ str(x) for x in values[:args.valid_repeat_dec] ] ) ), \ 270 | "i={}".format( values[args.valid_repeat_dec] ), ) ) 271 | else: 272 | best.accumulate(*[outputs_data['pred_target_len_correct']], iters) 273 | values = list( best.metrics.values() ) 274 | args.logger.info("best model : {}".format( "pred_target_len_correct = {}".format(values[0])) ) 275 | 276 | args.logger.info('model:' + args.prefix + args.hp_str) 277 | 278 | # --- set up a new progress bar --- 279 | if not args.no_tqdm: 280 | progressbar.close() 281 | progressbar = tqdm(total=args.eval_every, desc='start training.') 282 | 283 | if type(model) is FastTransformer and args.anneal_denoising_weight: 284 | for ii, bb in enumerate([xx[0] for xx in outputs_data['real']][:-1]): 285 | denoising_weights[ii] = 0.9 - 0.1 * int(math.floor(bb / 3.0)) 286 | 287 | if iters > maxsteps: 288 | args.logger.info('reached the maximum updating steps.') 289 | break 290 | 291 | model.train() 292 | 293 | def get_lr_transformer(i, lr0=0.1): 294 | return lr0 * 10 / math.sqrt(args.d_model) * min( 295 | 1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup))) 296 | 297 | def get_lr_anneal(iters, lr0=0.1): 298 | lr_end = 1e-5 299 | return max( 0, (args.lr - lr_end) * (args.anneal_steps - iters) / args.anneal_steps ) + lr_end 300 |
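# Editor's note (not in the original source): with the defaults used by the shell scripts
# (--lr 3e-4, --lr_schedule anneal, --anneal_steps 250000), get_lr_anneal decays the learning
# rate linearly from 3e-4 down to lr_end = 1e-5. A worked example at the halfway point:
#   get_lr_anneal(125000) = (3e-4 - 1e-5) * (250000 - 125000) / 250000 + 1e-5 = 1.55e-4
# Past anneal_steps, the max(0, ...) clamp holds the rate at lr_end.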
301 | if args.lr_schedule == "fixed": 302 | opt.param_groups[0]['lr'] = args.lr 303 | elif args.lr_schedule == "anneal": 304 | opt.param_groups[0]['lr'] = get_lr_anneal(iters + 1) 305 | elif args.lr_schedule == "transformer": 306 | opt.param_groups[0]['lr'] = get_lr_transformer(iters + 1) 307 | 308 | opt.zero_grad() # clear accumulated gradients before computing this update 309 | 310 | if args.dataset == "mscoco": 311 | decoder_inputs, decoder_masks,\ 312 | targets, target_masks,\ 313 | _, source_masks,\ 314 | encoding, batch_size, rest = model.quick_prepare_mscoco(train_batch, all_captions=train_batch[1], fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec, trg_len_option=args.trg_len_option, max_len=args.max_offset, trg_len_dic=trg_len_dic, bp=args.bp) 315 | else: 316 | decoder_inputs, decoder_masks,\ 317 | targets, target_masks,\ 318 | sources, source_masks,\ 319 | encoding, batch_size, rest = model.quick_prepare(train_batch, fast=(type(model) is FastTransformer), trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio, trg_len_dic=trg_len_dic, bp=args.bp) 320 | 321 | losses = [] 322 | if type(model) is Transformer: 323 | loss = model.cost(targets, target_masks, out=model(encoding, source_masks, decoder_inputs, decoder_masks)) 324 | losses.append( loss ) 325 | 326 | elif type(model) is FastTransformer: 327 | all_logits = [] 328 | all_denoising_masks = [] 329 | for iter_ in range(args.train_repeat_dec): 330 | torch.cuda.empty_cache() 331 | curr_iter = min(iter_, args.num_decs-1) 332 | next_iter = min(curr_iter + 1, args.num_decs-1) 333 | 334 | out = model(encoding, source_masks, decoder_inputs, decoder_masks, iter_=curr_iter, return_probs=False) 335 | 336 | if args.rf_finetune: 337 | loss = model.rf_cost(args, targets, target_masks, out=out, iter_=curr_iter) 338 | elif args.nat_finetune: 339 | loss = model.nat_cost(args, targets, target_masks, out=out, iter_=curr_iter) 340 | else: 341 | loss = model.cost(targets, target_masks, out=out, iter_=curr_iter) 342 | 343 | logits = model.decoder[curr_iter].out(out) 344 | 345 | if args.use_argmax: 346 | _, argmax = torch.max(logits, dim=-1) 347 | else: 348 | probs = softmax(logits) 349 | probs_sz = probs.size() 350 | logits_ = Variable(probs.data, requires_grad=False) 351 | argmax = torch.multinomial(logits_.contiguous().view(-1, probs_sz[-1]), 1).view(*probs_sz[:-1]) 352 | 353 | if args.self_distil > 0.0: 354 | all_logits.append(logits) 355 | 356 | losses.append(loss) 357 | 358 | decoder_inputs_ = 0 359 | denoising_mask = 1 360 | if args.next_dec_input in ["both", "emb"]: 361 | if args.denoising_prob > 0.0 and np.random.rand() < args.denoising_prob: 362 | cor = corrupt_target(targets, decoder_masks, len(trg.vocab), denoising_weights[iter_], args.corruption_probs) 363 | 364 | emb = F.embedding(cor, model.decoder[next_iter].out.weight * math.sqrt(args.d_model)) 365 | denoising_mask = 0 366 | else: 367 | emb = F.embedding(argmax, model.decoder[next_iter].out.weight * math.sqrt(args.d_model)) 368 | 369 | if args.denoising_out_weight > 0: 370 | if denoising_out_weights[iter_] > 0.0: 371 | corrupted_argmax = corrupt_target(argmax, decoder_masks, denoising_out_weights[iter_]) 372 | else: 373 | corrupted_argmax = argmax 374 | emb = F.embedding(corrupted_argmax, model.decoder[next_iter].out.weight * math.sqrt(args.d_model)) 375 | decoder_inputs_ += emb 376 | all_denoising_masks.append( denoising_mask ) 377 | 378 | if args.next_dec_input in ["both", "out"]: 379 | decoder_inputs_ += out 380 | decoder_inputs = decoder_inputs_ 381 | 382 | # self distillation loss if requested 383 | if args.self_distil > 0.0: 384 | self_distil_losses = [] 385 | 386 | for logits_i in range(1, len(all_logits)-1): 387 | self_distill_loss_i = 0.0 388 | for logits_j in range(logits_i+1, len(all_logits)): 389 | self_distill_loss_i += \ 390 | all_denoising_masks[logits_j] * \ 391 | all_denoising_masks[logits_i] * \ 392 | (1/(logits_j-logits_i)) * args.self_distil * F.mse_loss(all_logits[logits_i], all_logits[logits_j].detach()) 393 | 394 | self_distil_losses.append(self_distill_loss_i) 395 | 396 | self_distil_loss = sum(self_distil_losses) 397 | 398 | loss = sum(losses) 399 |
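# Editor's note (not in the original source): the self-distillation term above pulls each
# iteration's logits toward the detached logits of every later iteration with an MSE penalty,
# down-weighted by 1/(j - i); pairs involving an iteration whose decoder input was replaced by
# a corrupted target are zeroed out via all_denoising_masks. For example, with
# --train_repeat_dec 4 the two loops form the index pairs (1,2), (1,3), (2,3) with weights
# 1, 1/2, 1; note that range(1, ...) leaves the first iteration's logits out of the loss.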
400 | # accumulate the training metrics 401 | train_metrics.accumulate(batch_size, *losses, print_iter=None) 402 | if train_metrics_trg is not None: 403 | train_metrics_trg.accumulate(batch_size, *[rest[0], rest[1], rest[2]]) 404 | if train_metrics_average is not None: 405 | train_metrics_average.accumulate(batch_size, *[rest[3], rest[4]]) 406 | if type(model) is FastTransformer and args.self_distil > 0.0: 407 | (loss+self_distil_loss).backward() 408 | else: 409 | if "predict" in args.trg_len_option: 410 | if args.finetune_trg_len: 411 | rest[0].backward() 412 | else: 413 | loss.backward() 414 | else: 415 | loss.backward() 416 | 417 | if args.grad_clip > 0: 418 | total_norm = nn.utils.clip_grad_norm(params, args.grad_clip) 419 | 420 | opt.step() # apply the parameter update 421 | 422 | mid_str = '' 423 | if type(model) is FastTransformer and args.self_distil > 0.0: 424 | mid_str += 'distil={:.5f}, '.format(self_distil_loss.cpu().data.numpy()[0]) 425 | #if type(model) is FastTransformer and "predict" in args.trg_len_option: 426 | # mid_str += 'pred_target_len_loss={:.5f}, '.format(rest[0].cpu().data.numpy()[0]) 427 | if type(model) is FastTransformer and args.denoising_prob > 0.0: 428 | mid_str += "/".join(["{:.1f}".format(ff) for ff in denoising_weights[:-1]])+", " 429 | 430 | info = 'update={}, loss={}, {}lr={:.1e}'.format( iters, 431 | "/".join(["{:.3f}".format(export(ll)) for ll in losses]), 432 | mid_str, 433 | opt.param_groups[0]['lr']) 434 | 435 | if args.no_tqdm: 436 | if iters % args.eval_every == 0: 437 | args.logger.info("update {} : {}".format(iters, str(train_metrics))) 438 | else: 439 | progressbar.update(1) 440 | progressbar.set_description(info) 441 | 442 | if (iters+1) % args.eval_every == 0 and args.tensorboard and (not args.debug): 443 | for idx in range(args.train_repeat_dec): 444 | writer.add_scalar('train/single/Loss_{}'.format(idx+1), getattr(train_metrics, "loss_{}".format(idx+1)), iters) 445 | if "predict" in args.trg_len_option: 446 | writer.add_scalar("train/single/pred_target_len_loss", getattr(train_metrics_trg, "pred_target_len_loss"), iters) 447 | writer.add_scalar("train/single/pred_target_len_correct", getattr(train_metrics_trg, "pred_target_len_correct"), iters) 448 | writer.add_scalar("train/single/pred_target_len_approx", getattr(train_metrics_trg, "pred_target_len_approx"), iters) 449 | writer.add_scalar("train/single/average_target_len_correct", getattr(train_metrics_average, "average_target_len_correct"), iters) 450 | writer.add_scalar("train/single/average_target_len_approx", getattr(train_metrics_average, "average_target_len_approx"), iters) 451 | 452 | train_metrics.reset() 453 | if train_metrics_trg is not None: 454 | train_metrics_trg.reset() 455 | if train_metrics_average is not None: 456 | train_metrics_average.reset() 457 | 458 | #torch.save(targetlength.lengths, str(args.data_prefix / "trg_len_dic" / args.dataset[-4:])) 459 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import ipdb 3 | import torch 4 | import random 5 | import numpy as np 6 | import _pickle as pickle 7 | import revtok 8 | import os 9 | from itertools import groupby 10 | import getpass 11 | from collections import Counter 12 | 13 | from torch.autograd import Variable 14 | from torchtext import data, datasets 15 | from nltk.translate.gleu_score import sentence_gleu, corpus_gleu 16 | from nltk.translate.bleu_score import closest_ref_length, brevity_penalty, modified_precision, SmoothingFunction 17 | from contextlib import ExitStack 18 | from collections import OrderedDict
19 | import fractions 20 | 21 | 22 | try: 23 | fractions.Fraction(0, 1000, _normalize=False) 24 | from fractions import Fraction 25 | except TypeError: 26 | from nltk.compat import Fraction 27 | 28 | def sentence_bleu(references, hypothesis, weights=(0.25, 0.25, 0.25, 0.25), 29 | smoothing_function=None, auto_reweigh=False, 30 | emulate_multibleu=False): 31 | 32 | return corpus_bleu([references], [hypothesis], 33 | weights, smoothing_function, auto_reweigh, 34 | emulate_multibleu) 35 | 36 | 37 | def corpus_bleu(list_of_references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25), 38 | smoothing_function=None, auto_reweigh=False, 39 | emulate_multibleu=False): 40 | p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches. 41 | p_denominators = Counter() # Key = ngram order, and value = no. of ngram in ref. 42 | hyp_lengths, ref_lengths = 0, 0 43 | 44 | if len(list_of_references) != len(hypotheses): 45 | print ("The number of hypotheses and their reference(s) should be the same") 46 | return (0, (0, 0, 0, 0), 0, 0, 0) 47 | 48 | # Iterate through each hypothesis and its corresponding references. 49 | for references, hypothesis in zip(list_of_references, hypotheses): 50 | # For each order of ngram, calculate the numerator and 51 | # denominator for the corpus-level modified precision. 52 | for i, _ in enumerate(weights, start=1): 53 | p_i = modified_precision(references, hypothesis, i) 54 | p_numerators[i] += p_i.numerator 55 | p_denominators[i] += p_i.denominator 56 | 57 | # Calculate the hypothesis length and the closest reference length. 58 | # Add them to the corpus-level hypothesis and reference counts. 59 | hyp_len = len(hypothesis) 60 | hyp_lengths += hyp_len 61 | ref_lengths += closest_ref_length(references, hyp_len) 62 | 63 | # Calculate corpus-level brevity penalty. 64 | bp = brevity_penalty(ref_lengths, hyp_lengths) 65 | 66 | # Uniformly re-weighting based on maximum hypothesis lengths if largest 67 | # order of n-grams < 4 and weights is set at default. 68 | if auto_reweigh: 69 | if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25): 70 | weights = ( 1 / hyp_lengths ,) * hyp_lengths 71 | 72 | # Collects the various precision values for the different ngram orders. 73 | p_n = [Fraction(p_numerators[i], p_denominators[i], _normalize=False) 74 | for i, _ in enumerate(weights, start=1)] 75 | 76 | p_n_ = [xx.numerator / xx.denominator * 100 for xx in p_n] 77 | 78 | # Return 0 if there are no matching n-grams. 79 | # We only need to check for p_numerators[1] == 0, since if there are 80 | # no unigrams, there won't be any higher-order ngrams. 81 | if p_numerators[1] == 0: 82 | return (0, (0, 0, 0, 0), 0, 0, 0) 83 | 84 | # If there's no smoothing, use method0 from the SmoothingFunction class. 85 | if not smoothing_function: 86 | smoothing_function = SmoothingFunction().method0 87 | # Smoothen the modified precision. 88 | # Note: smoothing_function() may convert values into floats; 89 | # it tries to retain the Fraction object as much as the 90 | # smoothing method allows. 91 | p_n = smoothing_function(p_n, references=references, hypothesis=hypothesis, 92 | hyp_len=hyp_len, emulate_multibleu=emulate_multibleu) 93 | s = (w * math.log(p_i) for i, (w, p_i) in enumerate(zip(weights, p_n))) 94 | s = bp * math.exp(math.fsum(s)) * 100 95 | final_bleu = round(s, 4) if emulate_multibleu else s 96 | return (final_bleu, p_n_, bp, ref_lengths, hyp_lengths) 97 |
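# Editor's note (not in the original source): unlike nltk's corpus_bleu, from which this is
# adapted, this variant returns a 5-tuple rather than a single score:
#   (bleu, [p1, p2, p3, p4] as percentages, brevity penalty, reference length, hypothesis length)
# print_bleu, further down in this file, formats exactly this tuple for logging.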
98 | INF = 1e10 99 | TINY = 1e-9 100 | 101 | def n_grams(list_words): 102 | set_1gram, set_2gram, set_3gram, set_4gram = set(), set(), set(), set() 103 | count = {} 104 | l = len(list_words) 105 | for i in range(l): 106 | word = list_words[i] 107 | if word not in set_1gram: 108 | set_1gram.add(word) 109 | count[word] = 1 110 | else: 111 | set_1gram.add((word,count[word])) 112 | count[word] += 1 113 | count = {} 114 | 115 | for i in range(l-1): 116 | word = (list_words[i],list_words[i+1]) 117 | if word not in set_2gram: 118 | set_2gram.add(word) 119 | count[word] = 1 120 | else: 121 | set_2gram.add((word,count[word])) 122 | count[word] += 1 123 | 124 | count = {} 125 | 126 | for i in range(l-2): 127 | word = (list_words[i],list_words[i+1], list_words[i+2]) 128 | if word not in set_3gram: 129 | set_3gram.add(word) 130 | count[word] = 1 131 | else: 132 | set_3gram.add((word,count[word])) 133 | count[word] += 1 134 | count = {} 135 | 136 | for i in range(l-3): 137 | word = (list_words[i],list_words[i+1], list_words[i+2], list_words[i+3]) 138 | if word not in set_4gram: 139 | set_4gram.add(word) 140 | count[word] = 1 141 | else: 142 | set_4gram.add((word,count[word])) 143 | count[word] += 1 144 | 145 | return set_1gram, set_2gram, set_3gram, set_4gram 146 | 147 | def my_sentence_gleu(references, hypothesis): 148 | 149 | reference = references[0] 150 | ref_grams = n_grams(reference) 151 | hyp_grams = n_grams(hypothesis) 152 | match_grams = [x.intersection(y) for (x,y) in zip(ref_grams, hyp_grams)] 153 | ref_count = sum([len(x) for x in ref_grams]) 154 | hyp_count = sum([len(x) for x in hyp_grams]) 155 | match_count = sum([len(x) for x in match_grams]) 156 | gleu = float(match_count) / float(max(ref_count,hyp_count)) 157 | return gleu 158 | 159 | def computeGLEU(outputs, targets, corpus=False, tokenizer=None): 160 | if tokenizer is None: 161 | tokenizer = revtok.tokenize 162 | 163 | 164 | if not corpus: 165 | return [my_sentence_gleu([t], o) for o, t in zip(outputs, targets)] 166 | 167 | return corpus_gleu([[t] for t in targets], [o for o in outputs]) 168 | 169 | def computeBLEU(outputs, targets, corpus=False, tokenizer=None): 170 | if tokenizer is None: 171 | tokenizer = revtok.tokenize 172 | 173 | outputs = [tokenizer(o) for o in outputs] 174 | targets = [tokenizer(t) for t in targets] 175 | 176 | if corpus: 177 | return corpus_bleu([[t] for t in targets], [o for o in outputs], emulate_multibleu=True) 178 | else: 179 | return [sentence_bleu([t], o)[0] for o, t in zip(outputs, targets)] 180 | #return torch.Tensor([sentence_bleu([t], o)[0] for o, t in zip(outputs, targets)]) 181 |
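# Editor's example (not in the original source; uses a plain whitespace tokenizer instead of
# the default revtok, and assumes the older NLTK API this file targets): a single identical
# four-token pair yields a perfect corpus score,
#   computeBLEU(["the cat sat down"], ["the cat sat down"], corpus=True, tokenizer=str.split)
#   -> (100.0, [100.0, 100.0, 100.0, 100.0], 1.0, 4, 4)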
182 | def computeBLEUMSCOCO(outputs, targets, corpus=True, tokenizer=None): 183 | # outputs is list of 5000 captions 184 | # targets is list of 5000 lists each length of 5 185 | if tokenizer is None: 186 | tokenizer = revtok.tokenize 187 | 188 | outputs = [tokenizer(o) for o in outputs] 189 | new_targets = [] 190 | for i, t in enumerate(targets): 191 | new_targets.append([tokenizer(tt) for tt in t]) 192 | #targets[i] = [tokenizer(tt) for tt in t] 193 | 194 | if corpus: 195 | return corpus_bleu(new_targets, outputs, emulate_multibleu=True) 196 | else: 197 | return [sentence_bleu(new_t, o)[0] for o, new_t in zip(outputs, new_targets)] 198 | 199 | def compute_bp(hypotheses, list_of_references): 200 | hyp_lengths, ref_lengths = 0, 0 201 | for references, hypothesis in zip(list_of_references, hypotheses): 202 | hyp_len = len(hypothesis) 203 | hyp_lengths += hyp_len 204 | ref_lengths += closest_ref_length(references, hyp_len) 205 | 206 | # Calculate corpus-level brevity penalty. 207 | bp = brevity_penalty(ref_lengths, hyp_lengths) 208 | return bp 209 | 210 | def computeGroupBLEU(outputs, targets, tokenizer=None, bra=10, maxmaxlen=80): 211 | if tokenizer is None: 212 | tokenizer = revtok.tokenize 213 | 214 | outputs = [tokenizer(o) for o in outputs] 215 | targets = [tokenizer(t) for t in targets] 216 | maxlens = max([len(t) for t in targets]) 217 | print(maxlens) 218 | maxlens = min([maxlens, maxmaxlen]) 219 | nums = int(np.ceil(maxlens / bra)) 220 | outputs_buckets = [[] for _ in range(nums)] 221 | targets_buckets = [[] for _ in range(nums)] 222 | for o, t in zip(outputs, targets): 223 | idx = len(o) // bra 224 | if idx >= len(outputs_buckets): 225 | idx = -1 226 | outputs_buckets[idx] += [o] 227 | targets_buckets[idx] += [t] 228 | 229 | for k in range(nums): 230 | print(corpus_bleu([[t] for t in targets_buckets[k]], [o for o in outputs_buckets[k]], emulate_multibleu=True)) 231 | 232 | class TargetLength: 233 | def __init__(self, lengths=None): # lengths : {src_len: (count, total_trg_len)} 234 | self.lengths = lengths if lengths is not None else dict() 235 | 236 | def accumulate(self, batch): 237 | src_len = (batch.src != 1).sum(-1).cpu().data.numpy() 238 | trg_len = (batch.trg != 1).sum(-1).cpu().data.numpy() 239 | for (slen, tlen) in zip(src_len, trg_len): 240 | if not slen in self.lengths: 241 | self.lengths[slen] = (1, int(tlen)) 242 | else: 243 | (count, acc) = self.lengths[slen] 244 | self.lengths[slen] = (count + 1, acc + int(tlen)) 245 | 246 | def get_trg_len(self, src_len): 247 | if not src_len in self.lengths: 248 | return self.get_trg_len(src_len + 1) - 1 249 | else: 250 | (count, acc) = self.lengths[src_len] 251 | return acc / float(count) 252 | 253 | def organise_trg_len_dic(trg_len_dic): 254 | trg_len_dic = {k:int(v[1]/float(v[0])) for (k, v) in trg_len_dic.items()} 255 | return trg_len_dic 256 | 257 | def query_trg_len_dic(trg_len_dic, q): 258 | max_src_len = max(trg_len_dic.keys()) 259 | if q <= max_src_len: 260 | if q in trg_len_dic: 261 | return trg_len_dic[q] 262 | else: 263 | return query_trg_len_dic(trg_len_dic, q+1) - 1 264 | else: 265 | return int(math.floor( trg_len_dic[max_src_len] / max_src_len * q )) 266 |
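# Editor's worked example (not in the original source; hypothetical dictionary): with
# trg_len_dic = {5: 6, 10: 12}, query_trg_len_dic(trg_len_dic, 7) walks up to the nearest
# known source length, 10, subtracting one per step: 12 - 3 = 9. A query beyond the largest
# key extrapolates proportionally: query_trg_len_dic(trg_len_dic, 13) = floor(12 / 10 * 13) = 15.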
267 | def make_decoder_masks(source_masks, trg_len_dic): 268 | batch_size, src_max_len = source_masks.size() 269 | src_len = (source_masks == 1).sum(-1).cpu().numpy() 270 | trg_len = [int(math.floor(query_trg_len_dic(trg_len_dic, src) * 1.1)) for src in src_len] 271 | trg_max_len = max(trg_len) 272 | decoder_masks = np.zeros((batch_size, trg_max_len)) 273 | #decoder_masks = Variable(torch.zeros(batch_size, trg_max_len), requires_grad=False) 274 | for idx, tt in enumerate(trg_len): 275 | decoder_masks[idx][:tt] = 1 276 | result = torch.from_numpy(decoder_masks).float() 277 | if source_masks.is_cuda: 278 | result = result.cuda() 279 | return result 280 | 281 | def double_source_masks(source_masks): 282 | batch_size, src_max_len = source_masks.size() 283 | src_len = (source_masks == 1).sum(-1).cpu().numpy() 284 | decoder_masks = np.zeros((batch_size, src_max_len * 2)) 285 | for idx, tt in enumerate(src_len): 286 | decoder_masks[idx][:2*tt] = 1 287 | result = torch.from_numpy(decoder_masks).float() 288 | if source_masks.is_cuda: 289 | result = result.cuda() 290 | return result 291 | 292 | class Metrics: 293 | 294 | def __init__(self, name, *metrics, data_type="sum"): # data_type : sum, avg 295 | self.count = 0 296 | self.metrics = OrderedDict((metric, 0) for metric in metrics) 297 | self.name = name 298 | self.data_type = data_type 299 | 300 | def accumulate(self, count, *values, print_iter=None): 301 | self.count += count 302 | if print_iter is not None: 303 | print(print_iter, end=' ') 304 | for value, metric in zip(values, self.metrics): 305 | if isinstance(value, torch.autograd.Variable): 306 | value = value.data 307 | if torch.is_tensor(value): 308 | with torch.cuda.device_of(value): 309 | value = value.cpu() 310 | value = value.float().sum() 311 | 312 | if print_iter is not None: 313 | print('%.3f' % value, end=' ') 314 | if self.data_type == "sum": 315 | self.metrics[metric] += value 316 | elif self.data_type == "avg": 317 | self.metrics[metric] += value * count 318 | 319 | if print_iter is not None: 320 | print() 321 | return values[0] # loss 322 | 323 | def __getattr__(self, key): 324 | if key in self.metrics: 325 | return self.metrics[key] / (self.count + 1e-9) 326 | raise AttributeError 327 | 328 | def __repr__(self): 329 | return ("{}: ".format(self.name) + 330 | "[{}]".format( ', '.join(["{:.4f}".format(getattr(self, metric)) for metric, value in self.metrics.items() if value != 0 ] ) ) ) 331 | 332 | def tensorboard(self, expt, i): 333 | for metric in self.metrics: 334 | value = getattr(self, metric) 335 | if value != 0: 336 | #expt.add_scalar_value(f'{self.name}_{metric}', value, step=i) 337 | expt.add_scalar_value("{}_{}".format(self.name, metric), value, step=i) 338 | 339 | def reset(self): 340 | self.count = 0 341 | self.metrics.update({metric: 0 for metric in self.metrics}) 342 | 343 | class Best: 344 | def __init__(self, cmp_fn, *metrics, model=None, opt=None, path='', gpu=0, which=[0]): 345 | self.cmp_fn = cmp_fn 346 | self.model = model 347 | self.opt = opt 348 | self.path = path + '.pt' 349 | self.metrics = OrderedDict((metric, None) for metric in metrics) 350 | self.gpu = gpu 351 | self.which = which 352 | self.best_cmp_value = None 353 | 354 | def accumulate(self, *other_values): 355 | 356 | with torch.cuda.device(self.gpu): 357 | cmp_values = [other_values[which] for which in self.which] 358 | if self.best_cmp_value is None or \ 359 | self.cmp_fn(self.best_cmp_value, *cmp_values) != self.best_cmp_value: 360 | self.metrics.update( { metric: value for metric, value in zip( 361 | list(self.metrics.keys()), other_values) } ) 362 | self.best_cmp_value = self.cmp_fn( [ list(self.metrics.items())[which][1] for which in self.which ] ) 363 | 364 | #open(self.path + '.temp', 'w') 365 | if self.model is not None: 366 | torch.save(self.model.state_dict(), self.path) 367 | 368 | if self.opt is not None: 369 | torch.save([self.i, self.opt.state_dict()], self.path + '.states') 370 | #os.remove(self.path + '.temp') 371 | 372 | def __getattr__(self, key): 373 | if key in self.metrics: 374 | return self.metrics[key] 375 | raise AttributeError 376 | 377 | def __repr__(self): 378 | return ("BEST: " + 379 | ', '.join(["{}: {:.4f}".format(metric, getattr(self, metric)) for metric, value in self.metrics.items() if value != 0])) 380 | 381 | class CacheExample(data.Example): 382 | 383 | @classmethod 384 | def fromsample(cls, data_lists, names): 385 | ex = cls() 386 |
for data, name in zip(data_lists, names): 387 | setattr(ex, name, data) 388 | return ex 389 | 390 | 391 | class Cache: 392 | 393 | def __init__(self, size=10000, fileds=["src", "trg"]): 394 | self.cache = [] 395 | self.maxsize = size 396 | 397 | def demask(self, data, mask): 398 | with torch.cuda.device_of(data): 399 | data = [d[:l] for d, l in zip(data.data.tolist(), mask.sum(1).long().tolist())] 400 | return data 401 | 402 | def add(self, data_lists, masks, names): 403 | data_lists = [self.demask(d, m) for d, m in zip(data_lists, masks)] 404 | for data in zip(*data_lists): 405 | self.cache.append(CacheExample.fromsample(data, names)) 406 | 407 | if len(self.cache) >= self.maxsize: 408 | self.cache = self.cache[-self.maxsize:] 409 | 410 | 411 | class Batch: 412 | def __init__(self, src=None, trg=None, dec=None): 413 | self.src, self.trg, self.dec = src, trg, dec 414 | 415 | def masked_sort(x, mask, dim=-1): 416 | x.data += ((1 - mask) * INF).long() 417 | y, i = torch.sort(x, dim) 418 | y.data *= mask.long() 419 | return y, i 420 | 421 | def unsorted(y, i, dim=-1): 422 | z = Variable(y.data.new(*y.size())) 423 | z.scatter_(dim, i, y) 424 | return z 425 | 426 | 427 | def merge_cache(decoding_path, names0, last_epoch=0, max_cache=20): 428 | file_lock = open(decoding_path + '/_temp_decode', 'w') 429 | 430 | for name in names0: 431 | filenames = [] 432 | for i in range(max_cache): 433 | filenames.append('{}/{}.ep{}'.format(decoding_path, name, last_epoch - i)) 434 | if (last_epoch - i) <= 0: 435 | break 436 | code = 'cat {} > {}.train.{}'.format(" ".join(filenames), '{}/{}'.format(decoding_path, name), last_epoch) 437 | os.system(code) 438 | os.remove(decoding_path + '/_temp_decode') 439 | 440 | def corrupt_target_fix(trg, decoder_masks, vocab_size, weight=0.1, cor_p=[0.1, 0.1, 0.1, 0.1]): 441 | batch_size, max_trg_len = trg.size() # actual trg len 442 | max_dec_len = decoder_masks.size(1) # 2 * actual src len 443 | dec_lens = (decoder_masks == 1).sum(-1).cpu().numpy() 444 | trg_lens = (trg != 1).sum(-1).data.cpu().numpy() 445 | 446 | num_corrupts = np.array( [ np.random.choice(dec_lens[bidx]//2, 447 | min( max( math.floor(weight * (dec_lens[bidx]//2)), 1 ), dec_lens[bidx]//2), 448 | replace=False ) \ 449 | for bidx in range(batch_size) ] ) 450 | 451 | #min_len = min(max_trg_len, max_dec_len) 452 | decoder_input = np.ones((batch_size, max_dec_len)) 453 | decoder_input.fill(3) 454 | #decoder_input[:, :min_len] = trg[:, :min_len].data.cpu().numpy() 455 | 456 | for bidx in range(batch_size): 457 | min_len = min(dec_lens[bidx], trg_lens[bidx]) 458 | decoder_input[bidx][:min_len] = trg[bidx, :min_len].data.cpu().numpy() 459 | nr_list = num_corrupts[bidx] 460 | for nr in nr_list: 461 | 462 | prob = np.random.rand() 463 | 464 | #### each corruption changes multiple words 465 | if prob < sum(cor_p[:1]): # repeat 466 | decoder_input[bidx][nr+1:] = decoder_input[bidx][nr:-1] 467 | 468 | elif prob < sum(cor_p[:2]): # drop 469 | decoder_input[bidx][nr:-1] = decoder_input[bidx][nr+1:] 470 | 471 | #### each corruption changes one word 472 | elif prob < sum(cor_p[:3]): # replace word with random word 473 | decoder_input[bidx][nr] = np.random.randint(vocab_size-4) + 4 474 | 475 | #### each corruption changes two words 476 | elif prob < sum(cor_p[:4]): # swap 477 | temp = decoder_input[bidx][nr] 478 | decoder_input[bidx][nr] = decoder_input[bidx][nr+1] 479 | decoder_input[bidx][nr+1] = temp 480 | 481 | result = torch.from_numpy(decoder_input).long() 482 | if decoder_masks.is_cuda: 483 | result = 
result.cuda(decoder_masks.get_device()) 484 | return Variable(result, requires_grad=False) 485 | 486 | def corrupt_target(trg, decoder_masks, vocab_size, weight=0.1, cor_p=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]): 487 | batch_size, max_trg_len = trg.size() 488 | max_dec_len = decoder_masks.size(1) 489 | dec_lens = (decoder_masks == 1).sum(-1).cpu().numpy() 490 | 491 | num_corrupts = np.array( [ np.random.choice(dec_lens[bidx]-1, 492 | min( max( math.floor(weight * dec_lens[bidx]), 1 ), dec_lens[bidx]-1 ), 493 | replace=False ) \ 494 | for bidx in range(batch_size) ] ) 495 | 496 | min_len = min(max_trg_len, max_dec_len) 497 | decoder_input = np.ones((batch_size, max_dec_len)) 498 | decoder_input.fill(3) 499 | decoder_input[:, :min_len] = trg[:, :min_len].data.cpu().numpy() 500 | 501 | for bidx in range(batch_size): 502 | nr_list = num_corrupts[bidx] 503 | for nr in nr_list: 504 | 505 | prob = np.random.rand() 506 | 507 | #### each corruption changes multiple words 508 | if prob < sum(cor_p[:1]): # repeat 509 | decoder_input[bidx][nr+1:] = decoder_input[bidx][nr:-1] 510 | 511 | elif prob < sum(cor_p[:2]): # drop 512 | decoder_input[bidx][nr:-1] = decoder_input[bidx][nr+1:] 513 | 514 | elif prob < sum(cor_p[:3]): # add random word 515 | decoder_input[bidx][nr+1:] = decoder_input[bidx][nr:-1] 516 | decoder_input[bidx][nr] = np.random.randint(vocab_size-4) + 4 # sample except UNK/PAD/INIT/EOS 517 | 518 | #### each corruption changes one word 519 | elif prob < sum(cor_p[:4]): # repeat and drop next 520 | decoder_input[bidx][nr+1] = decoder_input[bidx][nr] 521 | 522 | elif prob < sum(cor_p[:5]): # replace word with random word 523 | decoder_input[bidx][nr] = np.random.randint(vocab_size-4) + 4 524 | 525 | #### each corruption changes two words 526 | elif prob < sum(cor_p[:6]): # swap 527 | temp = decoder_input[bidx][nr] 528 | decoder_input[bidx][nr] = decoder_input[bidx][nr+1] 529 | decoder_input[bidx][nr+1] = temp 530 | 531 | elif prob < sum(cor_p[:7]): # global swap 532 | swap_idx = np.random.randint(1, dec_lens[bidx]-nr) + nr 533 | temp = decoder_input[bidx][nr] 534 | decoder_input[bidx][nr] = decoder_input[bidx][swap_idx] 535 | decoder_input[bidx][swap_idx] = temp 536 | 537 | result = torch.from_numpy(decoder_input).long() 538 | if decoder_masks.is_cuda: 539 | result = result.cuda(decoder_masks.get_device()) 540 | return Variable(result, requires_grad=False) 541 | 542 | def drop(sentence, n_d): 543 | cur_len = np.sum( sentence != 1 ) 544 | for idx in range(n_d): 545 | drop_pos = random.randint(0, cur_len - 1) # a <= N <= b 546 | sentence[drop_pos:-1] = sentence[drop_pos+1:] 547 | cur_len = cur_len - 1 548 | sentence[-n_d:] = 1 549 | return sentence 550 | 551 | def repeat(sentence, n_r): 552 | cur_len = np.sum( sentence != 1 ) 553 | for idx in range(n_r): 554 | drop_pos = random.randint(0, cur_len) # a <= N <= b 555 | sentence[drop_pos+1:] = sentence[drop_pos:-1] 556 | sentence[cur_len:] = 1 557 | return sentence 558 | 559 | def remove_repeats(lst_of_sentences): 560 | lst = [] 561 | for sentence in lst_of_sentences: 562 | lst.append( " ".join([x[0] for x in groupby(sentence.split())]) ) 563 | return lst 564 | 565 | def remove_repeats_tensor(tensor): 566 | tensor = tensor.data.cpu() 567 | newtensor = tensor.clone() 568 | batch_size, seq_len = tensor.size() 569 | for bidx in range(batch_size): 570 | for sidx in range(seq_len-1): 571 | if newtensor[bidx, sidx] == newtensor[bidx, sidx+1]: 572 | newtensor[bidx, sidx:-1] = newtensor[bidx, sidx+1:] 573 | return Variable(newtensor) 574 | 575 | def 
mkdir(path): 576 | if not os.path.exists(path): 577 | os.mkdir(path) 578 | 579 | def print_bleu(bleu_output, verbose=True): 580 | (final_bleu, prec, bp, ref_lengths, hyp_lengths) = bleu_output 581 | ratio = 0 if ref_lengths == 0 else hyp_lengths/ref_lengths 582 | if verbose: 583 | return "BLEU = {:.2f}, {:.1f}/{:.1f}/{:.1f}/{:.1f} (BP={:.3f}, ratio={:.3f}, hyp_len={}, ref_len={})".format( 584 | final_bleu, prec[0], prec[1], prec[2], prec[3], bp, ratio, hyp_lengths, ref_lengths 585 | ) 586 | else: 587 | return "BLEU = {:.2f}, {:.1f}/{:.1f}/{:.1f}/{:.1f} (BP={:.3f}, ratio={:.3f})".format( 588 | final_bleu, prec[0], prec[1], prec[2], prec[3], bp, ratio 589 | ) 590 | 591 | def set_eos(argmax): 592 | new_argmax = Variable(argmax.data.new(*argmax.size()), requires_grad=False) 593 | new_argmax.fill_(3) 594 | batch_size, seq_len = argmax.size() 595 | argmax_lst = argmax.data.cpu().numpy().tolist() 596 | for bidx in range(batch_size): 597 | if 3 in argmax_lst[bidx]: 598 | idx = argmax_lst[bidx].index(3) 599 | if idx > 0 : 600 | new_argmax[bidx,:idx] = argmax[bidx,:idx] 601 | return new_argmax 602 | 603 | def init_encoder(model, saved): 604 | saved_ = {k.replace("encoder.",""):v for (k,v) in saved.items() if "encoder" in k} 605 | encoder = model.encoder 606 | encoder.load_state_dict(saved_) 607 | return model 608 | 609 | def oracle_converged(bleu_hist, num_items=5): 610 | batch_size = len(bleu_hist) 611 | converged = [False for bidx in range(batch_size)] 612 | for bidx in range(batch_size): 613 | if len(bleu_hist[bidx]) < num_items: 614 | converged[bidx] = False 615 | else: 616 | converged[bidx] = True 617 | hist = bleu_hist[bidx][-num_items:] 618 | for item in hist[1:]: 619 | if item > hist[0]: 620 | converged[bidx] = False # if BLEU improves in 4 iters, not converged 621 | return converged 622 | 623 | def equality_converged(output_hist, num_items=5): 624 | batch_size = len(output_hist) 625 | converged = [False for bidx in range(batch_size)] 626 | for bidx in range(batch_size): 627 | if len(output_hist[bidx]) < num_items: 628 | converged[bidx] = False 629 | else: 630 | converged[bidx] = False 631 | hist = output_hist[bidx][-num_items:] 632 | for item in hist[1:]: 633 | if item == hist[0]: 634 | converged[bidx] = True # if out_i == out_j for (j = i+1, i+2, i+3, i+4), converged 635 | return converged 636 | 637 | def jaccard_converged(multiset_hist, num_items=5, jaccard_thresh=1.0): 638 | batch_size = len(multiset_hist) 639 | converged = [False for bidx in range(batch_size)] 640 | for bidx in range(batch_size): 641 | if len(multiset_hist[bidx]) < num_items: 642 | converged[bidx] = False 643 | else: 644 | converged[bidx] = False 645 | hist = multiset_hist[bidx][-num_items:] 646 | for item in hist[1:]: 647 | 648 | inters = len(item.intersection(hist[0])) 649 | unio = len(item.union(hist[0])) 650 | jaccard_index = float(inters) / np.maximum(1.,float(unio)) 651 | 652 | if jaccard_index >= jaccard_thresh: 653 | converged[bidx] = True 654 | return converged 655 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['QT_QPA_PLATFORM']='offscreen' # weird can't ipdb with mscoco without this flag 3 | import torch 4 | import numpy as np 5 | from torchtext import data 6 | from torchtext import datasets 7 | from torch.nn import functional as F 8 | from torch.autograd import Variable 9 | 10 | import revtok 11 | import logging 12 | import random 13 | import ipdb 14 | 
import string 15 | import traceback 16 | import math 17 | import uuid 18 | import argparse 19 | import copy 20 | import time 21 | import pickle 22 | 23 | from train import train_model 24 | from distill import distill_model 25 | from model import FastTransformer, Transformer, INF, TINY, HighwayBlock, ResidualBlock, NonresidualBlock 26 | from utils import mkdir, organise_trg_len_dic, init_encoder 27 | from data import NormalField, NormalTranslationDataset, MSCOCODataset, data_path 28 | from time import gmtime, strftime 29 | from decode import decode_model 30 | 31 | import itertools 32 | from traceback import extract_tb 33 | from code import interact 34 | from pathlib import Path 35 | 36 | parser = argparse.ArgumentParser(description='Train a Transformer / FastTransformer.') 37 | 38 | # dataset settings 39 | parser.add_argument('--rf_finetune', action='store_true', default=False) 40 | parser.add_argument('--nat_finetune', action='store_true', default=False) 41 | parser.add_argument('--sample_method', type=str, default='sentence', choices=['sentence','stepwise']) 42 | parser.add_argument('--stepwise_sampletimes', type=int, default=10) 43 | parser.add_argument('--topk', type=int, default=5) 44 | parser.add_argument('--workers', type=int, default=5) 45 | 46 | 47 | parser.add_argument('--dataset', type=str, default='iwslt-ende', choices=['iwslt-ende', 'iwslt-deen', \ 48 | 'wmt14-ende', 'wmt14-deen', \ 49 | 'wmt16-enro', 'wmt16-roen', \ 50 | 'wmt17-enlv', 'wmt17-lven', \ 51 | 'mscoco']) 52 | parser.add_argument('--vocab_size', type=int, default=40000, help='maximum size of the vocabulary') 53 | 54 | parser.add_argument('--valid_size', type=int, default=None, help='size of valid dataset (tested on coco only)') 55 | parser.add_argument('--load_vocab', action='store_true', help='load a pre-computed vocabulary') 56 | parser.add_argument('--load_dataset', action='store_true', default=False, help='load a pre-processed dataset') 57 | parser.add_argument('--save_dataset', action='store_true', default=False, help='save a pre-processed dataset') 58 | parser.add_argument('--max_len', type=int, default=None, help='limit the train set sentences to this many tokens') 59 | parser.add_argument('--max_train_data', type=int, default=None, help='limit the train set sentences to this many sentences') 60 | 61 | # model basic settings 62 | parser.add_argument('--prefix', type=str, default='[time]', help='prefix to denote the model, nothing or [time]') 63 | parser.add_argument('--fast', dest='model', action='store_const', const=FastTransformer, default=Transformer) 64 | 65 | # model ablation settings 66 | parser.add_argument('--ffw_block', type=str, default="residual", choices=['residual', 'highway', 'nonresidual']) 67 | parser.add_argument('--diag', action='store_true', default=False, help='ignore diagonal attention when doing self-attention.') 68 | parser.add_argument('--use_wo', action='store_true', default=True, help='use output weight matrix in multihead attention') 69 | parser.add_argument('--inputs_dec', type=str, default='pool', choices=['zeros', 'pool'], help='inputs to first decoder') 70 | parser.add_argument('--out_norm', action='store_true', default=False, help='normalize last softmax layer') 71 | parser.add_argument('--share_embed', action='store_true', default=True, help='share embeddings and linear out weight') 72 | parser.add_argument('--share_vocab', action='store_true', default=True, help='share vocabulary between src and target') 73 | parser.add_argument('--share_embed_enc_dec1',
action='store_true', default=False, help='share embedding weight between encoder and first decoder') 74 | parser.add_argument('--positional', action='store_true', default=True, help='incorporate positional information in key/value') 75 | parser.add_argument('--enc_last', action='store_true', default=False, help='attend only to last encoder hidden states') 76 | 77 | parser.add_argument('--params', type=str, default='user', choices=['user', 'small', 'big']) 78 | parser.add_argument('--n_layers', type=int, default=5, help='number of layers') 79 | parser.add_argument('--n_heads', type=int, default=2, help='number of heads') 80 | parser.add_argument('--d_model', type=int, default=278, help='model (embedding) dimension') 81 | parser.add_argument('--d_hidden', type=int, default=507, help='feed-forward hidden dimension') 82 | 83 | parser.add_argument('--num_decs', type=int, default=2, help='1 (one shared decoder) \ 84 | 2 (2nd decoder and above is shared) \ 85 | -1 (no decoder is shared)') 86 | parser.add_argument('--train_repeat_dec', type=int, default=4, help='number of times to repeat generation') 87 | parser.add_argument('--valid_repeat_dec', type=int, default=4, help='number of times to repeat generation') 88 | parser.add_argument('--use_argmax', action='store_true', default=False) 89 | parser.add_argument('--next_dec_input', type=str, default='both', choices=['emb', 'out', 'both']) 90 | 91 | parser.add_argument('--bp', type=float, default=1.0, help='target length factor used when trg_len_option is noisy_ref') 92 | 93 | # running setting 94 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test', 'distill']) # distill : take a trained AR model and decode a training set 95 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use or -1 for CPU') 96 | parser.add_argument('--seed', type=int, default=19920206, help='seed for randomness') 97 | parser.add_argument('--distill_which', type=int, default=0 ) 98 | parser.add_argument('--decode_which', type=int, default=0 ) 99 | parser.add_argument('--test_which', type=str, default='test', choices=['valid', 'test']) # which split to decode at test time 100 | 101 | # training 102 | parser.add_argument('--no_tqdm', action="store_true", default=False) 103 | parser.add_argument('--eval_every', type=int, default=1000, help='run validation every this many updates') 104 | parser.add_argument('--save_every', type=int, default=-1, help='save a back-up checkpoint every this many updates (-1: disabled)') 105 | parser.add_argument('--batch_size', type=int, default=2048, help='# of tokens processed per batch') 106 | parser.add_argument('--optimizer', type=str, default='Adam') 107 | parser.add_argument('--lr', type=float, default=3e-4) 108 | parser.add_argument('--lr_schedule', type=str, default='anneal', choices=['transformer', 'anneal', 'fixed']) 109 | parser.add_argument('--warmup', type=int, default=16000, help='warmup steps for the transformer lr schedule') 110 | parser.add_argument('--anneal_steps', type=int, default=250000, help='maximum steps to linearly anneal the learning rate') 111 | parser.add_argument('--maximum_steps', type=int, default=5000000, help='maximum steps you take to train a model') 112 | parser.add_argument('--drop_ratio', type=float, default=0.1, help='dropout ratio') 113 | parser.add_argument('--drop_len_pred', type=float, default=0.3, help='dropout ratio for length prediction module') 114 | parser.add_argument('--input_drop_ratio', type=float, default=0.1, help='dropout ratio only for inputs') 115 | parser.add_argument('--grad_clip', type=float, default=-1.0, help='gradient clipping') 116 | 117 | # target length 118 |
parser.add_argument('--trg_len_option', type=str, default="reference", choices=['reference', "noisy_ref", 'average', 'fixed', 'predict']) 119 | #parser.add_argument('--trg_len_option_valid', type=str, default="average", choices=['reference', "noisy_ref", 'average', 'fixed', 'predict']) 120 | parser.add_argument('--trg_len_ratio', type=float, default=2.0) 121 | parser.add_argument('--decoder_input_how', type=str, default='copy', choices=['copy', 'interpolate', 'pad', 'wrap']) 122 | parser.add_argument('--finetune_trg_len', action='store_true', default=False, help="finetune one layer that predicts target len offset") 123 | parser.add_argument('--use_predicted_trg_len', action='store_true', default=False, help="use predicted target len masks") 124 | parser.add_argument('--max_offset', type=int, default=20, help='max target len offset of the whole dataset') 125 | 126 | # denoising 127 | parser.add_argument('--denoising_prob', type=float, default=0.0, help="use denoising with this probability") 128 | parser.add_argument('--denoising_weight', type=float, default=0.1, help="use denoising with this weight.") 129 | parser.add_argument('--corruption_probs', type=str, default="0-0-0-1-1-1-0", help="probs for \ 130 | repeat\ 131 | drop\ 132 | add random word\ 133 | repeat and drop next\ 134 | replace with random word\ 135 | swap, global swap") 136 | parser.add_argument('--denoising_out_weight', type=float, default=0.0, help="use denoising for decoder output with this weight.") 137 | parser.add_argument('--anneal_denoising_weight', action='store_true', default=False, help="anneal denoising weight over time") 138 | parser.add_argument('--layerwise_denoising_weight', action='store_true', default=False, help="use different denoising weight per iteration") 139 | 140 | # self-distillation 141 | parser.add_argument('--self_distil', type=float, default=0.0) 142 | 143 | # decoding 144 | parser.add_argument('--length_ratio', type=int, default=2, help='maximum length of decoding') 145 | parser.add_argument('--length_dec', type=int, default=20, help='maximum length of decoding for MSCOCO dataset') 146 | parser.add_argument('--beam_size', type=int, default=1, help='beam-size used in Beamsearch, default using greedy decoding') 147 | parser.add_argument('--f_size', type=int, default=1, help='heap size for sampling/searching in the fertility space') 148 | parser.add_argument('--alpha', type=float, default=1, help='length normalization weights') 149 | parser.add_argument('--temperature', type=float, default=1, help='smoothing temperature for noisy decoding') 150 | parser.add_argument('--remove_repeats', action='store_true', default=False, help='remove consecutive repeated tokens from the decoded output') 151 | parser.add_argument('--num_samples', type=int, default=2, help='number of samples to use when using non-argmax decoding') 152 | parser.add_argument('--T', type=float, default=1, help='softmax temperature when decoding') 153 | 154 | #parser.add_argument('--jaccard_stop', action='store_true', default=False, help='use jaccard index to stop decoding') 155 | parser.add_argument('--adaptive_decoding', type=str, default=None, choices=["oracle", "jaccard", "equality"]) 156 | parser.add_argument('--adaptive_window', type=int, default=5, help='window size for adaptive decoding') 157 | parser.add_argument('--len_stop', action='store_true', default=False, help='use length of sentence to stop decoding') 158 | parser.add_argument('--jaccard_thresh', type=float, default=1.0) 159 | 160 | # model saving/reloading, output translations 161 |
parser.add_argument('--load_from', type=str, default=None, help='load from checkpoint') 162 | parser.add_argument('--load_encoder_from', type=str, default=None, help='load from checkpoint') 163 | parser.add_argument('--resume', action='store_true', help='when loading from the saved model, it resumes from that.') 164 | parser.add_argument('--use_distillation', action='store_true', default=False, help='train a NAR model from output of an AR model') 165 | 166 | # debugging 167 | parser.add_argument('--debug', action='store_true', help='debug mode: no saving or tensorboard') 168 | parser.add_argument('--tensorboard', action='store_true', help='use TensorBoard') 169 | 170 | # save path 171 | parser.add_argument('--main_path', type=str, default="./") # /misc/vlgscratch2/ChoGroup/mansimov/ 172 | parser.add_argument('--model_path', type=str, default="models") # /misc/vlgscratch2/ChoGroup/mansimov/ 173 | parser.add_argument('--log_path', type=str, default="logs") # /misc/vlgscratch2/ChoGroup/mansimov/ 174 | parser.add_argument('--event_path', type=str, default="events") # /misc/vlgscratch2/ChoGroup/mansimov/ 175 | parser.add_argument('--decoding_path', type=str, default="decoding") # /misc/vlgscratch2/ChoGroup/mansimov/ 176 | parser.add_argument('--distill_path', type=str, default="distill") # /misc/vlgscratch2/ChoGroup/mansimov/ 177 | 178 | parser.add_argument('--model_str', type=str, default="") # /misc/vlgscratch2/ChoGroup/mansimov/ 179 | 180 | # ----------------------------------------------------------------------------------------------------------------- # 181 | 182 | args = parser.parse_args() 183 | 184 | if args.model is Transformer: 185 | args.num_decs = 1 186 | args.train_repeat_dec = 1 187 | args.valid_repeat_dec = 1 188 | 189 | args.main_path = Path(args.main_path) 190 | 191 | args.model_path = args.main_path / args.model_path / args.dataset 192 | args.log_path = args.main_path / args.log_path / args.dataset 193 | args.event_path = args.main_path / args.event_path / args.dataset 194 | args.decoding_path = args.main_path / args.decoding_path / args.dataset 195 | args.distill_path = args.main_path / args.distill_path / args.dataset 196 | 197 | if not args.debug: 198 | for path in [args.model_path, args.log_path, args.event_path, args.decoding_path, args.distill_path]: 199 | path.mkdir(parents=True, exist_ok=True) 200 | 201 | if args.prefix == '[time]': 202 | args.prefix = strftime("%m.%d_%H.%M.", gmtime()) 203 | 204 | if args.train_repeat_dec == 1: 205 | args.num_decs = 1 206 | 207 | # get the langauage pairs: 208 | if args.dataset != "mscoco": 209 | args.src = args.dataset[-4:][:2] # source language 210 | args.trg = args.dataset[-4:][2:] # target language 211 | else: 212 | args.src = "" 213 | args.trg = "" 214 | 215 | if args.params == 'small': 216 | hparams = {'d_model': 278, 'd_hidden': 507, 'n_layers': 5, 'n_heads': 2, 'warmup': 746} 217 | args.__dict__.update(hparams) 218 | elif args.params == 'big': 219 | if args.dataset != "mscoco": 220 | hparams = {'d_model': 512, 'd_hidden': 512, 'n_layers': 6, 'n_heads': 8, 'warmup': 16000} 221 | else: 222 | hparams = {'d_model': 512, 'd_hidden': 512, 'n_heads': 8, 'warmup': 16000} 223 | args.__dict__.update(hparams) 224 | 225 | hp_str = "{}".format('' if args.model is FastTransformer else 'ar_') + \ 226 | "{}".format(args.model_str+"_" if args.model_str != "" else "") + \ 227 | "{}".format("ar_distil_" if args.use_distillation else "") + \ 228 | "{}".format("ptrn_enc_" if not args.load_encoder_from is None else "") + \ 229 | 
"{}".format("ptrn_model_" if not args.load_from is None else "") + \ 230 | "voc{}k_".format(args.vocab_size//1000) + \ 231 | "{}_".format(args.batch_size) + \ 232 | "{}".format("" if args.share_embed else "no_share_emb_") + \ 233 | "{}".format("" if args.share_vocab else "no_share_voc_") + \ 234 | "{}".format("share_emb_enc_dec1_" if args.share_embed_enc_dec1 else "") + \ 235 | "{}_{}_{}_{}_".format(args.n_layers, args.d_model, args.d_hidden, args.n_heads) + \ 236 | "{}".format("enc_last_" if args.enc_last else "") + \ 237 | "drop_{}_".format(args.drop_ratio) + \ 238 | "{}".format("drop_len_pred_{}_".format(args.drop_len_pred) if args.finetune_trg_len else "") + \ 239 | "{}_".format(args.lr) + \ 240 | "{}_".format("{}".format(args.lr_schedule[:4])) + \ 241 | "{}".format("anneal_steps_{}_".format(args.anneal_steps) if args.lr_schedule == "anneal" else "") + \ 242 | "{}_".format(args.ffw_block[:4]) + \ 243 | "{}".format("clip_{}_".format(args.grad_clip) if args.grad_clip != -1.0 else "") + \ 244 | "{}".format("diag_" if args.diag else "") + \ 245 | ("tr{}_".format(args.train_repeat_dec) + \ 246 | "{}decs_".format(args.num_decs) + \ 247 | "{}_".format(args.bp if args.trg_len_option == "noisy_ref" else "") + \ 248 | "{}_".format(args.trg_len_option[:4]) + \ 249 | "{}_".format(args.next_dec_input) + \ 250 | "{}".format("trg_{}x_".format(args.trg_len_ratio) if "fixed" in args.trg_len_option else "") + \ 251 | "{}_".format(args.decoder_input_how[:4]) + \ 252 | "{}".format("dn_{}_".format(args.denoising_prob) if args.denoising_prob != 0.0 else "") + \ 253 | "{}".format("dn_w{}_".format(args.denoising_weight) if args.denoising_prob != 0.0 and not args.anneal_denoising_weight and not args.layerwise_denoising_weight else "") + \ 254 | "{}".format("dn_anneal_" if args.anneal_denoising_weight else "") + \ 255 | "{}".format("dn_layer_" if args.layerwise_denoising_weight else "") + \ 256 | "{}".format("dn_out_w{}_".format(args.denoising_out_weight) if args.denoising_out_weight != 0.0 else "") + \ 257 | "{}".format("distil{}_".format(args.self_distil) if args.self_distil != 0.0 else "") + \ 258 | "{}".format("argmax_" if args.use_argmax else "sample_") + \ 259 | "{}".format("out_norm_" if args.out_norm else "") + \ 260 | "" if args.model is FastTransformer else "" ) 261 | 262 | args.id_str = Path(args.prefix + hp_str) 263 | 264 | args.corruption_probs = [int(xx) for xx in args.corruption_probs.split("-") ] 265 | c_probs_sum = sum(args.corruption_probs) 266 | args.corruption_probs = [xx/c_probs_sum for xx in args.corruption_probs] 267 | 268 | if args.ffw_block == "nonresidual": 269 | args.block_cls = NonresidualBlock 270 | elif args.ffw_block == "residual": 271 | args.block_cls = ResidualBlock 272 | elif args.ffw_block == "highway": 273 | args.block_cls = HighwayBlock 274 | else: 275 | raise 276 | 277 | # setup logger settings 278 | logger = logging.getLogger() 279 | logger.setLevel(logging.DEBUG) 280 | formatter = logging.Formatter('%(asctime)s %(levelname)s: - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 281 | 282 | ch = logging.StreamHandler() 283 | ch.setLevel(logging.DEBUG) 284 | ch.setFormatter(formatter) 285 | logger.addHandler(ch) 286 | if not args.debug: 287 | fh = logging.FileHandler( str( args.log_path / args.id_str ) + ".txt" ) 288 | fh.setLevel(logging.DEBUG) 289 | fh.setFormatter(formatter) 290 | logger.addHandler(fh) 291 | 292 | # setup random seeds 293 | random.seed(args.seed) 294 | np.random.seed(args.seed) 295 | torch.manual_seed(args.seed) 296 | torch.cuda.manual_seed_all(args.seed) 297 | 
298 | # ----------------------------------------------------------------------------------------------------------------- # 299 | if args.dataset != "mscoco": 300 | DataField = NormalField 301 | TRG = DataField(init_token='<init>', eos_token='<eos>', batch_first=True) 302 | SRC = DataField(batch_first=True) if not args.share_vocab else TRG 303 | # NOTE : UNK, PAD, INIT, EOS 304 | 305 | # set up the datasets (each corpus needs manual setup) 306 | data_prefix = Path(data_path(args.dataset)) 307 | args.data_prefix = data_prefix 308 | if args.dataset == "mscoco": 309 | data_prefix = str(data_prefix) 310 | train_dir = "train" if not args.use_distillation else "distill/" + args.dataset[-4:] 311 | if args.dataset == 'iwslt-ende' or args.dataset == 'iwslt-deen': 312 | #if args.resume: 313 | # train_dir += "2" 314 | logger.info("TRAINING CORPUS : " + str(data_prefix / train_dir / 'train.tags.en-de.bpe')) 315 | train_data = NormalTranslationDataset(path=str(data_prefix / train_dir / 'train.tags.en-de.bpe'), 316 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 317 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') \ 318 | if args.mode in ["train", "distill"] else None 319 | 320 | dev_dir = "dev" 321 | dev_file = "valid.en-de.bpe" 322 | if args.mode == "test" and args.decode_which > 0: 323 | dev_dir = "dev_split" 324 | dev_file += ".{}".format(args.decode_which) 325 | dev_data = NormalTranslationDataset(path=str(data_prefix / dev_dir / dev_file), 326 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 327 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') 328 | 329 | test_data = None 330 | 331 | elif args.dataset == 'wmt14-ende' or args.dataset == 'wmt14-deen': 332 | train_file = 'all_en-de.bpe' 333 | if args.mode == "distill" and args.distill_which > 0: 334 | train_file += ".{}".format(args.distill_which) 335 | train_data = NormalTranslationDataset(path=str(data_prefix / train_dir / train_file), 336 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 337 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') \ 338 | if args.mode in ["train", "distill"] else None 339 | 340 | dev_dir = "dev" 341 | dev_file = "wmt13-en-de.bpe" 342 | if args.mode == "test" and args.decode_which > 0: 343 | dev_dir = "dev_split" 344 | dev_file += ".{}".format(args.decode_which) 345 | dev_data = NormalTranslationDataset(path=str(data_prefix / dev_dir / dev_file), 346 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 347 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') 348 | 349 | test_dir = "test" 350 | test_file = "wmt14-en-de.bpe" 351 | if args.mode == "test" and args.decode_which > 0: 352 | test_dir = "test_split" 353 | test_file += ".{}".format(args.decode_which) 354 | test_data = NormalTranslationDataset(path=str(data_prefix / test_dir / test_file), 355 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 356 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') 357 | 358 | elif args.dataset == 'wmt16-enro' or args.dataset == 'wmt16-roen': 359 | train_file = 'corpus.bpe' 360 | if args.mode == "distill" and args.distill_which > 0: 361 | train_file += ".{}".format(args.distill_which) 362 | train_data = NormalTranslationDataset(path=str(data_prefix / train_dir / train_file), 363 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 364 | 
load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') \ 365 | if args.mode in ["train", "distill"] else None 366 | 367 | dev_dir = "dev" 368 | dev_file = "dev.bpe" 369 | if args.mode == "test" and args.decode_which > 0: 370 | dev_dir = "dev_split" 371 | dev_file += ".{}".format(args.decode_which) 372 | dev_data = NormalTranslationDataset(path=str(data_prefix / dev_dir / dev_file), 373 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 374 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') 375 | 376 | test_dir = "test" 377 | test_file = "test.bpe" 378 | if args.mode == "test" and args.decode_which > 0: 379 | test_dir = "test_split" 380 | test_file += ".{}".format(args.decode_which) 381 | test_data = NormalTranslationDataset(path=str(data_prefix / test_dir / test_file), 382 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 383 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') 384 | 385 | elif args.dataset == 'wmt17-enlv' or args.dataset == 'wmt17-lven': 386 | train_data, dev_data, test_data = NormalTranslationDataset.splits( 387 | path=data_prefix, train='{}/corpus.bpe'.format(train_dir), test='test/newstest2017.bpe', 388 | validation='dev/newsdev2017.bpe', exts=('.{}'.format(args.src), '.{}'.format(args.trg)), 389 | fields=(SRC, TRG), load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') 390 | 391 | elif args.dataset == "mscoco": 392 | mscoco_dataset = MSCOCODataset(path=data_prefix, batch_size=args.batch_size, \ 393 | max_len=args.max_len, valid_size=args.valid_size, \ 394 | distill=(args.mode == "distill"), use_distillation=args.use_distillation) 395 | train_data, train_sampler = mscoco_dataset.train_data, mscoco_dataset.train_sampler 396 | dev_data, dev_sampler = mscoco_dataset.valid_data, mscoco_dataset.valid_sampler 397 | test_data, test_sampler = mscoco_dataset.test_data, mscoco_dataset.test_sampler 398 | if args.trg_len_option == "predict" and args.max_offset is None: 399 | args.max_offset = mscoco_dataset.max_dataset_length 400 | else: 401 | raise NotImplementedError 402 | # build vocabularies for translation dataset 403 | if args.dataset != "mscoco": 404 | vocab_path = data_prefix / 'vocab' / '{}_{}_{}.pt'.format('{}-{}'.format(args.src, args.trg), args.vocab_size, 'shared' if args.share_vocab else '') 405 | if args.load_vocab and vocab_path.exists(): 406 | src_vocab, trg_vocab = torch.load(str(vocab_path)) 407 | SRC.vocab = src_vocab 408 | TRG.vocab = trg_vocab 409 | logger.info('vocab loaded') 410 | else: 411 | assert train_data is not None 412 | if not args.share_vocab: 413 | SRC.build_vocab(train_data, dev_data, max_size=args.vocab_size) 414 | TRG.build_vocab(train_data, dev_data, max_size=args.vocab_size) 415 | if not args.debug: 416 | logger.info('save the vocabulary') 417 | vocab_path.parent.mkdir(parents=True, exist_ok=True) 418 | torch.save([SRC.vocab, TRG.vocab], str(vocab_path)) 419 | args.__dict__.update({'trg_vocab': len(TRG.vocab), 'src_vocab': len(SRC.vocab)}) 420 | # for mscoco 421 | else: 422 | vocab_path = os.path.join(data_prefix, "vocab.pkl") 423 | assert args.load_vocab 424 | if args.load_vocab and os.path.exists(vocab_path): 425 | vocab = pickle.load(open(vocab_path, 'rb')) 426 | mscoco_dataset.vocab = vocab 427 | else: 428 | logger.info('save the vocabulary') 429 | mscoco_dataset.build_vocab() 430 | pickle.dump(mscoco_dataset.vocab, open(vocab_path, 'wb')) 431 | print('vocab 
building done') 432 | args.__dict__.update({'vocab': len(mscoco_dataset.vocab)}) 433 | 434 | def dyn_batch_with_padding(new, i, sofar): 435 | prev_max_len = sofar / (i - 1) if i > 1 else 0 436 | return max(len(new.src), len(new.trg), prev_max_len) * i 437 | 438 | def dyn_batch_without_padding(new, i, sofar): 439 | return sofar + max(len(new.src), len(new.trg)) 440 | 441 | # not sure if absolutely necessary? seems to mess things up. 442 | if args.dataset != "mscoco" and args.share_vocab: 443 | SRC = copy.deepcopy(SRC) 444 | SRC.init_token = None 445 | SRC.eos_token = None 446 | 447 | for data_ in [train_data, dev_data, test_data]: 448 | if not data_ is None: 449 | data_.fields['src'] = SRC 450 | 451 | if args.dataset != "mscoco": 452 | if not train_data is None: 453 | logger.info("before pruning : {} training examples".format(len(train_data.examples))) 454 | if args.max_len is not None: 455 | if args.dataset != "mscoco": 456 | train_data.examples = [ex for ex in train_data.examples if len(ex.trg) <= args.max_len] 457 | if args.max_train_data is not None: 458 | train_data.examples = train_data.examples[:args.max_train_data] 459 | logger.info("after pruning : {} training examples".format(len(train_data.examples))) 460 | 461 | if args.batch_size == 1: # speed-test: one sentence per batch. 462 | batch_size_fn = lambda new, count, sofar: count 463 | else: 464 | batch_size_fn = dyn_batch_without_padding# if args.model is Transformer else dyn_batch_with_padding 465 | 466 | if args.dataset != "mscoco": 467 | if args.mode == "train": 468 | train_flag = True 469 | elif args.mode == "distill": 470 | train_flag = False 471 | else: 472 | train_flag = False 473 | train_real = data.BucketIterator(train_data, args.batch_size, device=args.gpu, batch_size_fn=batch_size_fn, 474 | train=train_flag, repeat=train_flag, shuffle=train_flag) if not train_data is None else None 475 | dev_real = data.BucketIterator(dev_data, args.batch_size, device=args.gpu, batch_size_fn=batch_size_fn, 476 | train=False, repeat=False, shuffle=False) if not dev_data is None else None 477 | test_real = data.BucketIterator(test_data, args.batch_size, device=args.gpu, batch_size_fn=batch_size_fn, 478 | train=False, repeat=False, shuffle=False) if not test_data is None else None 479 | else: 480 | train_real = torch.utils.data.DataLoader( 481 | train_data, batch_sampler=train_sampler, pin_memory=args.gpu>-1, num_workers=8) 482 | dev_real = torch.utils.data.DataLoader( 483 | dev_data, batch_sampler=dev_sampler, pin_memory=args.gpu>-1, num_workers=8) 484 | test_real = torch.utils.data.DataLoader( 485 | test_data, batch_sampler=test_sampler, pin_memory=args.gpu>-1, num_workers=8) 486 | def rcycle(iterable): 487 | saved = [] # In-memory cache 488 | for element in iterable: 489 | yield element 490 | saved.append(element) 491 | while saved: 492 | random.shuffle(saved) # Shuffle every batch 493 | for element in saved: 494 | yield element 495 | if args.mode != "distill": 496 | train_real = rcycle(train_real) 497 | 498 | logger.info("build the dataset. 
done!") 499 | # ----------------------------------------------------------------------------------------------------------------- # 500 | 501 | # ----------------------------------------------------------------------------------------------------------------- # 502 | if args.mode == "train": 503 | logger.info(args) 504 | 505 | logger.info('Starting with HPARAMS: {}'.format(hp_str)) 506 | 507 | # build the model 508 | if args.dataset != "mscoco": 509 | model = args.model(src=SRC, trg=TRG, args=args) 510 | else: 511 | model = args.model(src=None, trg=mscoco_dataset, args=args) 512 | 513 | if args.mode == "train": 514 | logger.info(str(model)) 515 | 516 | if args.load_encoder_from is not None: 517 | if args.gpu > -1: 518 | with torch.cuda.device(args.gpu): 519 | encoder = torch.load(str(args.model_path / args.load_encoder_from) + '.pt', 520 | map_location=lambda storage, loc: storage.cuda()) 521 | else: 522 | encoder = torch.load(str(args.model_path / args.load_encoder_from) + '.pt', 523 | map_location=lambda storage, loc: storage) 524 | init_encoder(model, encoder) 525 | logger.info("Pretrained encoder loaded.") 526 | 527 | if args.load_from is not None: 528 | if args.gpu > -1: 529 | with torch.cuda.device(args.gpu): 530 | model.load_state_dict(torch.load(str(args.model_path / args.load_from) + '.pt', 531 | map_location=lambda storage, loc: storage.cuda()), strict=False) # load the pretrained models. 532 | else: 533 | model.load_state_dict(torch.load(str(args.model_path / args.load_from) + '.pt', 534 | map_location=lambda storage, loc: storage), strict=False) # load the pretrained models. 535 | logger.info("Pretrained model loaded.") 536 | 537 | params, param_names = [], [] 538 | for name, param in model.named_parameters(): 539 | params.append(param) 540 | param_names.append(name) 541 | 542 | if args.mode == "train": 543 | logger.info(param_names) 544 | logger.info("Size {}".format( sum( [ np.prod(x.size()) for x in params ] )) ) 545 | 546 | # use cuda 547 | if args.gpu > -1: 548 | model.cuda(args.gpu) 549 | 550 | # additional information 551 | args.__dict__.update({'hp_str': hp_str, 'logger': logger}) 552 | 553 | # ----------------------------------------------------------------------------------------------------------------- # 554 | 555 | trg_len_dic = None 556 | if args.dataset != "mscoco" and (not "ro" in args.dataset or "predict" in args.trg_len_option or "average" in args.trg_len_option): 557 | #if "predict" in args.trg_len_option or "average" in args.trg_len_option: 558 | #trg_len_dic = torch.load(os.path.join(data_path(args.dataset), "trg_len")) 559 | trg_len_dic = torch.load( str(args.data_prefix / "trg_len_dic" / args.dataset[-4:]) ) 560 | trg_len_dic = organise_trg_len_dic(trg_len_dic) 561 | if args.mode == 'train': 562 | logger.info('starting training') 563 | 564 | if args.dataset != "mscoco": 565 | train_model(args, model, train_real, dev_real, src=SRC, trg=TRG, trg_len_dic=trg_len_dic) 566 | else: 567 | train_model(args, model, train_real, dev_real, src=None, trg=mscoco_dataset, trg_len_dic=trg_len_dic) 568 | 569 | elif args.mode == 'test': 570 | logger.info('starting decoding from the pre-trained model, on the test set...') 571 | args.decoding_path = args.decoding_path / args.load_from 572 | name_suffix = 'b={}_{}.txt'.format(args.beam_size, args.load_from) 573 | names = ['src.{}'.format(name_suffix), 'trg.{}'.format(name_suffix), 'dec.{}'.format(name_suffix)] 574 | 575 | if args.test_which == "test" and (not test_real is None): 576 | logger.info("---------- Decoding TEST set 
----------") 577 | decode_model(args, model, test_real, evaluate=True, trg_len_dic=trg_len_dic, decoding_path=args.decoding_path, \ 578 | names=["test."+xx for xx in names], maxsteps=None) 579 | else: 580 | logger.info("---------- Decoding VALID set ----------") 581 | decode_model(args, model, dev_real, evaluate=True, trg_len_dic=trg_len_dic, decoding_path=args.decoding_path, \ 582 | names=["valid."+xx for xx in names], maxsteps=None) 583 | 584 | elif args.mode == 'distill': 585 | logger.info('starting decoding the training set from an AR model') 586 | args.distill_path = args.distill_path / args.id_str 587 | args.distill_path.mkdir(parents=True, exist_ok=True) 588 | name_suffix = args.distill_which 589 | names = ['src.{}'.format(name_suffix), 'trg.{}'.format(name_suffix), 'dec.{}'.format(name_suffix)] 590 | 591 | distill_model(args, model, train_real, evaluate=False, distill_path=args.distill_path, \ 592 | names=["train."+xx for xx in names], maxsteps=None) 593 | 594 | logger.info("done.") 595 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import ipdb 3 | import torch 4 | from torch import nn 5 | import torch.nn.init as init 6 | from torch.nn import functional as F 7 | from torch.autograd import Variable, Function 8 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 9 | from collections import Counter 10 | 11 | import math 12 | import random 13 | 14 | from utils import computeGLEU, masked_sort, unsorted, make_decoder_masks, query_trg_len_dic,my_sentence_gleu 15 | from nltk.translate.gleu_score import sentence_gleu, corpus_gleu 16 | from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor 17 | 18 | INF = 1e10 19 | TINY = 1e-9 20 | def shape(targets, target_lens): 21 | 22 | list_targets = [] 23 | begin = 0 24 | end = 0 25 | for length in target_lens: 26 | end += length 27 | list_targets.append([str(index) for index in targets[begin:end]]) 28 | begin += length 29 | 30 | return list_targets 31 | def parallel_gleu( inputs): 32 | (sample_idx, list_samples, list_targets, count,target_lens) = inputs 33 | l_samples = shape(sample_idx,target_lens) 34 | gleus = [] 35 | for j in range(count): 36 | for k in range(len(l_samples[j])): 37 | t = l_samples[j][k] 38 | l_samples[j][k] = list_samples[j][k] 39 | gleu = my_sentence_gleu([l_samples[j]], list_targets[j]) 40 | l_samples[j][k] = t 41 | gleus.append(gleu) 42 | return gleus 43 | 44 | class GradReverse(Function): 45 | @staticmethod 46 | def forward(ctx, x): 47 | return x.view_as(x) 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | return grad_output.neg() 52 | 53 | def grad_reverse(x): 54 | return GradReverse.apply(x) 55 | 56 | def positional_encodings_like(x, t=None): # hope to be differentiable 57 | if t is None: 58 | positions = torch.arange(0, x.size(-2)) # .expand(*x.size()[:2]) 59 | if x.is_cuda: 60 | positions = positions.cuda(x.get_device()) 61 | positions = Variable(positions.float()) 62 | else: 63 | positions = t 64 | # channels 65 | channels = torch.arange(0, x.size(-1), 2).float() / x.size(-1) # 0 2 4 6 ... 
(256) 66 | if x.is_cuda: 67 | channels = channels.cuda(x.get_device()) 68 | channels = 1 / (10000 ** Variable(channels)) 69 | # get the positional encoding: batch x target_len 70 | encodings = positions.unsqueeze(-1) @ channels.unsqueeze(0) # batch x target_len x 256 71 | encodings = torch.cat([torch.sin(encodings).unsqueeze(-1), torch.cos(encodings).unsqueeze(-1)], -1) 72 | encodings = encodings.contiguous().view(*encodings.size()[:-2], -1) # batch x target_len x 512 73 | 74 | if encodings.ndimension() == 2: 75 | encodings = encodings.unsqueeze(0).expand_as(x) 76 | 77 | return encodings 78 | 79 | class Linear(nn.Linear): 80 | def __init__(self, d_in, d_out, bias=True, out_norm=False): 81 | super().__init__(d_in, d_out, bias) 82 | self.out_norm = out_norm 83 | stdv = 1. / math.sqrt(self.weight.size(1)) 84 | init.uniform(self.weight, -stdv, stdv) 85 | if bias: 86 | self.bias.data.zero_() 87 | 88 | def forward(self, x): 89 | size = x.size() 90 | if self.out_norm: 91 | weight = self.weight / (1e-6 + torch.sqrt((self.weight ** 2).sum(0, keepdim=True))) 92 | x_ = x / (1e-6 + torch.sqrt((x ** 2).sum(-1, keepdim=True))) 93 | logit_ = torch.mm(x_.contiguous().view(-1, size[-1]), weight.t()).view(*size[:-1], -1) 94 | if self.bias is not None: 95 | logit_ = logit_ + self.bias 96 | return logit_ 97 | return super().forward( 98 | x.contiguous().view(-1, size[-1])).view(*size[:-1], -1) 99 | 100 | def demask(inputs, the_mask): 101 | # inputs: 1-D sequences 102 | # the_mask: batch x max-len 103 | outputs = Variable((the_mask == 0).long().view(-1)) # 1-D 104 | indices = torch.arange(0, outputs.size(0)) 105 | if inputs.is_cuda: 106 | indices = indices.cuda(inputs.get_device()) 107 | indices = indices.view(*the_mask.size()).long() 108 | indices = indices[the_mask] 109 | outputs[indices] = inputs 110 | return outputs.view(*the_mask.size()) 111 | 112 | # F.softmax has strange default behavior, normalizing over dim 0 for 3D inputs 113 | def softmax(x, T=1): 114 | return F.softmax(x/T, dim=-1) 115 | """ 116 | if x.dim() == 3: 117 | return F.softmax(x.transpose(0, 2)).transpose(0, 2) 118 | return F.softmax(x) 119 | """ 120 | 121 | def log_softmax(x): 122 | if x.dim() == 3: 123 | return F.log_softmax(x.transpose(0, 2)).transpose(0, 2) 124 | return F.log_softmax(x) 125 | 126 | def logsumexp(x, dim=-1): 127 | x_max = x.max(dim, keepdim=True)[0] 128 | return torch.log(torch.exp(x - x_max.expand_as(x)).sum(dim, keepdim=True) + TINY) + x_max 129 | 130 | def gumbel_softmax(input, beta=0.5, tau=1.0): 131 | noise = input.data.new(*input.size()).uniform_() 132 | noise.add_(TINY).log_().neg_().add_(TINY).log_().neg_() 133 | return softmax((input + beta * Variable(noise)) / tau) 134 | 135 | # (4, 3, 2) @ (4, 2) -> (4, 3) 136 | # (4, 3) @ (4, 3, 2) -> (4, 2) 137 | # (4, 3, 2) @ (4, 2, 4) -> (4, 3, 4) 138 | def matmul(x, y): 139 | if x.dim() == y.dim(): 140 | return x @ y 141 | if x.dim() == y.dim() - 1: 142 | return (x.unsqueeze(-2) @ y).squeeze(-2) 143 | return (x @ y.unsqueeze(-1)).squeeze(-1) 144 | 145 | def pad_to_match(x, y): 146 | x_len, y_len = x.size(1), y.size(1) 147 | if x_len == y_len: 148 | return x, y 149 | add_to = x if x_len < y_len else y 150 | fill = 1 if add_to.dim() == 2 else 0 151 | extra = add_to.data.new( 152 | x.size(0), abs(y_len - x_len), *add_to.size()[2:]).fill_(fill) 153 | if x_len < y_len: 154 | return torch.cat((x, extra), 1), y 155 | return x, torch.cat((y, extra), 1) 156 | 157 | # --- Top K search with PQ 158 | def topK_search(logits, mask_src, N=100): 159 | # prepare data 160 | nlogP = 
-log_softmax(logits).data 161 | maxL = nlogP.size(-1) 162 | overmask = torch.cat([mask_src[:, :, None], 163 | (1 - mask_src[:, :, None]).expand(*mask_src.size(), maxL-1) * INF 164 | + mask_src[:, :, None]], 2) 165 | nlogP = nlogP * overmask 166 | 167 | batch_size, src_len, L = logits.size() 168 | _, R = nlogP.sort(-1) 169 | 170 | def get_score(data, index): 171 | # avoid all zero 172 | # zero_mask = (index.sum(-2) == 0).float() * INF 173 | return data.gather(-1, index).sum(-2) 174 | 175 | heap_scores = torch.ones(batch_size, N) * INF 176 | heap_inx = torch.zeros(batch_size, src_len, N).long() 177 | heap_scores[:, :1] = get_score(nlogP, R[:, :, :1]) 178 | if nlogP.is_cuda: 179 | heap_scores = heap_scores.cuda(nlogP.get_device()) 180 | heap_inx = heap_inx.cuda(nlogP.get_device()) 181 | 182 | def span(ins): 183 | inds = torch.eye(ins.size(1)).long() 184 | if ins.is_cuda: 185 | inds = inds.cuda(ins.get_device()) 186 | return ins[:, :, None].expand(ins.size(0), ins.size(1), ins.size(1)) + inds[None, :, :] 187 | 188 | # iteration starts 189 | for k in range(1, N): 190 | cur_inx = heap_inx[:, :, k-1] 191 | I_t = span(cur_inx).clamp(0, L-1) # B x N x N 192 | S_t = get_score(nlogP, R.gather(-1, I_t)) 193 | S_t, _inx = torch.cat([heap_scores[:, k:], S_t], 1).sort(1) 194 | S_t[:, 1:] += ((S_t[:, 1:] - S_t[:, :-1]) == 0).float() * INF # remove duplicates 195 | S_t, _inx2 = S_t.sort(1) 196 | I_t = torch.cat([heap_inx[:, :, k:], I_t], 2).gather( 197 | 2, _inx.gather(1, _inx2)[:, None, :].expand(batch_size, src_len, _inx.size(-1))) 198 | heap_scores[:, k:] = S_t[:, :N-k] 199 | heap_inx[:, :, k:] = I_t[:, :, :N-k] 200 | 201 | # get the searched 202 | output = R.gather(-1, heap_inx) 203 | output = output.transpose(2, 1).contiguous().view(batch_size * N, src_len) # (B x N) x Ts 204 | output = Variable(output) 205 | mask_src = mask_src[:, None, :].expand(batch_size, N, src_len).contiguous().view(batch_size * N, src_len) 206 | 207 | return output, mask_src 208 | 209 | class LayerNorm(nn.Module): 210 | 211 | def __init__(self, d_model, eps=1e-6): 212 | super().__init__() 213 | self.gamma = nn.Parameter(torch.ones(d_model)) 214 | self.beta = nn.Parameter(torch.zeros(d_model)) 215 | self.eps = eps 216 | 217 | def forward(self, x): 218 | mean = x.mean(-1, keepdim=True) 219 | std = x.std(-1, keepdim=True) 220 | return self.gamma * (x - mean) / (std + self.eps) + self.beta 221 | 222 | class Attention(nn.Module): 223 | 224 | def __init__(self, d_key, drop_ratio, causal, diag=False): 225 | super().__init__() 226 | self.scale = math.sqrt(d_key) 227 | self.dropout = nn.Dropout(drop_ratio) 228 | self.causal = causal 229 | self.diag = diag 230 | 231 | def forward(self, query, key, value=None, mask=None, 232 | feedback=None, beta=0, tau=1, weights=None): 233 | dot_products = matmul(query, key.transpose(1, 2)) # batch x trg_len x trg_len 234 | 235 | if weights is not None: 236 | dot_products = dot_products + weights # additive bias 237 | 238 | if query.dim() == 3 and self.causal and (query.size(1) == key.size(1)): 239 | tri = key.data.new(key.size(1), key.size(1)).fill_(1).triu(1) * INF 240 | dot_products.data.sub_(tri.unsqueeze(0)) 241 | 242 | if self.diag: 243 | inds = torch.arange(0, key.size(1)).long().view(1, 1, -1) 244 | if key.is_cuda: 245 | inds = inds.cuda(key.get_device()) 246 | dot_products.data.scatter_(1, inds.expand(dot_products.size(0), 1, inds.size(-1)), -INF) 247 | # eye = key.data.new(key.size(1), key.size(1)).fill_(1).eye() * INF 248 | # dot_products.data.sub_(eye.unsqueeze(0)) 249 | 250 | if mask is not 
None: 251 | if dot_products.dim() == 2: 252 | dot_products.data -= ((1 - mask) * INF) 253 | else: 254 | dot_products.data -= ((1 - mask[:, None, :]) * INF) 255 | 256 | if value is None: 257 | return dot_products 258 | 259 | logits = dot_products / self.scale 260 | probs = softmax(logits) 261 | 262 | if feedback is not None: 263 | feedback.append(probs.contiguous()) 264 | 265 | return matmul(self.dropout(probs), value) 266 | 267 | class MultiHead2(nn.Module): 268 | 269 | def __init__(self, d_key, d_value, n_heads, drop_ratio, 270 | causal=False, diag=False, use_wo=True): 271 | super().__init__() 272 | self.attention = Attention(d_key, drop_ratio, causal=causal, diag=diag) 273 | self.wq = Linear(d_key, d_key, bias=use_wo) 274 | self.wk = Linear(d_key, d_key, bias=use_wo) 275 | self.wv = Linear(d_value, d_value, bias=use_wo) 276 | if use_wo: 277 | self.wo = Linear(d_value, d_key, bias=use_wo) 278 | self.use_wo = use_wo 279 | self.n_heads = n_heads 280 | 281 | def forward(self, query, key, value, mask=None, feedback=None, weights=None, beta=0, tau=1): 282 | # query : B x T1 x D 283 | # key : B x T2 x D 284 | # value : B x T2 x D 285 | query, key, value = self.wq(query), self.wk(key), self.wv(value) # B x T x D 286 | B, Tq, D = query.size() 287 | _, Tk, _ = key.size() 288 | N = self.n_heads 289 | probs = [] 290 | 291 | query, key, value = (x.contiguous().view(B, -1, N, D//N).transpose(2, 1).contiguous().view(B*N, -1, D//N) 292 | for x in (query, key, value)) 293 | if mask is not None: 294 | mask = mask[:, None, :].expand(B, N, Tk).contiguous().view(B*N, -1) 295 | outputs = self.attention(query, key, value, mask, probs, beta, tau, weights) # (B x N) x T x (D/N) 296 | outputs = outputs.view(B, N, -1, D//N).transpose(2, 1).contiguous().view(B, -1, D) 297 | 298 | if feedback is not None: 299 | feedback.append(probs[0].view(B, N, Tq, Tk)) 300 | 301 | if self.use_wo: 302 | return self.wo(outputs) 303 | return outputs 304 | 305 | class NonresidualBlock(nn.Module): 306 | 307 | def __init__(self, layer, d_model, d_hidden, drop_ratio, pos=0): 308 | super().__init__() 309 | self.layer = layer 310 | self.dropout = nn.Dropout(drop_ratio) 311 | self.layernorm = LayerNorm(d_model) 312 | self.pos = pos 313 | 314 | def forward(self, *x): 315 | return self.layernorm(self.dropout(self.layer(*x))) 316 | 317 | 318 | class ResidualBlock(nn.Module): 319 | 320 | def __init__(self, layer, d_model, d_hidden, drop_ratio, pos=0): 321 | super().__init__() 322 | self.layer = layer 323 | self.dropout = nn.Dropout(drop_ratio) 324 | self.layernorm = LayerNorm(d_model) 325 | self.pos = pos 326 | 327 | def forward(self, *x): 328 | return self.layernorm(x[self.pos] + self.dropout(self.layer(*x))) 329 | 330 | class HighwayBlock(nn.Module): 331 | 332 | def __init__(self, layer, d_model, d_hidden, drop_ratio, pos=0): 333 | super().__init__() 334 | self.layer = layer 335 | self.gate = FeedForward(d_model, d_hidden) 336 | self.dropout = nn.Dropout(drop_ratio) 337 | self.layernorm = LayerNorm(d_model) 338 | self.pos = pos 339 | 340 | def forward(self, *x): 341 | g = F.sigmoid(self.gate(x[self.pos])) 342 | return self.layernorm(x[self.pos] * g + self.dropout(self.layer(*x)) * (1 - g)) 343 | 344 | class FeedForward(nn.Module): 345 | 346 | def __init__(self, d_model, d_hidden): 347 | super().__init__() 348 | self.linear1 = Linear(d_model, d_hidden) 349 | self.linear2 = Linear(d_hidden, d_model) 350 | 351 | def forward(self, x): 352 | return self.linear2(F.relu(self.linear1(x))) 353 | 354 | class EncoderLayer(nn.Module): 355 | 356 | def 
__init__(self, args): 357 | super().__init__() 358 | self.selfattn = ResidualBlock( 359 | MultiHead2(args.d_model, args.d_model, args.n_heads, 360 | args.drop_ratio, use_wo=args.use_wo), 361 | args.d_model, args.d_hidden, args.drop_ratio) 362 | self.feedforward = args.block_cls( 363 | FeedForward(args.d_model, args.d_hidden), 364 | args.d_model, args.d_hidden, args.drop_ratio ) 365 | 366 | def forward(self, x, mask=None): 367 | x = self.selfattn(x, x, x, mask) 368 | x = self.feedforward(x) 369 | return x 370 | 371 | class DecoderLayer(nn.Module): 372 | 373 | def __init__(self, args, causal=True, diag=False, 374 | positional=False): 375 | super().__init__() 376 | self.positional = positional 377 | self.selfattn = ResidualBlock( 378 | MultiHead2(args.d_model, args.d_model, args.n_heads, 379 | args.drop_ratio, causal=causal, diag=diag, 380 | use_wo=args.use_wo), 381 | args.d_model, args.d_hidden, args.drop_ratio) 382 | 383 | self.attention = ResidualBlock( 384 | MultiHead2(args.d_model, args.d_model, args.n_heads, 385 | args.drop_ratio, use_wo=args.use_wo), 386 | args.d_model, args.d_hidden, args.drop_ratio) 387 | 388 | if positional: 389 | self.pos_selfattn = ResidualBlock( 390 | MultiHead2(args.d_model, args.d_model, args.n_heads, 391 | args.drop_ratio, causal=causal, diag=diag, 392 | use_wo=args.use_wo), 393 | args.d_model, args.d_hidden, args.drop_ratio, pos=2) 394 | 395 | self.feedforward = args.block_cls( 396 | FeedForward(args.d_model, args.d_hidden), 397 | args.d_model, args.d_hidden, args.drop_ratio ) 398 | 399 | def forward(self, x, encoding, p=None, mask_src=None, mask_trg=None, feedback=None): 400 | 401 | feedback_src = [] 402 | feedback_trg = [] 403 | 404 | x = self.selfattn(x, x, x, mask_trg, feedback_trg) # 405 | 406 | if self.positional: 407 | pos_encoding, weights = positional_encodings_like(x), None 408 | x = self.pos_selfattn(pos_encoding, pos_encoding, x, mask_trg, None, weights) # positional attention 409 | 410 | x = self.attention(x, encoding, encoding, mask_src, feedback_src) 411 | 412 | x = self.feedforward(x) 413 | 414 | if feedback is not None: 415 | if 'source' not in feedback: 416 | feedback['source'] = feedback_src 417 | else: 418 | feedback['source'] += feedback_src 419 | 420 | if 'target' not in feedback: 421 | feedback['target'] = feedback_trg 422 | else: 423 | feedback['target'] += feedback_trg 424 | return x 425 | 426 | class Encoder(nn.Module): 427 | 428 | def __init__(self, field, args): 429 | super().__init__() 430 | 431 | if args.dataset != "mscoco": 432 | if args.share_embed: 433 | self.out = Linear(args.d_model, len(field.vocab), bias=False) 434 | else: 435 | self.embed = nn.Embedding(len(field.vocab), args.d_model) 436 | self.layers = nn.ModuleList( 437 | [EncoderLayer(args) for i in range(args.n_layers)]) 438 | self.dropout = nn.Dropout(args.input_drop_ratio) 439 | if args.dataset != "mscoco": 440 | self.field = field 441 | self.d_model = args.d_model 442 | self.share_embed = args.share_embed 443 | self.dataset = args.dataset 444 | 445 | def forward(self, x, mask=None): 446 | if self.dataset != "mscoco": 447 | if self.share_embed: 448 | x = F.embedding(x, self.out.weight * math.sqrt(self.d_model)) 449 | else: 450 | x = self.embed(x) 451 | x += positional_encodings_like(x) 452 | encoding = [x] 453 | 454 | x = self.dropout(x) 455 | for layer in self.layers: 456 | x = layer(x, mask) 457 | encoding.append(x) 458 | return encoding 459 | 460 | def conv3x3(in_planes, out_planes, stride=1): 461 | """3x3 convolution with padding""" 462 | return 
nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 463 | padding=1, bias=False) 464 | 465 | class EncoderCNN(nn.Module): 466 | def __init__(self, args): 467 | super(EncoderCNN, self).__init__() 468 | 469 | self.d_encoder = 512 # hardcoded because of resnet 512 hidden size 470 | self.d_model = args.d_model 471 | if self.d_encoder != self.d_model: 472 | self.conv = conv3x3(self.d_encoder, self.d_model, stride=1) 473 | self.bn = nn.BatchNorm2d(self.d_model) 474 | 475 | def forward(self, features): 476 | if self.d_encoder != self.d_model: 477 | return self.bn(self.conv(features)) 478 | else: 479 | return features 480 | 481 | class Decoder(nn.Module): 482 | 483 | def __init__(self, field, args, causal=True, positional=False, diag=False, out=None): 484 | 485 | super().__init__() 486 | 487 | self.layers = nn.ModuleList( 488 | [DecoderLayer(args, causal, diag, positional) 489 | for i in range(args.n_layers)]) 490 | 491 | if out is None: 492 | self.out = Linear(args.d_model, len(field.vocab), bias=False, out_norm=args.out_norm) 493 | else: 494 | self.out = out 495 | 496 | self.dropout = nn.Dropout(args.input_drop_ratio) 497 | self.out_norm = args.out_norm 498 | self.d_model = args.d_model 499 | self.field = field 500 | self.length_ratio = args.length_ratio 501 | self.positional = positional 502 | self.enc_last = args.enc_last 503 | self.dataset = args.dataset 504 | self.length_dec = args.length_dec 505 | 506 | def forward(self, x, encoding, source_masks=None, decoder_masks=None, 507 | input_embeddings=False, positions=None, feedback=None): 508 | # x : decoder_inputs 509 | 510 | if self.out_norm: 511 | out_weight = self.out.weight / (1e-6 + torch.sqrt((self.out.weight ** 2).sum(0, keepdim=True))) 512 | else: 513 | out_weight = self.out.weight 514 | 515 | if not input_embeddings: # NOTE only for Transformer 516 | if x.ndimension() == 2: 517 | x = F.embedding(x, out_weight * math.sqrt(self.d_model)) 518 | elif x.ndimension() == 3: # softmax relaxation 519 | x = x @ out_weight * math.sqrt(self.d_model) # batch x len x embed_size 520 | 521 | x += positional_encodings_like(x) 522 | x = self.dropout(x) 523 | 524 | if self.enc_last: 525 | for l, layer in enumerate(self.layers): 526 | x = layer(x, encoding[-1], mask_src=source_masks, mask_trg=decoder_masks, feedback=feedback) 527 | else: 528 | for l, (layer, enc) in enumerate(zip(self.layers, encoding[1:])): 529 | x = layer(x, enc, mask_src=source_masks, mask_trg=decoder_masks, feedback=feedback) 530 | return x 531 | 532 | def greedy(self, encoding, mask_src=None, mask_trg=None, feedback=None): 533 | 534 | encoding = encoding[1:] 535 | B, T, C = encoding[0].size() # batch-size, decoding-length, size 536 | if self.dataset == "mscoco": 537 | T = self.length_dec 538 | else: 539 | T *= self.length_ratio 540 | 541 | outs = Variable(encoding[0].data.new(B, T + 1).long().fill_( 542 | self.field.vocab.stoi['<init>'])) 543 | hiddens = [Variable(encoding[0].data.new(B, T, C).zero_()) 544 | for l in range(len(self.layers) + 1)] 545 | embedW = self.out.weight * math.sqrt(self.d_model) 546 | hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0]) 547 | 548 | eos_yet = encoding[0].data.new(B).byte().zero_() 549 | 550 | attentions = [] 551 | 552 | for t in range(T): 553 | #torch.cuda.nvtx.mark(f'greedy:{t}') 554 | torch.cuda.nvtx.mark('greedy:{}'.format(t)) 555 | hiddens[0][:, t] = self.dropout( 556 | hiddens[0][:, t] + F.embedding(outs[:, t], embedW)) 557 | 558 | inter_attention = [] 559 | for l in range(len(self.layers)): 560 | x = hiddens[l][:, :t+1]
561 | x = self.layers[l].selfattn(hiddens[l][:, t:t+1], x, x) # we need to make the dimension 3D 562 | hiddens[l + 1][:, t] = self.layers[l].feedforward( 563 | self.layers[l].attention(x, encoding[l], encoding[l], mask_src, inter_attention))[:, 0] 564 | 565 | inter_attention = torch.cat(inter_attention, 1) 566 | attentions.append(inter_attention) 567 | 568 | _, preds = self.out(hiddens[-1][:, t]).max(-1) 569 | preds[eos_yet] = self.field.vocab.stoi['<pad>'] 570 | 571 | eos_yet = eos_yet | (preds.data == self.field.vocab.stoi['<eos>']) 572 | outs[:, t + 1] = preds 573 | if eos_yet.all(): 574 | break 575 | 576 | if feedback is not None: 577 | feedback['source'] = torch.cat(attentions, 2) 578 | 579 | return outs[:, 1:t+2] 580 | 581 | def beam_search(self, encoding, mask_src=None, mask_trg=None, width=2, alpha=0.6): # width: beamsize, alpha: length-norm 582 | encoding = encoding[1:] 583 | W = width 584 | B, T, C = encoding[0].size() 585 | 586 | # expanding 587 | for i in range(len(encoding)): 588 | encoding[i] = encoding[i][:, None, :].expand( 589 | B, W, T, C).contiguous().view(B * W, T, C) 590 | mask_src = mask_src[:, None, :].expand(B, W, T).contiguous().view(B * W, T) 591 | 592 | T *= self.length_ratio 593 | outs = Variable(encoding[0].data.new(B, W, T + 1).long().fill_( 594 | self.field.vocab.stoi['<init>'])) 595 | 596 | logps = Variable(encoding[0].data.new(B, W).float().fill_(0)) # scores 597 | hiddens = [Variable(encoding[0].data.new(B, W, T, C).zero_()) # decoder states: batch x beamsize x len x h 598 | for l in range(len(self.layers) + 1)] 599 | embedW = self.out.weight * math.sqrt(self.d_model) 600 | hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0]) 601 | eos_yet = encoding[0].data.new(B, W).byte().zero_() # batch x beamsize, all the sentences are not finished yet.
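# Bookkeeping for finished hypotheses: eos_mask below is batch x beam x beam,
# with column 0 kept at 0 and every other column filled with -INF. Together
# with the masking inside the decoding loop, it is meant to let a beam that
# has already produced <eos> carry its cumulative score forward through a
# single placeholder expansion instead of competing with W fresh continuations.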
602 | eos_mask = eos_yet.float().fill_(-INF)[:, :, None].expand(B, W, W) 603 | eos_mask[:, :, 0] = 0 # batch x beam x beam 604 | 605 | for t in range(T): 606 | hiddens[0][:, :, t] = self.dropout( 607 | hiddens[0][:, :, t] + F.embedding(outs[:, :, t], embedW)) 608 | for l in range(len(self.layers)): 609 | x = hiddens[l][:, :, :t + 1].contiguous().view(B * W, -1, C) 610 | x = self.layers[l].selfattn(x[:, -1:, :], x, x) 611 | hiddens[l + 1][:, :, t] = self.layers[l].feedforward( 612 | self.layers[l].attention(x, encoding[l], encoding[l], mask_src)).view( 613 | B, W, C) 614 | 615 | # topk2_logps: scores, topk2_inds: top word index at each beam, batch x beam x beam 616 | topk2_logps, topk2_inds = log_softmax( 617 | self.out(hiddens[-1][:, :, t])).topk(W, dim=-1) 618 | 619 | # mask out the sentences which are finished 620 | topk2_logps = topk2_logps * Variable(eos_yet[:, :, None].float() * eos_mask + 1 - eos_yet[:, :, None].float()) 621 | topk2_logps = topk2_logps + logps[:, :, None] 622 | 623 | if t == 0: 624 | logps, topk_inds = topk2_logps[:, 0].topk(W, dim=-1) 625 | else: 626 | logps, topk_inds = topk2_logps.view(B, W * W).topk(W, dim=-1) 627 | 628 | topk_beam_inds = topk_inds.div(W) 629 | topk_token_inds = topk2_inds.view(B, W * W).gather(1, topk_inds) 630 | eos_yet = eos_yet.gather(1, topk_beam_inds.data) 631 | 632 | logps = logps * (1 - Variable(eos_yet.float()) * 1 / (t + 2)).pow(alpha) 633 | outs = outs.gather(1, topk_beam_inds[:, :, None].expand_as(outs)) 634 | outs[:, :, t + 1] = topk_token_inds 635 | topk_beam_inds = topk_beam_inds[:, :, None, None].expand_as( 636 | hiddens[0]) 637 | for i in range(len(hiddens)): 638 | hiddens[i] = hiddens[i].gather(1, topk_beam_inds) 639 | eos_yet = eos_yet | (topk_token_inds.data == self.field.vocab.stoi['<eos>']) 640 | if eos_yet.all(): 641 | return outs[:, 0, 1:] 642 | return outs[:, 0, 1:] 643 | 644 | class Transformer(nn.Module): 645 | 646 | def __init__(self, src=None, trg=None, args=None): 647 | super().__init__() 648 | if args.dataset != "mscoco": 649 | self.is_mscoco = False 650 | # prepare regular translation encoder and decoder 651 | self.encoder = Encoder(src, args) 652 | self.decoder = Decoder(trg, args) 653 | self.field = trg 654 | self.share_embed = args.share_embed 655 | if args.share_embed: 656 | self.encoder.out.weight = self.decoder.out.weight 657 | else: 658 | # prepare image encoder and decoder 659 | self.is_mscoco = True 660 | mscoco_dataset = trg 661 | #self.encoder = EncoderCNN(args) 662 | self.encoder = Encoder(src, args) 663 | self.decoder = Decoder(mscoco_dataset, args) 664 | self.field = mscoco_dataset 665 | self.share_embed = False 666 | 667 | self.n_layers = args.n_layers 668 | self.d_model = args.d_model 669 | self.gpu = args.gpu 670 | 671 | def denum(self, data, target=True): 672 | field = self.decoder.field if target else self.encoder.field 673 | return field.reverse(data.unsqueeze(0))[0] 674 | 675 | def apply_mask(self, inputs, mask, p=1): 676 | _mask = Variable(mask.long()) 677 | #outputs = inputs * _mask + (1 - _mask) * p 678 | outputs = inputs * _mask + (torch.mul(_mask, -1) + 1 ) * p 679 | return outputs 680 | 681 | def apply_mask_cost(self, loss, mask, batched=False): 682 | loss.data *= mask 683 | cost = loss.sum() / (mask.sum() + TINY) 684 | 685 | if not batched: 686 | return cost 687 | 688 | loss = loss.sum(1, keepdim=True) / (TINY + Variable(mask).sum(1, keepdim=True)) 689 | return cost, loss 690 | 691 | def output_decoding(self, outputs, unbpe=True): 692 | field, text = outputs 693 | if field == 'src': 694 | 
return self.encoder.field.reverse(text.data, unbpe) 695 | else: 696 | return self.decoder.field.reverse(text.data, unbpe) 697 | 698 | def prepare_sources(self, batch, masks=None): 699 | masks = self.prepare_masks(batch.src) if masks is None else masks 700 | return batch.src, masks 701 | 702 | def encoding(self, encoder_inputs, source_masks=None): 703 | if self.is_mscoco: 704 | return self.encoder(encoder_inputs) 705 | else: 706 | return self.encoder(encoder_inputs, source_masks) 707 | 708 | def prepare_targets(self, batch, targets=None, masks=None): 709 | if targets is None: 710 | targets = batch.trg[:, 1:].contiguous() 711 | masks = self.prepare_masks(targets) if masks is None else masks 712 | return targets, masks 713 | 714 | def prepare_decoder_inputs(self, trg_inputs, inputs=None, masks=None, bp=1.00): 715 | decoder_inputs = trg_inputs[:, :-1].contiguous() 716 | decoder_masks = self.prepare_masks(trg_inputs[:, 1:], bp=bp) if masks is None else masks 717 | # NOTE why [1:], not [:-1]? 718 | 719 | return decoder_inputs, decoder_masks 720 | 721 | def change_bp_masks(self, masks, bp): 722 | input_lengths = np.int32( masks.sum(1).cpu().numpy() ) 723 | batch_size, seq_len = masks.size() 724 | add_pad = [ int( math.floor( each_len * ( (1 / bp) - 1.0 ) ) ) for each_len in input_lengths] 725 | if max(add_pad) > 0 : 726 | add_mask = torch.zeros((batch_size, max(add_pad))).float() # NOTE the extra ones are appended right after each sequence below 727 | if masks.is_cuda: 728 | add_mask = add_mask.cuda(masks.get_device()) 729 | masks = torch.cat((masks, add_mask), dim=1) 730 | for bidx in range(input_lengths.shape[0]): 731 | if add_pad[bidx] > 0: 732 | masks[bidx, input_lengths[bidx]:input_lengths[bidx]+add_pad[bidx]] = 1 733 | return masks 734 | 735 | def prepare_masks(self, inputs, bp=1.0): 736 | if inputs.ndimension() == 2: 737 | masks = (inputs.data != self.field.vocab.stoi['<pad>']).float() 738 | else: # NOTE FALSE 739 | masks = (inputs.data[:, :, self.field.vocab.stoi['<pad>']] != 1).float() 740 | 741 | if bp < 1.0: 742 | masks = self.change_bp_masks(masks, bp) 743 | 744 | return masks 745 | 746 | def find_captions_length(self, all_captions): 747 | # find length of each caption 748 | all_captions_lengths = [] 749 | # list of lists 750 | if type(all_captions[0]) == list: 751 | num_captions = len(all_captions[0]) 752 | for i in range(num_captions): 753 | caption_length = 0 754 | for j in range(len(all_captions)): 755 | caption_length += len(all_captions[j][i].split(' ')) 756 | caption_length = int(caption_length / len(all_captions)) 757 | all_captions_lengths.append(caption_length) 758 | else: 759 | for cap in all_captions: 760 | all_captions_lengths.append(len(cap.split(' '))) 761 | 762 | return all_captions_lengths 763 | 764 | def quick_prepare_mscoco(self, batch, all_captions=None, fast=True, inputs_dec='pool', trg_len_option=None, max_len=20, trg_len_dic=None, decoder_inputs=None, targets=None, decoder_masks=None, target_masks=None, source_masks=None, bp=1.00, gpu=True): 765 | features_beforepool, captions = batch[0], batch[1] 766 | batch_size, d_model = features_beforepool.size(0), features_beforepool.size(1) 767 | 768 | # batch_size x 49 x 512 769 | features_beforepool = features_beforepool.view(batch_size, d_model, 49).transpose(1, 2) 770 | if gpu: 771 | encoding = self.encoding(Variable(features_beforepool, requires_grad=False).cuda(), source_masks) # batch of resnet features 772 | source_masks = torch.FloatTensor(batch_size, 49).fill_(1).cuda() 773 | targets = self.prepare_target_captions(captions, 
self.field.vocab.stoi).cuda() 774 | else: 775 | encoding = self.encoding(Variable(features_beforepool, requires_grad=False), source_masks) # batch of resnet features 776 | source_masks = torch.FloatTensor(batch_size, 49).fill_(1) 777 | targets = self.prepare_target_captions(captions, self.field.vocab.stoi) 778 | 779 | # list of batch_size 780 | all_captions_lengths = self.find_captions_length(all_captions) 781 | 782 | # predicted decoder lens 783 | if trg_len_option == "predict": 784 | # batch_size tensor 785 | if gpu: 786 | target_len = Variable(torch.from_numpy(np.clip(np.array(all_captions_lengths), 0, self.max_offset)).cuda(), requires_grad=False) 787 | else: 788 | target_len = Variable(torch.from_numpy(np.clip(np.array(all_captions_lengths), 0, self.max_offset)), requires_grad=False) 789 | 790 | # HARDCODED (4 layer model) !!! 791 | pred_target_len_logits = self.pred_len((encoding[0]+encoding[1]+encoding[2]+encoding[3]+encoding[4]).mean(1)) 792 | pred_target_len_loss = F.cross_entropy(pred_target_len_logits, target_len.long()) 793 | pred_target_len = pred_target_len_logits.max(-1)[1] 794 | 795 | if fast == False: 796 | decoder_inputs, decoder_masks = self.prepare_decoder_inputs(targets, decoder_inputs, decoder_masks) # prepare decoder-inputs 797 | else: 798 | if trg_len_option == "fixed": 799 | decoder_len = int(max_len) 800 | decoder_masks = torch.ones(batch_size, decoder_len) 801 | if gpu: 802 | decoder_masks = decoder_masks.cuda(encoding[0].get_device()) 803 | 804 | # TODO ADD BP OPTION 805 | elif trg_len_option == "reference" or (trg_len_option == "predict" and self.use_predicted_trg_len == False): 806 | decoder_len = max(all_captions_lengths) 807 | decoder_masks = np.zeros((batch_size, decoder_len)) 808 | for idx in range(decoder_masks.shape[0]): 809 | decoder_masks[idx][:all_captions_lengths[idx]] = 1 810 | decoder_masks = torch.from_numpy(decoder_masks).float() 811 | if gpu: 812 | decoder_masks = decoder_masks.cuda(encoding[0].get_device()) 813 | 814 | if trg_len_option == "predict": 815 | if self.use_predicted_trg_len: 816 | pred_target_len = pred_target_len.data.cpu().numpy() 817 | decoder_len = np.max(pred_target_len) 818 | decoder_masks = np.zeros((batch_size, decoder_len)) 819 | for idx in range(pred_target_len.shape[0]): 820 | decoder_masks[idx][:pred_target_len[idx]] = 1 821 | decoder_masks = torch.from_numpy(decoder_masks).float() 822 | if gpu: 823 | decoder_masks = decoder_masks.cuda(encoding[0].get_device()) 824 | if bp < 1.0: 825 | decoder_masks = self.change_bp_masks(decoder_masks, bp) 826 | 827 | if not self.use_predicted_trg_len: 828 | pred_target_len = pred_target_len.data.cpu().numpy() 829 | 830 | target_len = target_len.data.cpu().numpy() 831 | 832 | # calculate error for predicted target length 833 | pred_target_len_correct = np.sum(pred_target_len == target_len)*100/batch_size 834 | pred_target_len_approx = np.sum(np.abs(pred_target_len - target_len) < 5)*100/batch_size 835 | average_target_len_correct = 0 836 | average_target_len_approx = 0 837 | 838 | rest = [pred_target_len_loss, pred_target_len_correct, pred_target_len_approx, average_target_len_correct, average_target_len_approx] 839 | 840 | if inputs_dec == 'pool': 841 | # batch_size x 1 x 512 842 | decoder_inputs = torch.mean(features_beforepool, 1, keepdim=True) 843 | decoder_inputs = decoder_inputs.repeat(1, int(decoder_len), 1) 844 | decoder_inputs = Variable(decoder_inputs, requires_grad=False) 845 | if gpu: 846 | decoder_inputs = decoder_inputs.cuda(encoding[0].get_device()) 847 | elif inputs_dec 
== 'zeros': 848 | decoder_inputs = Variable(torch.zeros(batch_size, int(decoder_len), d_model), requires_grad=False) 849 | if gpu: 850 | decoder_inputs = decoder_inputs.cuda(encoding[0].get_device()) 851 | 852 | # REMOVE THE FIRST <init> TAG FROM CAPTIONS 853 | targets = targets[:, 1:] 854 | if gpu: 855 | target_masks = (targets != 1).float().cuda().data 856 | else: 857 | target_masks = (targets != 1).float().data 858 | 859 | if trg_len_option != "predict": 860 | rest = [] 861 | sources = None 862 | 863 | return decoder_inputs, decoder_masks, targets, target_masks, sources, source_masks, encoding, decoder_inputs.size(0), rest 864 | 865 | def prepare_target_captions(self, captions, vocab): 866 | # captions : batch_size X seq_len 867 | lst = [] 868 | batch_size = len(captions) 869 | for bidx in range(batch_size): 870 | lst.append( ["<init>"] + captions[ bidx ].lower().split() + ["<eos>"] ) 871 | #lst.append( [ vocab[idx] for idx in captions[ random.randint(0,4) ][ bidx ].lower().split() ] ) 872 | lst = [[vocab[idx] if idx in vocab else 0 for idx in sentence] for sentence in lst] 873 | seq_len = max( [len(xx) for xx in lst] ) 874 | captions = np.ones((batch_size, seq_len)) 875 | for bidx in range(batch_size): 876 | min_len = min(seq_len, len(lst[bidx])) 877 | captions[bidx, :min_len] = np.array(lst[bidx][:min_len]) 878 | captions = torch.from_numpy(captions).long() 879 | return Variable(captions, requires_grad=False) 880 | 881 | def quick_prepare(self, batch, fast=True, trg_len_option=None, trg_len_ratio=2.0, trg_len_dic=None, decoder_inputs=None, targets=None, decoder_masks=None, target_masks=None, source_masks=None, bp=1.00): 882 | sources, source_masks = self.prepare_sources(batch, source_masks) 883 | sources = sources.cuda(self.gpu) 884 | source_masks = source_masks.cuda(self.gpu) 885 | encoding = self.encoding(sources, source_masks) 886 | targets, target_masks = self.prepare_targets(batch, targets, decoder_masks) # prepare decoder-targets 887 | targets = targets.cuda(self.gpu) 888 | target_masks = target_masks.cuda(self.gpu) 889 | # predicted decoder masks 890 | if trg_len_option == "predict": 891 | target_offset = Variable((target_masks.sum(-1) - source_masks.sum(-1)).clamp_(-self.max_offset, self.max_offset), requires_grad=False) # batch_size tensor 892 | source_len = Variable(source_masks.sum(-1), requires_grad=False) 893 | 894 | pred_target_offset_logits = self.pred_len((encoding[0]+encoding[1]+encoding[2]+encoding[3]+encoding[4]+encoding[5]).mean(1)) 895 | pred_target_offset_logits = self.pred_len_drop( pred_target_offset_logits ) 896 | pred_target_len_loss = F.cross_entropy(pred_target_offset_logits, (target_offset + self.max_offset).long()) 897 | pred_target_offset = pred_target_offset_logits.max(-1)[1] - self.max_offset 898 | pred_target_len = source_len.long() + pred_target_offset 899 | 900 | d_model = encoding[0].size(-1) 901 | batch_size, src_max_len = source_masks.size() 902 | rest = [] 903 | 904 | if fast: 905 | # compute decoder_masks 906 | if trg_len_option == "reference": 907 | _, decoder_masks = self.prepare_decoder_inputs(batch.trg.cuda(self.gpu), decoder_inputs, decoder_masks, bp=bp) 908 | 909 | elif trg_len_option == "noisy_ref": 910 | bp = np.random.uniform(bp, 1.0) 911 | _, decoder_masks = self.prepare_decoder_inputs(batch.trg, decoder_inputs, decoder_masks, bp=bp) 912 | 913 | elif trg_len_option == "average": 914 | decoder_masks = make_decoder_masks(source_masks, trg_len_dic) 915 | # we use the average target lengths 916 | 917 | elif trg_len_option == "predict": 918 | # convert to 
numpy arrays first 919 | source_len = source_masks.sum(-1).cpu().numpy() 920 | target_len = target_masks.sum(-1).cpu().numpy() 921 | pred_target_len = pred_target_len.data.cpu().numpy() 922 | 923 | if not self.use_predicted_trg_len: 924 | _, decoder_masks = self.prepare_decoder_inputs(batch.trg, decoder_inputs, decoder_masks, bp=bp) 925 | else: 926 | decoder_max_len = max(pred_target_len) 927 | decoder_masks = np.zeros((batch_size, decoder_max_len)) 928 | for idx in range(pred_target_len.shape[0]): 929 | decoder_masks[idx][:pred_target_len[idx]] = 1 930 | decoder_masks = torch.from_numpy(decoder_masks).float() 931 | if source_masks.is_cuda: 932 | decoder_masks = decoder_masks.cuda(source_masks.get_device()) 933 | if bp < 1.0: 934 | decoder_masks = self.change_bp_masks(decoder_masks, bp) 935 | 936 | # check the results of predicting target length 937 | pred_target_len_correct = np.sum(pred_target_len == target_len)*100/batch_size 938 | pred_target_len_approx = np.sum(np.abs(pred_target_len - target_len) < 5)*100/batch_size 939 | 940 | # results with average len 941 | average_target_len = [query_trg_len_dic(trg_len_dic, source) for source in source_len] 942 | average_target_len = np.array(average_target_len) 943 | average_target_len_correct = np.sum(average_target_len == target_len)*100/batch_size 944 | average_target_len_approx = np.sum(np.abs(average_target_len - target_len) < 5)*100/batch_size 945 | 946 | rest = [pred_target_len_loss, pred_target_len_correct, pred_target_len_approx, average_target_len_correct, average_target_len_approx] 947 | 948 | elif "fixed" in trg_len_option: 949 | trg_len = (batch.trg != 1).sum(-1).int().data.cpu().numpy().tolist() 950 | 951 | source_lens = source_masks.sum(-1).cpu().numpy() 952 | decoder_masks = torch.zeros(batch_size, int(round(trg_len_ratio * src_max_len))) 953 | dec_len = int(round(trg_len_ratio * src_max_len)) 954 | 955 | for bi in range(batch_size): 956 | ss = source_lens[bi] 957 | decoder_masks[bi,:int(round(trg_len_ratio*ss))] = 1 958 | 959 | if encoding[0].is_cuda: 960 | decoder_masks = decoder_masks.cuda(encoding[0].get_device()) 961 | decoder_inputs, decoder_masks = self.prepare_initial(encoding, sources, source_masks, decoder_masks) 962 | else: 963 | decoder_inputs, decoder_masks = self.prepare_decoder_inputs(batch.trg, decoder_inputs, decoder_masks) # prepare decoder-inputs 964 | 965 | return decoder_inputs, decoder_masks, targets, target_masks, sources, source_masks, encoding, decoder_inputs.size(0), rest 966 | 967 | def forward(self, encoding, source_masks, decoder_inputs, decoder_masks, 968 | decoding=False, beam=1, alpha=0.6, return_probs=False, positions=None, feedback=None): 969 | 970 | if (return_probs and decoding) or (not decoding): 971 | out = self.decoder(decoder_inputs, encoding, source_masks, decoder_masks) 972 | 973 | if decoding: 974 | if beam == 1: # greedy decoding 975 | output = self.decoder.greedy(encoding, source_masks, decoder_masks, feedback=feedback) 976 | else: 977 | output = self.decoder.beam_search(encoding, source_masks, decoder_masks, beam, alpha) 978 | 979 | if return_probs: 980 | return output, out, self.decoder.out(out) # NOTE don't do softmax for validation 981 | #return output, out, softmax(self.decoder.out(out)) 982 | return output 983 | 984 | if return_probs: 985 | return out, softmax(self.decoder.out(out)) 986 | return out 987 | 988 | def cost(self, decoder_targets, decoder_masks, out=None): 989 | # get loss in a sequence-format to save computational time. 
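# prepare_cost (defined near the bottom of this file) drops the padded
# positions: targets is flattened to a (num_real_tokens,) vector and out to a
# (num_real_tokens, d_model) matrix, so a single cross_entropy call scores
# every non-padding token in the batch at once.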
990 | decoder_targets, out = prepare_cost(decoder_targets, out, decoder_masks.byte()) 991 | logits = self.decoder.out(out) 992 | loss = F.cross_entropy(logits, decoder_targets) 993 | return loss 994 | 995 | def batched_cost(self, decoder_targets, decoder_masks, probs, batched=False): 996 | # get loss in a batch-mode 997 | 998 | if decoder_targets.ndimension() == 2: # batch x length 999 | loss = -torch.log(probs + TINY).gather(2, decoder_targets[:, :, None])[:, :, 0] # batch x length 1000 | else: 1001 | loss = -(torch.log(probs + TINY) * decoder_targets).sum(-1) 1002 | return self.apply_mask_cost(loss, decoder_masks, batched) 1003 | 1004 | class FastTransformer(Transformer): 1005 | 1006 | def __init__(self, src=None, trg=None, args=None): 1007 | super(Transformer, self).__init__() 1008 | self.is_mscoco = args.dataset == "mscoco" 1009 | self.decoder_input_how = args.decoder_input_how 1010 | self.encoder = Encoder(src, args) 1011 | ''' 1012 | if self.is_mscoco == False: 1013 | self.encoder = Encoder(src, args) 1014 | else: 1015 | self.encoder = EncoderCNN(args) 1016 | ''' 1017 | self.decoder = nn.ModuleList() 1018 | for ni in range(args.num_decs): 1019 | self.decoder.append(Decoder(trg, args, 1020 | causal=False, 1021 | positional=args.positional, 1022 | diag=args.diag, 1023 | out=self.encoder.out if args.share_embed_enc_dec1 and ni == 0 else None) 1024 | ) 1025 | self.field = trg 1026 | if not self.is_mscoco: 1027 | self.share_embed = args.share_embed 1028 | else: 1029 | self.share_embed = False 1030 | self.train_repeat_dec = args.train_repeat_dec 1031 | self.num_decs = args.num_decs 1032 | if args.trg_len_option == "predict": 1033 | if args.dataset != "mscoco": 1034 | self.pred_len = Linear(args.d_model, 2*args.max_offset + 1) 1035 | else: 1036 | self.pred_len = Linear(args.d_model, args.max_offset+1) 1037 | self.pred_len_drop = nn.Dropout(args.drop_len_pred) 1038 | self.max_offset = args.max_offset 1039 | self.use_predicted_trg_len = args.use_predicted_trg_len 1040 | self.n_layers = args.n_layers 1041 | self.d_model = args.d_model 1042 | self.softmax = nn.Softmax(dim=-1) 1043 | self.gpu = args.gpu 1044 | 1045 | def output_decoding(self, outputs, unbpe=True): 1046 | field, text = outputs 1047 | if field == 'src': 1048 | return self.encoder.field.reverse(text.data, unbpe) 1049 | else: 1050 | return self.decoder[0].field.reverse(text.data, unbpe) 1051 | 1052 | # decoder_masks already decided 1053 | # computes decoder_inputs 1054 | def prepare_initial(self, encoding, source=None, source_masks=None, decoder_masks=None, 1055 | N=1, tau=1): 1056 | 1057 | decoder_input_how = self.decoder_input_how 1058 | d_model = encoding[0].size()[-1] 1059 | attention = linear_attention(source_masks, decoder_masks, decoder_input_how) 1060 | 1061 | if decoder_input_how in ["copy", "pad", "wrap"]: 1062 | attention = self.apply_mask(attention, decoder_masks, p=1) # p doesn't matter since it is masked out 1063 | attention = attention[:,:,None].expand(*attention.size(), d_model) 1064 | decoder_inputs = torch.gather(encoding[0], dim=1, index=attention) 1065 | 1066 | elif decoder_input_how == "interpolate": 1067 | decoder_inputs = matmul(attention, encoding[0]) # batch x max_trg x size 1068 | 1069 | return decoder_inputs, decoder_masks 1070 | 1071 | def forward(self, encoding, source_masks, decoder_inputs, decoder_masks, 1072 | decoding=False, beam=1, alpha=0.6, 1073 | return_probs=False, positions=None, feedback=None, iter_=0, T=1): 1074 | 1075 | thedecoder = self.decoder[iter_] 1076 | out = thedecoder(decoder_inputs, 
    def forward(self, encoding, source_masks, decoder_inputs, decoder_masks,
                decoding=False, beam=1, alpha=0.6,
                return_probs=False, positions=None, feedback=None, iter_=0, T=1):

        thedecoder = self.decoder[iter_]
        out = thedecoder(decoder_inputs, encoding, source_masks, decoder_masks,
                         input_embeddings=True, positions=positions, feedback=feedback)
        # out : output of the last DecoderLayer

        if not decoding:  # NOTE training
            if not return_probs:
                return out
            return out, softmax(thedecoder.out(out), T=T)  # probs

        logits = thedecoder.out(out)

        if beam == 1:
            output = self.apply_mask(logits.max(-1)[1], decoder_masks)  # NOTE given the mask, set non-masked positions to 1
        else:
            output, decoder_masks = topK_search(logits, decoder_masks, N=beam)
            output = self.apply_mask(output, decoder_masks)

        if not return_probs:
            return output
        else:
            return output, out, logits  # NOTE don't apply softmax for validation
            # return output, out, softmax(logits, T=T)

    def cost(self, targets, target_mask, out=None, iter_=0, return_logits=False):
        # compute the loss in flattened-sequence format to save computation
        targets, out = prepare_cost(targets, out, target_mask.byte())
        logits = self.decoder[iter_].out(out)
        loss = F.cross_entropy(logits, targets)
        if return_logits:
            return loss, logits
        return loss

    def rf_cost(self, args, targets, target_mask, out=None, iter_=0, return_logits=False):
        # REINFORCE, Eq.(11)
        targets, out = prepare_cost(targets, out, target_mask.byte())
        logits = self.decoder[iter_].out(out)
        probs = self.softmax(logits)

        sample_index = torch.multinomial(probs, 1)
        sample_prob = torch.gather(probs, -1, sample_index)
        target_lens = torch.sum(target_mask, dim=-1).long().tolist()
        targets = targets.data.tolist()
        sample_index = sample_index.data.view(-1).tolist()
        if args.sample_method == 'sentence':
            gleu = self.compute_sentence_gleu(args, sample_index, targets, target_lens)
        else:
            gleu = self.compute_stepwise_gleu(args.stepwise_sampletimes, args.workers, sample_index, probs, targets, target_lens)

        loss = torch.sum(-1 * torch.log(sample_prob).view(-1) * gleu, dim=0).div(len(targets))
        return loss

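    # Illustrative note (not in the original source): rf_cost above is the plain REINFORCE
    # estimator, roughly loss = -(1/N) * sum_t log p(y_t) * GLEU_t over the N unmasked
    # target positions, with each y_t drawn from the per-position softmax. nat_cost below
    # lowers the variance of this estimator: the top-k tokens at each position are
    # enumerated exactly (loss_traverse), and only the remaining probability mass is
    # sampled, down-weighted by (1 - sum of top-k probabilities) in loss_sample.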
    def nat_cost(self, args, targets, target_mask, out=None, iter_=0, return_logits=False):
        # RF-NAT, Eq.(12)
        targets, out = prepare_cost(targets, out, target_mask.byte())
        logits = self.decoder[iter_].out(out)
        probs = self.softmax(logits)
        target_lens = torch.sum(target_mask, dim=-1).long().tolist()
        targets = targets.data.tolist()

        top_probs, top_index = torch.topk(probs, args.topk, dim=-1)

        a, b = probs.size()
        copy_probs = Variable(torch.zeros(a, b).cuda(probs.get_device()))
        copy_probs.data.copy_(probs.data)
        copy_probs.scatter_add_(1, top_index, -1 * top_probs)
        sample_index = torch.multinomial(copy_probs, 1)
        sample_prob = torch.gather(probs, -1, sample_index)
        sample_index = sample_index.data.view(-1).tolist()

        top_index = top_index.t().data.tolist()
        weight = torch.sum(top_probs, dim=-1).detach()
        gleus = []

        for i in range(args.topk):
            top_idx = top_index[i]
            gleu = self.compute_stepwise_gleu(args.stepwise_sampletimes, args.workers, top_idx, probs, targets, target_lens)
            gleus.append(gleu)
        gleus = torch.stack(gleus, dim=1)
        loss_traverse = -1 * torch.sum(top_probs * gleus)

        gleu = self.compute_stepwise_gleu(args.stepwise_sampletimes, args.workers, sample_index, probs, targets, target_lens)
        loss_sample = torch.sum(-1 * (1 - weight) * torch.log(sample_prob).view(-1) * gleu, dim=0)

        loss = (loss_sample + loss_traverse).div(len(targets))
        return loss

    def compute_sentence_gleu(self, args, sample_index, targets, target_lens):

        # tokenizer = lambda x: x.replace('@@ ', '').split()
        list_targets = self.shape(targets, target_lens)
        list_samples = self.shape(sample_index, target_lens)
        gleus = computeGLEU(list_samples, list_targets)
        gleus = self.deshape(gleus, target_lens)

        return Variable(torch.Tensor(gleus).cuda(args.gpu))

    def compute_stepwise_gleu(self, sample_times, workers, sample_index, sample_prob, targets, target_lens):

        list_targets = self.shape(targets, target_lens)
        list_samples = self.shape(sample_index, target_lens)
        count = len(list_samples)
        sample_idxs = [torch.multinomial(sample_prob, 1).data.view(-1).tolist() for i in range(sample_times)]
        inputs = [(sample_idxs[i], list_samples, list_targets, count, target_lens) for i in range(sample_times)]
        pool = ProcessPoolExecutor(max_workers=workers)
        gleus = list(pool.map(parallel_gleu, inputs))
        gleus = Variable(torch.Tensor(gleus).cuda(sample_prob.get_device()))
        gleus = torch.mean(gleus, dim=0)

        return gleus

    def shape(self, targets, target_lens):

        list_targets = []
        begin = 0
        end = 0
        for length in target_lens:
            end += length
            list_targets.append([str(index) for index in targets[begin:end]])
            begin += length

        return list_targets

    def deshape(self, prev_targets, target_lens):
        targets = []
        for i in range(len(target_lens)):
            targets += [prev_targets[i]] * target_lens[i]
        return targets


def prepare_cost(targets, out, target_mask=None, return_mask=None):
    # targets     : batch_size, seq_len
    # out         : batch_size, seq_len, vocab_size
    # target_mask : batch_size, seq_len
    if target_mask is None:
        target_mask = (targets != 1)

    if targets.size(1) < out.size(1):
        out = out[:, :targets.size(1), :]
    elif targets.size(1) > out.size(1):
        targets = targets[:, :out.size(1)]
        target_mask = target_mask[:, :out.size(1)]

    out_mask = target_mask.unsqueeze(-1).expand_as(out)

    if return_mask:
        return targets[target_mask], out[out_mask].view(-1, out.size(-1)), out_mask
    else:
        return targets[target_mask], out[out_mask].view(-1, out.size(-1))


def linear_attention(source_masks, decoder_masks, decoder_input_how):
    if decoder_input_how == "copy":
        max_src_len = source_masks.size(1)
        max_trg_len = decoder_masks.size(1)

        src_lens = source_masks.sum(-1).float() - 1  # batch_size
        trg_lens = decoder_masks.sum(-1).float() - 1  # batch_size
        steps = src_lens / trg_lens  # batch_size

        index_s = torch.arange(max_trg_len).float()  # max_trg_len
        if decoder_masks.is_cuda:
            index_s = index_s.cuda(decoder_masks.get_device())

        index_s = steps[:, None] * index_s[None, :]  # batch_size x max_trg_len
        index_s = Variable(torch.round(index_s), requires_grad=False).long()
        return index_s
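    # Illustrative note (not in the original source): "wrap" below reuses source positions
    # once when the target is longer than the source: with src_len=3 and max_trg_len=7 the
    # index row [0,1,2,3,4,5,6] becomes [0,1,2,0,1,2,6]. "pad" instead repeats the last
    # valid source position for the overflow: [0,1,2,2,2,2,2].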
    elif decoder_input_how == "wrap":
        batch_size, max_src_len = source_masks.size()
        max_trg_len = decoder_masks.size(1)

        src_lens = source_masks.sum(-1).int()  # batch_size

        index_s = torch.arange(max_trg_len)[None, :]  # max_trg_len
        index_s = index_s.repeat(batch_size, 1)  # (batch_size, max_trg_len)

        for sin in range(batch_size):
            if src_lens[sin] + 1 < max_trg_len:
                index_s[sin, src_lens[sin]:2 * src_lens[sin]] = index_s[sin, :src_lens[sin]]

        if decoder_masks.is_cuda:
            index_s = index_s.cuda(decoder_masks.get_device())

        return Variable(index_s, requires_grad=False).long()

    elif decoder_input_how == "pad":
        batch_size, max_src_len = source_masks.size()
        max_trg_len = decoder_masks.size(1)

        src_lens = source_masks.sum(-1).int() - 1  # batch_size

        index_s = torch.arange(max_trg_len)[None, :]  # max_trg_len
        index_s = index_s.repeat(batch_size, 1)  # (batch_size, max_trg_len)

        for sin in range(batch_size):
            if src_lens[sin] + 1 < max_trg_len:
                index_s[sin, src_lens[sin] + 1:] = index_s[sin, src_lens[sin]]

        if decoder_masks.is_cuda:
            index_s = index_s.cuda(decoder_masks.get_device())

        return Variable(index_s, requires_grad=False).long()

    elif decoder_input_how == "interpolate":
        max_src_len = source_masks.size(1)
        max_trg_len = decoder_masks.size(1)
        src_lens = source_masks.sum(-1).float()  # batch_size
        trg_lens = decoder_masks.sum(-1).float()  # batch_size
        steps = src_lens / trg_lens  # batch_size
        index_t = torch.arange(0, max_trg_len)  # max_trg_len
        if decoder_masks.is_cuda:
            index_t = index_t.cuda(decoder_masks.get_device())
        index_t = steps[:, None] @ index_t[None, :]  # batch x max_trg_len
        index_s = torch.arange(0, max_src_len)  # max_src_len
        if decoder_masks.is_cuda:
            index_s = index_s.cuda(decoder_masks.get_device())
        indexxx_ = (index_s[None, None, :] - index_t[:, :, None]) ** 2  # batch x max_trg x max_src
        indexxx = softmax(Variable(-indexxx_.float() / 0.3 - INF * (1 - source_masks[:, None, :].float())))  # batch x max_trg x max_src
        return indexxx
--------------------------------------------------------------------------------