├── .gitignore
├── README.md
├── average_checkpoints.py
├── combine_corpus.py
└── example.md
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Data Diversification: A Simple Strategy For Neural Machine Translation
2 | #### Accepted as a conference paper at the 34th Conference on Neural Information Processing Systems (NeurIPS 2020), Vancouver, Canada
3 | #### Authors: Xuan-Phi Nguyen, Shafiq Joty, Wu Kui, Ai Ti Aw
4 |
5 | Paper link: [https://arxiv.org/abs/1911.01986](https://arxiv.org/abs/1911.01986)
6 |
7 | # Citation
8 |
9 | Please cite as:
10 |
11 | ```bibtex
12 | @incollection{nguyen2020data,
13 | title = {Data Diversification: A Simple Strategy For Neural Machine Translation},
14 | author = {Xuan-Phi Nguyen and Shafiq Joty and Wu Kui and Ai Ti Aw},
15 | booktitle = {Advances in Neural Information Processing Systems 33},
16 | year = {2020},
17 | publisher = {Curran Associates, Inc.},
18 | }
19 | ```
20 |
21 | ## Pretrained Models
22 |
23 | Model | Description | Dataset | Download
24 | ---|---|---|---
25 | `WMT'16 En-De` | Transformer | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: [download (.tar.gz)](https://drive.google.com/file/d/1dpUPmVvLZKiUHWeqo0_0-ox2yjGb109e/view?usp=sharing)
26 |
27 | ## Instructions to train WMT English-German
28 |
29 | **Step 1**: Follow the instructions from [Fairseq](https://github.com/pytorch/fairseq/tree/master/examples/translation) to create
30 | the WMT'14 English-German dataset.
31 |
32 | Save the processed (binarized) data to ``data_fairseq/translate_ende_wmt16_bpe32k``.
33 |
34 | Save the raw data (which contains the file train.tok.clean.bpe.32000.en) to ``raw_data/wmt_ende``.
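
After this step, the directory layout assumed by the commands below looks roughly like this (a sketch; the exact file list depends on how the data was prepared):

```bash
data_fairseq/translate_ende_wmt16_bpe32k/   # binarized data: dict.en.txt, dict.de.txt, train/valid/test .bin/.idx files
raw_data/wmt_ende/                          # raw BPE'd text: train.tok.clean.bpe.32000.en / .de
```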
35 |
36 | **Step 2**: Copy the same data to ``data_fairseq/translate_deen_wmt16_bpe32k`` for the De-En direction:
37 | ```bash
38 | cp -r data_fairseq/translate_ende_wmt16_bpe32k data_fairseq/translate_deen_wmt16_bpe32k
39 | ```
40 |
41 |
42 | **Step 3**: Train the forward models. Steps 3 and 4 can run in parallel; with more than 8 GPUs you can train all 6 models at once.
43 | ```bash
44 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
45 | export seed_prefix=100
46 | export problem=translate_ende_wmt16_bpe32k
47 | export model_name=big_tfm_baseline_df3584_s${seed_prefix}
48 | export data_dir=`pwd`/data_fairseq/$problem
49 |
50 | for index in {1..3}
51 | do
52 | export model_dir=train_fairseq/${problem}/${model_name}/model_${index}
53 | fairseq-train \
54 | ${data_dir} \
55 | -s en -t de \
56 | --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
57 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
58 | --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
59 | --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
60 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
61 | --max-update 43000 \
62 | --keep-last-epochs 10 \
63 | --save-dir ${model_dir} \
64 | --ddp-backend no_c10d \
65 | --seed ${seed_prefix}${index} \
66 | --max-tokens 3584 \
67 | --fp16 --update-freq 16 --log-interval 10000 --no-progress-bar
68 | done
69 | ```
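
For reference, the effective batch size with these flags is the number of GPUs times `--max-tokens` times `--update-freq`; if you train on fewer GPUs, increase `--update-freq` to keep it roughly constant:

```bash
# approximate number of tokens per optimizer update with the settings above
echo $((8 * 3584 * 16))   # 458752
```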
70 |
71 |
72 | **Step 4**: Train the backward models
73 | ```bash
74 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
75 | export seed_prefix=101
76 | export problem=translate_deen_wmt16_bpe32k
77 | export model_name=big_tfm_baseline_df3584_s${seed_prefix}
78 | export data_dir=`pwd`/data_fairseq/$problem
79 |
80 | for index in {1..3}
81 | do
82 | export model_dir=train_fairseq/${problem}/${model_name}/model_${index}
83 | fairseq-train \
84 | ${data_dir} \
85 | -s de -t en \
86 | --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
87 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
88 | --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
89 | --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
90 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
91 | --max-update 43000 \
92 | --keep-last-epochs 10 \
93 | --save-dir ${model_dir} \
94 | --ddp-backend no_c10d \
95 | --seed ${seed_prefix}${index} \
96 | --max-tokens 3584 \
97 | --fp16 --update-freq 16 --log-interval 10000 --no-progress-bar
98 | done
99 | ```
100 |
101 |
102 | **Step 5**: Run inference with the forward models to translate the training set
103 |
104 | ```bash
105 | export CUDA_VISIBLE_DEVICES=0
106 | export seed_prefix=100
107 | export problem=translate_ende_wmt16_bpe32k
108 | export model_name=big_tfm_baseline_df3584_s${seed_prefix}
109 | export data_dir=`pwd`/data_fairseq/$problem
110 | export beam=5
111 | export lenpen=0.6
112 | export round=1
113 | export infer_bsz=8192  # max tokens per generation batch; adjust to fit your GPU memory
114 | for index in {1..3}
115 | do
116 | export model_dir=train_fairseq/${problem}/${model_name}/model_${index}
117 | export best_file=$model_dir/checkpoint_best.pt
118 | export gen_out=$model_dir/infer_train_b${beam}_lp${lenpen}
119 | fairseq-generate ${data_dir} \
120 | -s en -t de \
121 | --path ${best_file} \
122 | --gen-subset train \
123 | --max-tokens ${infer_bsz} --beam ${beam} --lenpen ${lenpen} | dd of=$gen_out
124 | grep ^S ${gen_out} | cut -f2- > $gen_out.en
125 | grep ^H ${gen_out} | cut -f3- > $gen_out.de
126 | done
127 |
128 | ```
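
After this step, each forward-model directory contains a synthetic parallel corpus extracted from the generation log; roughly (a sketch of the expected files):

```bash
ls train_fairseq/translate_ende_wmt16_bpe32k/big_tfm_baseline_df3584_s100/model_1/
# ... checkpoint_best.pt  infer_train_b5_lp0.6  infer_train_b5_lp0.6.en  infer_train_b5_lp0.6.de ...
```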
129 |
130 |
131 | **Step 6**: Run inference with the backward models to translate the training set
132 |
133 | ```bash
134 | export CUDA_VISIBLE_DEVICES=0
135 | export seed_prefix=101
136 | export problem=translate_deen_wmt16_bpe32k
137 | export model_name=big_tfm_baseline_df3584_s${seed_prefix}
138 | export data_dir=`pwd`/data_fairseq/$problem
139 | export beam=5
140 | export lenpen=0.6
141 | export round=1
142 | export infer_bsz=8192  # max tokens per generation batch; adjust to fit your GPU memory
143 | for index in {1..3}
144 | do
145 | export model_dir=train_fairseq/${problem}/${model_name}/model_${index}
146 | export best_file=$model_dir/checkpoint_best.pt
147 | export gen_out=$model_dir/infer_train_b${beam}_lp${lenpen}
148 | fairseq-generate ${data_dir} \
149 | -s de -t en \
150 | --path ${best_file} \
151 | --gen-subset train \
152 | --max-tokens ${infer_bsz} --beam ${beam} --lenpen ${lenpen} | dd of=$gen_out
153 | grep ^S ${gen_out} | cut -f2- > $gen_out.de
154 | grep ^H ${gen_out} | cut -f3- > $gen_out.en
155 | done
156 |
157 | ```
158 |
159 | **Step 7**: Merge the synthetic data with the original dataset and filter out duplicate pairs
160 |
161 | ```bash
162 |
163 | export ori=raw_data/wmt_ende/train.tok.clean.bpe.32000
164 | export bw_prefix=train_fairseq/translate_deen_wmt16_bpe32k/big_tfm_baseline_df3584_s101/model_
165 | export fw_prefix=train_fairseq/translate_ende_wmt16_bpe32k/big_tfm_baseline_df3584_s100/model_
166 | export prefix=
167 | for i in {1..3}
168 | do
169 | export prefix=$bw_prefix$i/infer_train_b5_lp0.6:$prefix
170 | done
171 | for i in {1..3}
172 | do
173 | export prefix=$fw_prefix$i/infer_train_b5_lp0.6:$prefix
174 | done
175 |
176 | mkdir -p raw_data/aug_ende_wmt16_bpe32k_s3_r1
177 | python -u combine_corpus.py --src en --tgt de --ori $ori --hypos $prefix --dir raw_data/aug_ende_wmt16_bpe32k_s3_r1 --out train
178 |
179 | export out=data_fairseq/translate_ende_aug_b5_r1_s3_nodup_wmt16_bpe32k
180 | # Copy the original binarized data into the new augmented-data directory. The valid/test sets stay the same; only the train set changes.
181 | cp -r data_fairseq/translate_ende_wmt16_bpe32k $out
182 |
183 | fairseq-preprocess --source-lang en --target-lang de \
184 | --trainpref raw_data/aug_ende_wmt16_bpe32k_s3_r1/train \
185 | --destdir $out \
186 | --nwordssrc 0 --nwordstgt 0 \
187 | --workers 16 \
188 | --srcdict $out/dict.en.txt --tgtdict $out/dict.de.txt
189 |
190 | # This should report around 27M sentences
191 | ```
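
A quick, optional sanity check before training on the augmented data: both sides of the merged corpus should have the same number of lines, roughly 27M:

```bash
wc -l raw_data/aug_ende_wmt16_bpe32k_s3_r1/train.en \
      raw_data/aug_ende_wmt16_bpe32k_s3_r1/train.de
```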
192 |
193 | **Step 8**: Train the final model on the augmented data, average checkpoints, and evaluate
194 |
195 | ```bash
196 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
197 | export seed_prefix=200
198 | export problem=translate_ende_aug_b5_r1_s3_nodup_wmt16_bpe32k
199 | export model_name=big_tfm_baseline_df3584_s${seed_prefix}
200 | export data_dir=`pwd`/data_fairseq/$problem
201 | export index=1
202 | export model_dir=train_fairseq/${problem}/${model_name}/model_${index}
203 | fairseq-train \
204 | ${data_dir} \
205 | -s en -t de \
206 | --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
207 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
208 | --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
209 | --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
210 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
211 | --max-update 43000 \
212 | --keep-last-epochs 10 \
213 | --save-dir ${model_dir} \
214 | --ddp-backend no_c10d \
215 | --seed ${seed_prefix}${index} \
216 | --max-tokens 3584 \
217 | --fp16 --update-freq 16 --log-interval 10000 --no-progress-bar
218 |
219 | export avg_checkpoint=$model_dir/checkpoint_avg5.pt
220 |
221 | # average the last 5 epoch checkpoints
222 | python average_checkpoints.py \
223 | --inputs ${model_dir} \
224 | --num-epoch-checkpoints 5 \
225 | --checkpoint-upper-bound 10000 \
226 | --output ${avg_checkpoint}
227 |
228 | export gen_out=$model_dir/infer.test.avg5.b5.lp0.6
229 | export ref=${gen_out}.ref
230 | export hypo=${gen_out}.hypo
231 | export ref_atat=${ref}.atat
232 | export hypo_atat=${hypo}.atat
233 | export beam=5
234 | export lenpen=0.6
235 | echo "Finished averaging checkpoints, generating test translations"
236 | fairseq-generate ${data_dir} \
237 | -s en -t de \
238 | --gen-subset test \
239 | --path ${avg_checkpoint} \
240 | --max-tokens 2048 \
241 | --beam ${beam} \
242 | --lenpen ${lenpen} \
243 | --remove-bpe | dd of=${gen_out}
244 | grep ^T ${gen_out} | cut -f2- > ${ref}
245 | grep ^H ${gen_out} | cut -f3- > ${hypo}
246 |
247 | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < ${hypo} > ${hypo_atat}
248 | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < ${ref} > ${ref_atat}
249 | echo "------ Score BLEU ------------"
250 | fairseq-score --sys ${hypo_atat} --ref ${ref_atat}
251 | # expected: BLEU4 = 30.7
252 | ```
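
If you only want to evaluate the pretrained checkpoint linked in the table above, the same generation and scoring commands apply. A minimal sketch, assuming the downloaded archive has been extracted to a hypothetical `pretrained/wmt16_ende/` directory:

```bash
# Hypothetical path below; point --path at the checkpoint file inside the extracted archive
fairseq-generate data_fairseq/translate_ende_wmt16_bpe32k \
    -s en -t de --gen-subset test \
    --path pretrained/wmt16_ende/checkpoint.pt \
    --max-tokens 2048 --beam 5 --lenpen 0.6 --remove-bpe
```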
253 |
254 |
255 |
256 |
--------------------------------------------------------------------------------
/average_checkpoints.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import argparse
4 | import collections
5 | import torch
6 | import os
7 | import re
8 | from fairseq.utils import import_user_module
9 |
10 |
11 | def default_avg_params(params_dict):
12 | averaged_params = collections.OrderedDict()
13 |
14 |     # v is a list of torch Tensors (one per checkpoint); average them elementwise.
15 | for k, v in params_dict.items():
16 | summed_v = None
17 | for x in v:
18 | summed_v = summed_v + x if summed_v is not None else x
19 | averaged_params[k] = summed_v / len(v)
20 |
21 | return averaged_params
22 |
23 |
24 | def ema_avg_params(params_dict, ema_decay):
25 | averaged_params = collections.OrderedDict()
26 | lens = [len(v) for k, v in params_dict.items()]
27 | assert all(x == lens[0] for x in lens), f'lens params: {lens}'
28 | num_checkpoints = lens[0]
29 |     # checkpoints in each params_dict[k] are ordered newest -> oldest
30 |
31 |     for k, v in params_dict.items():
32 |         # iterate oldest -> newest and apply the EMA recursion:
33 |         #   y_t = x_t * decay + y_{t-1} * (1 - decay)
34 |         # so the newest checkpoint receives the largest weight
35 | total_v = None
36 | for x in reversed(v):
37 | if total_v is None:
38 | total_v = x
39 | else:
40 | total_v = x * ema_decay + total_v * (1.0 - ema_decay)
41 |
42 | averaged_params[k] = total_v
43 | return averaged_params
44 |
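# Worked example of the EMA recursion above (made-up numbers): with ema_decay = 0.9 and a
# parameter whose checkpoint values are [2.0 (newest), 0.0 (oldest)], the loop computes
# 0.9 * 2.0 + 0.1 * 0.0 = 1.8. In general, the k-th newest of n checkpoints receives weight
# decay * (1 - decay)**(k - 1), and the oldest receives (1 - decay)**(n - 1).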
45 |
46 | def average_checkpoints(inputs, ema_decay=1.0):
47 | """Loads checkpoints from inputs and returns a model with averaged weights.
48 |
49 | Args:
50 |         inputs: An iterable of string paths of checkpoints to load from.
51 |         ema_decay: if set below 1.0, apply an exponential moving average with this decay instead of a uniform average.
52 | Returns:
53 | A dict of string keys mapping to various values. The 'model' key
54 | from the returned dict should correspond to an OrderedDict mapping
55 | string parameter names to torch Tensors.
56 | """
57 | params_dict = collections.OrderedDict()
58 | params_keys = None
59 | new_state = None
60 | for i, f in enumerate(inputs):
61 | state = torch.load(
62 | f,
63 | map_location=(
64 | lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
65 | ),
66 | )
67 | # Copies over the settings from the first checkpoint
68 | if new_state is None:
69 | new_state = state
70 |
71 | model_params = state['model']
72 |
73 | model_params_keys = list(model_params.keys())
74 | if params_keys is None:
75 | params_keys = model_params_keys
76 | elif params_keys != model_params_keys:
77 | raise KeyError(
78 | 'For checkpoint {}, expected list of params: {}, '
79 | 'but found: {}'.format(f, params_keys, model_params_keys)
80 | )
81 |
82 | for k in params_keys:
83 | if k not in params_dict:
84 | params_dict[k] = []
85 | p = model_params[k]
86 | if isinstance(p, torch.HalfTensor):
87 | p = p.float()
88 | params_dict[k].append(p)
89 |
90 | if ema_decay < 1.0:
91 | print(f'Exponential moving averaging, decay={ema_decay}')
92 | averaged_params = ema_avg_params(params_dict, ema_decay)
93 | else:
94 |         print('Default averaging')
95 | averaged_params = default_avg_params(params_dict)
96 | new_state['model'] = averaged_params
97 | return new_state
98 |
99 |
100 | def last_n_checkpoints(paths, n, update_based, upper_bound=None):
101 | assert len(paths) == 1
102 | path = paths[0]
103 | if update_based:
104 | pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
105 | else:
106 | pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
107 | files = os.listdir(path)
108 |
109 | entries = []
110 | for f in files:
111 | m = pt_regexp.fullmatch(f)
112 | if m is not None:
113 | sort_key = int(m.group(1))
114 | if upper_bound is None or sort_key <= upper_bound:
115 | entries.append((sort_key, m.group(0)))
116 | if len(entries) < n:
117 |         raise Exception('Found {} checkpoint files but need at least {}'.format(len(entries), n))
118 | return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
119 |
120 |
121 | def main():
122 | parser = argparse.ArgumentParser(
123 | description='Tool to average the params of input checkpoints to '
124 | 'produce a new checkpoint',
125 | )
126 | # fmt: off
127 | parser.add_argument('--inputs', required=True, nargs='+',
128 | help='Input checkpoint file paths.')
129 | parser.add_argument('--output', required=True, metavar='FILE',
130 | help='Write the new checkpoint containing the averaged weights to this path.')
131 | num_group = parser.add_mutually_exclusive_group()
132 | num_group.add_argument('--num-epoch-checkpoints', type=int,
133 | help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
134 | 'and average last this many of them.')
135 | num_group.add_argument('--num-update-checkpoints', type=int,
136 | help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
137 | 'and average last this many of them.')
138 | parser.add_argument('--checkpoint-upper-bound', type=int,
139 | help='when using --num-epoch-checkpoints, this will set an upper bound on which checkpoint to use, '
140 | 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.')
141 |
142 | # parser.add_argument('--ema', type=float, default=1.0, help='exponential moving average decay')
143 | # parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
144 | parser.add_argument('--ema', default='False', type=str, metavar='BOOL', help='whether EMA averaging is used (informational only; the averaging mode is determined by --ema_decay)')
145 | parser.add_argument('--ema_decay', type=float, default=1.0, help='exponential moving average decay')
146 | parser.add_argument('--user-dir', default=None)
147 |
148 | # fmt: on
149 | args = parser.parse_args()
150 |
151 | import_user_module(args)
152 | print(args)
153 |
154 | num = None
155 | is_update_based = False
156 | if args.num_update_checkpoints is not None:
157 | num = args.num_update_checkpoints
158 | is_update_based = True
159 | elif args.num_epoch_checkpoints is not None:
160 | num = args.num_epoch_checkpoints
161 |
162 | assert args.checkpoint_upper_bound is None or args.num_epoch_checkpoints is not None, \
163 | '--checkpoint-upper-bound requires --num-epoch-checkpoints'
164 | assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
165 | 'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'
166 |
167 | if num is not None:
168 | args.inputs = last_n_checkpoints(
169 | args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound,
170 | )
171 | # print('averaging checkpoints: ', args.inputs)
172 | print('averaging checkpoints: ')
173 | for checkpoint in args.inputs:
174 | print(checkpoint)
175 | print('-' * 40)
176 |
177 | # ema = args.ema
178 | # assert isinstance(args.ema, bool)
179 |     print(f'Start averaging with ema={args.ema}, ema_decay={args.ema_decay}')
180 | new_state = average_checkpoints(args.inputs, args.ema_decay)
181 | torch.save(new_state, args.output)
182 | print('Finished writing averaged checkpoint to {}.'.format(args.output))
183 |
184 |
185 | if __name__ == '__main__':
186 | main()
187 |
--------------------------------------------------------------------------------
/combine_corpus.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | Written by: Xuan-Phi Nguyen (nxphi47)
4 |
5 | """
6 |
7 | import torch
8 | import os
9 | import re
10 | import argparse
11 |
12 |
13 | def merge_nodup(src_ori, tgt_ori, src_hyps, tgt_hyps, **kwargs):
14 | sep = ' |||||||| '
15 | merge = [f'{x}{sep}{y}' for x, y in zip(src_ori, tgt_ori)]
16 | # ori_merge = set(ori_merge)
17 | for i, (src, tgt) in enumerate(zip(src_hyps, tgt_hyps)):
18 | merge += [f'{x}{sep}{y}' for x, y in zip(src, tgt)]
19 |
20 | merge = set(merge)
21 | out = [x.split(sep) for x in merge]
22 | print(f'Total size: {len(out)}')
23 | src = [x[0] for x in out]
24 | tgt = [x[1] for x in out]
25 |
26 | return src, tgt
27 |
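# Example of the dedup behaviour (made-up sentences): if a model-generated pair is identical
# to an original pair ("a b", "x y"), the set() above keeps a single copy, whereas a different
# translation of the same source ("a b", "x z") survives as an extra, diversified pair.
# Note that set() does not preserve the original line order.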
28 |
29 | if __name__ == '__main__':
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('--src', default='en', type=str)
32 | parser.add_argument('--tgt', default='de', type=str)
33 | parser.add_argument('--ori')
34 | parser.add_argument('--hypos')
35 | parser.add_argument('--dir')
36 | parser.add_argument('--out')
37 |
38 | args = parser.parse_args()
39 |
40 | ori_src_f = f'{args.ori}.{args.src}'
41 | ori_tgt_f = f'{args.ori}.{args.tgt}'
42 | hypos = [x for x in args.hypos.split(":") if x != ""]
43 |
44 | hypos_src_f = [f'{h}.{args.src}' for h in hypos]
45 | hypos_tgt_f = [f'{h}.{args.tgt}' for h in hypos]
46 |
47 |
48 | def read(fo):
49 | with open(fo, 'r') as f:
50 | out = f.read().strip().split('\n')
51 | return out
52 |
53 | ori_src = read(ori_src_f)
54 | ori_tgt = read(ori_tgt_f)
55 | hypos_src = [read(h) for h in hypos_src_f]
56 | hypos_tgt = [read(h) for h in hypos_tgt_f]
57 | assert len(hypos_src) == len(hypos_tgt)
58 |     print(f'Number of hypothesis corpora to merge: {len(hypos_src)}')
59 |
60 | assert len(ori_src) == len(ori_tgt)
61 | for i, (hx, hy) in enumerate(zip(hypos_src, hypos_tgt)):
62 | assert len(hx) == len(hy), f'invalid len {i}'
63 |
64 | src, tgt = merge_nodup(ori_src, ori_tgt, hypos_src, hypos_tgt)
65 | os.makedirs(args.dir, exist_ok=True)
66 | src_out = os.path.join(args.dir, f'{args.out}.{args.src}')
67 | tgt_out = os.path.join(args.dir, f'{args.out}.{args.tgt}')
68 | print(f'src_out:{src_out}')
69 | print(f'tgt_out:{tgt_out}')
70 | with open(src_out, 'w') as f:
71 | f.write('\n'.join(src))
72 | with open(tgt_out, 'w') as f:
73 | f.write('\n'.join(tgt))
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
--------------------------------------------------------------------------------
/example.md:
--------------------------------------------------------------------------------
1 | # Neural Machine Translation
2 |
3 | This README contains instructions for [using pretrained translation models](#example-usage-torchhub)
4 | as well as [training new models](#training-a-new-model).
5 |
6 | ## Pre-trained models
7 |
8 | Model | Description | Dataset | Download
9 | ---|---|---|---
10 | `conv.wmt14.en-fr` | Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
11 | `conv.wmt14.en-de` | Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2)
12 | `conv.wmt17.en-de` | Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
13 | `transformer.wmt14.en-fr` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
14 | `transformer.wmt16.en-de` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
15 | `transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive
16 | `transformer.wmt19.en-de` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 English-German](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz)
17 | `transformer.wmt19.de-en` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 German-English](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz)
18 | `transformer.wmt19.en-ru` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 English-Russian](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz)
19 | `transformer.wmt19.ru-en` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 Russian-English](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz)
20 |
21 | ## Example usage (torch.hub)
22 |
23 | Interactive translation via PyTorch Hub:
24 | ```python
25 | import torch
26 | import fairseq  # needed for the isinstance check below
27 | # List available models
28 | torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt16.en-de', ... ]
29 |
30 | # Load a transformer trained on WMT'16 En-De
31 | en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de', tokenizer='moses', bpe='subword_nmt')
32 |
33 | # The underlying model is available under the *models* attribute
34 | assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel)
35 |
36 | # Translate a sentence
37 | en2de.translate('Hello world!')
38 | # 'Hallo Welt!'
39 | ```
40 |
41 | ## Example usage (CLI tools)
42 |
43 | Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti:
44 | ```bash
45 | mkdir -p data-bin
46 | curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin
47 | curl https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin
48 | fairseq-generate data-bin/wmt14.en-fr.newstest2014 \
49 | --path data-bin/wmt14.en-fr.fconv-py/model.pt \
50 | --beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
51 | # ...
52 | # | Translated 3003 sentences (96311 tokens) in 166.0s (580.04 tokens/s)
53 | # | Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
54 |
55 | # Compute BLEU score
56 | grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys
57 | grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref
58 | fairseq-score --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
59 | # BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
60 | ```
61 |
62 | ## Training a new model
63 |
64 | ### IWSLT'14 German to English (Transformer)
65 |
66 | The following instructions can be used to train a Transformer model on the [IWSLT'14 German to English dataset](http://workshop2014.iwslt.org/downloads/proceeding.pdf).
67 |
68 | First download and preprocess the data:
69 | ```bash
70 | # Download and prepare the data
71 | cd examples/translation/
72 | bash prepare-iwslt14.sh
73 | cd ../..
74 |
75 | # Preprocess/binarize the data
76 | TEXT=examples/translation/iwslt14.tokenized.de-en
77 | fairseq-preprocess --source-lang de --target-lang en \
78 | --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
79 | --destdir data-bin/iwslt14.tokenized.de-en \
80 | --workers 20
81 | ```
82 |
83 | Next we'll train a Transformer translation model over this data:
84 | ```bash
85 | CUDA_VISIBLE_DEVICES=0 fairseq-train \
86 | data-bin/iwslt14.tokenized.de-en \
87 | --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
88 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
89 | --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
90 | --dropout 0.3 --weight-decay 0.0001 \
91 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
92 | --max-tokens 4096
93 | ```
94 |
95 | Finally we can evaluate our trained model:
96 | ```bash
97 | fairseq-generate data-bin/iwslt14.tokenized.de-en \
98 | --path checkpoints/checkpoint_best.pt \
99 | --batch-size 128 --beam 5 --remove-bpe
100 | ```
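
To also report a BLEU score for this run, the same grep/`fairseq-score` recipe from the CLI example above applies (a sketch; the `/tmp` paths are arbitrary):

```bash
fairseq-generate data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --batch-size 128 --beam 5 --remove-bpe | tee /tmp/iwslt14.gen.out
grep ^H /tmp/iwslt14.gen.out | cut -f3- > /tmp/iwslt14.gen.out.sys
grep ^T /tmp/iwslt14.gen.out | cut -f2- > /tmp/iwslt14.gen.out.ref
fairseq-score --sys /tmp/iwslt14.gen.out.sys --ref /tmp/iwslt14.gen.out.ref
```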
101 |
102 | ### WMT'14 English to German (Convolutional)
103 |
104 | The following instructions can be used to train a Convolutional translation model on the WMT English to German dataset.
105 | See the [Scaling NMT README](../scaling_nmt/README.md) for instructions to train a Transformer translation model on this data.
106 |
107 | The WMT English to German dataset can be preprocessed using the `prepare-wmt14en2de.sh` script.
108 | By default it will produce a dataset that was modeled after [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762), but with additional news-commentary-v12 data from WMT'17.
109 |
110 | To use only data available in WMT'14 or to replicate results obtained in the original [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](https://arxiv.org/abs/1705.03122) paper, please use the `--icml17` option.
111 |
112 | ```bash
113 | # Download and prepare the data
114 | cd examples/translation/
115 | # WMT'17 data:
116 | bash prepare-wmt14en2de.sh
117 | # or to use WMT'14 data:
118 | # bash prepare-wmt14en2de.sh --icml17
119 | cd ../..
120 |
121 | # Binarize the dataset
122 | TEXT=examples/translation/wmt17_en_de
123 | fairseq-preprocess \
124 | --source-lang en --target-lang de \
125 | --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
126 | --destdir data-bin/wmt17_en_de --thresholdtgt 0 --thresholdsrc 0 \
127 | --workers 20
128 |
129 | # Train the model
130 | mkdir -p checkpoints/fconv_wmt_en_de
131 | fairseq-train \
132 | data-bin/wmt17_en_de \
133 | --arch fconv_wmt_en_de \
134 | --lr 0.5 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
135 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
136 | --lr-scheduler fixed --force-anneal 50 \
137 | --save-dir checkpoints/fconv_wmt_en_de
138 |
139 | # Evaluate
140 | fairseq-generate data-bin/wmt17_en_de \
141 | --path checkpoints/fconv_wmt_en_de/checkpoint_best.pt \
142 | --beam 5 --remove-bpe
143 | ```
144 |
145 | ### WMT'14 English to French
146 | ```bash
147 | # Download and prepare the data
148 | cd examples/translation/
149 | bash prepare-wmt14en2fr.sh
150 | cd ../..
151 |
152 | # Binarize the dataset
153 | TEXT=examples/translation/wmt14_en_fr
154 | fairseq-preprocess \
155 | --source-lang en --target-lang fr \
156 | --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
157 | --destdir data-bin/wmt14_en_fr --thresholdtgt 0 --thresholdsrc 0 \
158 | --workers 60
159 |
160 | # Train the model
161 | mkdir -p checkpoints/fconv_wmt_en_fr
162 | fairseq-train \
163 | data-bin/wmt14_en_fr \
164 | --lr 0.5 --clip-norm 0.1 --dropout 0.1 --max-tokens 3000 \
165 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
166 | --lr-scheduler fixed --force-anneal 50 \
167 | --arch fconv_wmt_en_fr \
168 | --save-dir checkpoints/fconv_wmt_en_fr
169 |
170 | # Evaluate
171 | fairseq-generate \
172 |     data-bin/wmt14_en_fr \
173 | --path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt \
174 | --beam 5 --remove-bpe
175 | ```
176 |
177 | ## Multilingual Translation
178 |
179 | We also support training multilingual translation models. In this example we'll
180 | train a multilingual `{de,fr}-en` translation model using the IWSLT'17 datasets.
181 |
182 | Note that we use slightly different preprocessing here than for the IWSLT'14
183 | De-En data above. In particular we learn a joint BPE code for all three
184 | languages and use fairseq-interactive and sacrebleu for scoring the test set.
185 |
186 | ```bash
187 | # First install sacrebleu and sentencepiece
188 | pip install sacrebleu sentencepiece
189 |
190 | # Then download and preprocess the data
191 | cd examples/translation/
192 | bash prepare-iwslt17-multilingual.sh
193 | cd ../..
194 |
195 | # Binarize the de-en dataset
196 | TEXT=examples/translation/iwslt17.de_fr.en.bpe16k
197 | fairseq-preprocess --source-lang de --target-lang en \
198 | --trainpref $TEXT/train.bpe.de-en --validpref $TEXT/valid.bpe.de-en \
199 | --joined-dictionary \
200 | --destdir data-bin/iwslt17.de_fr.en.bpe16k \
201 | --workers 10
202 |
203 | # Binarize the fr-en dataset
204 | # NOTE: it's important to reuse the en dictionary from the previous step
205 | fairseq-preprocess --source-lang fr --target-lang en \
206 | --trainpref $TEXT/train.bpe.fr-en --validpref $TEXT/valid.bpe.fr-en \
207 | --joined-dictionary --tgtdict data-bin/iwslt17.de_fr.en.bpe16k/dict.en.txt \
208 | --destdir data-bin/iwslt17.de_fr.en.bpe16k \
209 | --workers 10
210 |
211 | # Train a multilingual transformer model
212 | # NOTE: the command below assumes 1 GPU, but accumulates gradients from
213 | # 8 fwd/bwd passes to simulate training on 8 GPUs
214 | mkdir -p checkpoints/multilingual_transformer
215 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \
216 | --max-epoch 50 \
217 | --ddp-backend=no_c10d \
218 | --task multilingual_translation --lang-pairs de-en,fr-en \
219 | --arch multilingual_transformer_iwslt_de_en \
220 | --share-decoders --share-decoder-input-output-embed \
221 | --optimizer adam --adam-betas '(0.9, 0.98)' \
222 | --lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \
223 | --warmup-updates 4000 --warmup-init-lr '1e-07' \
224 | --label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
225 | --dropout 0.3 --weight-decay 0.0001 \
226 | --save-dir checkpoints/multilingual_transformer \
227 | --max-tokens 4000 \
228 | --update-freq 8
229 |
230 | # Generate and score the test set with sacrebleu
231 | SRC=de
232 | sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
233 | | python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
234 | > iwslt17.test.${SRC}-en.${SRC}.bpe
235 | cat iwslt17.test.${SRC}-en.${SRC}.bpe \
236 | | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
237 | --task multilingual_translation --source-lang ${SRC} --target-lang en \
238 | --path checkpoints/multilingual_transformer/checkpoint_best.pt \
239 | --buffer-size 2000 --batch-size 128 \
240 | --beam 5 --remove-bpe=sentencepiece \
241 | > iwslt17.test.${SRC}-en.en.sys
242 | grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \
243 | | sacrebleu --test-set iwslt17 --language-pair ${SRC}-en
244 | ```
245 |
246 | ##### Argument format during inference
247 |
248 | During inference it is required to specify a single `--source-lang` and
249 | `--target-lang`, which indicates the inference language direction.
250 | `--lang-pairs`, `--encoder-langtok`, and `--decoder-langtok` have to be set to
251 | the same values that were used during training.
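
For example, to generate from the binarized data for the de-en direction with the model trained above (a sketch, reusing the paths from the previous block):

```bash
fairseq-generate data-bin/iwslt17.de_fr.en.bpe16k/ \
    --task multilingual_translation --lang-pairs de-en,fr-en \
    --source-lang de --target-lang en \
    --path checkpoints/multilingual_transformer/checkpoint_best.pt \
    --batch-size 128 --beam 5 --remove-bpe=sentencepiece
```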
--------------------------------------------------------------------------------