├── .gitignore ├── README.md ├── average_checkpoints.py ├── combine_corpus.py └── example.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data Diversification: A Simple Strategy For Neural Machine Translation 2 | #### Accepted as a conference paper at the 34th Conference on Neural Information Processing Systems (NeurIPS 2020), Vancouver, Canada, 2020 3 | #### Authors: Xuan-Phi Nguyen, Shafiq Joty, Wu Kui, Ai Ti Aw 4 | 5 | Paper link: [https://arxiv.org/abs/1911.01986](https://arxiv.org/abs/1911.01986) 6 | 7 | # Citation 8 | 9 | Please cite as: 10 | 11 | ```bibtex 12 | @incollection{nguyen2020data, 13 | title = {Data Diversification: A Simple Strategy For Neural Machine Translation}, 14 | author = {Xuan-Phi Nguyen and Shafiq Joty and Wu Kui and Ai Ti Aw}, 15 | booktitle = {Advances in Neural Information Processing Systems 33}, 16 | year = {2020}, 17 | publisher = {Curran Associates, Inc.}, 18 | } 19 | ``` 20 | 21 | ## Pretrained Models 22 | 23 | Model | Description | Dataset | Download 24 | ---|---|---|--- 25 | `WMT'16 En-De` | Transformer | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: [download (.tar.gz)](https://drive.google.com/file/d/1dpUPmVvLZKiUHWeqo0_0-ox2yjGb109e/view?usp=sharing) 26 | 27 | ## Instructions to train WMT English-German 28 | 29 | **Step 1**: Follow the instructions from [Fairseq](https://github.com/pytorch/fairseq/tree/master/examples/translation) to create 30 | the WMT'14 English-German dataset.
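Before going further, it can save time to sanity-check that the binarized and raw data ended up where the later steps expect them; the exact paths are the ones given just below and used throughout this README. A minimal, optional Python sketch (adjust the paths if you keep your data elsewhere):

```python
import os

# Locations assumed by the commands in this README.
binarized_dir = "data_fairseq/translate_ende_wmt16_bpe32k"   # fairseq-preprocess output
raw_prefix = "raw_data/wmt_ende/train.tok.clean.bpe.32000"   # raw BPE training text

expected = [
    os.path.join(binarized_dir, "dict.en.txt"),
    os.path.join(binarized_dir, "dict.de.txt"),
    raw_prefix + ".en",
    raw_prefix + ".de",
]
for path in expected:
    status = "ok" if os.path.exists(path) else "MISSING"
    print(f"{status:7s} {path}")
```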
31 | 32 | Save the processed data as ``data_fairseq/translate_ende_wmt16_bpe32k`` 33 | 34 | Save the raw data (which contains the file train.tok.clean.bpe.32000.en) to ``raw_data/wmt_ende`` 35 | 36 | **Step 2**: Copy the same data to ``data_fairseq/translate_deen_wmt16_bpe32k`` for De-En 37 | ```bash 38 | cp -r data_fairseq/translate_ende_wmt16_bpe32k data_fairseq/translate_deen_wmt16_bpe32k 39 | ``` 40 | 41 | 42 | **Step 3**: Train the forward models. Steps 3 and 4 can be run in parallel; if you have more than 8 GPUs, you can train all 6 models at once. 43 | ```bash 44 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 45 | export seed_prefix=100 46 | export problem=translate_ende_wmt16_bpe32k 47 | export model_name=big_tfm_baseline_df3584_s${seed_prefix} 48 | export data_dir=`pwd`/data_fairseq/$problem 49 | 50 | for index in {1..3} 51 | do 52 | export model_dir=train_fairseq/${problem}/${model_name}/model_${index} 53 | fairseq-train \ 54 | ${data_dir} \ 55 | -s en -t de \ 56 | --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \ 57 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 58 | --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \ 59 | --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ 60 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 61 | --max-update 43000 \ 62 | --keep-last-epochs 10 \ 63 | --save-dir ${model_dir} \ 64 | --ddp-backend no_c10d \ 65 | --seed ${seed_prefix}${index} \ 66 | --max-tokens 3584 \ 67 | --fp16 --update-freq 16 --log-interval 10000 --no-progress-bar 68 | done 69 | ``` 70 | 71 | 72 | **Step 4**: Train the backward models 73 | ```bash 74 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 75 | export seed_prefix=101 76 | export problem=translate_deen_wmt16_bpe32k 77 | export model_name=big_tfm_baseline_df3584_s${seed_prefix} 78 | export data_dir=`pwd`/data_fairseq/$problem 79 | 80 | for index in {1..3} 81 | do 82 | export model_dir=train_fairseq/${problem}/${model_name}/model_${index} 83 | fairseq-train \ 84 | ${data_dir} \ 85 | -s de -t en \ 86 | --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \ 87 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 88 | --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \ 89 | --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ 90 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 91 | --max-update 43000 \ 92 | --keep-last-epochs 10 \ 93 | --save-dir ${model_dir} \ 94 | --ddp-backend no_c10d \ 95 | --seed ${seed_prefix}${index} \ 96 | --max-tokens 3584 \ 97 | --fp16 --update-freq 16 --log-interval 10000 --no-progress-bar 98 | done 99 | ``` 100 | 101 | 102 | **Step 5**: Run inference with the forward models on the training set 103 | 104 | ```bash 105 | export CUDA_VISIBLE_DEVICES=0 106 | export seed_prefix=100 107 | export problem=translate_ende_wmt16_bpe32k 108 | export model_name=big_tfm_baseline_df3584_s${seed_prefix} 109 | export data_dir=`pwd`/data_fairseq/$problem 110 | export beam=5 111 | export lenpen=0.6 112 | export round=1 113 | export infer_bsz=8192  # max tokens per generation batch; adjust to your GPU memory 114 | for index in {1..3} 115 | do 116 | export model_dir=train_fairseq/${problem}/${model_name}/model_${index} 117 | export best_file=$model_dir/checkpoint_best.pt 118 | export gen_out=$model_dir/infer_train_b${beam}_lp${lenpen} 119 | fairseq-generate ${data_dir} \ 120 | -s en -t de \ 121 | --path ${best_file} \ 122 | --gen-subset train \ 123 | --max-tokens ${infer_bsz} --beam ${beam} --lenpen ${lenpen} | dd of=$gen_out 124 | grep ^S ${gen_out} | cut -f2- > $gen_out.en
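# Note on the fairseq-generate output parsed here and in Step 6: it writes tab-separated
# records per sentence, S-<id><TAB>source, T-<id><TAB>reference and H-<id><TAB>score<TAB>hypothesis,
# so `cut -f2-` recovers the source from S lines and `cut -f3-` recovers the hypothesis from H lines.
# S and H lines of the same sentence stay adjacent, so the extracted .en/.de files remain
# line-aligned even though generation does not follow the original corpus order.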
125 | grep ^H ${gen_out} | cut -f3- > $gen_out.de 126 | done 127 | 128 | ``` 129 | 130 | 131 | **Step 6**: Run inference with the backward models on the training set 132 | 133 | ```bash 134 | export CUDA_VISIBLE_DEVICES=0 135 | export seed_prefix=101 136 | export problem=translate_deen_wmt16_bpe32k 137 | export model_name=big_tfm_baseline_df3584_s${seed_prefix} 138 | export data_dir=`pwd`/data_fairseq/$problem 139 | export beam=5 140 | export lenpen=0.6 141 | export round=1 142 | export infer_bsz=8192  # max tokens per generation batch; adjust to your GPU memory 143 | for index in {1..3} 144 | do 145 | export model_dir=train_fairseq/${problem}/${model_name}/model_${index} 146 | export best_file=$model_dir/checkpoint_best.pt 147 | export gen_out=$model_dir/infer_train_b${beam}_lp${lenpen} 148 | fairseq-generate ${data_dir} \ 149 | -s de -t en \ 150 | --path ${best_file} \ 151 | --gen-subset train \ 152 | --max-tokens ${infer_bsz} --beam ${beam} --lenpen ${lenpen} | dd of=$gen_out 153 | grep ^S ${gen_out} | cut -f2- > $gen_out.de 154 | grep ^H ${gen_out} | cut -f3- > $gen_out.en 155 | done 156 | 157 | ``` 158 | 159 | **Step 7**: Merge the forward and backward translations with the original dataset and filter duplicates 160 | 161 | ```bash 162 | 163 | export ori=raw_data/wmt_ende/train.tok.clean.bpe.32000 164 | export bw_prefix=train_fairseq/translate_deen_wmt16_bpe32k/big_tfm_baseline_df3584_s101/model_ 165 | export fw_prefix=train_fairseq/translate_ende_wmt16_bpe32k/big_tfm_baseline_df3584_s100/model_ 166 | export prefix= 167 | for i in {1..3} 168 | do 169 | export prefix=$bw_prefix$i/infer_train_b5_lp0.6:$prefix 170 | done 171 | for i in {1..3} 172 | do 173 | export prefix=$fw_prefix$i/infer_train_b5_lp0.6:$prefix 174 | done 175 | 176 | mkdir -p raw_data/aug_ende_wmt16_bpe32k_s3_r1 177 | python -u combine_corpus.py --src en --tgt de --ori $ori --hypos $prefix --dir raw_data/aug_ende_wmt16_bpe32k_s3_r1 --out train 178 | 179 | export out=data_fairseq/translate_ende_aug_b5_r1_s3_nodup_wmt16_bpe32k 180 | # Copy the original data to new augmented data.
We keep the valid/test set the same, only change the train set 181 | cp -r data_fairseq/translate_ende_wmt16_bpe32k $out 182 | 183 | fairseq-preprocess --source-lang en --target-lang de \ 184 | --trainpref raw_data/aug_ende_wmt16_bpe32k_s3_r1/train \ 185 | --destdir $out \ 186 | --nwordssrc 0 --nwordstgt 0 \ 187 | --workers 16 \ 188 | --srcdict $out/dict.en.txt --tgtdict $out/dict.de.txt 189 | 190 | # This should report around 27M sentences 191 | ``` 192 | 193 | **Step 8**: Train final models 194 | 195 | ```bash 196 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 197 | export seed_prefix=200 198 | export problem=translate_ende_aug_b5_r1_s3_nodup_wmt16_bpe32k 199 | export model_name=big_tfm_baseline_df3584_s${seed_prefix} 200 | export data_dir=`pwd`/data_fairseq/$problem 201 | export index=1 202 | export model_dir=train_fairseq/${problem}/${model_name}/model_${index} 203 | fairseq-train \ 204 | ${data_dir} \ 205 | -s en -t de \ 206 | --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \ 207 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 208 | --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \ 209 | --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ 210 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 211 | --max-update 43000 \ 212 | --keep-last-epochs 10 \ 213 | --save-dir ${model_dir} \ 214 | --ddp-backend no_c10d \ 215 | --seed ${seed_prefix}${index} \ 216 | --max-tokens 3584 \ 217 | --fp16 --update-freq 16 --log-interval 10000 --no-progress-bar 218 | 219 | export avg_checkpoint=$model_dir/checkpoint_avg5.pt 220 | 221 | # average checkpoints 222 | python average_checkpoints.py \ 223 | --inputs ${model_dir} \ 224 | --num-epoch-checkpoints 5 \ 225 | --checkpoint-upper-bound 10000 \ 226 | --output ${avg_checkpoint} 227 | 228 | export gen_out=$model_dir/infer.test.avg5.b5.lp0.6 229 | export ref=${gen_out}.ref 230 | export hypo=${gen_out}.hypo 231 | export ref_atat=${ref}.atat 232 | export hypo_atat=${hypo}.atat 233 | export beam=5 234 | export lenpen=0.6 235 | echo "Finish generating averaged, start generating samples" 236 | fairseq-generate ${data_dir} \ 237 | -s en -t de \ 238 | --gen-subset test \ 239 | --path ${avg_checkpoint} \ 240 | --max-tokens 2048 \ 241 | --beam ${beam} \ 242 | --lenpen ${lenpen} \ 243 | --remove-bpe | dd of=${gen_out} 244 | grep ^T ${gen_out} | cut -f2- > ${ref} 245 | grep ^H ${gen_out} | cut -f3- > ${hypo} 246 | 247 | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < ${hypo} > ${hypo_atat} 248 | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < ${ref} > ${ref_atat} 249 | echo "------ Score BLEU ------------" 250 | $(which fairseq-score) --sys ${hypo_atat} --ref ${ref_atat} 251 | # expected: BLEU4 = 30.7 252 | ``` 253 | 254 | 255 | 256 | -------------------------------------------------------------------------------- /average_checkpoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import collections 5 | import torch 6 | import os 7 | import re 8 | from fairseq.utils import import_user_module 9 | 10 | 11 | def default_avg_params(params_dict): 12 | averaged_params = collections.OrderedDict() 13 | 14 | # v should be a list of torch Tensor. 
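    # Uniform averaging: for each parameter name k, sum the tensors from all checkpoints
    # and divide by the number of checkpoints, i.e. avg[k] = (1 / N) * sum_i v[i][k].
    # Contrast with ema_avg_params below, which weights recent checkpoints more heavily.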
15 | for k, v in params_dict.items(): 16 | summed_v = None 17 | for x in v: 18 | summed_v = summed_v + x if summed_v is not None else x 19 | averaged_params[k] = summed_v / len(v) 20 | 21 | return averaged_params 22 | 23 | 24 | def ema_avg_params(params_dict, ema_decay): 25 | averaged_params = collections.OrderedDict() 26 | lens = [len(v) for k, v in params_dict.items()] 27 | assert all(x == lens[0] for x in lens), f'lens params: {lens}' 28 | num_checkpoints = lens[0] 29 | # y = x 30 | 31 | for k, v in params_dict.items(): 32 | # order: newest to oldest 33 | # reverse the order 34 | # y_t = x_t * decay + y_{t-1} * (1 - decay) 35 | total_v = None 36 | for x in reversed(v): 37 | if total_v is None: 38 | total_v = x 39 | else: 40 | total_v = x * ema_decay + total_v * (1.0 - ema_decay) 41 | 42 | averaged_params[k] = total_v 43 | return averaged_params 44 | 45 | 46 | def average_checkpoints(inputs, ema_decay=1.0): 47 | """Loads checkpoints from inputs and returns a model with averaged weights. 48 | 49 | Args: 50 | inputs: An iterable of string paths of checkpoints to load from. 51 | 52 | Returns: 53 | A dict of string keys mapping to various values. The 'model' key 54 | from the returned dict should correspond to an OrderedDict mapping 55 | string parameter names to torch Tensors. 56 | """ 57 | params_dict = collections.OrderedDict() 58 | params_keys = None 59 | new_state = None 60 | for i, f in enumerate(inputs): 61 | state = torch.load( 62 | f, 63 | map_location=( 64 | lambda s, _: torch.serialization.default_restore_location(s, 'cpu') 65 | ), 66 | ) 67 | # Copies over the settings from the first checkpoint 68 | if new_state is None: 69 | new_state = state 70 | 71 | model_params = state['model'] 72 | 73 | model_params_keys = list(model_params.keys()) 74 | if params_keys is None: 75 | params_keys = model_params_keys 76 | elif params_keys != model_params_keys: 77 | raise KeyError( 78 | 'For checkpoint {}, expected list of params: {}, ' 79 | 'but found: {}'.format(f, params_keys, model_params_keys) 80 | ) 81 | 82 | for k in params_keys: 83 | if k not in params_dict: 84 | params_dict[k] = [] 85 | p = model_params[k] 86 | if isinstance(p, torch.HalfTensor): 87 | p = p.float() 88 | params_dict[k].append(p) 89 | 90 | if ema_decay < 1.0: 91 | print(f'Exponential moving averaging, decay={ema_decay}') 92 | averaged_params = ema_avg_params(params_dict, ema_decay) 93 | else: 94 | print(f'Default averaging') 95 | averaged_params = default_avg_params(params_dict) 96 | new_state['model'] = averaged_params 97 | return new_state 98 | 99 | 100 | def last_n_checkpoints(paths, n, update_based, upper_bound=None): 101 | assert len(paths) == 1 102 | path = paths[0] 103 | if update_based: 104 | pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt') 105 | else: 106 | pt_regexp = re.compile(r'checkpoint(\d+)\.pt') 107 | files = os.listdir(path) 108 | 109 | entries = [] 110 | for f in files: 111 | m = pt_regexp.fullmatch(f) 112 | if m is not None: 113 | sort_key = int(m.group(1)) 114 | if upper_bound is None or sort_key <= upper_bound: 115 | entries.append((sort_key, m.group(0))) 116 | if len(entries) < n: 117 | raise Exception('Found {} checkpoint files but need at least {}', len(entries), n) 118 | return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]] 119 | 120 | 121 | def main(): 122 | parser = argparse.ArgumentParser( 123 | description='Tool to average the params of input checkpoints to ' 124 | 'produce a new checkpoint', 125 | ) 126 | # fmt: off 127 | parser.add_argument('--inputs', 
required=True, nargs='+', 128 | help='Input checkpoint file paths.') 129 | parser.add_argument('--output', required=True, metavar='FILE', 130 | help='Write the new checkpoint containing the averaged weights to this path.') 131 | num_group = parser.add_mutually_exclusive_group() 132 | num_group.add_argument('--num-epoch-checkpoints', type=int, 133 | help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, ' 134 | 'and average last this many of them.') 135 | num_group.add_argument('--num-update-checkpoints', type=int, 136 | help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, ' 137 | 'and average last this many of them.') 138 | parser.add_argument('--checkpoint-upper-bound', type=int, 139 | help='when using --num-epoch-checkpoints, this will set an upper bound on which checkpoint to use, ' 140 | 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.') 141 | 142 | # parser.add_argument('--ema', type=float, default=1.0, help='exponential moving average decay') 143 | # parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar') 144 | parser.add_argument('--ema', default='False', type=str, metavar='BOOL', help='ema') 145 | parser.add_argument('--ema_decay', type=float, default=1.0, help='exponential moving average decay') 146 | parser.add_argument('--user-dir', default=None) 147 | 148 | # fmt: on 149 | args = parser.parse_args() 150 | 151 | import_user_module(args) 152 | print(args) 153 | 154 | num = None 155 | is_update_based = False 156 | if args.num_update_checkpoints is not None: 157 | num = args.num_update_checkpoints 158 | is_update_based = True 159 | elif args.num_epoch_checkpoints is not None: 160 | num = args.num_epoch_checkpoints 161 | 162 | assert args.checkpoint_upper_bound is None or args.num_epoch_checkpoints is not None, \ 163 | '--checkpoint-upper-bound requires --num-epoch-checkpoints' 164 | assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \ 165 | 'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints' 166 | 167 | if num is not None: 168 | args.inputs = last_n_checkpoints( 169 | args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound, 170 | ) 171 | # print('averaging checkpoints: ', args.inputs) 172 | print('averaging checkpoints: ') 173 | for checkpoint in args.inputs: 174 | print(checkpoint) 175 | print('-' * 40) 176 | 177 | # ema = args.ema 178 | # assert isinstance(args.ema, bool) 179 | print(f'Start averaing with ema={args.ema}, ema_decay={args.ema_decay}') 180 | new_state = average_checkpoints(args.inputs, args.ema_decay) 181 | torch.save(new_state, args.output) 182 | print('Finished writing averaged checkpoint to {}.'.format(args.output)) 183 | 184 | 185 | if __name__ == '__main__': 186 | main() 187 | -------------------------------------------------------------------------------- /combine_corpus.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Written by: Xuan-Phi Nguyen (nxphi47) 4 | 5 | """ 6 | 7 | import torch 8 | import os 9 | import re 10 | import argparse 11 | 12 | 13 | def merge_nodup(src_ori, tgt_ori, src_hyps, tgt_hyps, **kwargs): 14 | sep = ' |||||||| ' 15 | merge = [f'{x}{sep}{y}' for x, y in zip(src_ori, tgt_ori)] 16 | # ori_merge = set(ori_merge) 17 | for i, (src, tgt) in enumerate(zip(src_hyps, tgt_hyps)): 18 | merge += [f'{x}{sep}{y}' for x, y in zip(src, tgt)] 19 
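    # Each pair is packed into a single 'src |||||||| tgt' string, so the set() below drops
    # exact duplicates, including synthetic pairs identical to an original pair or to each other.
    # Note that set() does not preserve order, so the merged corpus comes out shuffled.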
| 20 | merge = set(merge) 21 | out = [x.split(sep) for x in merge] 22 | print(f'Total size: {len(out)}') 23 | src = [x[0] for x in out] 24 | tgt = [x[1] for x in out] 25 | 26 | return src, tgt 27 | 28 | 29 | if __name__ == '__main__': 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--src', default='en', type=str) 32 | parser.add_argument('--tgt', default='de', type=str) 33 | parser.add_argument('--ori') 34 | parser.add_argument('--hypos') 35 | parser.add_argument('--dir') 36 | parser.add_argument('--out') 37 | 38 | args = parser.parse_args() 39 | 40 | ori_src_f = f'{args.ori}.{args.src}' 41 | ori_tgt_f = f'{args.ori}.{args.tgt}' 42 | hypos = [x for x in args.hypos.split(":") if x != ""] 43 | 44 | hypos_src_f = [f'{h}.{args.src}' for h in hypos] 45 | hypos_tgt_f = [f'{h}.{args.tgt}' for h in hypos] 46 | 47 | 48 | def read(fo): 49 | with open(fo, 'r') as f: 50 | out = f.read().strip().split('\n') 51 | return out 52 | 53 | ori_src = read(ori_src_f) 54 | ori_tgt = read(ori_tgt_f) 55 | hypos_src = [read(h) for h in hypos_src_f] 56 | hypos_tgt = [read(h) for h in hypos_tgt_f] 57 | assert len(hypos_src) == len(hypos_tgt) 58 | print(f'Merge size: {len(hypos_src)}') 59 | 60 | assert len(ori_src) == len(ori_tgt) 61 | for i, (hx, hy) in enumerate(zip(hypos_src, hypos_tgt)): 62 | assert len(hx) == len(hy), f'invalid len {i}' 63 | 64 | src, tgt = merge_nodup(ori_src, ori_tgt, hypos_src, hypos_tgt) 65 | os.makedirs(args.dir, exist_ok=True) 66 | src_out = os.path.join(args.dir, f'{args.out}.{args.src}') 67 | tgt_out = os.path.join(args.dir, f'{args.out}.{args.tgt}') 68 | print(f'src_out:{src_out}') 69 | print(f'tgt_out:{tgt_out}') 70 | with open(src_out, 'w') as f: 71 | f.write('\n'.join(src)) 72 | with open(tgt_out, 'w') as f: 73 | f.write('\n'.join(tgt)) 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /example.md: -------------------------------------------------------------------------------- 1 | # Neural Machine Translation 2 | 3 | This README contains instructions for [using pretrained translation models](#example-usage-torchhub) 4 | as well as [training new models](#training-a-new-model). 5 | 6 | ## Pre-trained models 7 | 8 | Model | Description | Dataset | Download 9 | ---|---|---|--- 10 | `conv.wmt14.en-fr` | Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2)
newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2)
newstest2012/2013:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2) 11 | `conv.wmt14.en-de` | Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | model:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2)
newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2) 12 | `conv.wmt17.en-de` | Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | model:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2)
newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2) 13 | `transformer.wmt14.en-fr` | Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2)
newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2) 14 | `transformer.wmt16.en-de` | Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2)
newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) 15 | `transformer.wmt18.en-de` | Transformer
([Edunov et al., 2018](https://arxiv.org/abs/1808.09381))
WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | model:
[download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz)
See NOTE in the archive 16 | `transformer.wmt19.en-de` | Transformer
([Ng et al., 2019](https://arxiv.org/abs/1907.06616))
WMT'19 winner | [WMT'19 English-German](http://www.statmt.org/wmt19/translation-task.html) | model:
[download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz) 17 | `transformer.wmt19.de-en` | Transformer
([Ng et al., 2019](https://arxiv.org/abs/1907.06616))
WMT'19 winner | [WMT'19 German-English](http://www.statmt.org/wmt19/translation-task.html) | model:
[download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz) 18 | `transformer.wmt19.en-ru` | Transformer
([Ng et al., 2019](https://arxiv.org/abs/1907.06616))
WMT'19 winner | [WMT'19 English-Russian](http://www.statmt.org/wmt19/translation-task.html) | model:
[download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz) 19 | `transformer.wmt19.ru-en` | Transformer
([Ng et al., 2019](https://arxiv.org/abs/1907.06616))
WMT'19 winner | [WMT'19 Russian-English](http://www.statmt.org/wmt19/translation-task.html) | model:
[download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz) 20 | 21 | ## Example usage (torch.hub) 22 | 23 | Interactive translation via PyTorch Hub: 24 | ```python 25 | import torch 26 | 27 | # List available models 28 | torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt16.en-de', ... ] 29 | 30 | # Load a transformer trained on WMT'16 En-De 31 | en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de', tokenizer='moses', bpe='subword_nmt') 32 | 33 | # The underlying model is available under the *models* attribute 34 | assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel) 35 | 36 | # Translate a sentence 37 | en2de.translate('Hello world!') 38 | # 'Hallo Welt!' 39 | ``` 40 | 41 | ## Example usage (CLI tools) 42 | 43 | Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti: 44 | ```bash 45 | mkdir -p data-bin 46 | curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin 47 | curl https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin 48 | fairseq-generate data-bin/wmt14.en-fr.newstest2014 \ 49 | --path data-bin/wmt14.en-fr.fconv-py/model.pt \ 50 | --beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out 51 | # ... 52 | # | Translated 3003 sentences (96311 tokens) in 166.0s (580.04 tokens/s) 53 | # | Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787) 54 | 55 | # Compute BLEU score 56 | grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys 57 | grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref 58 | fairseq-score --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref 59 | # BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787) 60 | ``` 61 | 62 | ## Training a new model 63 | 64 | ### IWSLT'14 German to English (Transformer) 65 | 66 | The following instructions can be used to train a Transformer model on the [IWSLT'14 German to English dataset](http://workshop2014.iwslt.org/downloads/proceeding.pdf). 67 | 68 | First download and preprocess the data: 69 | ```bash 70 | # Download and prepare the data 71 | cd examples/translation/ 72 | bash prepare-iwslt14.sh 73 | cd ../.. 
74 | 75 | # Preprocess/binarize the data 76 | TEXT=examples/translation/iwslt14.tokenized.de-en 77 | fairseq-preprocess --source-lang de --target-lang en \ 78 | --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ 79 | --destdir data-bin/iwslt14.tokenized.de-en \ 80 | --workers 20 81 | ``` 82 | 83 | Next we'll train a Transformer translation model over this data: 84 | ```bash 85 | CUDA_VISIBLE_DEVICES=0 fairseq-train \ 86 | data-bin/iwslt14.tokenized.de-en \ 87 | --arch transformer_iwslt_de_en --share-decoder-input-output-embed \ 88 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 89 | --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ 90 | --dropout 0.3 --weight-decay 0.0001 \ 91 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 92 | --max-tokens 4096 93 | ``` 94 | 95 | Finally we can evaluate our trained model: 96 | ```bash 97 | fairseq-generate data-bin/iwslt14.tokenized.de-en \ 98 | --path checkpoints/checkpoint_best.pt \ 99 | --batch-size 128 --beam 5 --remove-bpe 100 | ``` 101 | 102 | ### WMT'14 English to German (Convolutional) 103 | 104 | The following instructions can be used to train a Convolutional translation model on the WMT English to German dataset. 105 | See the [Scaling NMT README](../scaling_nmt/README.md) for instructions to train a Transformer translation model on this data. 106 | 107 | The WMT English to German dataset can be preprocessed using the `prepare-wmt14en2de.sh` script. 108 | By default it will produce a dataset that was modeled after [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762), but with additional news-commentary-v12 data from WMT'17. 109 | 110 | To use only data available in WMT'14 or to replicate results obtained in the original [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](https://arxiv.org/abs/1705.03122) paper, please use the `--icml17` option. 111 | 112 | ```bash 113 | # Download and prepare the data 114 | cd examples/translation/ 115 | # WMT'17 data: 116 | bash prepare-wmt14en2de.sh 117 | # or to use WMT'14 data: 118 | # bash prepare-wmt14en2de.sh --icml17 119 | cd ../.. 120 | 121 | # Binarize the dataset 122 | TEXT=examples/translation/wmt17_en_de 123 | fairseq-preprocess \ 124 | --source-lang en --target-lang de \ 125 | --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ 126 | --destdir data-bin/wmt17_en_de --thresholdtgt 0 --thresholdsrc 0 \ 127 | --workers 20 128 | 129 | # Train the model 130 | mkdir -p checkpoints/fconv_wmt_en_de 131 | fairseq-train \ 132 | data-bin/wmt17_en_de \ 133 | --arch fconv_wmt_en_de \ 134 | --lr 0.5 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \ 135 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 136 | --lr-scheduler fixed --force-anneal 50 \ 137 | --save-dir checkpoints/fconv_wmt_en_de 138 | 139 | # Evaluate 140 | fairseq-generate data-bin/wmt17_en_de \ 141 | --path checkpoints/fconv_wmt_en_de/checkpoint_best.pt \ 142 | --beam 5 --remove-bpe 143 | ``` 144 | 145 | ### WMT'14 English to French 146 | ```bash 147 | # Download and prepare the data 148 | cd examples/translation/ 149 | bash prepare-wmt14en2fr.sh 150 | cd ../.. 
151 | 152 | # Binarize the dataset 153 | TEXT=examples/translation/wmt14_en_fr 154 | fairseq-preprocess \ 155 | --source-lang en --target-lang fr \ 156 | --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ 157 | --destdir data-bin/wmt14_en_fr --thresholdtgt 0 --thresholdsrc 0 \ 158 | --workers 60 159 | 160 | # Train the model 161 | mkdir -p checkpoints/fconv_wmt_en_fr 162 | fairseq-train \ 163 | data-bin/wmt14_en_fr \ 164 | --lr 0.5 --clip-norm 0.1 --dropout 0.1 --max-tokens 3000 \ 165 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 166 | --lr-scheduler fixed --force-anneal 50 \ 167 | --arch fconv_wmt_en_fr \ 168 | --save-dir checkpoints/fconv_wmt_en_fr 169 | 170 | # Evaluate 171 | fairseq-generate \ 172 | data-bin/fconv_wmt_en_fr \ 173 | --path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt \ 174 | --beam 5 --remove-bpe 175 | ``` 176 | 177 | ## Multilingual Translation 178 | 179 | We also support training multilingual translation models. In this example we'll 180 | train a multilingual `{de,fr}-en` translation model using the IWSLT'17 datasets. 181 | 182 | Note that we use slightly different preprocessing here than for the IWSLT'14 183 | En-De data above. In particular we learn a joint BPE code for all three 184 | languages and use interactive.py and sacrebleu for scoring the test set. 185 | 186 | ```bash 187 | # First install sacrebleu and sentencepiece 188 | pip install sacrebleu sentencepiece 189 | 190 | # Then download and preprocess the data 191 | cd examples/translation/ 192 | bash prepare-iwslt17-multilingual.sh 193 | cd ../.. 194 | 195 | # Binarize the de-en dataset 196 | TEXT=examples/translation/iwslt17.de_fr.en.bpe16k 197 | fairseq-preprocess --source-lang de --target-lang en \ 198 | --trainpref $TEXT/train.bpe.de-en --validpref $TEXT/valid.bpe.de-en \ 199 | --joined-dictionary \ 200 | --destdir data-bin/iwslt17.de_fr.en.bpe16k \ 201 | --workers 10 202 | 203 | # Binarize the fr-en dataset 204 | # NOTE: it's important to reuse the en dictionary from the previous step 205 | fairseq-preprocess --source-lang fr --target-lang en \ 206 | --trainpref $TEXT/train.bpe.fr-en --validpref $TEXT/valid.bpe.fr-en \ 207 | --joined-dictionary --tgtdict data-bin/iwslt17.de_fr.en.bpe16k/dict.en.txt \ 208 | --destdir data-bin/iwslt17.de_fr.en.bpe16k \ 209 | --workers 10 210 | 211 | # Train a multilingual transformer model 212 | # NOTE: the command below assumes 1 GPU, but accumulates gradients from 213 | # 8 fwd/bwd passes to simulate training on 8 GPUs 214 | mkdir -p checkpoints/multilingual_transformer 215 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \ 216 | --max-epoch 50 \ 217 | --ddp-backend=no_c10d \ 218 | --task multilingual_translation --lang-pairs de-en,fr-en \ 219 | --arch multilingual_transformer_iwslt_de_en \ 220 | --share-decoders --share-decoder-input-output-embed \ 221 | --optimizer adam --adam-betas '(0.9, 0.98)' \ 222 | --lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \ 223 | --warmup-updates 4000 --warmup-init-lr '1e-07' \ 224 | --label-smoothing 0.1 --criterion label_smoothed_cross_entropy \ 225 | --dropout 0.3 --weight-decay 0.0001 \ 226 | --save-dir checkpoints/multilingual_transformer \ 227 | --max-tokens 4000 \ 228 | --update-freq 8 229 | 230 | # Generate and score the test set with sacrebleu 231 | SRC=de 232 | sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \ 233 | | python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \ 234 | > 
iwslt17.test.${SRC}-en.${SRC}.bpe 235 | cat iwslt17.test.${SRC}-en.${SRC}.bpe \ 236 | | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \ 237 | --task multilingual_translation --source-lang ${SRC} --target-lang en \ 238 | --path checkpoints/multilingual_transformer/checkpoint_best.pt \ 239 | --buffer-size 2000 --batch-size 128 \ 240 | --beam 5 --remove-bpe=sentencepiece \ 241 | > iwslt17.test.${SRC}-en.en.sys 242 | grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \ 243 | | sacrebleu --test-set iwslt17 --language-pair ${SRC}-en 244 | ``` 245 | 246 | ##### Argument format during inference 247 | 248 | During inference it is required to specify a single `--source-lang` and 249 | `--target-lang`, which indicates the inference language direction. 250 | `--lang-pairs`, `--encoder-langtok`, and `--decoder-langtok` must be set to 251 | the same values used during training. --------------------------------------------------------------------------------
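The shell pipe above is the standard way to score; if you prefer doing it from Python, sacrebleu also exposes a Python API. A minimal sketch assuming `SRC=de`; the reference file name is illustrative and would be dumped beforehand with sacrebleu's `--echo ref`:

```python
import sacrebleu

# Extract the hypothesis text from the H lines of the raw fairseq-interactive output
# (same as the `grep ^H ... | cut -f3` pipe above); lines look like H-<id>\t<score>\t<text>.
hyps = []
with open("iwslt17.test.de-en.en.sys", encoding="utf-8") as f:
    for line in f:
        if line.startswith("H-"):
            hyps.append(line.rstrip("\n").split("\t")[2])

# References dumped beforehand, e.g.:
#   sacrebleu --test-set iwslt17 --language-pair de-en --echo ref > iwslt17.test.de-en.en.ref
with open("iwslt17.test.de-en.en.ref", encoding="utf-8") as f:
    refs = [line.rstrip("\n") for line in f]

assert len(hyps) == len(refs), "system and reference sides must be line-aligned"

bleu = sacrebleu.corpus_bleu(hyps, [refs])  # one reference per segment
print(f"BLEU = {bleu.score:.2f}")
```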