├── .gitignore ├── LICENSE ├── LICENSE_nyu ├── README.md ├── bontune_wmt.sh ├── data.py ├── decode.py ├── decode_wmt.sh ├── distill.py ├── joint_wmt.sh ├── mle_wmt.sh ├── model.py ├── mscoco.py ├── run.py ├── scripts ├── i2_iwslt-ende │ ├── bontune_iwslt.sh │ ├── decode_iwslt.sh │ ├── joint_iwslt.sh │ ├── mle_iwslt.sh │ └── tune_iwslt.sh ├── i2_wmt14-deen │ ├── bontune.sh │ ├── decode.sh │ ├── joint.sh │ ├── mle.sh │ └── tune.sh ├── i2_wmt14-ende │ ├── bontune.sh │ ├── decode.sh │ ├── joint.sh │ ├── mle.sh │ └── tune.sh ├── i2_wmt16-enro │ ├── bontune.sh │ ├── decode.sh │ ├── joint.sh │ ├── mle.sh │ └── tune.sh ├── i2_wmt16-roen │ ├── bontune.sh │ ├── decode.sh │ ├── joint.sh │ ├── mle.sh │ └── tune.sh ├── iwslt-ende │ ├── bontune_iwslt.sh │ ├── decode_iwslt.sh │ ├── joint_iwslt.sh │ ├── mle_iwslt.sh │ └── tune_iwslt.sh ├── wmt14-deen │ ├── bontune.sh │ ├── decode.sh │ ├── joint.sh │ ├── mle.sh │ └── tune.sh ├── wmt14-ende │ ├── bontune.sh │ ├── decode.sh │ ├── joint.sh │ ├── mle.sh │ └── tune.sh ├── wmt16-enro │ ├── bontune.sh │ ├── decode.sh │ ├── joint.sh │ ├── mle.sh │ └── tune.sh └── wmt16-roen │ ├── bontune.sh │ ├── decode.sh │ ├── joint.sh │ ├── mle.sh │ └── tune.sh ├── slides.pdf ├── test.py ├── train.py ├── tune_wmt.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, University of Chinese Academy of Sciences (Chenze Shao) 4 | All rights reserved. 
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /LICENSE_nyu: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, New York University (Kyunghyun Cho, Jason Lee, Elman Mansimov) 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Minimizing the Bag-of-Ngrams Difference for Non-Autoregressive Neural Machine Translation 2 | ================================== 3 | PyTorch implementation of the models described in the paper [Minimizing the Bag-of-Ngrams Difference for Non-Autoregressive Neural Machine Translation 4 | ](https://arxiv.org/pdf/1911.09320.pdf "Minimizing the Bag-of-Ngrams Difference for Non-Autoregressive Neural Machine Translation"). 5 | 6 | Dependencies 7 | ------------------ 8 | ### Python 9 | * Python 3.6 10 | * PyTorch >= 0.4 11 | * Numpy 12 | * NLTK 13 | * torchtext 0.2.1 14 | * torchvision 15 | * revtok 16 | * multiset 17 | * ipdb 18 | 19 | Related code 20 | ------------------ 21 | * This code is based on [dl4mt-nonauto](https://github.com/nyu-dl/dl4mt-nonauto "dl4mt-nonauto") and [RSI-NAT](https://github.com/ictnlp/RSI-NAT "RSI-NAT"). Our main modifications are in [`model.py`](https://github.com/ictnlp/BoN-NAT/blob/master/model.py "model.py") (lines 1107-1292). 22 | 23 | Downloading Datasets 24 | ------------------ 25 | The original translation corpora can be downloaded from ([IWSLT'16 En-De](https://wit3.fbk.eu/), [WMT'16 En-Ro](http://www.statmt.org/wmt16/translation-task.html), [WMT'14 En-De](http://www.statmt.org/wmt14/translation-task.html)). We recommend downloading the preprocessed corpora released in [dl4mt-nonauto](https://github.com/nyu-dl/dl4mt-nonauto/tree/multigpu "dl4mt-nonauto"). 26 | Set the correct path to the data in the `data_path()` function in [`data.py`](https://github.com/ictnlp/BoN-NAT/blob/master/data.py) before running the code. 27 | 28 | BoN-Joint 29 | ------------------ 30 | Combine the BoN objective and the cross-entropy loss to train a NAT model from scratch. This process usually takes about 5 days. 31 | ```bash 32 | $ sh joint_wmt.sh 33 | ``` 34 | Take a checkpoint and train the length prediction model. This process usually takes about 1 day. 35 | ```bash 36 | $ sh tune_wmt.sh 37 | ``` 38 | Decode the test set. This process usually takes about 20 seconds. 39 | ```bash 40 | $ sh decode_wmt.sh 41 | ``` 42 | 43 | 44 | BoN-FT 45 | ------------------ 46 | First, train a NAT model using the cross-entropy loss. This process usually takes about 5 days. 47 | ```bash 48 | $ sh mle_wmt.sh 49 | ``` 50 | Then, take a pre-trained checkpoint and finetune the NAT model using the BoN objective. This process usually takes about 3 hours. 51 | ```bash 52 | $ sh bontune_wmt.sh 53 | ``` 54 | Take a finetuned checkpoint and train the length prediction model. This process usually takes about 1 day. 55 | ```bash 56 | $ sh tune_wmt.sh 57 | ``` 58 | Decode the test set. This process usually takes about 20 seconds. 59 | ```bash 60 | $ sh decode_wmt.sh 61 | ``` 62 | 63 | Reinforce-NAT 64 | ------------------ 65 | We also implement Reinforce-NAT (lines 1294-1390 of `model.py`) described in the paper [Retrieving Sequential Information for Non-Autoregressive Neural Machine Translation](https://arxiv.org/abs/1906.09444 "Retrieving Sequential Information for Non-Autoregressive Neural Machine Translation"). See [RSI-NAT](https://github.com/ictnlp/RSI-NAT "RSI-NAT") for usage instructions.
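
BoN objective (sketch)
------------------
For intuition, below is a minimal, illustrative sketch of the BoN-L1 objective for bigrams (n=2). It is **not** the actual implementation (see lines 1107-1292 of [`model.py`](https://github.com/ictnlp/BoN-NAT/blob/master/model.py) for that): it handles a single sentence, assumes the decoder output length equals the reference length, and clips each occurrence of a repeated reference bigram at a count of one.
```python
import torch

def bon_l1_bigram(probs, ref):
    # probs: (T, V) per-position output distributions of the NAT decoder (one sentence)
    # ref:   (T,)  reference token indices
    T, V = probs.size()
    w1, w2 = ref[:-1], ref[1:]               # the T-1 reference bigrams (y_i, y_{i+1})
    # expected count of each reference bigram under the model:
    # E[C(w1, w2)] = sum_t p_t(w1) * p_{t+1}(w2)
    p_first = probs[:-1, w1]                 # (T-1, T-1): p_t(w1_i) for every position t
    p_second = probs[1:, w2]                 # (T-1, T-1): p_{t+1}(w2_i)
    exp_counts = (p_first * p_second).sum(0) # (T-1,) expected count per reference bigram
    # overlap ("match") term: clip each expected count at the reference count
    match = torch.min(exp_counts, torch.ones_like(exp_counts)).sum()
    # both bags contain T-1 bigrams in total, so the L1 distance reduces to
    return 2.0 * (T - 1) - 2.0 * match
```
Since every term is differentiable in the decoder's output probabilities, this loss can be added to the cross-entropy term (BoN-Joint) or used alone to finetune a pre-trained NAT model (BoN-FT).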
66 | 67 | Citation 68 | ------------------ 69 | If you find the resources in this repository useful, please consider citing: 70 | ``` 71 | @article{Shao:19, 72 | author = {Chenze Shao and Jinchao Zhang and Yang Feng and Fandong Meng and Jie Zhou}, 73 | title = {Minimizing the Bag-of-Ngrams Difference for Non-Autoregressive Neural Machine Translation}, 74 | year = {2019}, 75 | journal = {arXiv preprint arXiv:1911.09320}, 76 | } 77 | ``` 78 | -------------------------------------------------------------------------------- /bontune_wmt.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --ng_finetune --load_from --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | 4 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import math 2 | import ipdb 3 | import torch 4 | import random 5 | import numpy as np 6 | import _pickle as pickle 7 | import revtok 8 | import os 9 | from itertools import groupby 10 | import getpass 11 | from collections import Counter 12 | 13 | from torch.autograd import Variable 14 | from torchtext import data, datasets 15 | from nltk.translate.gleu_score import sentence_gleu, corpus_gleu 16 | from nltk.translate.bleu_score import closest_ref_length, brevity_penalty, modified_precision, SmoothingFunction 17 | from contextlib import ExitStack 18 | from collections import OrderedDict 19 | import fractions 20 | 21 | from mscoco import CocoCaptionsIndexedImage, CocoCaptionsIndexedCaption, CocoCaptionsIndexedImageDistill, \ 22 | BatchSamplerImagesSameLength, BatchSamplerCaptionsSameLength 23 | from mscoco import process_json 24 | 25 | try: 26 | fractions.Fraction(0, 1000, _normalize=False) 27 | from fractions import Fraction 28 | except TypeError: 29 | from nltk.compat import Fraction 30 | 31 | def data_path(dataset): 32 | if dataset == "iwslt-ende" or dataset == "iwslt-deen": 33 | path="../IWSLT/en-de/" 34 | elif dataset == "wmt14-ende" or dataset == "wmt14-deen": 35 | path="../wmt14/en-de/" 36 | elif dataset == "wmt16-enro" or dataset == "wmt16-roen": 37 | path="../wmt16/en-ro/" 38 | elif dataset == "wmt17-enlv" or dataset == "wmt17-lven": 39 | path="../wmt17/en-lv/" 40 | elif dataset == "mscoco": 41 | path="mscoco" 42 | 43 | return path
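
# Note (illustrative): data_path() only returns a directory prefix; e.g. --dataset wmt14-ende
# resolves to ../wmt14/en-de/, which is expected to contain the preprocessed
# train/dev/test files released with dl4mt-nonauto. Edit the paths above to match
# wherever you unpacked each corpus (see the README).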
44 | 45 | # load the dataset + reversible tokenization 46 | class NormalField(data.Field): 47 | 48 | def reverse(self, batch, unbpe=True): 49 | if not self.batch_first: 50 | batch.t_() 51 | 52 | with torch.cuda.device_of(batch): 53 | batch = batch.tolist() 54 | 55 | batch = [[self.vocab.itos[ind] for ind in ex] for ex in batch] # denumericalize 56 | 57 | def trim(s, t): 58 | sentence = [] 59 | for w in s: 60 | if w == t: 61 | break 62 | sentence.append(w) 63 | return sentence 64 | 65 | batch = [trim(ex, self.eos_token) for ex in batch] # trim past first eos 66 | def filter_special(tok): 67 | return tok not in (self.init_token, self.pad_token) 68 | 69 | if unbpe: 70 | batch = [" ".join(filter(filter_special, ex)).replace("@@ ","") for ex in batch] 71 | else: 72 | batch = [" ".join(filter(filter_special, ex)) for ex in batch] 73 | return batch 74 | 75 | class MSCOCOVocab(object): 76 | """Simple vocabulary wrapper.""" 77 | def __init__(self): 78 | self.stoi = {} 79 | self.itos = {} 80 | self.idx = 0 81 | 82 | def add_word(self, word): 83 | if not word in self.stoi: 84 | self.stoi[word] = self.idx 85 | self.itos[self.idx] = word 86 | self.idx += 1 87 | 88 | def __call__(self, word): 89 | if not word in self.stoi: 90 | return self.stoi['<unk>'] 91 | return self.stoi[word] 92 | 93 | def __len__(self): 94 | return len(self.stoi) 95 | 96 | class MSCOCODataset(object): 97 | def __init__(self, path, batch_size, max_len=None, valid_size=None, distill=False, use_distillation=False): 98 | self.path = path 99 | 100 | if distill: 101 | self.train_data, self.train_sampler = self.prepare_distill_data(path, 'karpathy_split/train.json.bpe.fixed', batch_size, max_len=max_len, size=None) 102 | else: 103 | train_f = 'karpathy_split/train.json.bpe.fixed' 104 | if use_distillation: 105 | train_f = 'karpathy_split/train.json.bpe.fixed.high.distill' 106 | self.train_data, self.train_sampler = self.prepare_train_data(path, train_f, batch_size, max_len=max_len, size=None) 107 | 108 | self.valid_data, self.valid_sampler = self.prepare_test_data(path, 'karpathy_split/valid.json.bpe.fixed', batch_size, max_len=None, size=valid_size) 109 | self.test_data, self.test_sampler = self.prepare_test_data(path, 'karpathy_split/test.json.bpe.fixed', batch_size, max_len=None, size=valid_size) 110 | 111 | self.unk_token = 0 112 | self.pad_token = 1 113 | self.init_token = 2 114 | self.eos_token = 3 115 | 116 | def prepare_train_data(self, dataPath, annFile, batch_size, max_len=None, size=None): 117 | bpes, features_path, bpe2img, img2bpes = process_json(dataPath, annFile, max_len=max_len, size=size) 118 | 119 | # get max len of dataset 120 | self.max_dataset_length = 0 121 | for bpe in bpes: 122 | len_bpe = len(bpe.split(' ')) 123 | if len_bpe > self.max_dataset_length: 124 | self.max_dataset_length = len_bpe 125 | 126 | dataset_captions = CocoCaptionsIndexedCaption(bpes, features_path, bpe2img, img2bpes) 127 | sampler_captions = BatchSamplerCaptionsSameLength(dataset_captions, batch_size=batch_size) 128 | return dataset_captions, sampler_captions 129 | 130 | def prepare_test_data(self, dataPath, annFile, batch_size, max_len=None, size=None): 131 | bpes, features_path, bpe2img, img2bpes = process_json(dataPath, annFile, max_len=max_len, size=size) 132 | 133 | dataset_images = CocoCaptionsIndexedImage(bpes, features_path, bpe2img, img2bpes) 134 | sampler_images = BatchSamplerImagesSameLength(dataset_images, batch_size=batch_size) 135 | return dataset_images, sampler_images 136 | 137 | def prepare_distill_data(self, dataPath, annFile, batch_size, max_len=None, size=None): 138 | bpes, features_path, bpe2img, img2bpes = process_json(dataPath, annFile, max_len=max_len, size=size) 139 | 140 | dataset_images = CocoCaptionsIndexedImageDistill(bpes, features_path, bpe2img, img2bpes) 141 | sampler_images = BatchSamplerImagesSameLength(dataset_images, batch_size=batch_size) 142 | return dataset_images, sampler_images 143 | 144 | 145 | def build_vocab(self): 146 | """Build a simple vocabulary wrapper.""" 147 | from collections import Counter 148 | 149 | bpes = self.train_data.bpes 150 | 151 | counter = Counter() 152 | for bpe in bpes: 153 | counter.update(bpe.split()) 154 | 155 | words = [word for word, cnt in counter.items()] 156 | 157 | # Creates a vocab wrapper and add some special tokens. 158 | # MAKE SURE CONSTANTS ARE CONSISTENT WITH TRANSLATION DATASETS !!! 159 | self.vocab = MSCOCOVocab() 160 | self.vocab.add_word('<unk>') 161 | self.vocab.add_word('<pad>') 162 | self.vocab.add_word('<init>') 163 | self.vocab.add_word('<eos>') 164 | 165 | # Adds the words to the vocabulary.
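
# Illustrative: after build_vocab() runs, vocab(word) returns the word's index, and
# MSCOCOVocab.__call__ falls back to self.stoi['<unk>'] (index 0) for unseen words.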
166 | for i, word in enumerate(words): 167 | self.vocab.add_word(word) 168 | 169 | def reverse(self, batch, unbpe=True): 170 | #batch = batch.t() 171 | with torch.cuda.device_of(batch): 172 | batch = batch.tolist() 173 | batch = [[self.vocab.itos[ind] for ind in ex] for ex in batch] # denumericalize 174 | 175 | def trim(s, t): 176 | sentence = [] 177 | for w in s: 178 | if w == t: 179 | break 180 | sentence.append(w) 181 | return sentence 182 | 183 | batch = [trim(ex, '<eos>') for ex in batch] # trim past first eos 184 | 185 | def filter_special(tok): 186 | return tok not in ('<init>', '<pad>') 187 | 188 | #batch = [filter(filter_special, ex) for ex in batch] 189 | if unbpe: 190 | batch = [" ".join(filter(filter_special, ex)).replace("@@ ","") for ex in batch] 191 | else: 192 | batch = [" ".join(filter(filter_special, ex)) for ex in batch] 193 | return batch 194 | 195 | class TranslationDataset(data.Dataset): 196 | """Defines a dataset for machine translation.""" 197 | 198 | @staticmethod 199 | def sort_key(ex): 200 | return data.interleave_keys(len(ex.src), len(ex.trg)) 201 | 202 | def __init__(self, path, exts, fields, **kwargs): 203 | """Create a TranslationDataset given paths and fields. 204 | Arguments: 205 | path: Common prefix of paths to the data files for both languages. 206 | exts: A tuple containing the extension to path for each language. 207 | fields: A tuple containing the fields that will be used for data 208 | in each language. 209 | Remaining keyword arguments: Passed to the constructor of 210 | data.Dataset. 211 | """ 212 | if not isinstance(fields[0], (tuple, list)): 213 | fields = [('src', fields[0]), ('trg', fields[1])] 214 | 215 | src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts) 216 | 217 | examples = [] 218 | with open(src_path) as src_file, open(trg_path) as trg_file: 219 | for src_line, trg_line in zip(src_file, trg_file): 220 | src_line, trg_line = src_line.strip(), trg_line.strip() 221 | if src_line != '' and trg_line != '': 222 | examples.append(data.Example.fromlist( 223 | [src_line, trg_line], fields)) 224 | 225 | super(TranslationDataset, self).__init__(examples, fields, **kwargs) 226 | 227 | @classmethod 228 | def splits(cls, path, exts, fields, root='.data', 229 | train='train', validation='val', test='test', **kwargs): 230 | """Create dataset objects for splits of a TranslationDataset. 231 | Arguments: 232 | root: Root dataset storage directory. Default is '.data'. 233 | exts: A tuple containing the extension to path for each language. 234 | fields: A tuple containing the fields that will be used for data 235 | in each language. 236 | train: The prefix of the train data. Default: 'train'. 237 | validation: The prefix of the validation data. Default: 'val'. 238 | test: The prefix of the test data. Default: 'test'. 239 | Remaining keyword arguments: Passed to the splits method of 240 | Dataset.
241 | """ 242 | #path = cls.download(root) 243 | 244 | train_data = None if train is None else cls( 245 | os.path.join(path, train), exts, fields, **kwargs) 246 | val_data = None if validation is None else cls( 247 | os.path.join(path, validation), exts, fields, **kwargs) 248 | test_data = None if test is None else cls( 249 | os.path.join(path, test), exts, fields, **kwargs) 250 | return tuple(d for d in (train_data, val_data, test_data) 251 | if d is not None) 252 | 253 | 254 | class NormalTranslationDataset(TranslationDataset): 255 | """Defines a dataset for machine translation.""" 256 | 257 | def __init__(self, path, exts, fields, load_dataset=False, save_dataset=False, prefix='', **kwargs): 258 | """Create a TranslationDataset given paths and fields. 259 | 260 | Arguments: 261 | path: Common prefix of paths to the data files for both languages. 262 | exts: A tuple containing the extension to path for each language. 263 | fields: A tuple containing the fields that will be used for data 264 | in each language. 265 | Remaining keyword arguments: Passed to the constructor of 266 | data.Dataset. 267 | """ 268 | if not isinstance(fields[0], (tuple, list)): 269 | fields = [('src', fields[0]), ('trg', fields[1])] 270 | 271 | src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts) 272 | if load_dataset and (os.path.exists(path + '.processed.{}.pt'.format(prefix))): 273 | examples = pickle.load(open(path + '.processed.{}.pt'.format(prefix), "rb")) 274 | print ("Loaded TorchText dataset") 275 | else: 276 | examples = [] 277 | with open(src_path,encoding='utf-8') as src_file, open(trg_path,encoding='utf-8') as trg_file: 278 | for src_line, trg_line in zip(src_file, trg_file): 279 | src_line, trg_line = src_line.strip(), trg_line.strip() 280 | if src_line != '' and trg_line != '': 281 | examples.append(data.Example.fromlist( 282 | [src_line, trg_line], fields)) 283 | if save_dataset: 284 | pickle.dump(examples, open(path + '.processed.{}.pt'.format(prefix), "wb")) 285 | print ("Saved TorchText dataset") 286 | 287 | super(TranslationDataset, self).__init__(examples, fields, **kwargs) 288 | 289 | class TripleTranslationDataset(datasets.TranslationDataset): 290 | """Define a triple-translation dataset: src, trg, dec (output of a pre-trained teacher)""" 291 | 292 | def __init__(self, path, exts, fields, load_dataset=False, prefix='', **kwargs): 293 | if not isinstance(fields[0], (tuple, list)): 294 | fields = [('src', fields[0]), ('trg', fields[1]), ('dec', fields[2])] 295 | 296 | src_path, trg_path, dec_path = tuple(os.path.expanduser(path + x) for x in exts) 297 | if load_dataset and (os.path.exists(path + '.processed.{}.pt'.format(prefix))): 298 | examples = torch.load(path + '.processed.{}.pt'.format(prefix)) 299 | else: 300 | examples = [] 301 | with open(src_path) as src_file, open(trg_path) as trg_file, open(dec_path) as dec_file: 302 | for src_line, trg_line, dec_line in zip(src_file, trg_file, dec_file): 303 | src_line, trg_line, dec_line = src_line.strip(), trg_line.strip(), dec_line.strip() 304 | if src_line != '' and trg_line != '' and dec_line != '': 305 | examples.append(data.Example.fromlist( 306 | [src_line, trg_line, dec_line], fields)) 307 | if load_dataset: 308 | torch.save(examples, path + '.processed.{}.pt'.format(prefix)) 309 | 310 | super(datasets.TranslationDataset, self).__init__(examples, fields, **kwargs) 311 | 312 | class ParallelDataset(datasets.TranslationDataset): 313 | """ Define an N-parallel dataset: supports arbitrary numbers of input streams""" 314 | 315
| def __init__(self, path=None, exts=None, fields=None, 316 | load_dataset=False, prefix='', examples=None, **kwargs): 317 | 318 | if examples is None: 319 | assert len(exts) == len(fields), 'N parallel dataset must match' 320 | self.N = len(fields) 321 | 322 | paths = tuple(os.path.expanduser(path + x) for x in exts) 323 | if load_dataset and (os.path.exists(path + '.processed.{}.pt'.format(prefix))): 324 | examples = torch.load(path + '.processed.{}.pt'.format(prefix)) 325 | else: 326 | examples = [] 327 | with ExitStack() as stack: 328 | files = [stack.enter_context(open(fname)) for fname in paths] 329 | for lines in zip(*files): 330 | lines = [line.strip() for line in lines] 331 | if not any(line == '' for line in lines): 332 | examples.append(data.Example.fromlist(lines, fields)) 333 | if load_dataset: 334 | torch.save(examples, path + '.processed.{}.pt'.format(prefix)) 335 | 336 | super(datasets.TranslationDataset, self).__init__(examples, fields, **kwargs) 337 | -------------------------------------------------------------------------------- /decode.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import ipdb 3 | import math 4 | import os 5 | import torch 6 | import numpy as np 7 | import time 8 | 9 | from torch.nn import functional as F 10 | from torch.autograd import Variable 11 | 12 | from tqdm import tqdm, trange 13 | from model import Transformer, FastTransformer, INF, TINY, softmax 14 | from data import NormalField, NormalTranslationDataset, TripleTranslationDataset, ParallelDataset, data_path 15 | from utils import Metrics, Best, computeBLEU, computeBLEUMSCOCO, Batch, masked_sort, computeGroupBLEU, organise_trg_len_dic, make_decoder_masks, \ 16 | double_source_masks, remove_repeats, remove_repeats_tensor, print_bleu, oracle_converged, equality_converged, jaccard_converged 17 | from time import gmtime, strftime 18 | import copy 19 | from multiset import Multiset 20 | 21 | tokenizer = lambda x: x.replace('@@ ', '').split() 22 | 23 | def run_fast_transformer(decoder_inputs, decoder_masks,\ 24 | sources, source_masks,\ 25 | targets,\ 26 | encoding,\ 27 | model, args, use_argmax=True): 28 | 29 | trg_unidx = model.output_decoding( ('trg', targets),unbpe=False) 30 | src_unidx = model.output_decoding( ('src', sources),unbpe=False) 31 | batch_size, src_len, hsize = encoding[0].size() 32 | #s = open("decoding/dec_source","a") 33 | #r = open("decoding/dec_ref","a") 34 | #l = len(src_unidx) 35 | #for i in range(l): 36 | # s.write(src_unidx[i]+'\n') 37 | # r.write(trg_unidx[i]+'\n') 38 | all_decodings = [] 39 | all_probs = [] 40 | iter_ = 0 41 | bleu_hist = [ [] for xx in range(batch_size) ] 42 | output_hist = [ [] for xx in range(batch_size) ] 43 | multiset_hist = [ [] for xx in range(batch_size) ] 44 | num_iters = [ 0 for xx in range(batch_size) ] 45 | done_ = [False for xx in range(batch_size)] 46 | final_decoding = [ None for xx in range(batch_size) ] 47 | 48 | while True: 49 | curr_iter = min(iter_, args.num_decs-1) 50 | next_iter = min(iter_+1, args.num_decs-1) 51 | 52 | decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks, 53 | decoding=True, return_probs=True, iter_=curr_iter) 54 | 55 | dec_output = decoding.data.cpu().numpy().tolist() 56 | #out_unidx = model.output_decoding( ('trg', decoding ),unbpe=False ) 57 | #o = open("decoding/decode_out" + str(iter_),"a") 58 | #l = len(src_unidx) 59 | #for i in range(l): 60 | # o.write(out_unidx[i]+'\n') 61 | 62 | """ 63 | if args.trg_len_option != 
"reference": 64 | decoder_masks = 0. * decoder_masks 65 | for bidx in range(batch_size): 66 | try: 67 | decoder_masks[bidx,:(dec_output[bidx].index(3))+1] = 1. 68 | except: 69 | decoder_masks[bidx,:] = 1. 70 | """ 71 | 72 | if args.adaptive_decoding == "oracle": 73 | out_unidx = model.output_decoding( ('trg', decoding ) ) 74 | sentence_bleus = computeBLEU(out_unidx, trg_unidx, corpus=False, tokenizer=tokenizer) 75 | 76 | for bidx in range(batch_size): 77 | output_hist[bidx].append( dec_output[bidx] ) 78 | bleu_hist[bidx].append(sentence_bleus[bidx]) 79 | 80 | converged = oracle_converged( bleu_hist, num_items=args.adaptive_window ) 81 | for bidx in range(batch_size): 82 | if not done_[bidx] and converged[bidx] and num_iters[bidx] == 0: 83 | num_iters[bidx] = iter_ + 1 - (args.adaptive_window -1) 84 | done_[bidx] = True 85 | final_decoding[bidx] = output_hist[bidx][-args.adaptive_window] 86 | 87 | elif args.adaptive_decoding == "equality": 88 | for bidx in range(batch_size): 89 | #if 3 in dec_output[bidx]: 90 | # dec_output[bidx] = dec_output[bidx][:dec_output[bidx].index(3)] 91 | output_hist[bidx].append( dec_output[bidx] ) 92 | 93 | converged = equality_converged( output_hist, num_items=args.adaptive_window ) 94 | 95 | for bidx in range(batch_size): 96 | if not done_[bidx] and converged[bidx] and num_iters[bidx] == 0: 97 | num_iters[bidx] = iter_ + 1 98 | done_[bidx] = True 99 | final_decoding[bidx] = output_hist[bidx][-1] 100 | 101 | elif args.adaptive_decoding == "jaccard": 102 | for bidx in range(batch_size): 103 | #if 3 in dec_output[bidx]: 104 | # dec_output[bidx] = dec_output[bidx][:dec_output[bidx].index(3)] 105 | output_hist[bidx].append( dec_output[bidx] ) 106 | multiset_hist[bidx].append( Multiset(dec_output[bidx]) ) 107 | 108 | converged = jaccard_converged( multiset_hist, num_items=args.adaptive_window ) 109 | 110 | for bidx in range(batch_size): 111 | if not done_[bidx] and converged[bidx] and num_iters[bidx] == 0: 112 | num_iters[bidx] = iter_ + 1 113 | done_[bidx] = True 114 | final_decoding[bidx] = output_hist[bidx][-1] 115 | 116 | all_decodings.append( decoding ) 117 | all_probs.append(probs) 118 | 119 | decoder_inputs = 0 120 | if args.next_dec_input in ["both", "emb"]: 121 | if use_argmax: 122 | _, argmax = torch.max(probs, dim=-1) 123 | else: 124 | probs_sz = probs.size() 125 | probs_ = Variable(probs.data, requires_grad=False) 126 | argmax = torch.multinomial(probs_.contiguous().view(-1, probs_sz[-1]), 1).view(*probs_sz[:-1]) 127 | emb = F.embedding(argmax, model.decoder[next_iter].out.weight * math.sqrt(args.d_model)) 128 | decoder_inputs += emb 129 | 130 | if args.next_dec_input in ["both", "out"]: 131 | decoder_inputs += out 132 | 133 | iter_ += 1 134 | if iter_ == args.valid_repeat_dec or (False not in done_): 135 | break 136 | 137 | if args.adaptive_decoding != None: 138 | for bidx in range(batch_size): 139 | if num_iters[bidx] == 0: 140 | num_iters[bidx] = 20 141 | if final_decoding[bidx] == None: 142 | if args.adaptive_decoding == "oracle": 143 | final_decoding[bidx] = output_hist[bidx][np.argmax(bleu_hist[bidx])] 144 | else: 145 | final_decoding[bidx] = output_hist[bidx][-1] 146 | 147 | decoding = Variable(torch.LongTensor(np.array(final_decoding))) 148 | if decoder_masks.is_cuda: 149 | decoding = decoding.cuda() 150 | 151 | return decoding, all_decodings, num_iters, all_probs 152 | 153 | def decode_model(args, model, dev, evaluate=True, trg_len_dic=None, 154 | decoding_path=None, names=None, maxsteps=None): 155 | 156 | args.logger.info("decoding, f_size={}, 
beam_size={}, alpha={}".format(args.f_size, args.beam_size, args.alpha)) 157 | dev.train = False # make iterator volatile=True 158 | 159 | if not args.no_tqdm: 160 | progressbar = tqdm(total=200, desc='start decoding') 161 | 162 | model.eval() 163 | if not args.debug: 164 | decoding_path.mkdir(parents=True, exist_ok=True) 165 | handles = [(decoding_path / name ).open('w') for name in names] 166 | 167 | corpus_size = 0 168 | src_outputs, trg_outputs, dec_outputs, timings = [], [], [], [] 169 | all_decs = [ [] for idx in range(args.valid_repeat_dec)] 170 | decoded_words, target_words, decoded_info = 0, 0, 0 171 | 172 | attentions = None 173 | decoder = model.decoder[0] if args.model is FastTransformer else model.decoder 174 | pad_id = decoder.field.vocab.stoi['<pad>'] 175 | eos_id = decoder.field.vocab.stoi['<eos>'] 176 | 177 | curr_time = 0 178 | cum_sentences = 0 179 | cum_tokens = 0 180 | cum_images = 0 # used for mscoco 181 | num_iters_total = [] 182 | 183 | for iters, dev_batch in enumerate(dev): 184 | start_t = time.time() 185 | 186 | if args.dataset != "mscoco": 187 | decoder_inputs, decoder_masks,\ 188 | targets, target_masks,\ 189 | sources, source_masks,\ 190 | encoding, batch_size, rest = model.quick_prepare(dev_batch, fast=(type(model) is FastTransformer), trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio, trg_len_dic=trg_len_dic, bp=args.bp) 191 | else: 192 | # only use first caption for calculating log likelihood 193 | all_captions = dev_batch[1] 194 | dev_batch[1] = dev_batch[1][0] 195 | decoder_inputs, decoder_masks,\ 196 | targets, target_masks,\ 197 | _, source_masks,\ 198 | encoding, batch_size, rest = model.quick_prepare_mscoco(dev_batch, all_captions=all_captions, fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec, trg_len_option=args.trg_len_option, max_len=args.max_len, trg_len_dic=trg_len_dic, bp=args.bp, gpu=args.gpu) 199 | sources = None 200 | 201 | cum_sentences += batch_size 202 | 203 | batch_size, src_len, hsize = encoding[0].size() 204 | 205 | # for now 206 | if type(model) is Transformer: 207 | all_decodings = [] 208 | decoding = model(encoding, source_masks, decoder_inputs, decoder_masks, 209 | beam=args.beam_size, alpha=args.alpha, \ 210 | decoding=True, feedback=attentions) 211 | all_decodings.append( decoding ) 212 | num_iters = [0] 213 | 214 | elif type(model) is FastTransformer: 215 | decoding, all_decodings, num_iters, argmax_all_probs = run_fast_transformer(decoder_inputs, decoder_masks, \ 216 | sources, source_masks, targets, encoding, model, args, use_argmax=True) 217 | num_iters_total.extend( num_iters ) 218 | 219 | if not args.use_argmax: 220 | for _ in range(args.num_samples): 221 | _, _, _, sampled_all_probs = run_fast_transformer(decoder_inputs, decoder_masks, \ 222 | sources, source_masks, targets, encoding, model, args, use_argmax=False) 223 | for iter_ in range(args.valid_repeat_dec): 224 | argmax_all_probs[iter_] = argmax_all_probs[iter_] + sampled_all_probs[iter_] 225 | 226 | all_decodings = [] 227 | for iter_ in range(args.valid_repeat_dec): 228 | argmax_all_probs[iter_] = argmax_all_probs[iter_] / args.num_samples 229 | all_decodings.append(torch.max(argmax_all_probs[iter_], dim=-1)[-1]) 230 | decoding = all_decodings[-1] 231 | 232 | used_t = time.time() - start_t 233 | curr_time += used_t 234 | 235 | if args.dataset != "mscoco": 236 | if args.remove_repeats: 237 | outputs_unidx = [model.output_decoding(d) for d in [('src', sources), ('trg', targets), ('trg', remove_repeats_tensor(decoding))]] 238 | else: 239 | outputs_unidx
= [model.output_decoding(d) for d in [('src', sources), ('trg', targets), ('trg', decoding)]] 240 | 241 | else: 242 | # make sure that 5 captions per each example 243 | num_captions = len(all_captions[0]) 244 | for c in range(1, len(all_captions)): 245 | assert (num_captions == len(all_captions[c])) 246 | 247 | # untokenize reference captions 248 | for n_ref in range(len(all_captions)): 249 | n_caps = len(all_captions[0]) 250 | for c in range(n_caps): 251 | all_captions[n_ref][c] = all_captions[n_ref][c].replace("@@ ","") 252 | 253 | outputs_unidx = [ list(map(list, zip(*all_captions))) ] 254 | 255 | if args.remove_repeats: 256 | all_dec_outputs = [model.output_decoding(d) for d in [('trg', remove_repeats_tensor(all_decodings[ii])) for ii in range(len(all_decodings))]] 257 | else: 258 | all_dec_outputs = [model.output_decoding(d) for d in [('trg', all_decodings[ii]) for ii in range(len(all_decodings))]] 259 | 260 | corpus_size += batch_size 261 | if args.dataset != "mscoco": 262 | cum_tokens += sum([len(xx.split(" ")) for xx in outputs_unidx[0]]) # NOTE source tokens, not target 263 | 264 | if args.dataset != "mscoco": 265 | src_outputs += outputs_unidx[0] 266 | trg_outputs += outputs_unidx[1] 267 | if args.remove_repeats: 268 | dec_outputs += remove_repeats(outputs_unidx[-1]) 269 | else: 270 | dec_outputs += outputs_unidx[-1] 271 | 272 | else: 273 | trg_outputs += outputs_unidx[0] 274 | 275 | for idx, each_output in enumerate(all_dec_outputs): 276 | if args.remove_repeats: 277 | all_decs[idx] += remove_repeats(each_output) 278 | else: 279 | all_decs[idx] += each_output 280 | 281 | #if True: 282 | if False and decoding_path is not None: 283 | for sent_i in range(len(outputs_unidx[0])): 284 | if args.dataset != "mscoco": 285 | print ('SRC') 286 | print (outputs_unidx[0][sent_i]) 287 | for ii in range(len(all_decodings)): 288 | print ('DEC iter {}'.format(ii)) 289 | print (all_dec_outputs[ii][sent_i]) 290 | print ('TRG') 291 | print (outputs_unidx[1][sent_i]) 292 | else: 293 | print ('TRG') 294 | trg = outputs_unidx[0] 295 | for subsent_i in range(len(trg[sent_i])): 296 | print ('TRG {}'.format(subsent_i)) 297 | print (trg[sent_i][subsent_i]) 298 | for ii in range(len(all_decodings)): 299 | print ('DEC iter {}'.format(ii)) 300 | print (all_dec_outputs[ii][sent_i]) 301 | print ('---------------------------') 302 | 303 | timings += [used_t] 304 | 305 | if not args.debug: 306 | for s, t, d in zip(outputs_unidx[0], outputs_unidx[1], outputs_unidx[2]): 307 | s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '') 308 | print(s, file=handles[0], flush=True) 309 | print(t, file=handles[1], flush=True) 310 | print(d, file=handles[2], flush=True) 311 | 312 | if not args.no_tqdm: 313 | progressbar.update(iters) 314 | progressbar.set_description('finishing sentences={}/batches={}, \ 315 | length={}/average iter={}, speed={} sec/batch'.format(\ 316 | corpus_size, iters, src_len, np.mean(np.array(num_iters)), curr_time / (1 + iters))) 317 | 318 | if evaluate: 319 | for idx, each_dec in enumerate(all_decs): 320 | if len(all_decs[idx]) != len(trg_outputs): 321 | break 322 | if args.dataset != "mscoco": 323 | bleu_output = computeBLEU(each_dec, trg_outputs, corpus=True, tokenizer=tokenizer) 324 | else: 325 | bleu_output = computeBLEUMSCOCO(each_dec, trg_outputs, corpus=True, tokenizer=tokenizer) 326 | args.logger.info("iter {} | {}".format(idx+1, print_bleu(bleu_output))) 327 | 328 | if args.adaptive_decoding != None: 329 | 
args.logger.info("----------------------------------------------") 330 | args.logger.info("Average # iters {}".format(np.mean(num_iters_total))) 331 | bleu_output = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer) 332 | args.logger.info("Adaptive BLEU | {}".format(print_bleu(bleu_output))) 333 | 334 | args.logger.info("----------------------------------------------") 335 | args.logger.info("Decoding speed analysis :") 336 | args.logger.info("{} sentences".format(cum_sentences)) 337 | if args.dataset != "mscoco": 338 | args.logger.info("{} tokens".format(cum_tokens)) 339 | args.logger.info("{:.3f} seconds".format(curr_time)) 340 | 341 | args.logger.info("{:.3f} ms / sentence".format((curr_time / float(cum_sentences) * 1000))) 342 | if args.dataset != "mscoco": 343 | args.logger.info("{:.3f} ms / token".format((curr_time / float(cum_tokens) * 1000))) 344 | 345 | args.logger.info("{:.3f} sentences / s".format(float(cum_sentences) / curr_time)) 346 | if args.dataset != "mscoco": 347 | args.logger.info("{:.3f} tokens / s".format(float(cum_tokens) / curr_time)) 348 | args.logger.info("----------------------------------------------") 349 | 350 | if args.decode_which > 0: 351 | args.logger.info("Writing to special file") 352 | parent = decoding_path / "speed" / "b_{}{}".format(args.beam_size if args.model is Transformer else args.valid_repeat_dec, 353 | "" if args.model is Transformer else "_{}".format(args.adaptive_decoding != None)) 354 | args.logger.info(str(parent)) 355 | parent.mkdir(parents=True, exist_ok=True) 356 | speed_handle = (parent / "results.{}".format(args.decode_which) ).open('w') 357 | 358 | print("----------------------------------------------", file=speed_handle, flush=True) 359 | print("Decoding speed analysis :", file=speed_handle, flush=True) 360 | print("{} sentences".format(cum_sentences), file=speed_handle, flush=True) 361 | if args.dataset != "mscoco": 362 | print("{} tokens".format(cum_tokens), file=speed_handle, flush=True) 363 | print("{:.3f} seconds".format(curr_time), file=speed_handle, flush=True) 364 | 365 | print("{:.3f} ms / sentence".format((curr_time / float(cum_sentences) * 1000)), file=speed_handle, flush=True) 366 | if args.dataset != "mscoco": 367 | print("{:.3f} ms / token".format((curr_time / float(cum_tokens) * 1000)), file=speed_handle, flush=True) 368 | 369 | print("{:.3f} sentences / s".format(float(cum_sentences) / curr_time), file=speed_handle, flush=True) 370 | if args.dataset != "mscoco": 371 | print("{:.3f} tokens / s".format(float(cum_tokens) / curr_time), file=speed_handle, flush=True) 372 | print("----------------------------------------------", file=speed_handle, flush=True) 373 | -------------------------------------------------------------------------------- /decode_wmt.sh: -------------------------------------------------------------------------------- 1 | python run.py --load_vocab --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 1 --use_argmax --next_dec_input both --mode test --remove_repeats --trg_len_option predict --use_predicted_trg_len --load_from 2 | 3 | -------------------------------------------------------------------------------- /distill.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import ipdb 3 | import math 4 | import os 5 | import torch 6 | import numpy as np 7 | import time 8 | 9 | from torch.nn import functional as F 10 | from torch.autograd import Variable 11 
| 12 | from tqdm import tqdm, trange 13 | from model import Transformer, FastTransformer, INF, TINY, softmax 14 | from data import NormalField, NormalTranslationDataset, TripleTranslationDataset, ParallelDataset, data_path 15 | from utils import Metrics, Best, computeBLEU, Batch, masked_sort, computeGroupBLEU, organise_trg_len_dic, make_decoder_masks, double_source_masks, remove_repeats, remove_repeats_tensor, print_bleu 16 | from time import gmtime, strftime 17 | import copy 18 | from multiset import Multiset 19 | import json 20 | 21 | tokenizer = lambda x: x.replace('@@ ', '').split() 22 | 23 | def distill_model(args, model, dev, evaluate=True, 24 | distill_path=None, names=None, maxsteps=None): 25 | 26 | if not args.no_tqdm: 27 | progressbar = tqdm(total=200, desc='start decoding') 28 | 29 | trg_len_dic = None 30 | 31 | args.logger.info("decoding, f_size={}, beam_size={}, alpha={}".format(args.f_size, args.beam_size, args.alpha)) 32 | dev.train = False # make iterator volatile=True 33 | 34 | model.eval() 35 | if distill_path is not None: 36 | if args.dataset != "mscoco": 37 | handles = [open(os.path.join(distill_path, name), 'w') for name in names] 38 | else: 39 | distill_annots = [] 40 | distill_filepath = os.path.join(str(distill_path), "train.bpe.fixed.distill") 41 | 42 | 43 | corpus_size = 0 44 | src_outputs, trg_outputs, dec_outputs, timings = [], [], [], [] 45 | all_decs = [ [] for idx in range(args.valid_repeat_dec)] 46 | decoded_words, target_words, decoded_info = 0, 0, 0 47 | 48 | attentions = None 49 | decoder = model.decoder[0] if args.model is FastTransformer else model.decoder 50 | pad_id = decoder.field.vocab.stoi['<pad>'] 51 | eos_id = decoder.field.vocab.stoi['<eos>'] 52 | 53 | curr_time = 0 54 | cum_bs = 0 55 | 56 | for iters, dev_batch in enumerate(dev): 57 | 58 | start_t = time.time() 59 | 60 | if args.dataset != "mscoco": 61 | decoder_inputs, decoder_masks,\ 62 | targets, target_masks,\ 63 | sources, source_masks,\ 64 | encoding, batch_size, rest = model.quick_prepare(dev_batch, fast=(type(model) is FastTransformer), trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio, trg_len_dic=trg_len_dic, bp=args.bp) 65 | else: 66 | all_captions = dev_batch[1] 67 | all_img_names = dev_batch[2] 68 | dev_batch[1] = dev_batch[1][0] 69 | decoder_inputs, decoder_masks,\ 70 | targets, target_masks,\ 71 | _, source_masks,\ 72 | encoding, batch_size, rest = model.quick_prepare_mscoco(dev_batch, all_captions=all_captions, fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec, trg_len_option=args.trg_len_option, max_len=args.max_offset, trg_len_dic=trg_len_dic, bp=args.bp) 73 | 74 | 75 | corpus_size += batch_size 76 | 77 | batch_size, src_len, hsize = encoding[0].size() 78 | 79 | # for now 80 | if type(model) is Transformer: 81 | all_decodings = [] 82 | decoding = model(encoding, source_masks, decoder_inputs, decoder_masks, 83 | beam=args.beam_size, alpha=args.alpha, \ 84 | decoding=True, feedback=attentions) 85 | all_decodings.append( decoding ) 86 | curr_iter = [0] 87 | 88 | used_t = time.time() - start_t 89 | curr_time += used_t 90 | 91 | real_mask = 1 - ((decoding.data == eos_id) + (decoding.data == pad_id)).float() 92 | if args.dataset != "mscoco": 93 | outputs = [model.output_decoding(d, False) for d in [('src', sources), ('trg', targets), ('trg', decoding)]] 94 | 95 | for s, t, d in zip(outputs[0], outputs[1], outputs[-1]): 96 | #s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '') 97 | print(s, file=handles[0], flush=True) 98 | print(t, file=handles[1], flush=True) 99 | print(d, file=handles[2], flush=True) 100 | else: 101 | outputs = [model.output_decoding(d, unbpe=False) for d in [('trg', targets), ('trg', decoding)]] 102 | 103 | for c, (t, d) in enumerate(zip(outputs[0], outputs[1])): 104 | annot = {} 105 | annot['bpes'] = [d] 106 | annot['img_name'] = all_img_names[c] 107 | distill_annots.append(annot) 108 | 109 | json.dump(distill_annots, open(distill_filepath, 'w')) 110 | 111 | if not args.no_tqdm: 112 | progressbar.update(iters) 113 | progressbar.set_description('finishing sentences={}/batches={}, \ 114 | length={}/average iter={}, speed={} sec/batch'.format(\ 115 | corpus_size, iters, src_len, np.mean(np.array(curr_iter)), curr_time / (1 + iters))) 116 | 117 | if args.dataset == "mscoco": 118 | json.dump(distill_annots, open(distill_filepath, 'w')) 119 | 120 | args.logger.info("Total time: {:.3f} ms / sentence".format(curr_time / float(corpus_size) * 1000))
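
# Usage note (illustrative): this module is driven by run.py, e.g.
#   python run.py --mode distill --dataset wmt14-ende --params big --load_from <AR checkpoint>
# where --load_from points to a trained autoregressive Transformer; the decoded training
# set written here is what --use_distillation consumes (see mle_wmt.sh / joint_wmt.sh).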
121 | -------------------------------------------------------------------------------- /joint_wmt.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --joint --alph 0.1 --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /mle_wmt.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /mscoco.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import _pickle as pickle 7 | import json 8 | import numpy as np 9 | import time 10 | import random 11 | from collections import OrderedDict 12 | import ipdb 13 | 14 | def process_json(dataPath, annFile, max_len=None, size=None): 15 | annPath = os.path.join(dataPath, annFile) 16 | 17 | # load dataset 18 | annots = json.load(open(annPath, 'r')) 19 | if size != None: 20 | annots = annots[:size] 21 | 22 | bpes = [] 23 | features_path = [] 24 | bpe2img = {} 25 | img2bpes = {} 26 | 27 | bpe_i, feature_i = 0, 0 28 | 29 | for annot in annots: 30 | bpes_i = [] 31 | for bpe in annot['bpes']: 32 | len_bpe = len(bpe.split(' ')) 33 | if max_len != None and len_bpe > max_len: 34 | continue 35 | bpes.append(bpe) 36 | bpe2img[bpe_i] = feature_i 37 | bpes_i.append(bpe_i) 38 | bpe_i = bpe_i + 1 39 | img2bpes[feature_i] = bpes_i 40 | img_name = annot['img_name'] + '.npy' 41 | if 'train' in img_name: 42 | load_path = os.path.join(dataPath, 'train2014_features') 43 | elif 'val' in img_name: 44 | load_path = os.path.join(dataPath, 'val2014_features') 45 | else: 46 | sys.exit() # abort on unknown image split 47 | features_path.append(os.path.join(load_path, img_name)) 48 | feature_i = feature_i + 1 49 | 50 | return bpes, features_path, bpe2img, img2bpes 51 | 52 | def minibatch_same_length(lengths, batch_size): 53 | # make sure all of them are integers 54 | assert all(isinstance(ll, int) for ll in lengths) 55 | 56 | # sort them out 57 | len_unique = np.unique(lengths) 58 | 59 | # indices of unique lengths 60 | len_indices = OrderedDict()
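
# Illustration: for lengths=[5, 5, 9] and batch_size=10, len_indices becomes
# {5: [0, 1], 9: [2]} and len_counts {5: 2, 9: 1}; with avg_samples = max(1, 10 // length)
# the loop below yields minibatches [[0, 1], [2]], so every minibatch holds
# captions of a single length.
61 | 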
len_counts = OrderedDict() 62 | for ll in len_unique: 63 | len_indices[ll] = np.where(lengths == ll)[0] 64 | len_counts[ll] = len(len_indices[ll]) 65 | 66 | # sort indicies into minibatches 67 | minibatches = [] 68 | len_indices_keys = list(len_indices.keys()) 69 | for k in len_indices_keys: 70 | avg_samples = max(1, int(batch_size / k)) 71 | for j in range(0, len_counts[k], avg_samples): 72 | minibatches.append(len_indices[k][j:j+avg_samples]) 73 | 74 | return minibatches 75 | 76 | class BatchSamplerCaptionsSameLength(object): 77 | def __init__(self, dataset, batch_size): 78 | assert (type(dataset) == CocoCaptionsIndexedCaption) 79 | self.bpes = dataset.bpes 80 | lengths = [] 81 | 82 | for bpe in self.bpes: 83 | len_bpe = len(bpe.split(' ')) 84 | lengths.append(len_bpe) 85 | 86 | self.minibatches = minibatch_same_length(lengths, batch_size) 87 | random.shuffle(self.minibatches) 88 | 89 | def __iter__(self): 90 | # randomly sample minibatch index 91 | for i in range(len(self.minibatches)): 92 | minibatch = self.minibatches[i] 93 | yield minibatch 94 | 95 | def __len__(self): 96 | return len(self.minibatches) 97 | 98 | class BatchSamplerImagesSameLength(object): 99 | def __init__(self, dataset, batch_size): 100 | assert (type(dataset) == CocoCaptionsIndexedImage or type(dataset) == CocoCaptionsIndexedImageDistill) 101 | self.img2bpes = dataset.img2bpes 102 | self.bpes = dataset.bpes 103 | 104 | # calculate average length of 5 captions for each image 105 | lengths = [] 106 | img_keys = self.img2bpes.keys() 107 | for i in img_keys: 108 | length_i = [] 109 | for bpe_i in self.img2bpes[i]: 110 | length_i.append(len(self.bpes[bpe_i].split())) 111 | lengths.append(int(np.mean(np.array(length_i)))) 112 | 113 | self.minibatches = minibatch_same_length(lengths, batch_size) 114 | random.shuffle(self.minibatches) 115 | 116 | 117 | def __iter__(self): 118 | # randomly sample minibatch index 119 | for i in range(len(self.minibatches)): 120 | minibatch = self.minibatches[i] 121 | yield minibatch 122 | 123 | def __len__(self): 124 | return len(self.minibatches) 125 | 126 | # dataset indexed based on images 127 | class CocoCaptionsIndexedImage(data.Dataset): 128 | def __init__(self, bpes, features_path, bpe2img, img2bpes): 129 | self.bpes = bpes 130 | self.features_path = features_path 131 | self.bpe2img = bpe2img 132 | self.img2bpes = img2bpes 133 | 134 | def __getitem__(self, index): 135 | feature = np.float32(np.load(self.features_path[index])) 136 | bpes = [] 137 | for i in self.img2bpes[index]: 138 | bpes.append(self.bpes[i]) 139 | return torch.from_numpy(feature), bpes 140 | 141 | def __len__(self): 142 | return len(self.img2bpes.keys()) 143 | 144 | class CocoCaptionsIndexedImageDistill(data.Dataset): 145 | def __init__(self, bpes, features_path, bpe2img, img2bpes): 146 | self.bpes = bpes 147 | self.features_path = features_path 148 | self.bpe2img = bpe2img 149 | self.img2bpes = img2bpes 150 | 151 | def __getitem__(self, index): 152 | feature = np.float32(np.load(self.features_path[index])) 153 | img_name = self.features_path[index].split('/')[-1].split('.')[0] 154 | bpes = [] 155 | for i in self.img2bpes[index]: 156 | bpes.append(self.bpes[i]) 157 | return torch.from_numpy(feature), bpes, img_name 158 | 159 | def __len__(self): 160 | return len(self.img2bpes.keys()) 161 | 162 | # dataset indexed based on captions 163 | class CocoCaptionsIndexedCaption(data.Dataset): 164 | def __init__(self, bpes, features_path, bpe2img, img2bpes): 165 | self.bpes = bpes 166 | self.features_path = features_path 167 | 
self.bpe2img = bpe2img 168 | self.img2bpes = img2bpes 169 | 170 | def __getitem__(self, index): 171 | bpe = self.bpes[index] 172 | feature = np.float32(np.load(self.features_path[self.bpe2img[index]])) 173 | return torch.from_numpy(feature), bpe 174 | 175 | def __len__(self): 176 | return len(self.bpe2img.keys()) 177 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['QT_QPA_PLATFORM']='offscreen' # weird can't ipdb with mscoco without this flag 3 | import torch 4 | import numpy as np 5 | from torchtext import data 6 | from torchtext import datasets 7 | from torch.nn import functional as F 8 | from torch.autograd import Variable 9 | 10 | import revtok 11 | import logging 12 | import random 13 | import ipdb 14 | import string 15 | import traceback 16 | import math 17 | import uuid 18 | import argparse 19 | import copy 20 | import time 21 | import pickle 22 | 23 | from train import train_model 24 | from distill import distill_model 25 | from model import FastTransformer, Transformer, INF, TINY, HighwayBlock, ResidualBlock, NonresidualBlock 26 | from utils import mkdir, organise_trg_len_dic, init_encoder 27 | from data import NormalField, NormalTranslationDataset, MSCOCODataset, data_path 28 | from time import gmtime, strftime 29 | from decode import decode_model 30 | 31 | import itertools 32 | from traceback import extract_tb 33 | from code import interact 34 | from pathlib import Path 35 | 36 | parser = argparse.ArgumentParser(description='Train a Transformer / FastTransformer.') 37 | 38 | # dataset settings 39 | parser.add_argument('--strong', action='store_true',default=False) 40 | parser.add_argument('--n',type=int, default=2) 41 | parser.add_argument('--alph',type=float, default=0.1) 42 | parser.add_argument('--joint', action='store_true', default=False) 43 | parser.add_argument('--ng_finetune', action='store_true', default=False) 44 | parser.add_argument('--rf_finetune', action='store_true', default=False) 45 | parser.add_argument('--nat_finetune', action='store_true', default=False) 46 | parser.add_argument('--sample_method', type=str, default='sentence', choices=['sentence','stepwise']) 47 | parser.add_argument('--stepwise_sampletimes', type=int, default=10) 48 | parser.add_argument('--topk', type=int, default=5) 49 | parser.add_argument('--workers', type=int, default=5) 50 | 51 | 52 | parser.add_argument('--dataset', type=str, default='iwslt-ende', choices=['iwslt-ende', 'iwslt-deen', \ 53 | 'wmt14-ende', 'wmt14-deen', \ 54 | 'wmt16-enro', 'wmt16-roen', \ 55 | 'wmt17-enlv', 'wmt17-lven', \ 56 | 'mscoco']) 57 | parser.add_argument('--vocab_size', type=int, default=40000, help='limit the train set sentences to this many tokens') 58 | 59 | parser.add_argument('--valid_size', type=int, default=None, help='size of valid dataset (tested on coco only)') 60 | parser.add_argument('--load_vocab', action='store_true', help='load a pre-computed vocabulary') 61 | parser.add_argument('--load_dataset', action='store_true', default=False, help='load a pre-processed dataset') 62 | parser.add_argument('--save_dataset', action='store_true', default=False, help='save a pre-processed dataset') 63 | parser.add_argument('--max_len', type=int, default=None, help='limit the train set sentences to this many tokens') 64 | parser.add_argument('--max_train_data', type=int, default=None, help='limit the train set sentences to this many sentences') 65 | 66 | # 
model basic settings 67 | parser.add_argument('--prefix', type=str, default='[time]', help='prefix to denote the model, nothing or [time]') 68 | parser.add_argument('--fast', dest='model', action='store_const', const=FastTransformer, default=Transformer) 69 | 70 | # model ablation settings 71 | parser.add_argument('--ffw_block', type=str, default="residual", choices=['residual', 'highway', 'nonresidual']) 72 | parser.add_argument('--diag', action='store_true', default=False, help='ignore diagonal attention when doing self-attention.') 73 | parser.add_argument('--use_wo', action='store_true', default=True, help='use output weight matrix in multihead attention') 74 | parser.add_argument('--inputs_dec', type=str, default='pool', choices=['zeros', 'pool'], help='inputs to first decoder') 75 | parser.add_argument('--out_norm', action='store_true', default=False, help='normalize last softmax layer') 76 | parser.add_argument('--share_embed', action='store_true', default=True, help='share embeddings and linear out weight') 77 | parser.add_argument('--share_vocab', action='store_true', default=True, help='share vocabulary between src and target') 78 | parser.add_argument('--share_embed_enc_dec1', action='store_true', default=False, help='share embedding weight between encoder and first decoder') 79 | parser.add_argument('--positional', action='store_true', default=True, help='incorporate positional information in key/value') 80 | parser.add_argument('--enc_last', action='store_true', default=False, help='attend only to last encoder hidden states') 81 | 82 | parser.add_argument('--params', type=str, default='user', choices=['user', 'small', 'big']) 83 | parser.add_argument('--n_layers', type=int, default=5, help='number of layers') 84 | parser.add_argument('--n_heads', type=int, default=2, help='number of heads') 85 | parser.add_argument('--d_model', type=int, default=278, help='model dimension') 86 | parser.add_argument('--d_hidden', type=int, default=507, help='feedforward hidden dimension') 87 | 88 | parser.add_argument('--num_decs', type=int, default=2, help='1 (one shared decoder) \ 89 | 2 (2nd decoder and above is shared) \ 90 | -1 (no decoder is shared)') 91 | parser.add_argument('--train_repeat_dec', type=int, default=4, help='number of times to repeat generation') 92 | parser.add_argument('--valid_repeat_dec', type=int, default=4, help='number of times to repeat generation') 93 | parser.add_argument('--use_argmax', action='store_true', default=False) 94 | parser.add_argument('--next_dec_input', type=str, default='both', choices=['emb', 'out', 'both']) 95 | 96 | parser.add_argument('--bp', type=float, default=1.0) 97 | 98 | # running setting 99 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test', 'distill']) # distill : take a trained AR model and decode a training set 100 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use or -1 for CPU') 101 | parser.add_argument('--seed', type=int, default=19920206, help='seed for randomness') 102 | parser.add_argument('--distill_which', type=int, default=0 ) 103 | parser.add_argument('--decode_which', type=int, default=0 ) 104 | parser.add_argument('--test_which', type=str, default='test', choices=['valid', 'test']) # which split to use in test mode 105 | 106 | # training 107 | parser.add_argument('--no_tqdm', action="store_true", default=False) 108 | parser.add_argument('--eval_every', type=int, default=1000, help='run dev every this many steps') 109 | 
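
# Note: --batch_size below is measured in tokens, not sentences; e.g. a 2048-token
# batch holds roughly 80 sentences of average length 25 (illustrative numbers only).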
109 | parser.add_argument('--save_every', type=int, default=-1, help='save a back-up checkpoint every this many steps (-1: disabled)') 110 | parser.add_argument('--batch_size', type=int, default=2048, help='# of tokens processed per batch') 111 | parser.add_argument('--optimizer', type=str, default='Adam') 112 | parser.add_argument('--lr', type=float, default=3e-4) 113 | parser.add_argument('--lr_schedule', type=str, default='anneal', choices=['transformer', 'anneal', 'fixed']) 114 | parser.add_argument('--warmup', type=int, default=16000, help='steps over which to linearly warm up the learning rate') 115 | parser.add_argument('--anneal_steps', type=int, default=250000, help='maximum steps to linearly anneal the learning rate') 116 | parser.add_argument('--maximum_steps', type=int, default=5000000, help='maximum number of training steps') 117 | parser.add_argument('--drop_ratio', type=float, default=0.1, help='dropout ratio') 118 | parser.add_argument('--drop_len_pred', type=float, default=0.3, help='dropout ratio for length prediction module') 119 | parser.add_argument('--input_drop_ratio', type=float, default=0.1, help='dropout ratio only for inputs') 120 | parser.add_argument('--grad_clip', type=float, default=-1.0, help='gradient clipping') 121 | 122 | # target length 123 | parser.add_argument('--trg_len_option', type=str, default="reference", choices=['reference', "noisy_ref", 'average', 'fixed', 'predict']) 124 | #parser.add_argument('--trg_len_option_valid', type=str, default="average", choices=['reference', "noisy_ref", 'average', 'fixed', 'predict']) 125 | parser.add_argument('--trg_len_ratio', type=float, default=2.0) 126 | parser.add_argument('--decoder_input_how', type=str, default='copy', choices=['copy', 'interpolate', 'pad', 'wrap']) 127 | parser.add_argument('--finetune_trg_len', action='store_true', default=False, help="finetune one layer that predicts target len offset") 128 | parser.add_argument('--use_predicted_trg_len', action='store_true', default=False, help="use predicted target len masks") 129 | parser.add_argument('--max_offset', type=int, default=20, help='max target len offset of the whole dataset') 130 | 131 | # denoising 132 | parser.add_argument('--denoising_prob', type=float, default=0.0, help="use denoising with this probability") 133 | parser.add_argument('--denoising_weight', type=float, default=0.1, help="use denoising with this weight.") 134 | parser.add_argument('--corruption_probs', type=str, default="0-0-0-1-1-1-0", help="probs for \ 135 | repeat\ 136 | add random word\ 137 | repeat and drop next\ 138 | replace with random word\ 139 | swap\ 140 | global swap") 141 | parser.add_argument('--denoising_out_weight', type=float, default=0.0, help="use denoising for decoder output with this weight.") 142 | parser.add_argument('--anneal_denoising_weight', action='store_true', default=False, help="anneal denoising weight over time") 143 | parser.add_argument('--layerwise_denoising_weight', action='store_true', default=False, help="use different denoising weight per iteration") 144 | 145 | # self-distillation 146 | parser.add_argument('--self_distil', type=float, default=0.0) 147 | 148 | # decoding 149 | parser.add_argument('--length_ratio', type=int, default=2, help='maximum lengths of decoding') 150 | parser.add_argument('--length_dec', type=int, default=20, help='maximum length of decoding for MSCOCO dataset') 151 | parser.add_argument('--beam_size', type=int, default=1, help='beam size used in beam search; the default of 1 means greedy decoding') 152 | parser.add_argument('--f_size', type=int, default=1, help='heap size for sampling/searching in the fertility space')
153 | parser.add_argument('--alpha', type=float, default=1, help='length normalization weights') 154 | parser.add_argument('--temperature', type=float, default=1, help='smoothing temperature for noisy decoding') 155 | parser.add_argument('--remove_repeats', action='store_true', default=False, help='remove repeated tokens from the decoded output') 156 | parser.add_argument('--num_samples', type=int, default=2, help='number of samples to use when using non-argmax decoding') 157 | parser.add_argument('--T', type=float, default=1, help='softmax temperature when decoding') 158 | 159 | #parser.add_argument('--jaccard_stop', action='store_true', default=False, help='use jaccard index to stop decoding') 160 | parser.add_argument('--adaptive_decoding', type=str, default=None, choices=["oracle", "jaccard", "equality"]) 161 | parser.add_argument('--adaptive_window', type=int, default=5, help='window size for adaptive decoding') 162 | parser.add_argument('--len_stop', action='store_true', default=False, help='use length of sentence to stop decoding') 163 | parser.add_argument('--jaccard_thresh', type=float, default=1.0) 164 | 165 | # model saving/reloading, output translations 166 | parser.add_argument('--load_from', type=str, default=None, help='load from checkpoint') 167 | parser.add_argument('--load_encoder_from', type=str, default=None, help='load encoder weights from checkpoint') 168 | parser.add_argument('--resume', action='store_true', help='when loading from the saved model, it resumes from that.') 169 | parser.add_argument('--use_distillation', action='store_true', default=False, help='train a NAR model from output of an AR model') 170 | 171 | # debugging 172 | parser.add_argument('--debug', action='store_true', help='debug mode: no saving or tensorboard') 173 | parser.add_argument('--tensorboard', action='store_true', help='use TensorBoard') 174 | 175 | # save path 176 | parser.add_argument('--main_path', type=str, default="./") 177 | parser.add_argument('--model_path', type=str, default="models") 178 | parser.add_argument('--log_path', type=str, default="logs") 179 | parser.add_argument('--event_path', type=str, default="events") 180 | parser.add_argument('--decoding_path', type=str, default="decoding") 181 | parser.add_argument('--distill_path', type=str, default="distill") 182 | 183 | parser.add_argument('--model_str', type=str, default="") 184 | 185 | # ----------------------------------------------------------------------------------------------------------------- # 186 | 187 | args = parser.parse_args() 188 | 189 | if args.model is Transformer: 190 | args.num_decs = 1 191 | args.train_repeat_dec = 1 192 | args.valid_repeat_dec = 1 193 | 194 | args.main_path = Path(args.main_path) 195 | 196 | args.model_path = args.main_path / args.model_path / args.dataset 197 | args.log_path = args.main_path / args.log_path / args.dataset 198 | args.event_path = args.main_path / args.event_path / args.dataset 199 | args.decoding_path = args.main_path / args.decoding_path / args.dataset 200 | args.distill_path = args.main_path / args.distill_path / args.dataset 201 | 202 | if not args.debug: 203 | for path in [args.model_path, args.log_path, args.event_path, args.decoding_path, args.distill_path]:
204 | path.mkdir(parents=True, exist_ok=True) 205 | 206 | if args.prefix == '[time]': 207 | args.prefix = strftime("%m.%d_%H.%M.", gmtime()) 208 | 209 | if args.train_repeat_dec == 1: 210 | args.num_decs = 1 211 | 212 | # get the language pair: 213 | if args.dataset != "mscoco": 214 | args.src = args.dataset[-4:][:2] # source language 215 | args.trg = args.dataset[-4:][2:] # target language 216 | else: 217 | args.src = "" 218 | args.trg = "" 219 | 220 | if args.params == 'small': 221 | hparams = {'d_model': 278, 'd_hidden': 507, 'n_layers': 5, 'n_heads': 2, 'warmup': 746} 222 | args.__dict__.update(hparams) 223 | elif args.params == 'big': 224 | if args.dataset != "mscoco": 225 | hparams = {'d_model': 512, 'd_hidden': 512, 'n_layers': 6, 'n_heads': 8, 'warmup': 16000} 226 | else: 227 | hparams = {'d_model': 512, 'd_hidden': 512, 'n_heads': 8, 'warmup': 16000} 228 | args.__dict__.update(hparams) 229 | 230 | hp_str = "{}".format('' if args.model is FastTransformer else 'ar_') + \ 231 | "{}".format(args.model_str+"_" if args.model_str != "" else "") + \ 232 | "{}".format("ar_distil_" if args.use_distillation else "") + \ 233 | "{}".format("ptrn_enc_" if not args.load_encoder_from is None else "") + \ 234 | "{}".format("ptrn_model_" if not args.load_from is None else "") + \ 235 | "voc{}k_".format(args.vocab_size//1000) + \ 236 | "{}_".format(args.batch_size) + \ 237 | "{}".format("" if args.share_embed else "no_share_emb_") + \ 238 | "{}".format("" if args.share_vocab else "no_share_voc_") + \ 239 | "{}".format("share_emb_enc_dec1_" if args.share_embed_enc_dec1 else "") + \ 240 | "{}_{}_{}_{}_".format(args.n_layers, args.d_model, args.d_hidden, args.n_heads) + \ 241 | "{}".format("enc_last_" if args.enc_last else "") + \ 242 | "drop_{}_".format(args.drop_ratio) + \ 243 | "{}".format("drop_len_pred_{}_".format(args.drop_len_pred) if args.finetune_trg_len else "") + \ 244 | "{}_".format(args.lr) + \ 245 | "{}_".format("{}".format(args.lr_schedule[:4])) + \ 246 | "{}".format("anneal_steps_{}_".format(args.anneal_steps) if args.lr_schedule == "anneal" else "") + \ 247 | "{}_".format(args.ffw_block[:4]) + \ 248 | "{}".format("clip_{}_".format(args.grad_clip) if args.grad_clip != -1.0 else "") + \ 249 | "{}".format("diag_" if args.diag else "") + \ 250 | ("tr{}_".format(args.train_repeat_dec) + \ 251 | "{}decs_".format(args.num_decs) + \ 252 | "{}_".format(args.bp if args.trg_len_option == "noisy_ref" else "") + \ 253 | "{}_".format(args.trg_len_option[:4]) + \ 254 | "{}_".format(args.next_dec_input) + \ 255 | "{}".format("trg_{}x_".format(args.trg_len_ratio) if "fixed" in args.trg_len_option else "") + \ 256 | "{}_".format(args.decoder_input_how[:4]) + \ 257 | "{}".format("dn_{}_".format(args.denoising_prob) if args.denoising_prob != 0.0 else "") + \ 258 | "{}".format("dn_w{}_".format(args.denoising_weight) if args.denoising_prob != 0.0 and not args.anneal_denoising_weight and not args.layerwise_denoising_weight else "") + \ 259 | "{}".format("dn_anneal_" if args.anneal_denoising_weight else "") + \ 260 | "{}".format("dn_layer_" if args.layerwise_denoising_weight else "") + \ 261 | "{}".format("dn_out_w{}_".format(args.denoising_out_weight) if args.denoising_out_weight != 0.0 else "") + \ 262 | "{}".format("distil{}_".format(args.self_distil) if args.self_distil != 0.0 else "") + \ 263 | "{}".format("argmax_" if args.use_argmax else "sample_") + \ 264 | "{}".format("out_norm_" if args.out_norm else "") + \ 265 | "" if args.model is FastTransformer else "" ) 266 | 267 | args.id_str = Path(args.prefix + hp_str)
268 | 269 | args.corruption_probs = [int(xx) for xx in args.corruption_probs.split("-") ] 270 | c_probs_sum = sum(args.corruption_probs) 271 | args.corruption_probs = [xx/c_probs_sum for xx in args.corruption_probs] 272 | 273 | if args.ffw_block == "nonresidual": 274 | args.block_cls = NonresidualBlock 275 | elif args.ffw_block == "residual": 276 | args.block_cls = ResidualBlock 277 | elif args.ffw_block == "highway": 278 | args.block_cls = HighwayBlock 279 | else: 280 | raise ValueError("unknown ffw_block: {}".format(args.ffw_block)) 281 | 282 | # setup logger settings 283 | logger = logging.getLogger() 284 | logger.setLevel(logging.DEBUG) 285 | formatter = logging.Formatter('%(asctime)s %(levelname)s: - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 286 | 287 | ch = logging.StreamHandler() 288 | ch.setLevel(logging.DEBUG) 289 | ch.setFormatter(formatter) 290 | logger.addHandler(ch) 291 | if not args.debug: 292 | fh = logging.FileHandler( str( args.log_path / args.id_str ) + ".txt" ) 293 | fh.setLevel(logging.DEBUG) 294 | fh.setFormatter(formatter) 295 | logger.addHandler(fh) 296 | 297 | # setup random seeds 298 | random.seed(args.seed) 299 | np.random.seed(args.seed) 300 | torch.manual_seed(args.seed) 301 | torch.cuda.manual_seed_all(args.seed) 302 | 303 | # ----------------------------------------------------------------------------------------------------------------- # 304 | if args.dataset != "mscoco": 305 | DataField = NormalField 306 | TRG = DataField(init_token='<init>', eos_token='<eos>', batch_first=True) 307 | SRC = DataField(batch_first=True) if not args.share_vocab else TRG 308 | # NOTE : UNK, PAD, INIT, EOS 309 | 310 | # set up the datasets (needs manual setup) 311 | data_prefix = Path(data_path(args.dataset)) 312 | args.data_prefix = data_prefix 313 | if args.dataset == "mscoco": 314 | data_prefix = str(data_prefix) 315 | train_dir = "train" if not args.use_distillation else "distill/" + args.dataset[-4:] 316 | if args.dataset == 'iwslt-ende' or args.dataset == 'iwslt-deen': 317 | #if args.resume: 318 | # train_dir += "2" 319 | logger.info("TRAINING CORPUS : " + str(data_prefix / train_dir / 'train.tags.en-de.bpe')) 320 | train_data = NormalTranslationDataset(path=str(data_prefix / train_dir / 'train.tags.en-de.bpe'), 321 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 322 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') \ 323 | if args.mode in ["train", "distill"] else None 324 | 325 | dev_dir = "dev" 326 | dev_file = "valid.en-de.bpe" 327 | if args.mode == "test" and args.decode_which > 0: 328 | dev_dir = "dev_split" 329 | dev_file += ".{}".format(args.decode_which) 330 | dev_data = NormalTranslationDataset(path=str(data_prefix / dev_dir / dev_file), 331 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 332 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') 333 | 334 | test_data = None 335 | 336 | elif args.dataset == 'wmt14-ende' or args.dataset == 'wmt14-deen': 337 | train_file = 'all_en-de.bpe' 338 | if args.strong == True: 339 | train_file += '.strong' 340 | if args.mode == "distill" and args.distill_which > 0: 341 | train_file += ".{}".format(args.distill_which) 342 | train_data = NormalTranslationDataset(path=str(data_prefix / train_dir / train_file), 343 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 344 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') \ 345 | if args.mode in ["train", "distill"] else None 346 | 347 | dev_dir = "dev" 348 |
dev_file = "wmt13-en-de.bpe" 349 | if args.mode == "test" and args.decode_which > 0: 350 | dev_dir = "dev_split" 351 | dev_file += ".{}".format(args.decode_which) 352 | dev_data = NormalTranslationDataset(path=str(data_prefix / dev_dir / dev_file), 353 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 354 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') 355 | 356 | test_dir = "test" 357 | test_file = "wmt14-en-de.bpe" 358 | if args.mode == "test" and args.decode_which > 0: 359 | test_dir = "test_split" 360 | test_file += ".{}".format(args.decode_which) 361 | test_data = NormalTranslationDataset(path=str(data_prefix / test_dir / test_file), 362 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 363 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') 364 | 365 | elif args.dataset == 'wmt16-enro' or args.dataset == 'wmt16-roen': 366 | train_file = 'corpus.bpe' 367 | if args.mode == "distill" and args.distill_which > 0: 368 | train_file += ".{}".format(args.distill_which) 369 | train_data = NormalTranslationDataset(path=str(data_prefix / train_dir / train_file), 370 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 371 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') \ 372 | if args.mode in ["train", "distill"] else None 373 | 374 | dev_dir = "dev" 375 | dev_file = "dev.bpe" 376 | if args.mode == "test" and args.decode_which > 0: 377 | dev_dir = "dev_split" 378 | dev_file += ".{}".format(args.decode_which) 379 | dev_data = NormalTranslationDataset(path=str(data_prefix / dev_dir / dev_file), 380 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 381 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') 382 | 383 | test_dir = "test" 384 | test_file = "test.bpe" 385 | if args.mode == "test" and args.decode_which > 0: 386 | test_dir = "test_split" 387 | test_file += ".{}".format(args.decode_which) 388 | test_data = NormalTranslationDataset(path=str(data_prefix / test_dir / test_file), 389 | exts=('.{}'.format(args.src), '.{}'.format(args.trg)), fields=(SRC, TRG), 390 | load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') 391 | 392 | elif args.dataset == 'wmt17-enlv' or args.dataset == 'wmt17-lven': 393 | train_data, dev_data, test_data = NormalTranslationDataset.splits( 394 | path=data_prefix, train='{}/corpus.bpe'.format(train_dir), test='test/newstest2017.bpe', 395 | validation='dev/newsdev2017.bpe', exts=('.{}'.format(args.src), '.{}'.format(args.trg)), 396 | fields=(SRC, TRG), load_dataset=args.load_dataset, save_dataset=args.save_dataset, prefix='normal') 397 | 398 | elif args.dataset == "mscoco": 399 | mscoco_dataset = MSCOCODataset(path=data_prefix, batch_size=args.batch_size, \ 400 | max_len=args.max_len, valid_size=args.valid_size, \ 401 | distill=(args.mode == "distill"), use_distillation=args.use_distillation) 402 | train_data, train_sampler = mscoco_dataset.train_data, mscoco_dataset.train_sampler 403 | dev_data, dev_sampler = mscoco_dataset.valid_data, mscoco_dataset.valid_sampler 404 | test_data, test_sampler = mscoco_dataset.test_data, mscoco_dataset.test_sampler 405 | if args.trg_len_option == "predict" and args.max_offset == None: 406 | args.max_offset = mscoco_dataset.max_dataset_length 407 | else: 408 | raise NotImplementedError 409 | # build vocabularies for translation dataset 410 | if args.dataset != "mscoco": 411 | 
vocab_path = data_prefix / 'vocab' / '{}_{}_{}_{}.pt'.format('{}-{}'.format(args.src, args.trg), args.vocab_size, 'shared' if args.share_vocab else '', 'strong' if args.strong else '') 412 | if args.load_vocab and vocab_path.exists(): 413 | src_vocab, trg_vocab = torch.load(str(vocab_path)) 414 | SRC.vocab = src_vocab 415 | TRG.vocab = trg_vocab 416 | logger.info('vocab loaded') 417 | else: 418 | assert (not train_data is None) 419 | if not args.share_vocab: 420 | SRC.build_vocab(train_data, dev_data, max_size=args.vocab_size) 421 | TRG.build_vocab(train_data, dev_data, max_size=args.vocab_size) 422 | if not args.debug: 423 | logger.info('save the vocabulary') 424 | vocab_path.parent.mkdir(parents=True, exist_ok=True) 425 | torch.save([SRC.vocab, TRG.vocab], str(vocab_path)) 426 | args.__dict__.update({'trg_vocab': len(TRG.vocab), 'src_vocab': len(SRC.vocab)}) 427 | # for mscoco 428 | else: 429 | vocab_path = os.path.join(data_prefix, "vocab.pkl") 430 | assert (args.load_vocab == True) 431 | if args.load_vocab and os.path.exists(vocab_path): 432 | vocab = pickle.load(open(vocab_path, 'rb')) 433 | mscoco_dataset.vocab = vocab 434 | else: 435 | logger.info('save the vocabulary') 436 | mscoco_dataset.build_vocab() 437 | pickle.dump(mscoco_dataset.vocab, open(vocab_path, 'wb')) 438 | print ('vocab building done') 439 | args.__dict__.update({'vocab': len(mscoco_dataset.vocab)}) 440 | 441 | def dyn_batch_with_padding(new, i, sofar): 442 | prev_max_len = sofar / (i - 1) if i > 1 else 0 443 | return max(len(new.src), len(new.trg), prev_max_len) * i 444 | 445 | def dyn_batch_without_padding(new, i, sofar): 446 | return sofar + max(len(new.src), len(new.trg)) 447 | 448 | # not sure if absolutely necessary? seems to mess things up. 449 | if args.dataset != "mscoco" and args.share_vocab: 450 | SRC = copy.deepcopy(SRC) 451 | SRC.init_token = None 452 | SRC.eos_token = None 453 | 454 | for data_ in [train_data, dev_data, test_data]: 455 | if not data_ is None: 456 | data_.fields['src'] = SRC 457 | 458 | if args.dataset != "mscoco": 459 | if not train_data is None: 460 | logger.info("before pruning : {} training examples".format(len(train_data.examples))) 461 | if args.max_len is not None: 462 | if args.dataset != "mscoco": 463 | train_data.examples = [ex for ex in train_data.examples if len(ex.trg) <= args.max_len] 464 | if args.max_train_data is not None: 465 | train_data.examples = train_data.examples[:args.max_train_data] 466 | logger.info("after pruning : {} training examples".format(len(train_data.examples))) 467 | 468 | if args.batch_size == 1: # speed-test: one sentence per batch. 
469 | batch_size_fn = lambda new, count, sofar: count 470 | else: 471 | batch_size_fn = dyn_batch_without_padding if args.model is Transformer else dyn_batch_with_padding 472 | 473 | if args.dataset != "mscoco": 474 | if args.mode == "train": 475 | train_flag = True 476 | elif args.mode == "distill": 477 | train_flag = False 478 | else: 479 | train_flag = False 480 | train_real = data.BucketIterator(train_data, args.batch_size, device=args.gpu, batch_size_fn=batch_size_fn, 481 | train=train_flag, repeat=train_flag, shuffle=train_flag) if not train_data is None else None 482 | dev_real = data.BucketIterator(dev_data, args.batch_size, device=args.gpu, batch_size_fn=batch_size_fn, 483 | train=False, repeat=False, shuffle=False) if not dev_data is None else None 484 | test_real = data.BucketIterator(test_data, args.batch_size, device=args.gpu, batch_size_fn=batch_size_fn, 485 | train=False, repeat=False, shuffle=False) if not test_data is None else None 486 | else: 487 | train_real = torch.utils.data.DataLoader( 488 | train_data, batch_sampler=train_sampler, pin_memory=args.gpu>-1, num_workers=8) 489 | dev_real = torch.utils.data.DataLoader( 490 | dev_data, batch_sampler=dev_sampler, pin_memory=args.gpu>-1, num_workers=8) 491 | test_real = torch.utils.data.DataLoader( 492 | test_data, batch_sampler=test_sampler, pin_memory=args.gpu>-1, num_workers=8) 493 | def rcycle(iterable): 494 | saved = [] # in-memory cache of the batches seen on the first pass 495 | for element in iterable: 496 | yield element 497 | saved.append(element) 498 | while saved: 499 | random.shuffle(saved) # reshuffle the cached batches once per pass (epoch) 500 | for element in saved: 501 | yield element 502 | if args.mode != "distill": 503 | train_real = rcycle(train_real) 504 | 505 | logger.info("build the dataset. done!") 506 | # ----------------------------------------------------------------------------------------------------------------- # 507 | 508 | # ----------------------------------------------------------------------------------------------------------------- # 509 | if args.mode == "train": 510 | logger.info(args) 511 | 512 | logger.info('Starting with HPARAMS: {}'.format(hp_str)) 513 | 514 | # build the model 515 | if args.dataset != "mscoco": 516 | model = args.model(src=SRC, trg=TRG, args=args) 517 | else: 518 | model = args.model(src=None, trg=mscoco_dataset, args=args) 519 | 520 | if args.mode == "train": 521 | logger.info(str(model)) 522 | 523 | if args.load_encoder_from is not None: 524 | if args.gpu > -1: 525 | with torch.cuda.device(args.gpu): 526 | encoder = torch.load(str(args.model_path / args.load_encoder_from) + '.pt', 527 | map_location=lambda storage, loc: storage.cuda()) 528 | else: 529 | encoder = torch.load(str(args.model_path / args.load_encoder_from) + '.pt', 530 | map_location=lambda storage, loc: storage) 531 | init_encoder(model, encoder) 532 | logger.info("Pretrained encoder loaded.") 533 | 534 | if args.load_from is not None: 535 | if args.gpu > -1: 536 | with torch.cuda.device(args.gpu): 537 | model.load_state_dict(torch.load(str(args.model_path / args.load_from) + '.pt', 538 | map_location=lambda storage, loc: storage.cuda()), strict=False) # load the pretrained models. 539 | else: 540 | model.load_state_dict(torch.load(str(args.model_path / args.load_from) + '.pt', 541 | map_location=lambda storage, loc: storage), strict=False) # load the pretrained models.
542 | logger.info("Pretrained model loaded.") 543 | 544 | params, param_names = [], [] 545 | for name, param in model.named_parameters(): 546 | params.append(param) 547 | param_names.append(name) 548 | 549 | if args.mode == "train": 550 | logger.info(param_names) 551 | logger.info("Size {}".format( sum( [ np.prod(x.size()) for x in params ] )) ) 552 | 553 | # use cuda 554 | if args.gpu > -1: 555 | model.cuda(args.gpu) 556 | 557 | # additional information 558 | args.__dict__.update({'hp_str': hp_str, 'logger': logger}) 559 | 560 | # ----------------------------------------------------------------------------------------------------------------- # 561 | 562 | trg_len_dic = None 563 | if args.dataset != "mscoco" and (not "ro" in args.dataset or "predict" in args.trg_len_option or "average" in args.trg_len_option): 564 | #if "predict" in args.trg_len_option or "average" in args.trg_len_option: 565 | #trg_len_dic = torch.load(os.path.join(data_path(args.dataset), "trg_len")) 566 | trg_len_dic = torch.load( str(args.data_prefix / "trg_len_dic" / args.dataset[-4:]) ) 567 | trg_len_dic = organise_trg_len_dic(trg_len_dic) 568 | if args.mode == 'train': 569 | logger.info('starting training') 570 | 571 | if args.dataset != "mscoco": 572 | train_model(args, model, train_real, dev_real, src=SRC, trg=TRG, trg_len_dic=trg_len_dic) 573 | else: 574 | train_model(args, model, train_real, dev_real, src=None, trg=mscoco_dataset, trg_len_dic=trg_len_dic) 575 | 576 | elif args.mode == 'test': 577 | logger.info('starting decoding from the pre-trained model, on the test set...') 578 | args.decoding_path = args.decoding_path / args.load_from 579 | name_suffix = 'b={}_{}.txt'.format(args.beam_size, args.load_from) 580 | names = ['src.{}'.format(name_suffix), 'trg.{}'.format(name_suffix), 'dec.{}'.format(name_suffix)] 581 | 582 | if args.test_which == "test" and (not test_real is None): 583 | logger.info("---------- Decoding TEST set ----------") 584 | decode_model(args, model, test_real, evaluate=True, trg_len_dic=trg_len_dic, decoding_path=args.decoding_path, \ 585 | names=["test."+xx for xx in names], maxsteps=None) 586 | else: 587 | logger.info("---------- Decoding VALID set ----------") 588 | decode_model(args, model, dev_real, evaluate=True, trg_len_dic=trg_len_dic, decoding_path=args.decoding_path, \ 589 | names=["valid."+xx for xx in names], maxsteps=None) 590 | 591 | elif args.mode == 'distill': 592 | logger.info('starting decoding the training set from an AR model') 593 | args.distill_path = args.distill_path / args.id_str 594 | args.distill_path.mkdir(parents=True, exist_ok=True) 595 | name_suffix = args.distill_which 596 | names = ['src.{}'.format(name_suffix), 'trg.{}'.format(name_suffix), 'dec.{}'.format(name_suffix)] 597 | 598 | distill_model(args, model, train_real, evaluate=False, distill_path=args.distill_path, \ 599 | names=["train."+xx for xx in names], maxsteps=None) 600 | 601 | logger.info("done.") 602 | -------------------------------------------------------------------------------- /scripts/i2_iwslt-ende/bontune_iwslt.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --ng_finetune --load_from --dataset iwslt-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params small --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation 2 | 3 | 4 | 
-------------------------------------------------------------------------------- /scripts/i2_iwslt-ende/decode_iwslt.sh: -------------------------------------------------------------------------------- 1 | python run.py --load_vocab --dataset iwslt-ende --vocab_size 40000 --gpu 0 --ffw_block highway --params small --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --mode test --remove_repeats --trg_len_option predict --use_predicted_trg_len --load_from 2 | 3 | -------------------------------------------------------------------------------- /scripts/i2_iwslt-ende/joint_iwslt.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --joint --alph 0.1 --dataset iwslt-ende --vocab_size 40000 --gpu 0 --ffw_block highway --params small --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation -------------------------------------------------------------------------------- /scripts/i2_iwslt-ende/mle_iwslt.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset iwslt-ende --vocab_size 40000 --gpu 0 --ffw_block highway --params small --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/i2_iwslt-ende/tune_iwslt.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset iwslt-ende --vocab_size 40000 --gpu 0 --ffw_block highway --params small --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/i2_wmt14-deen/bontune.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --ng_finetune --load_from --dataset wmt14-deen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/i2_wmt14-deen/decode.sh: -------------------------------------------------------------------------------- 1 | python run.py --load_vocab --dataset wmt14-deen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --mode test --remove_repeats --trg_len_option predict --use_predicted_trg_len --load_from 2 | 3 | -------------------------------------------------------------------------------- /scripts/i2_wmt14-deen/joint.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --joint --alph 0.1 --dataset wmt14-deen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 
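In the joint.sh scripts, `--joint --alph 0.1` trains with a joint objective in which `alph` balances the word-level MLE loss against the sequence-level BoN loss. Exactly how `alph` enters the loss is defined in the repo's training code, not in the script; purely as an illustration, a linear interpolation of the two terms would look like:

    # Illustrative only: mle_loss and bon_loss stand in for the two objectives;
    # the actual combination used by --joint is defined in the repo's training code.
    def joint_loss(mle_loss, bon_loss, alph=0.1):
        return alph * mle_loss + (1.0 - alph) * bon_loss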
-------------------------------------------------------------------------------- /scripts/i2_wmt14-deen/mle.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt14-deen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/i2_wmt14-deen/tune.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt14-deen --vocab_size 60000 --gpu 2 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/i2_wmt14-ende/bontune.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --ng_finetune --load_from --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/i2_wmt14-ende/decode.sh: -------------------------------------------------------------------------------- 1 | python run.py --load_vocab --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --mode test --remove_repeats --trg_len_option predict --use_predicted_trg_len --load_from 2 | 3 | -------------------------------------------------------------------------------- /scripts/i2_wmt14-ende/joint.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --joint --alph 0.1 --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 -------------------------------------------------------------------------------- /scripts/i2_wmt14-ende/mle.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/i2_wmt14-ende/tune.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt14-ende --vocab_size 60000 --gpu 2 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict 
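Each tune.sh script combines `--load_from ... --resume` with `--finetune_trg_len --trg_len_option predict`: starting from a trained checkpoint, only the target-length predictor is updated. train.py (included later in this dump) implements this by freezing every parameter whose name does not contain "pred_len"; a condensed equivalent using named_parameters:

    # Mirrors the freezing loop in train.py: with --finetune_trg_len only the
    # length-prediction parameters stay trainable; otherwise they are the frozen ones.
    for name, p in model.named_parameters():
        is_len_pred = "pred_len" in name
        p.requires_grad = is_len_pred if args.finetune_trg_len else not is_len_pred
    params = [p for p in model.parameters() if p.requires_grad]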
-------------------------------------------------------------------------------- /scripts/i2_wmt16-enro/bontune.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --ng_finetune --load_from --dataset wmt16-enro --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/i2_wmt16-enro/decode.sh: -------------------------------------------------------------------------------- 1 | python run.py --load_vocab --dataset wmt16-enro --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --mode test --remove_repeats --trg_len_option predict --use_predicted_trg_len --load_from 2 | 3 | -------------------------------------------------------------------------------- /scripts/i2_wmt16-enro/joint.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --joint --alph 0.1 --dataset wmt16-enro --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/i2_wmt16-enro/mle.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt16-enro --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/i2_wmt16-enro/tune.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt16-enro --vocab_size 60000 --gpu 2 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/i2_wmt16-roen/bontune.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --ng_finetune --load_from --dataset wmt16-roen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/i2_wmt16-roen/decode.sh: -------------------------------------------------------------------------------- 1 | python run.py --load_vocab --dataset wmt16-roen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --mode test --remove_repeats --trg_len_option predict --use_predicted_trg_len --load_from 
2 | 3 | -------------------------------------------------------------------------------- /scripts/i2_wmt16-roen/joint.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --joint --alph 0.1 --dataset wmt16-roen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/i2_wmt16-roen/mle.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt16-roen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 2 --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/i2_wmt16-roen/tune.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt16-roen --vocab_size 60000 --gpu 2 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 2 --use_argmax --next_dec_input both --denoising_prob 0.5 --layerwise_denoising_weight --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/iwslt-ende/bontune_iwslt.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --ng_finetune --load_from --dataset iwslt-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params small --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/iwslt-ende/decode_iwslt.sh: -------------------------------------------------------------------------------- 1 | python run.py --load_vocab --dataset iwslt-ende --vocab_size 40000 --gpu 0 --ffw_block highway --params small --lr_schedule anneal --fast --valid_repeat_dec 1 --use_argmax --next_dec_input both --mode test --remove_repeats --trg_len_option predict --use_predicted_trg_len --load_from 2 | 3 | -------------------------------------------------------------------------------- /scripts/iwslt-ende/joint_iwslt.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --joint --alph 0.1 --dataset iwslt-ende --vocab_size 40000 --gpu 0 --ffw_block highway --params small --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation -------------------------------------------------------------------------------- /scripts/iwslt-ende/mle_iwslt.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset iwslt-ende --vocab_size 40000 --gpu 0 --ffw_block highway --params small --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/iwslt-ende/tune_iwslt.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset 
iwslt-ende --vocab_size 40000 --gpu 0 --ffw_block highway --params small --lr_schedule anneal --fast --valid_repeat_dec 1 --use_argmax --next_dec_input both --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/wmt14-deen/bontune.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --ng_finetune --load_from --dataset wmt14-deen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/wmt14-deen/decode.sh: -------------------------------------------------------------------------------- 1 | python run.py --load_vocab --dataset wmt14-deen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 1 --use_argmax --next_dec_input both --mode test --remove_repeats --trg_len_option predict --use_predicted_trg_len --load_from 2 | 3 | -------------------------------------------------------------------------------- /scripts/wmt14-deen/joint.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --joint --alph 0.1 --dataset wmt14-deen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/wmt14-deen/mle.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt14-deen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/wmt14-deen/tune.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt14-deen --vocab_size 60000 --gpu 2 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 1 --use_argmax --next_dec_input both --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/wmt14-ende/bontune.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --ng_finetune --load_from --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/wmt14-ende/decode.sh: -------------------------------------------------------------------------------- 1 | python run.py --load_vocab --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 1 --use_argmax --next_dec_input both --mode test --remove_repeats --trg_len_option predict --use_predicted_trg_len --load_from 2 | 3 | -------------------------------------------------------------------------------- 
/scripts/wmt14-ende/joint.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --joint --alph 0.1 --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/wmt14-ende/mle.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt14-ende --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/wmt14-ende/tune.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt14-ende --vocab_size 60000 --gpu 2 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 1 --use_argmax --next_dec_input both --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/wmt16-enro/bontune.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --ng_finetune --load_from --dataset wmt16-enro --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/wmt16-enro/decode.sh: -------------------------------------------------------------------------------- 1 | python run.py --load_vocab --dataset wmt16-enro --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 1 --use_argmax --next_dec_input both --mode test --remove_repeats --trg_len_option predict --use_predicted_trg_len --load_from 2 | 3 | -------------------------------------------------------------------------------- /scripts/wmt16-enro/joint.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --joint --alph 0.1 --dataset wmt16-enro --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/wmt16-enro/mle.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt16-enro --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/wmt16-enro/tune.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt16-enro --vocab_size 60000 --gpu 2 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 1 --use_argmax --next_dec_input both --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict 2 | 3 | 4 | 
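All of the decode.sh scripts pass `--remove_repeats`, a post-processing step for the adjacent token repetitions that non-autoregressive decoding is prone to. The real implementation is `remove_repeats` in utils.py, which this dump does not include; a minimal stand-in that collapses runs of identical adjacent tokens:

    def remove_adjacent_repeats(tokens):
        # stand-in for utils.remove_repeats (not shown in this dump):
        # "the the cat sat sat ." -> "the cat sat ."
        out = []
        for tok in tokens:
            if not out or tok != out[-1]:
                out.append(tok)
        return out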
-------------------------------------------------------------------------------- /scripts/wmt16-roen/bontune.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --ng_finetune --load_from --dataset wmt16-roen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | 4 | -------------------------------------------------------------------------------- /scripts/wmt16-roen/decode.sh: -------------------------------------------------------------------------------- 1 | python run.py --load_vocab --dataset wmt16-roen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 1 --use_argmax --next_dec_input both --mode test --remove_repeats --trg_len_option predict --use_predicted_trg_len --load_from 2 | 3 | -------------------------------------------------------------------------------- /scripts/wmt16-roen/joint.sh: -------------------------------------------------------------------------------- 1 | python run.py --n 2 --joint --alph 0.1 --dataset wmt16-roen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/wmt16-roen/mle.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt16-roen --vocab_size 60000 --gpu 0 --ffw_block highway --params big --lr_schedule anneal --fast --train_repeat_dec 1 --valid_repeat_dec 1 --use_argmax --use_distillation --max_len 100 2 | 3 | -------------------------------------------------------------------------------- /scripts/wmt16-roen/tune.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt16-roen --vocab_size 60000 --gpu 2 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 1 --use_argmax --next_dec_input both --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict 2 | 3 | 4 | -------------------------------------------------------------------------------- /slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/BoN-NAT/458808b08421dfac18100dca88ab4148703d1459/slides.pdf -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import ipdb 3 | import torch 4 | import numpy as np 5 | from torch.autograd import Variable 6 | from utils import corrupt_target 7 | 8 | def convert(lst): 9 | vocab = "what I 've come to realize about Afghanistan , and this is something that is often dismissed in the West".split() 10 | dd = {idx+4 : word for idx, word in enumerate(vocab)} 11 | dd[0] = "UNK" 12 | dd[1] = "PAD" 13 | dd[2] = "BOS" 14 | dd[3] = "EOS" 15 | return " ".join( dd[xx] for xx in lst ) 16 | 17 | trg = [ [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 3, 1, 1] ] 18 | decoder_masks = [ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0] ] 19 | weight = float(sys.argv[1]) 20 | 21 | cor_p = sys.argv[2] # repeat / drop / repeat and drop next / swap / add random word 22 | cor_p = [int(xx) for xx in cor_p.split("-")] 23 | cor_p = [xx/sum(cor_p) 
for xx in cor_p] 24 | 25 | trg = Variable( torch.from_numpy( np.array( trg ) ) ) 26 | decoder_masks = torch.from_numpy( np.array( decoder_masks ) ) 27 | 28 | print ( convert( trg.data.numpy().tolist()[0] ) ) 29 | print ( convert( corrupt_target( trg, decoder_masks, 15, weight, cor_p ).data.numpy().tolist()[0] ) ) 30 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import torch 3 | import numpy as np 4 | import math 5 | import gc 6 | import os 7 | 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | from torch.autograd import Variable 11 | 12 | from tqdm import tqdm, trange 13 | from model import Transformer, FastTransformer, INF, TINY, softmax 14 | from data import NormalField, NormalTranslationDataset, TripleTranslationDataset, ParallelDataset, data_path 15 | from utils import Metrics, Best, TargetLength, computeBLEU, computeBLEUMSCOCO, compute_bp, Batch, masked_sort, computeGroupBLEU, \ 16 | corrupt_target, remove_repeats, remove_repeats_tensor, print_bleu, corrupt_target_fix, set_eos, organise_trg_len_dic 17 | from time import gmtime, strftime 18 | 19 | # helper functions 20 | def export(x): 21 | try: 22 | with torch.cuda.device_of(x): 23 | return x.data.cpu().float().mean() 24 | except Exception: 25 | return 0 26 | 27 | tokenizer = lambda x: x.replace('@@ ', '').split() 28 | 29 | def valid_model(args, model, dev, dev_metrics=None, dev_metrics_trg=None, dev_metrics_average=None, 30 | print_out=False, teacher_model=None, trg_len_dic=None): 31 | print_seq = (['REF '] if args.dataset == "mscoco" else ['SRC ', 'REF ']) + ['HYP{}'.format(ii+1) for ii in range(args.valid_repeat_dec)] 32 | 33 | trg_outputs = [] 34 | real_all_outputs = [ [] for ii in range(args.valid_repeat_dec)] 35 | short_all_outputs = [ [] for ii in range(args.valid_repeat_dec)] 36 | outputs_data = {} 37 | 38 | model.eval() 39 | for j, dev_batch in enumerate(dev): 40 | if args.dataset == "mscoco": 41 | # only use first caption for calculating log likelihood 42 | all_captions = dev_batch[1] 43 | dev_batch[1] = dev_batch[1][0] 44 | decoder_inputs, decoder_masks,\ 45 | targets, target_masks,\ 46 | _, source_masks,\ 47 | encoding, batch_size, rest = model.quick_prepare_mscoco(dev_batch, all_captions=all_captions, fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec, trg_len_option=args.trg_len_option, max_len=args.max_offset, trg_len_dic=trg_len_dic, bp=args.bp) 48 | 49 | else: 50 | decoder_inputs, decoder_masks,\ 51 | targets, target_masks,\ 52 | sources, source_masks,\ 53 | encoding, batch_size, rest = model.quick_prepare(dev_batch, fast=(type(model) is FastTransformer), trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio, trg_len_dic=trg_len_dic, bp=args.bp) 54 | 55 | losses, all_decodings = [], [] 56 | if type(model) is Transformer: 57 | decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks, beam=1, decoding=True, return_probs=True) 58 | loss = model.cost(targets, target_masks, out=out) 59 | losses.append(loss) 60 | all_decodings.append( decoding ) 61 | 62 | elif type(model) is FastTransformer: 63 | for iter_ in range(args.valid_repeat_dec): 64 | curr_iter = min(iter_, args.num_decs-1) 65 | next_iter = min(curr_iter + 1, args.num_decs-1) 66 | 67 | decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks, decoding=True, return_probs=True, iter_=curr_iter) 68 | 69 | #loss = 
model.cost(targets, target_masks, out=out, iter_=curr_iter) 70 | #losses.append(loss) 71 | all_decodings.append( decoding ) 72 | 73 | decoder_inputs = 0 74 | if args.next_dec_input in ["both", "emb"]: 75 | _, argmax = torch.max(probs, dim=-1) 76 | emb = F.embedding(argmax, model.decoder[next_iter].out.weight * math.sqrt(args.d_model)) 77 | decoder_inputs += emb 78 | 79 | if args.next_dec_input in ["both", "out"]: 80 | decoder_inputs += out 81 | 82 | if args.dataset == "mscoco": 83 | # make sure there are 5 captions per example 84 | num_captions = len(all_captions[0]) 85 | for c in range(1, len(all_captions)): 86 | assert (num_captions == len(all_captions[c])) 87 | 88 | # untokenize reference captions 89 | for n_ref in range(len(all_captions)): 90 | n_caps = len(all_captions[0]) 91 | for c in range(n_caps): 92 | all_captions[n_ref][c] = all_captions[n_ref][c].replace("@@ ","") 93 | 94 | src_ref = [ list(map(list, zip(*all_captions))) ] 95 | else: 96 | src_ref = [ model.output_decoding(d) for d in [('src', sources), ('trg', targets)] ] 97 | 98 | real_outputs = [ model.output_decoding(d) for d in [('trg', xx) for xx in all_decodings] ] 99 | 100 | if print_out: 101 | if args.dataset != "mscoco": 102 | for k, d in enumerate(src_ref + real_outputs): 103 | args.logger.info("{} ({}): {}".format(print_seq[k], len(d[0].split(" ")), d[0])) 104 | else: 105 | for k in range(len(all_captions[0])): 106 | for c in range(len(all_captions)): 107 | args.logger.info("REF ({}): {}".format(len(all_captions[c][k].split(" ")), all_captions[c][k])) 108 | 109 | for c in range(len(real_outputs)): 110 | args.logger.info("HYP {} ({}): {}".format(c+1, len(real_outputs[c][k].split(" ")), real_outputs[c][k])) 111 | args.logger.info('------------------------------------------------------------------') 112 | 113 | trg_outputs += src_ref[-1] 114 | for ii, d_outputs in enumerate(real_outputs): 115 | real_all_outputs[ii] += d_outputs 116 | 117 | #if dev_metrics is not None: 118 | # dev_metrics.accumulate(batch_size, *losses) 119 | if dev_metrics_trg is not None: 120 | dev_metrics_trg.accumulate(batch_size, *[rest[0], rest[1], rest[2]]) 121 | if dev_metrics_average is not None: 122 | dev_metrics_average.accumulate(batch_size, *[rest[3], rest[4]]) 123 | 124 | if args.dataset != "mscoco": 125 | real_bleu = [computeBLEU(ith_output, trg_outputs, corpus=True, tokenizer=tokenizer) for ith_output in real_all_outputs] 126 | else: 127 | real_bleu = [computeBLEUMSCOCO(ith_output, trg_outputs, corpus=True, tokenizer=tokenizer) for ith_output in real_all_outputs] 128 | 129 | outputs_data['real'] = real_bleu 130 | 131 | if "predict" in args.trg_len_option: 132 | outputs_data['pred_target_len_loss'] = getattr(dev_metrics_trg, 'pred_target_len_loss') 133 | outputs_data['pred_target_len_correct'] = getattr(dev_metrics_trg, 'pred_target_len_correct') 134 | outputs_data['pred_target_len_approx'] = getattr(dev_metrics_trg, 'pred_target_len_approx') 135 | outputs_data['average_target_len_correct'] = getattr(dev_metrics_average, 'average_target_len_correct') 136 | outputs_data['average_target_len_approx'] = getattr(dev_metrics_average, 'average_target_len_approx') 137 | 138 | #if dev_metrics is not None: 139 | # args.logger.info(dev_metrics) 140 | if dev_metrics_trg is not None: 141 | args.logger.info(dev_metrics_trg) 142 | if dev_metrics_average is not None: 143 | args.logger.info(dev_metrics_average) 144 | 145 | for idx in range(args.valid_repeat_dec): 146 | print_str = "iter {} | {}".format(idx+1, print_bleu(real_bleu[idx], verbose=False))
| args.logger.info( print_str ) 148 | 149 | return outputs_data 150 | 151 | def train_model(args, model, train, dev, src=None, trg=None, trg_len_dic=None, teacher_model=None, save_path=None, maxsteps=None): 152 | 153 | if args.tensorboard and (not args.debug): 154 | from tensorboardX import SummaryWriter 155 | writer = SummaryWriter(str(args.event_path / args.id_str)) 156 | 157 | if type(model) is FastTransformer and args.denoising_prob > 0.0: 158 | denoising_weights = [args.denoising_weight for idx in range(args.train_repeat_dec)] 159 | denoising_out_weights = [args.denoising_out_weight for idx in range(args.train_repeat_dec)] 160 | 161 | if type(model) is FastTransformer and args.layerwise_denoising_weight: 162 | start, end = 0.9, 0.1 163 | diff = (start-end)/(args.train_repeat_dec-1) 164 | denoising_weights = np.arange(start=end, stop=start, step=diff).tolist()[::-1] + [0.1] 165 | 166 | # optimizer 167 | for k, p in zip(model.state_dict().keys(), model.parameters()): 168 | # only finetune layers that are responsible to predicting target len 169 | if args.finetune_trg_len: 170 | if "pred_len" not in k: 171 | p.requires_grad = False 172 | else: 173 | print(k) 174 | else: 175 | if "pred_len" in k: 176 | p.requires_grad = False 177 | 178 | params = [p for p in model.parameters() if p.requires_grad] 179 | if args.optimizer == 'Adam': 180 | opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9) 181 | else: 182 | raise NotImplementedError 183 | 184 | # if resume training 185 | if (args.load_from is not None) and (args.resume) and not args.finetune_trg_len: 186 | with torch.cuda.device(args.gpu): # very important. 187 | offset, opt_states = torch.load(str(args.model_path / args.load_from) + '.pt.states', 188 | map_location=lambda storage, loc: storage.cuda()) 189 | opt.load_state_dict(opt_states) 190 | else: 191 | offset = 0 192 | 193 | if not args.finetune_trg_len: 194 | best = Best(max, *['BLEU_dec{}'.format(ii+1) for ii in range(args.valid_repeat_dec)], 195 | 'i', model=model, opt=opt, path=str(args.model_path / args.id_str), gpu=args.gpu, 196 | which=range(args.valid_repeat_dec)) 197 | else: 198 | best = Best(max, *['pred_target_len_correct'], 199 | 'i', model=model, opt=opt, path=str(args.model_path / args.id_str), gpu=args.gpu, 200 | which=[0]) 201 | train_metrics = Metrics('train loss', *['loss_{}'.format(idx+1) for idx in range(args.train_repeat_dec)], data_type = "avg") 202 | dev_metrics = Metrics('dev loss', *['loss_{}'.format(idx+1) for idx in range(args.valid_repeat_dec)], data_type = "avg") 203 | 204 | if "predict" in args.trg_len_option: 205 | train_metrics_trg = Metrics('train loss target', *["pred_target_len_loss", "pred_target_len_correct", "pred_target_len_approx"], data_type="avg") 206 | train_metrics_average = Metrics('train loss average', *["average_target_len_correct", "average_target_len_approx"], data_type="avg") 207 | dev_metrics_trg = Metrics('dev loss target', *["pred_target_len_loss", "pred_target_len_correct", "pred_target_len_approx"], data_type="avg") 208 | dev_metrics_average = Metrics('dev loss average', *["average_target_len_correct", "average_target_len_approx"], data_type="avg") 209 | else: 210 | train_metrics_trg = None 211 | train_metrics_average = None 212 | dev_metrics_trg = None 213 | dev_metrics_average = None 214 | 215 | if not args.no_tqdm: 216 | progressbar = tqdm(total=args.eval_every, desc='start training.') 217 | 218 | if maxsteps is None: 219 | maxsteps = args.maximum_steps 220 | 221 | #targetlength = TargetLength() 222 | for iters, 
train_batch in enumerate(train): 223 | #targetlength.accumulate( train_batch ) 224 | #continue 225 | 226 | iters += offset 227 | 228 | if args.save_every > 0 and iters % args.save_every == 0: 229 | args.logger.info('save (back-up) checkpoints at iter={}'.format(iters)) 230 | with torch.cuda.device(args.gpu): 231 | torch.save(best.model.state_dict(), '{}_iter={}.pt'.format(str(args.model_path / args.id_str), iters)) 232 | torch.save([iters, best.opt.state_dict()], '{}_iter={}.pt.states'.format(str(args.model_path / args.id_str), iters)) 233 | 234 | if (iters+1) % args.eval_every == 0: 235 | torch.cuda.empty_cache() 236 | gc.collect() 237 | dev_metrics.reset() 238 | if dev_metrics_trg is not None: 239 | dev_metrics_trg.reset() 240 | if dev_metrics_average is not None: 241 | dev_metrics_average.reset() 242 | outputs_data = valid_model(args, model, dev, dev_metrics, dev_metrics_trg=dev_metrics_trg, dev_metrics_average=dev_metrics_average, teacher_model=None, print_out=False, trg_len_dic=trg_len_dic) 243 | #outputs_data = [0, [0,0,0,0], 0, 0] 244 | if args.tensorboard and (not args.debug): 245 | for ii in range(args.valid_repeat_dec): 246 | writer.add_scalar('dev/single/Loss_{}'.format(ii + 1), getattr(dev_metrics, "loss_{}".format(ii+1)), iters) # NLL averaged over dev corpus 247 | writer.add_scalar('dev/single/BLEU_{}'.format(ii + 1), outputs_data['real'][ii][0], iters) # NOTE corpus bleu 248 | 249 | if "predict" in args.trg_len_option: 250 | writer.add_scalar("dev/single/pred_target_len_loss", outputs_data["pred_target_len_loss"], iters) 251 | writer.add_scalar("dev/single/pred_target_len_correct", outputs_data["pred_target_len_correct"], iters) 252 | writer.add_scalar("dev/single/pred_target_len_approx", outputs_data["pred_target_len_approx"], iters) 253 | writer.add_scalar("dev/single/average_target_len_correct", outputs_data["average_target_len_correct"], iters) 254 | writer.add_scalar("dev/single/average_target_len_approx", outputs_data["average_target_len_approx"], iters) 255 | 256 | """ 257 | writer.add_scalars('dev/total/BLEUs', {"iter_{}".format(idx+1):bleu for idx, bleu in enumerate(outputs_data['bleu']) }, iters) 258 | writer.add_scalars('dev/total/Losses', 259 | { "iter_{}".format(idx+1):getattr(dev_metrics, "loss_{}".format(idx+1)) 260 | for idx in range(args.valid_repeat_dec) }, 261 | iters ) 262 | """ 263 | 264 | if not args.debug: 265 | if not args.finetune_trg_len: 266 | best.accumulate(*[xx[0] for xx in outputs_data['real']], iters) 267 | 268 | values = list( best.metrics.values() ) 269 | args.logger.info("best model : {}, {}".format( "BLEU=[{}]".format(", ".join( [ str(x) for x in values[:args.valid_repeat_dec] ] ) ), \ 270 | "i={}".format( values[args.valid_repeat_dec] ), ) ) 271 | else: 272 | best.accumulate(*[outputs_data['pred_target_len_correct']], iters) 273 | values = list( best.metrics.values() ) 274 | args.logger.info("best model : {}".format( "pred_target_len_correct = {}".format(values[0])) ) 275 | 276 | args.logger.info('model:' + args.prefix + args.hp_str) 277 | 278 | # ---set-up a new progressor--- 279 | if not args.no_tqdm: 280 | progressbar.close() 281 | progressbar = tqdm(total=args.eval_every, desc='start training.') 282 | 283 | if type(model) is FastTransformer and args.anneal_denoising_weight: 284 | for ii, bb in enumerate([xx[0] for xx in outputs_data['real']][:-1]): 285 | denoising_weights[ii] = 0.9 - 0.1 * int(math.floor(bb / 3.0)) 286 | 287 | if iters > maxsteps: 288 | args.logger.info('reached the maximum updating steps.') 289 | break 290 | 291 | 
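# [annotation, not part of the original file] The two schedules defined just below:
# get_lr_transformer is the inverse-sqrt warmup rule from "Attention Is All You Need",
# scaled by lr0 * 10,
#     lr(i) = (lr0 * 10 / sqrt(d_model)) * min(1 / sqrt(i), i / warmup**1.5),
# so with lr0=0.1, d_model=512, warmup=4000 the peak is lr(4000) = 1 / sqrt(512 * 4000)
# ~= 7.0e-4; the rate grows linearly before step 4000 and decays as 1/sqrt(i) after.
# get_lr_anneal instead decays linearly from args.lr down to 1e-5 over args.anneal_steps updates.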
model.train() 292 | 293 | def get_lr_transformer(i, lr0=0.1): 294 | return lr0 * 10 / math.sqrt(args.d_model) * min( 295 | 1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup))) 296 | 297 | def get_lr_anneal(iters, lr0=0.1): 298 | lr_end = 1e-5 299 | return max( 0, (args.lr - lr_end) * (args.anneal_steps - iters) / args.anneal_steps ) + lr_end 300 | 301 | if args.lr_schedule == "fixed": 302 | opt.param_groups[0]['lr'] = args.lr 303 | elif args.lr_schedule == "anneal": 304 | opt.param_groups[0]['lr'] = get_lr_anneal(iters + 1) 305 | elif args.lr_schedule == "transformer": 306 | opt.param_groups[0]['lr'] = get_lr_transformer(iters + 1) 307 | opt.zero_grad() 308 | 309 | if args.dataset == "mscoco": 310 | decoder_inputs, decoder_masks,\ 311 | targets, target_masks,\ 312 | _, source_masks,\ 313 | encoding, batch_size, rest = model.quick_prepare_mscoco(train_batch, all_captions=train_batch[1], fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec, trg_len_option=args.trg_len_option, max_len=args.max_offset, trg_len_dic=trg_len_dic, bp=args.bp) 314 | else: 315 | decoder_inputs, decoder_masks,\ 316 | targets, target_masks,\ 317 | sources, source_masks,\ 318 | encoding, batch_size, rest = model.quick_prepare(train_batch, fast=(type(model) is FastTransformer), trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio, trg_len_dic=trg_len_dic, bp=args.bp) 319 | 320 | losses = [] 321 | if type(model) is Transformer: 322 | loss = model.cost(targets, target_masks, out=model(encoding, source_masks, decoder_inputs, decoder_masks)) 323 | losses.append( loss ) 324 | 325 | elif type(model) is FastTransformer: 326 | all_logits = [] 327 | all_denoising_masks = [] 328 | for iter_ in range(args.train_repeat_dec): 329 | torch.cuda.empty_cache() 330 | curr_iter = min(iter_, args.num_decs-1) 331 | next_iter = min(curr_iter + 1, args.num_decs-1) 332 | 333 | out = model(encoding, source_masks, decoder_inputs, decoder_masks, iter_=curr_iter, return_probs=False) 334 | 335 | if args.rf_finetune: 336 | loss = model.rf_cost(args, targets, target_masks, out=out, iter_=curr_iter) 337 | elif args.nat_finetune: 338 | loss = model.nat_cost(args, targets, target_masks, out=out, iter_=curr_iter) 339 | elif args.ng_finetune or args.joint: 340 | loss = model.ngram_cost(args, iters, targets, target_masks, out=out, iter_=curr_iter) 341 | else: 342 | loss = model.cost(targets, target_masks, out=out, iter_=curr_iter) 343 | 344 | logits = model.decoder[curr_iter].out(out) 345 | 346 | if args.use_argmax: 347 | _, argmax = torch.max(logits, dim=-1) 348 | else: 349 | probs = softmax(logits) 350 | probs_sz = probs.size() 351 | logits_ = Variable(probs.data, requires_grad=False) 352 | argmax = torch.multinomial(logits_.contiguous().view(-1, probs_sz[-1]), 1).view(*probs_sz[:-1]) 353 | 354 | if args.self_distil > 0.0: 355 | all_logits.append(logits) # keep this iteration's logits for the self-distillation loss; the original referenced "logits_masked", which is undefined here and raised a NameError when --self_distil > 0 356 | del logits 357 | losses.append(loss) 358 | 359 | decoder_inputs_ = 0 360 | denoising_mask = 1 361 | if args.next_dec_input in ["both", "emb"]: 362 | if args.denoising_prob > 0.0 and np.random.rand() < args.denoising_prob: 363 | cor = corrupt_target(targets, decoder_masks, len(trg.vocab), denoising_weights[iter_], args.corruption_probs) 364 | 365 | emb = F.embedding(cor, model.decoder[next_iter].out.weight * math.sqrt(args.d_model)) 366 | denoising_mask = 0 367 | else: 368 | emb = F.embedding(argmax, model.decoder[next_iter].out.weight * math.sqrt(args.d_model)) 369 | 370 | if args.denoising_out_weight > 0: 371 | if
denoising_out_weights[iter_] > 0.0: 372 | corrupted_argmax = corrupt_target(argmax, decoder_masks, denoising_out_weights[iter_]) 373 | else: 374 | corrupted_argmax = argmax 375 | emb = F.embedding(corrupted_argmax, model.decoder[next_iter].out.weight * math.sqrt(args.d_model)) 376 | decoder_inputs_ += emb 377 | all_denoising_masks.append( denoising_mask ) 378 | 379 | if args.next_dec_input in ["both", "out"]: 380 | decoder_inputs_ += out 381 | decoder_inputs = decoder_inputs_ 382 | 383 | # self distillation loss if requested 384 | if args.self_distil > 0.0: 385 | self_distil_losses = [] 386 | 387 | for logits_i in range(1, len(all_logits)-1): 388 | self_distill_loss_i = 0.0 389 | for logits_j in range(logits_i+1, len(all_logits)): 390 | self_distill_loss_i += \ 391 | all_denoising_masks[logits_j] * \ 392 | all_denoising_masks[logits_i] * \ 393 | (1/(logits_j-logits_i)) * args.self_distil * F.mse_loss(all_logits[logits_i], all_logits[logits_j].detach()) 394 | 395 | self_distil_losses.append(self_distill_loss_i) 396 | 397 | self_distil_loss = sum(self_distil_losses) 398 | 399 | loss = sum(losses) 400 | 401 | # accmulate the training metrics 402 | train_metrics.accumulate(batch_size, *losses, print_iter=None) 403 | if train_metrics_trg is not None: 404 | train_metrics_trg.accumulate(batch_size, *[rest[0], rest[1], rest[2]]) 405 | if train_metrics_average is not None: 406 | train_metrics_average.accumulate(batch_size, *[rest[3], rest[4]]) 407 | if type(model) is FastTransformer and args.self_distil > 0.0: 408 | (loss+self_distil_loss).backward() 409 | else: 410 | if "predict" in args.trg_len_option: 411 | if args.finetune_trg_len: 412 | rest[0].backward() 413 | else: 414 | loss.backward() 415 | else: 416 | loss.backward() 417 | 418 | if args.grad_clip > 0: 419 | total_norm = nn.utils.clip_grad_norm(params, args.grad_clip) 420 | 421 | opt.step() 422 | 423 | mid_str = '' 424 | if type(model) is FastTransformer and args.self_distil > 0.0: 425 | mid_str += 'distil={:.5f}, '.format(self_distil_loss.cpu().data.numpy()[0]) 426 | #if type(model) is FastTransformer and "predict" in args.trg_len_option: 427 | # mid_str += 'pred_target_len_loss={:.5f}, '.format(rest[0].cpu().data.numpy()[0]) 428 | if type(model) is FastTransformer and args.denoising_prob > 0.0: 429 | mid_str += "/".join(["{:.1f}".format(ff) for ff in denoising_weights[:-1]])+", " 430 | 431 | info = 'update={}, loss={}, {}lr={:.1e}'.format( iters, 432 | "/".join(["{:.3f}".format(export(ll)) for ll in losses]), 433 | mid_str, 434 | opt.param_groups[0]['lr']) 435 | 436 | if args.no_tqdm: 437 | if iters % args.eval_every == 0: 438 | args.logger.info("update {} : {}".format(iters, str(train_metrics))) 439 | else: 440 | progressbar.update(1) 441 | progressbar.set_description(info) 442 | 443 | if (iters+1) % args.eval_every == 0 and args.tensorboard and (not args.debug): 444 | for idx in range(args.train_repeat_dec): 445 | writer.add_scalar('train/single/Loss_{}'.format(idx+1), getattr(train_metrics, "loss_{}".format(idx+1)), iters) 446 | if "predict" in args.trg_len_option: 447 | writer.add_scalar("train/single/pred_target_len_loss", getattr(train_metrics_trg, "pred_target_len_loss"), iters) 448 | writer.add_scalar("train/single/pred_target_len_correct", getattr(train_metrics_trg, "pred_target_len_correct"), iters) 449 | writer.add_scalar("train/single/pred_target_len_approx", getattr(train_metrics_trg, "pred_target_len_approx"), iters) 450 | writer.add_scalar("train/single/average_target_len_correct", getattr(train_metrics_average, 
"average_target_len_correct"), iters) 451 | writer.add_scalar("train/single/average_target_len_approx", getattr(train_metrics_average, "average_target_len_approx"), iters) 452 | 453 | train_metrics.reset() 454 | if train_metrics_trg is not None: 455 | train_metrics_trg.reset() 456 | if train_metrics_average is not None: 457 | train_metrics_average.reset() 458 | 459 | #torch.save(targetlength.lengths, str(args.data_prefix / "trg_len_dic" / args.dataset[-4:])) 460 | -------------------------------------------------------------------------------- /tune_wmt.sh: -------------------------------------------------------------------------------- 1 | python run.py --dataset wmt14-ende --vocab_size 60000 --gpu 2 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 1 --use_argmax --next_dec_input both --use_distillation --load_from --resume --finetune_trg_len --trg_len_option predict 2 | 3 | 4 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import ipdb 3 | import torch 4 | import random 5 | import numpy as np 6 | import _pickle as pickle 7 | import revtok 8 | import os 9 | from itertools import groupby 10 | import getpass 11 | from collections import Counter 12 | 13 | from torch.autograd import Variable 14 | from torchtext import data, datasets 15 | from nltk.translate.gleu_score import sentence_gleu, corpus_gleu 16 | from nltk.translate.bleu_score import closest_ref_length, brevity_penalty, modified_precision, SmoothingFunction 17 | from contextlib import ExitStack 18 | from collections import OrderedDict 19 | import fractions 20 | 21 | 22 | try: 23 | fractions.Fraction(0, 1000, _normalize=False) 24 | from fractions import Fraction 25 | except TypeError: 26 | from nltk.compat import Fraction 27 | 28 | def sentence_bleu(references, hypothesis, weights=(0.25, 0.25, 0.25, 0.25), 29 | smoothing_function=None, auto_reweigh=False, 30 | emulate_multibleu=False): 31 | 32 | return corpus_bleu([references], [hypothesis], 33 | weights, smoothing_function, auto_reweigh, 34 | emulate_multibleu) 35 | 36 | 37 | def corpus_bleu(list_of_references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25), 38 | smoothing_function=None, auto_reweigh=False, 39 | emulate_multibleu=False): 40 | p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches. 41 | p_denominators = Counter() # Key = ngram order, and value = no. of ngram in ref. 42 | hyp_lengths, ref_lengths = 0, 0 43 | 44 | if len(list_of_references) != len(hypotheses): 45 | print ("The number of hypotheses and their reference(s) should be the same") 46 | return (0, (0, 0, 0, 0), 0, 0, 0) 47 | 48 | # Iterate through each hypothesis and their corresponding references. 49 | for references, hypothesis in zip(list_of_references, hypotheses): 50 | # For each order of ngram, calculate the numerator and 51 | # denominator for the corpus-level modified precision. 52 | for i, _ in enumerate(weights, start=1): 53 | p_i = modified_precision(references, hypothesis, i) 54 | p_numerators[i] += p_i.numerator 55 | p_denominators[i] += p_i.denominator 56 | 57 | # Calculate the hypothesis length and the closest reference length. 58 | # Adds them to the corpus-level hypothesis and reference counts. 59 | hyp_len = len(hypothesis) 60 | hyp_lengths += hyp_len 61 | ref_lengths += closest_ref_length(references, hyp_len) 62 | 63 | # Calculate corpus-level brevity penalty. 
64 | bp = brevity_penalty(ref_lengths, hyp_lengths) 65 | 66 | # Uniformly re-weighting based on maximum hypothesis lengths if largest 67 | # order of n-grams < 4 and weights is set at default. 68 | if auto_reweigh: 69 | if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25): 70 | weights = ( 1 / hyp_lengths ,) * hyp_lengths 71 | 72 | # Collects the various precision values for the different ngram orders. 73 | p_n = [Fraction(p_numerators[i], p_denominators[i], _normalize=False) 74 | for i, _ in enumerate(weights, start=1)] 75 | 76 | p_n_ = [xx.numerator / xx.denominator * 100 for xx in p_n] 77 | 78 | # Returns 0 if there's no matching n-grams 79 | # We only need to check for p_numerators[1] == 0, since if there's 80 | # no unigrams, there won't be any higher order ngrams. 81 | if p_numerators[1] == 0: 82 | return (0, (0, 0, 0, 0), 0, 0, 0) 83 | 84 | # If there's no smoothing, set use method0 from SmoothinFunction class. 85 | if not smoothing_function: 86 | smoothing_function = SmoothingFunction().method0 87 | # Smoothen the modified precision. 88 | # Note: smoothing_function() may convert values into floats; 89 | # it tries to retain the Fraction object as much as the 90 | # smoothing method allows. 91 | p_n = smoothing_function(p_n, references=references, hypothesis=hypothesis, 92 | hyp_len=hyp_len, emulate_multibleu=emulate_multibleu) 93 | s = (w * math.log(p_i) for i, (w, p_i) in enumerate(zip(weights, p_n))) 94 | s = bp * math.exp(math.fsum(s)) * 100 95 | final_bleu = round(s, 4) if emulate_multibleu else s 96 | return (final_bleu, p_n_, bp, ref_lengths, hyp_lengths) 97 | 98 | INF = 1e10 99 | TINY = 1e-9 100 | 101 | def n_grams(list_words): 102 | set_1gram, set_2gram, set_3gram, set_4gram = set(), set(), set(), set() 103 | count = {} 104 | l = len(list_words) 105 | for i in range(l): 106 | word = list_words[i] 107 | if word not in set_1gram: 108 | set_1gram.add(word) 109 | count[word] = 1 110 | else: 111 | set_1gram.add((word,count[word])) 112 | count[word] += 1 113 | count = {} 114 | 115 | for i in range(l-1): 116 | word = (list_words[i],list_words[i+1]) 117 | if word not in set_2gram: 118 | set_2gram.add(word) 119 | count[word] = 1 120 | else: 121 | set_2gram.add((word,count[word])) 122 | count[word] += 1 123 | 124 | count = {} 125 | 126 | for i in range(l-2): 127 | word = (list_words[i],list_words[i+1], list_words[i+2]) 128 | if word not in set_3gram: 129 | set_3gram.add(word) 130 | count[word] = 1 131 | else: 132 | set_3gram.add((word,count[word])) 133 | count[word] += 1 134 | count = {} 135 | 136 | for i in range(l-3): 137 | word = (list_words[i],list_words[i+1], list_words[i+2], list_words[i+3]) 138 | if word not in set_4gram: 139 | set_4gram.add(word) 140 | count[word] = 1 141 | else: 142 | set_4gram.add((word,count[word])) 143 | count[word] += 1 144 | 145 | return set_1gram, set_2gram, set_3gram, set_4gram 146 | 147 | def my_sentence_gleu(references, hypothesis): 148 | global t1,t2 149 | reference = references[0] 150 | ref_grams = n_grams(reference) 151 | hyp_grams = n_grams(hypothesis) 152 | match_grams = [x.intersection(y) for (x,y) in zip(ref_grams, hyp_grams)] 153 | ref_count = sum([len(x) for x in ref_grams]) 154 | hyp_count = sum([len(x) for x in hyp_grams]) 155 | match_count = sum([len(x) for x in match_grams]) 156 | gleu = float(match_count) / float(max(ref_count,hyp_count)) 157 | return gleu 158 | 159 | def computeGLEU(outputs, targets, corpus=False, tokenizer=None): 160 | if tokenizer is None: 161 | tokenizer = revtok.tokenize 162 | 163 | 164 | if not corpus: 165 
| return [my_sentence_gleu([t], o) for o, t in zip(outputs, targets)] 166 | 167 | return corpus_gleu([[t] for t in targets], [o for o in outputs]) 168 | 169 | def computeBLEU(outputs, targets, corpus=False, tokenizer=None): 170 | if tokenizer is None: 171 | tokenizer = revtok.tokenize 172 | 173 | outputs = [tokenizer(o) for o in outputs] 174 | targets = [tokenizer(t) for t in targets] 175 | 176 | if corpus: 177 | return corpus_bleu([[t] for t in targets], [o for o in outputs], emulate_multibleu=True) 178 | else: 179 | return [sentence_bleu([t], o)[0] for o, t in zip(outputs, targets)] 180 | #return torch.Tensor([sentence_bleu([t], o)[0] for o, t in zip(outputs, targets)]) 181 | 182 | def computeBLEUMSCOCO(outputs, targets, corpus=True, tokenizer=None): 183 | # outputs is list of 5000 captions 184 | # targets is list of 5000 lists each length of 5 185 | if tokenizer is None: 186 | tokenizer = revtok.tokenize 187 | 188 | outputs = [tokenizer(o) for o in outputs] 189 | new_targets = [] 190 | for i, t in enumerate(targets): 191 | new_targets.append([tokenizer(tt) for tt in t]) 192 | #targets[i] = [tokenizer(tt) for tt in t] 193 | 194 | if corpus: 195 | return corpus_bleu(new_targets, outputs, emulate_multibleu=True) 196 | else: 197 | return [sentence_bleu(new_t, o)[0] for o, new_t in zip(outputs, new_targets)] 198 | 199 | def compute_bp(hypotheses, list_of_references): 200 | hyp_lengths, ref_lengths = 0, 0 201 | for references, hypothesis in zip(list_of_references, hypotheses): 202 | hyp_len = len(hypothesis) 203 | hyp_lengths += hyp_len 204 | ref_lengths += closest_ref_length(references, hyp_len) 205 | 206 | # Calculate corpus-level brevity penalty. 207 | bp = brevity_penalty(ref_lengths, hyp_lengths) 208 | return bp 209 | 210 | def computeGroupBLEU(outputs, targets, tokenizer=None, bra=10, maxmaxlen=80): 211 | if tokenizer is None: 212 | tokenizer = revtok.tokenize 213 | 214 | outputs = [tokenizer(o) for o in outputs] 215 | targets = [tokenizer(t) for t in targets] 216 | maxlens = max([len(t) for t in targets]) 217 | print(maxlens) 218 | maxlens = min([maxlens, maxmaxlen]) 219 | nums = int(np.ceil(maxlens / bra)) 220 | outputs_buckets = [[] for _ in range(nums)] 221 | targets_buckets = [[] for _ in range(nums)] 222 | for o, t in zip(outputs, targets): 223 | idx = len(o) // bra 224 | if idx >= len(outputs_buckets): 225 | idx = -1 226 | outputs_buckets[idx] += [o] 227 | targets_buckets[idx] += [t] 228 | 229 | for k in range(nums): 230 | print(corpus_bleu([[t] for t in targets_buckets[k]], [o for o in outputs_buckets[k]], emulate_multibleu=True)) 231 | 232 | class TargetLength: 233 | def __init__(self, lengths=None): # data_type : sum, avg 234 | self.lengths = lengths if lengths != None else dict() 235 | 236 | def accumulate(self, batch): 237 | src_len = (batch.src != 1).sum(-1).cpu().data.numpy() 238 | trg_len = (batch.trg != 1).sum(-1).cpu().data.numpy() 239 | for (slen, tlen) in zip(src_len, trg_len): 240 | if not slen in self.lengths: 241 | self.lengths[slen] = (1, int(tlen)) 242 | else: 243 | (count, acc) = self.lengths[slen] 244 | self.lengths[slen] = (count + 1, acc + int(tlen)) 245 | 246 | def get_trg_len(self, src_len): 247 | if not src_len in self.lengths: 248 | return self.get_trg_len(src_len + 1) - 1 249 | else: 250 | (count, acc) = self.lengths[src_len] 251 | return acc / float(count) 252 | 253 | def organise_trg_len_dic(trg_len_dic): 254 | trg_len_dic = {k:int(v[1]/float(v[0])) for (k, v) in trg_len_dic.items()} 255 | return trg_len_dic 256 | 257 | def 
query_trg_len_dic(trg_len_dic, q): 258 | max_src_len = max(trg_len_dic.keys()) 259 | if q <= max_src_len: 260 | if q in trg_len_dic: 261 | return trg_len_dic[q] 262 | else: 263 | return query_trg_len_dic(trg_len_dic, q+1) - 1 264 | else: 265 | return int(math.floor( trg_len_dic[max_src_len] / max_src_len * q )) 266 | 267 | def make_decoder_masks(source_masks, trg_len_dic): 268 | batch_size, src_max_len = source_masks.size() 269 | src_len = (source_masks == 1).sum(-1).cpu().numpy() 270 | trg_len = [int(math.floor(query_trg_len_dic(trg_len_dic, src) * 1.1)) for src in src_len] 271 | trg_max_len = max(trg_len) 272 | decoder_masks = np.zeros((batch_size, trg_max_len)) 273 | #decoder_masks = Variable(torch.zeros(batch_size, trg_max_len), requires_grad=False) 274 | for idx, tt in enumerate(trg_len): 275 | decoder_masks[idx][:tt] = 1 276 | result = torch.from_numpy(decoder_masks).float() 277 | if source_masks.is_cuda: 278 | result = result.cuda() 279 | return result 280 | 281 | def double_source_masks(source_masks): 282 | batch_size, src_max_len = source_masks.size() 283 | src_len = (source_masks == 1).sum(-1).cpu().numpy() 284 | decoder_masks = np.zeros((batch_size, src_max_len * 2)) 285 | for idx, tt in enumerate(src_len): 286 | decoder_masks[idx][:2*tt] = 1 287 | result = torch.from_numpy(decoder_masks).float() 288 | if source_masks.is_cuda: 289 | result = result.cuda() 290 | return result 291 | 292 | class Metrics: 293 | 294 | def __init__(self, name, *metrics, data_type="sum"): # data_type : sum, avg 295 | self.count = 0 296 | self.metrics = OrderedDict((metric, 0) for metric in metrics) 297 | self.name = name 298 | self.data_type = data_type 299 | 300 | def accumulate(self, count, *values, print_iter=None): 301 | self.count += count 302 | if print_iter is not None: 303 | print(print_iter, end=' ') 304 | for value, metric in zip(values, self.metrics): 305 | if isinstance(value, torch.autograd.Variable): 306 | value = value.data 307 | if torch.is_tensor(value): 308 | with torch.cuda.device_of(value): 309 | value = value.cpu() 310 | value = value.float().sum() 311 | 312 | if print_iter is not None: 313 | print('%.3f' % value, end=' ') 314 | if self.data_type == "sum": 315 | self.metrics[metric] += value 316 | elif self.data_type == "avg": 317 | self.metrics[metric] += value * count 318 | 319 | if print_iter is not None: 320 | print() 321 | return values[0] # loss 322 | 323 | def __getattr__(self, key): 324 | if key in self.metrics: 325 | return self.metrics[key] / (self.count + 1e-9) 326 | raise AttributeError 327 | 328 | def __repr__(self): 329 | return ("{}: ".format(self.name) + 330 | "[{}]".format( ', '.join(["{:.4f}".format(getattr(self, metric)) for metric, value in self.metrics.items() if value != 0 ] ) ) ) 331 | 332 | def tensorboard(self, expt, i): 333 | for metric in self.metrics: 334 | value = getattr(self, metric) 335 | if value != 0: 336 | #expt.add_scalar_value(f'{self.name}_{metric}', value, step=i) 337 | expt.add_scalar_value("{}_{}".format(self.name, metric), value, step=i) 338 | 339 | def reset(self): 340 | self.count = 0 341 | self.metrics.update({metric: 0 for metric in self.metrics}) 342 | 343 | class Best: 344 | def __init__(self, cmp_fn, *metrics, model=None, opt=None, path='', gpu=0, which=[0]): 345 | self.cmp_fn = cmp_fn 346 | self.model = model 347 | self.opt = opt 348 | self.path = path + '.pt' 349 | self.metrics = OrderedDict((metric, None) for metric in metrics) 350 | self.gpu = gpu 351 | self.which = which 352 | self.best_cmp_value = None 353 | 354 | def
accumulate(self, *other_values): 355 | 356 | with torch.cuda.device(self.gpu): 357 | cmp_values = [other_values[which] for which in self.which] 358 | if self.best_cmp_value is None or \ 359 | self.cmp_fn(self.best_cmp_value, *cmp_values) != self.best_cmp_value: 360 | self.metrics.update( { metric: value for metric, value in zip( 361 | list(self.metrics.keys()), other_values) } ) 362 | self.best_cmp_value = self.cmp_fn( [ list(self.metrics.items())[which][1] for which in self.which ] ) 363 | 364 | #open(self.path + '.temp', 'w') 365 | if self.model is not None: 366 | torch.save(self.model.state_dict(), self.path) 367 | 368 | if self.opt is not None: 369 | torch.save([self.i, self.opt.state_dict()], self.path + '.states') 370 | #os.remove(self.path + '.temp') 371 | 372 | def __getattr__(self, key): 373 | if key in self.metrics: 374 | return self.metrics[key] 375 | raise AttributeError 376 | 377 | def __repr__(self): 378 | return ("BEST: " + 379 | ', '.join(["{}: {:.4f}".format(metric, getattr(self, metric)) for metric, value in self.metrics.items() if value != 0])) 380 | 381 | class CacheExample(data.Example): 382 | 383 | @classmethod 384 | def fromsample(cls, data_lists, names): 385 | ex = cls() 386 | for data, name in zip(data_lists, names): 387 | setattr(ex, name, data) 388 | return ex 389 | 390 | 391 | class Cache: 392 | 393 | def __init__(self, size=10000, fields=["src", "trg"]): 394 | self.cache = [] 395 | self.maxsize = size 396 | 397 | def demask(self, data, mask): 398 | with torch.cuda.device_of(data): 399 | data = [d[:l] for d, l in zip(data.data.tolist(), mask.sum(1).long().tolist())] 400 | return data 401 | 402 | def add(self, data_lists, masks, names): 403 | data_lists = [self.demask(d, m) for d, m in zip(data_lists, masks)] 404 | for data in zip(*data_lists): 405 | self.cache.append(CacheExample.fromsample(data, names)) 406 | 407 | if len(self.cache) >= self.maxsize: 408 | self.cache = self.cache[-self.maxsize:] 409 | 410 | 411 | class Batch: 412 | def __init__(self, src=None, trg=None, dec=None): 413 | self.src, self.trg, self.dec = src, trg, dec 414 | 415 | def masked_sort(x, mask, dim=-1): 416 | x.data += ((1 - mask) * INF).long() 417 | y, i = torch.sort(x, dim) 418 | y.data *= mask.long() 419 | return y, i 420 | 421 | def unsorted(y, i, dim=-1): 422 | z = Variable(y.data.new(*y.size())) 423 | z.scatter_(dim, i, y) 424 | return z 425 | 426 | 427 | def merge_cache(decoding_path, names0, last_epoch=0, max_cache=20): 428 | file_lock = open(decoding_path + '/_temp_decode', 'w') 429 | 430 | for name in names0: 431 | filenames = [] 432 | for i in range(max_cache): 433 | filenames.append('{}/{}.ep{}'.format(decoding_path, name, last_epoch - i)) 434 | if (last_epoch - i) <= 0: 435 | break 436 | code = 'cat {} > {}.train.{}'.format(" ".join(filenames), '{}/{}'.format(decoding_path, name), last_epoch) 437 | os.system(code) 438 | os.remove(decoding_path + '/_temp_decode') 439 | 440 | def corrupt_target_fix(trg, decoder_masks, vocab_size, weight=0.1, cor_p=[0.1, 0.1, 0.1, 0.1]): 441 | batch_size, max_trg_len = trg.size() # actual trg len 442 | max_dec_len = decoder_masks.size(1) # 2 * actual src len 443 | dec_lens = (decoder_masks == 1).sum(-1).cpu().numpy() 444 | trg_lens = (trg != 1).sum(-1).data.cpu().numpy() 445 | 446 | num_corrupts = np.array( [ np.random.choice(dec_lens[bidx]//2, 447 | min( max( math.floor(weight * (dec_lens[bidx]//2)), 1 ), dec_lens[bidx]//2), 448 | replace=False ) \ 449 | for bidx in range(batch_size) ] ) 450 | 451 | #min_len = min(max_trg_len, max_dec_len)
452 | decoder_input = np.ones((batch_size, max_dec_len)) 453 | decoder_input.fill(3) 454 | #decoder_input[:, :min_len] = trg[:, :min_len].data.cpu().numpy() 455 | 456 | for bidx in range(batch_size): 457 | min_len = min(dec_lens[bidx], trg_lens[bidx]) 458 | decoder_input[bidx][:min_len] = trg[bidx, :min_len].data.cpu().numpy() 459 | nr_list = num_corrupts[bidx] 460 | for nr in nr_list: 461 | 462 | prob = np.random.rand() 463 | 464 | #### each corruption changes multiple words 465 | if prob < sum(cor_p[:1]): # repeat 466 | decoder_input[bidx][nr+1:] = decoder_input[bidx][nr:-1] 467 | 468 | elif prob < sum(cor_p[:2]): # drop 469 | decoder_input[bidx][nr:-1] = decoder_input[bidx][nr+1:] 470 | 471 | #### each corruption changes one word 472 | elif prob < sum(cor_p[:3]): # replace word with random word 473 | decoder_input[bidx][nr] = np.random.randint(vocab_size-4) + 4 474 | 475 | #### each corruption changes two words 476 | elif prob < sum(cor_p[:4]): # swap 477 | temp = decoder_input[bidx][nr] 478 | decoder_input[bidx][nr] = decoder_input[bidx][nr+1] 479 | decoder_input[bidx][nr+1] = temp 480 | 481 | result = torch.from_numpy(decoder_input).long() 482 | if decoder_masks.is_cuda: 483 | result = result.cuda(decoder_masks.get_device()) 484 | return Variable(result, requires_grad=False) 485 | 486 | def corrupt_target(trg, decoder_masks, vocab_size, weight=0.1, cor_p=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]): 487 | batch_size, max_trg_len = trg.size() 488 | max_dec_len = decoder_masks.size(1) 489 | dec_lens = (decoder_masks == 1).sum(-1).cpu().numpy() 490 | 491 | num_corrupts = np.array( [ np.random.choice(dec_lens[bidx]-1, 492 | min( max( math.floor(weight * dec_lens[bidx]), 1 ), dec_lens[bidx]-1 ), 493 | replace=False ) \ 494 | for bidx in range(batch_size) ] ) 495 | 496 | min_len = min(max_trg_len, max_dec_len) 497 | decoder_input = np.ones((batch_size, max_dec_len)) 498 | decoder_input.fill(3) 499 | decoder_input[:, :min_len] = trg[:, :min_len].data.cpu().numpy() 500 | 501 | for bidx in range(batch_size): 502 | nr_list = num_corrupts[bidx] 503 | for nr in nr_list: 504 | 505 | prob = np.random.rand() 506 | 507 | #### each corruption changes multiple words 508 | if prob < sum(cor_p[:1]): # repeat 509 | decoder_input[bidx][nr+1:] = decoder_input[bidx][nr:-1] 510 | 511 | elif prob < sum(cor_p[:2]): # drop 512 | decoder_input[bidx][nr:-1] = decoder_input[bidx][nr+1:] 513 | 514 | elif prob < sum(cor_p[:3]): # add random word 515 | decoder_input[bidx][nr+1:] = decoder_input[bidx][nr:-1] 516 | decoder_input[bidx][nr] = np.random.randint(vocab_size-4) + 4 # sample except UNK/PAD/INIT/EOS 517 | 518 | #### each corruption changes one word 519 | elif prob < sum(cor_p[:4]): # repeat and drop next 520 | decoder_input[bidx][nr+1] = decoder_input[bidx][nr] 521 | 522 | elif prob < sum(cor_p[:5]): # replace word with random word 523 | decoder_input[bidx][nr] = np.random.randint(vocab_size-4) + 4 524 | 525 | #### each corruption changes two words 526 | elif prob < sum(cor_p[:6]): # swap 527 | temp = decoder_input[bidx][nr] 528 | decoder_input[bidx][nr] = decoder_input[bidx][nr+1] 529 | decoder_input[bidx][nr+1] = temp 530 | 531 | elif prob < sum(cor_p[:7]): # global swap 532 | swap_idx = np.random.randint(1, dec_lens[bidx]-nr) + nr 533 | temp = decoder_input[bidx][nr] 534 | decoder_input[bidx][nr] = decoder_input[bidx][swap_idx] 535 | decoder_input[bidx][swap_idx] = temp 536 | 537 | result = torch.from_numpy(decoder_input).long() 538 | if decoder_masks.is_cuda: 539 | result = result.cuda(decoder_masks.get_device()) 
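# [annotation, not part of the original file] cor_p is read cumulatively above, so
# with the default [0.1]*7 a sampled position is: repeated (prob < 0.1), dropped
# (< 0.2), preceded by an inserted random word (< 0.3), repeated with the next word
# dropped (< 0.4), replaced by a random word (< 0.5), swapped with its right
# neighbour (< 0.6), or swapped with a random later position (< 0.7); with the
# remaining 0.3 probability it is left unchanged.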
540 | return Variable(result, requires_grad=False) 541 | 542 | def drop(sentence, n_d): 543 | cur_len = np.sum( sentence != 1 ) 544 | for idx in range(n_d): 545 | drop_pos = random.randint(0, cur_len - 1) # a <= N <= b 546 | sentence[drop_pos:-1] = sentence[drop_pos+1:] 547 | cur_len = cur_len - 1 548 | sentence[-n_d:] = 1 549 | return sentence 550 | 551 | def repeat(sentence, n_r): 552 | cur_len = np.sum( sentence != 1 ) 553 | for idx in range(n_r): 554 | drop_pos = random.randint(0, cur_len) # a <= N <= b 555 | sentence[drop_pos+1:] = sentence[drop_pos:-1] 556 | sentence[cur_len:] = 1 557 | return sentence 558 | 559 | def remove_repeats(lst_of_sentences): 560 | lst = [] 561 | for sentence in lst_of_sentences: 562 | lst.append( " ".join([x[0] for x in groupby(sentence.split())]) ) 563 | return lst 564 | 565 | def remove_repeats_tensor(tensor): 566 | tensor = tensor.data.cpu() 567 | newtensor = tensor.clone() 568 | batch_size, seq_len = tensor.size() 569 | for bidx in range(batch_size): 570 | for sidx in range(seq_len-1): 571 | if newtensor[bidx, sidx] == newtensor[bidx, sidx+1]: 572 | newtensor[bidx, sidx:-1] = newtensor[bidx, sidx+1:] 573 | return Variable(newtensor) 574 | 575 | def mkdir(path): 576 | if not os.path.exists(path): 577 | os.mkdir(path) 578 | 579 | def print_bleu(bleu_output, verbose=True): 580 | (final_bleu, prec, bp, ref_lengths, hyp_lengths) = bleu_output 581 | ratio = 0 if ref_lengths == 0 else hyp_lengths/ref_lengths 582 | if verbose: 583 | return "BLEU = {:.2f}, {:.1f}/{:.1f}/{:.1f}/{:.1f} (BP={:.3f}, ratio={:.3f}, hyp_len={}, ref_len={})".format( 584 | final_bleu, prec[0], prec[1], prec[2], prec[3], bp, ratio, hyp_lengths, ref_lengths 585 | ) 586 | else: 587 | return "BLEU = {:.2f}, {:.1f}/{:.1f}/{:.1f}/{:.1f} (BP={:.3f}, ratio={:.3f})".format( 588 | final_bleu, prec[0], prec[1], prec[2], prec[3], bp, ratio 589 | ) 590 | 591 | def set_eos(argmax): 592 | new_argmax = Variable(argmax.data.new(*argmax.size()), requires_grad=False) 593 | new_argmax.fill_(3) 594 | batch_size, seq_len = argmax.size() 595 | argmax_lst = argmax.data.cpu().numpy().tolist() 596 | for bidx in range(batch_size): 597 | if 3 in argmax_lst[bidx]: 598 | idx = argmax_lst[bidx].index(3) 599 | if idx > 0 : 600 | new_argmax[bidx,:idx] = argmax[bidx,:idx] 601 | return new_argmax 602 | 603 | def init_encoder(model, saved): 604 | saved_ = {k.replace("encoder.",""):v for (k,v) in saved.items() if "encoder" in k} 605 | encoder = model.encoder 606 | encoder.load_state_dict(saved_) 607 | return model 608 | 609 | def oracle_converged(bleu_hist, num_items=5): 610 | batch_size = len(bleu_hist) 611 | converged = [False for bidx in range(batch_size)] 612 | for bidx in range(batch_size): 613 | if len(bleu_hist[bidx]) < num_items: 614 | converged[bidx] = False 615 | else: 616 | converged[bidx] = True 617 | hist = bleu_hist[bidx][-num_items:] 618 | for item in hist[1:]: 619 | if item > hist[0]: 620 | converged[bidx] = False # if BLEU improves in 4 iters, not converged 621 | return converged 622 | 623 | def equality_converged(output_hist, num_items=5): 624 | batch_size = len(output_hist) 625 | converged = [False for bidx in range(batch_size)] 626 | for bidx in range(batch_size): 627 | if len(output_hist[bidx]) < num_items: 628 | converged[bidx] = False 629 | else: 630 | converged[bidx] = False 631 | hist = output_hist[bidx][-num_items:] 632 | for item in hist[1:]: 633 | if item == hist[0]: 634 | converged[bidx] = True # if out_i == out_j for (j = i+1, i+2, i+3, i+4), converged 635 | return converged 636 | 637 | def 
jaccard_converged(multiset_hist, num_items=5, jaccard_thresh=1.0): 638 | batch_size = len(multiset_hist) 639 | converged = [False for bidx in range(batch_size)] 640 | for bidx in range(batch_size): 641 | if len(multiset_hist[bidx]) < num_items: 642 | converged[bidx] = False 643 | else: 644 | converged[bidx] = False 645 | hist = multiset_hist[bidx][-num_items:] 646 | for item in hist[1:]: 647 | 648 | inters = len(item.intersection(hist[0])) 649 | unio = len(item.union(hist[0])) 650 | jaccard_index = float(inters) / np.maximum(1.,float(unio)) 651 | 652 | if jaccard_index >= jaccard_thresh: 653 | converged[bidx] = True 654 | return converged 655 | --------------------------------------------------------------------------------
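The three *_converged helpers above implement per-sentence stopping tests for adaptive iterative-refinement decoding: oracle_converged stops a sentence once its BLEU has not improved over the last num_items iterations, equality_converged once an output reappears verbatim in that window, and jaccard_converged once the n-gram multiset overlap with an earlier output reaches jaccard_thresh. A minimal sketch of a decode loop built on equality_converged follows; refine_once (one refinement pass over a single sentence) is a hypothetical callback, not a function from this repository:

    # Sketch only: adaptive-iteration decoding driven by equality_converged.
    # Assumes utils.py is importable; refine_once(sent) -> sent is hypothetical.
    from utils import equality_converged

    def adaptive_decode(refine_once, init_outputs, max_iters=10, num_items=5):
        output_hist = [[out] for out in init_outputs]  # per-sentence output history
        outputs = list(init_outputs)
        done = [False] * len(outputs)
        for _ in range(max_iters):
            # freeze sentences whose recent outputs contain an exact repeat
            done = [d or c for d, c in zip(done, equality_converged(output_hist, num_items))]
            if all(done):
                break
            for bidx, is_done in enumerate(done):
                if not is_done:  # keep refining only the unconverged sentences
                    outputs[bidx] = refine_once(outputs[bidx])
                    output_hist[bidx].append(outputs[bidx])
        return outputs

Swapping in oracle_converged (with a per-sentence BLEU history) or jaccard_converged (with n-gram multisets, e.g. from n_grams above) changes only the bookkeeping passed to the convergence test.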