├── .gitignore ├── .gitmodules ├── Dockerfile ├── LICENSE ├── README.md ├── cmlm ├── data.py ├── distributed.py ├── model.py └── util.py ├── dump_teacher_hiddens.py ├── dump_teacher_topk.py ├── launch_container.sh ├── run_cmlm_finetuning.py ├── run_mt.sh └── scripts ├── bert_detokenize.py ├── bert_prepro.py ├── bert_tokenize.py ├── download-iwslt_deen.sh ├── prepare-iwslt_deen.sh └── setup.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "opennmt"] 2 | path = opennmt 3 | url = git@github.com:chenrocks/distill-bert-textgen-onmt 4 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:19.03-py3 2 | 3 | # python dependencies 4 | RUN pip install \ 5 | six==1.12.0 future==0.17.1 configargparse==0.14.0 tensorboardX==1.6 ipdb==0.12 \ 6 | pytorch-pretrained-bert==0.6.1 tqdm==4.30.0 torchtext==0.4.0 7 | 8 | # download pretrained BERT checkpoint 9 | RUN python -c \ 10 | "from pytorch_pretrained_bert import BertTokenizer, BertForPreTraining; m = 'bert-base-multilingual-cased'; BertTokenizer.from_pretrained(m); BertForPreTraining.from_pretrained(m)" 11 | 12 | # moses for MT preprocessing 13 | RUN git clone https://github.com/moses-smt/mosesdecoder.git /workspace/mosesdecoder && \ 14 | cd /workspace/mosesdecoder && git checkout c054501 && rm -r .git && cd - 15 | 16 | WORKDIR /src 17 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Microsoft Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distill-BERT-Textgen 2 | Research code for ACL 2020 paper: "[Distilling Knowledge Learned in BERT for Text Generation](https://arxiv.org/abs/1911.03829)". 3 | 4 | ![Overview](https://acvrpublicycchen.blob.core.windows.net/distill-bert-textgen/overview.png) 5 | 6 | This repository contains the code needed to reproduce our IWSLT De-En experiments. 7 | 8 | ## Setting Up 9 | This repo is tested on Ubuntu 18.04 machine with Nvidia GPU. We do not plan to support other OS or CPU-only machines. 10 | 11 | 1. Prerequisite 12 | - [Docker](https://docs.docker.com/engine/install/ubuntu/) 13 | 14 | you also need to follow [this](https://docs.docker.com/engine/install/linux-postinstall/) to run docker without sudo 15 | - nvidia-driver (we tested on version 418) 16 | ```bash 17 | # reference installation command on Ubuntu 18 | sudo add-apt-repository ppa:graphics-drivers/ppa 19 | sudo apt update 20 | sudo apt install nvidia-driver-418 21 | ``` 22 | - [nvidia-docker](https://github.com/NVIDIA/nvidia-docker#ubuntu-160418042004-debian-jessiestretchbuster) 23 | - clone this repo and its submodule (we use a modified version of OpenNMT-py) 24 | ```bash 25 | git clone --recursive git@github.com:ChenRocks/Distill-BERT-Textgen.git 26 | ``` 27 | 28 | Users can potentially setup non-docker environment following the `Dockerfile` to install python packages and other dependencies. 29 | However, to guarantee reproducibility, it is safest to use our official docker image and we will not provide official support/troubleshooting if you do not use dockerized setup. 30 | (If you absolutely need non-docker install, feel free to discuss in github issue with other users and contribution is welcome.) 31 | 32 | 2. Downloading Data and Preprocessing 33 | 34 | - Run the following command to download raw data and then preprocess 35 | ```bash 36 | source scripts/setup.sh 37 | ``` 38 | and then you should see populated with files of the following structure. 
39 | ``` 40 | ├── download 41 | │   ├── de-en 42 | │   └── de-en.tgz 43 | ├── dump 44 | │   └── de-en 45 | │   ├── DEEN.db.bak 46 | │   ├── DEEN.db.dat 47 | │   ├── DEEN.db.dir 48 | │   ├── DEEN.train.0.pt 49 | │   ├── DEEN.valid.0.pt 50 | │   ├── DEEN.vocab.pt 51 | │   ├── dev.de.bert 52 | │   ├── dev.en.bert 53 | │   ├── ref 54 | │   ├── test.de.bert 55 | │   └── test.en.bert 56 | ├── raw 57 | │   └── de-en 58 | └── tmp 59 | └── de-en 60 | ``` 61 | 62 | ## Usage 63 | First, launch the docker container 64 | ```bash 65 | source launch_container.sh 66 | ``` 67 | This will mount /dump (contains preprocessed data), (store experiment outputs), 68 | and the repo itself (so that any code you change is reflected inside the container). 69 | All following commands in this section should be run inside the docker container. 70 | To exit the docker environment, type `exit` or press Ctrl+D. 71 | 72 | 1. Training 73 | 1. C-MLM finetuning 74 | ```bash 75 | python run_cmlm_finetuning.py --train_file /data/de-en/DEEN.db \ 76 | --vocab_file /data/de-en/DEEN.vocab.pt \ 77 | --valid_src /data/de-en/dev.de.bert \ 78 | --valid_tgt /data/de-en/dev.en.bert \ 79 | --bert_model bert-base-multilingual-cased \ 80 | --output_dir /output/ \ 81 | --train_batch_size 16384 \ 82 | --learning_rate 5e-5 \ 83 | --valid_steps 5000 \ 84 | --num_train_steps 100000 \ 85 | --warmup_proportion 0.05 \ 86 | --gradient_accumulation_steps 1 \ 87 | --fp16 88 | ``` 89 | 2. Extract teacher soft label 90 | ```bash 91 | # extract hidden states of teacher 92 | python dump_teacher_hiddens.py \ 93 | --bert bert-base-multilingual-cased \ 94 | --ckpt /output//ckpt/model_step_100000.pt \ 95 | --db /data/de-en/DEEN.db --output /data/de-en/targets/ 96 | 97 | # extract top-k logits 98 | python dump_teacher_topk.py --bert_hidden /data/de-en/targets/ 99 | ``` 100 | 3. Seq2Seq training with KD 101 | ```bash 102 | python opennmt/train.py \ 103 | --bert_kd \ 104 | --bert_dump /data/de-en/targets/ \ 105 | --data_db /data/de-en/DEEN.db \ 106 | -data /data/de-en/DEEN \ 107 | -config opennmt/config/config-transformer-base-mt-deen.yml \ 108 | -learning_rate 2.0 \ 109 | -warmup_steps 8000 \ 110 | --kd_alpha 0.5 \ 111 | --kd_temperature 10.0 \ 112 | --kd_topk 8 \ 113 | --train_steps 100000 \ 114 | -save_model /output/ 115 | ``` 116 | 117 | 118 | 2. Inference and Evaluatation 119 | 120 | The following command will translate the dev split using the 100k step checkpoint, with beam size 5 and length penalty 0.6. 121 | ```bash 122 | ./run_mt.sh /output/ 100000 dev 5 0.6 123 | ``` 124 | Usually the BLEU score correlates well with the accuracy in validation. 125 | The results will be stored at `/output//output/`. 126 | 127 | 128 | ## Misc 129 | - We test on a single Nvidia Titan RTX GPU, which has 24GB of RAM. If you encounter OOM, try 130 | decrease batch size and increase gradient accumulation. 131 | - If you have a multi-GPU machine, use `CUDA_VISIBLE_DEVICES` to sepcify GPU you want to use before 132 | launching the docker container. Otherwise it will use GPU 0 only. 133 | - Feel free to ask questions and discuss in the github issues. 
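For example, a minimal sketch of pinning specific GPUs before entering the container (the paths here are placeholders: the first argument is the directory in which `scripts/setup.sh` created the `dump` folder, the second is where experiment outputs should go, matching the two positional arguments read by `launch_container.sh`):
```bash
# make only GPU 0 and GPU 1 visible inside the container
export CUDA_VISIBLE_DEVICES=0,1
source launch_container.sh /path/to/data /path/to/experiment_output
```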
134 | 135 | 136 | 137 | ## Citation 138 | If you find this work helpful to your research, please consider citing: 139 | ``` 140 | @inproceedings{chen2020distilling, 141 | title={Distilling Knowledge Learned in BERT for Text Generation}, 142 | author={Chen, Yen-Chun and Gan, Zhe and Cheng, Yu and Liu, Jingzhou and Liu, Jingjing}, 143 | booktitle={ACL}, 144 | year={2020} 145 | } 146 | ``` 147 | -------------------------------------------------------------------------------- /cmlm/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | data for C-MLM 6 | """ 7 | import math 8 | import random 9 | import shelve 10 | import warnings 11 | 12 | import torch 13 | from torch.utils.data import Dataset, Sampler 14 | 15 | from toolz.sandbox.core import unzip 16 | 17 | 18 | EOS = '' 19 | IN_WORD = '@@' 20 | UNK = '' 21 | UNK_BERT = '[UNK]' 22 | MASK = '[MASK]' 23 | CLS = '[CLS]' 24 | SEP = '[SEP]' 25 | MOSES_SPECIALS = {'&': '&', '|': '|', '<': '<', '>': '>', 26 | ''': "'", '"': '"', '[': '[', ']': ']', 27 | '@-@': '-'} 28 | 29 | 30 | class BertDataset(Dataset): 31 | def __init__(self, corpus_path, tokenizer, vocab, seq_len, max_len=150): 32 | self.db = shelve.open(corpus_path, 'r') 33 | self.lens = [] 34 | self.ids = [] 35 | for i, example in self.db.items(): 36 | src_len = len(example['src']) 37 | tgt_len = len(example['tgt']) 38 | if (src_len <= max_len and tgt_len <= max_len): 39 | self.ids.append(i) 40 | self.lens.append(min(seq_len, src_len+tgt_len+3)) 41 | self.vocab = vocab # vocab for output seq2seq 42 | self.tokenizer = tokenizer 43 | self.seq_len = seq_len 44 | 45 | def __len__(self): 46 | return len(self.ids) 47 | 48 | def __getitem__(self, i): 49 | id_ = self.ids[i] 50 | item = self.db[id_] 51 | t1, t2 = item['src'], item['tgt'] 52 | 53 | # combine to one sample 54 | cur_example = InputExample(guid=i, tokens_a=t1, tokens_b=t2) 55 | 56 | # transform sample to features 57 | cur_features = convert_example_to_features( 58 | cur_example, self.seq_len, self.tokenizer, self.vocab) 59 | 60 | features = (cur_features.input_ids, cur_features.input_mask, 61 | cur_features.segment_ids, cur_features.lm_label_ids) 62 | 63 | return features 64 | 65 | @staticmethod 66 | def pad_collate(features): 67 | """ pad the input features to same length""" 68 | input_ids, input_masks, segment_ids, lm_label_ids = map( 69 | list, unzip(features)) 70 | max_len = max(map(len, input_ids)) 71 | for ids, masks, segs, labels in zip(input_ids, input_masks, 72 | segment_ids, lm_label_ids): 73 | while len(ids) < max_len: 74 | ids.append(0) 75 | masks.append(0) 76 | segs.append(0) 77 | labels.append(-1) 78 | input_ids = torch.tensor(input_ids) 79 | input_mask = torch.tensor(input_masks) 80 | segment_ids = torch.tensor(segment_ids) 81 | lm_label_ids = torch.tensor(lm_label_ids) 82 | return input_ids, input_mask, segment_ids, lm_label_ids 83 | 84 | 85 | class InputExample(object): 86 | """A single training/test example for the language model.""" 87 | 88 | def __init__(self, guid, tokens_a, tokens_b=None, lm_labels=None): 89 | """Constructs a InputExample. 90 | 91 | Args: 92 | guid: Unique id for the example. 93 | tokens_a: string. The untokenized text of the first sequence. For 94 | single sequence tasks, only this sequence must be specified. 95 | tokens_b: (Optional) string. The untokenized text of the second 96 | sequence. Only must be specified for sequence pair tasks. 97 | label: (Optional) string. 
The label of the example. This should be 98 | specified for train and dev examples, but not for test examples. 99 | """ 100 | self.guid = guid 101 | self.tokens_a = tokens_a 102 | self.tokens_b = tokens_b 103 | self.lm_labels = lm_labels # masked words for language model 104 | 105 | 106 | class InputFeatures(object): 107 | """A single set of features of data.""" 108 | 109 | def __init__(self, input_ids, input_mask, 110 | segment_ids, lm_label_ids): 111 | self.input_ids = input_ids 112 | self.input_mask = input_mask 113 | self.segment_ids = segment_ids 114 | self.lm_label_ids = lm_label_ids 115 | 116 | 117 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 118 | """Truncates a sequence pair in place to the maximum length.""" 119 | 120 | # This is a simple heuristic which will always truncate the longer sequence 121 | # one token at a time. This makes more sense than truncating an equal 122 | # percent of tokens from each, since if one sequence is very short then 123 | # each token that's truncated likely contains more information than a 124 | # longer sequence. 125 | while True: 126 | total_length = len(tokens_a) + len(tokens_b) 127 | if total_length <= max_length: 128 | break 129 | if len(tokens_a) > len(tokens_b): 130 | tokens_a.pop() 131 | else: 132 | tokens_b.pop() 133 | 134 | 135 | def convert_token_to_bert(token): 136 | bert_token = token.replace(IN_WORD, '') 137 | try: 138 | bert_token = MOSES_SPECIALS[bert_token] # handle moses tokens 139 | except KeyError: 140 | pass 141 | if bert_token == UNK: 142 | # this should only happen with gigaword 143 | bert_token = UNK_BERT 144 | return bert_token 145 | 146 | 147 | def convert_raw_input_to_features(src_line, tgt_line, 148 | toker, vocab, device, p=0.15): 149 | src_toks = [convert_token_to_bert(tok) 150 | for tok in src_line.strip().split()] 151 | tgt_toks = tgt_line.strip().split() 152 | output_labels = [] 153 | for i, tok in enumerate(tgt_toks): 154 | if random.random() < p: 155 | tgt_toks[i] = MASK 156 | output_labels.append(vocab[tok]) 157 | else: 158 | tgt_toks[i] = convert_token_to_bert(tok) 159 | output_labels.append(-1) 160 | if random.random() < p: 161 | tgt_toks.append(MASK) 162 | output_labels.append(vocab[EOS]) 163 | else: 164 | tgt_toks.append(SEP) 165 | output_labels.append(-1) 166 | input_ids = toker.convert_tokens_to_ids( 167 | [CLS] + src_toks + [SEP] + tgt_toks) 168 | type_ids = [0]*(len(src_toks) + 2) + [1]*(len(tgt_toks)) 169 | mask = [1] * len(input_ids) 170 | labels = [-1] * (len(src_toks) + 2) + output_labels 171 | input_ids = torch.LongTensor(input_ids).to(device).unsqueeze(0) 172 | type_ids = torch.LongTensor(type_ids).to(device).unsqueeze(0) 173 | mask = torch.LongTensor(mask).to(device).unsqueeze(0) 174 | labels = torch.LongTensor(labels).to(device).unsqueeze(0) 175 | return input_ids, type_ids, mask, labels 176 | 177 | 178 | def random_word(tokens, output_vocab): 179 | """ 180 | NOTE: this assumes other MT prepro like moses and we try to align 181 | them with BERT 182 | Masking some random tokens for Language Model task with probabilities as in 183 | the original BERT paper. 184 | :param tokens: list of str, tokenized sentence. 
185 | :param output_vocab: vocab for seq2seq output 186 | :return: (list of str, list of int), masked tokens and related labels for 187 | LM prediction 188 | """ 189 | output_label = [] 190 | 191 | for i, token in enumerate(tokens): 192 | # mask token with 15% probability 193 | if random.random() < 0.15: 194 | # we always MASK given our purpose 195 | tokens[i] = MASK 196 | 197 | # append current token to output (we will predict these later) 198 | try: 199 | output_label.append(output_vocab[token]) 200 | except KeyError: 201 | # For unknown words (should not occur with BPE vocab) 202 | output_label.append(output_vocab[UNK]) 203 | warnings.warn(f"Cannot find token '{token}' in vocab. Using " 204 | f"{UNK} insetad") 205 | else: 206 | # handle input for BERT 207 | tokens[i] = convert_token_to_bert(token) 208 | 209 | # no masking token (will be ignored by loss function later) 210 | output_label.append(-1) 211 | 212 | # last SEP is used to learn EOS 213 | if random.random() < 0.15: 214 | tokens.append(MASK) 215 | output_label.append(output_vocab[EOS]) 216 | else: 217 | tokens.append(SEP) 218 | output_label.append(-1) 219 | 220 | return tokens, output_label 221 | 222 | 223 | def convert_example_to_features(example, max_seq_length, 224 | tokenizer, output_vocab): 225 | """ 226 | Convert a raw sample (pair of sentences as tokenized strings) into a proper 227 | training sample with IDs, LM labels, input_mask, CLS and SEP tokens etc. 228 | :param example: InputExample, containing sentence input as strings 229 | :param max_seq_length: int, maximum length of sequence. 230 | :param tokenizer: Tokenizer 231 | :param output_vocab: vocab for seq2seq output 232 | :return: InputFeatures, containing all inputs and labels of one sample as 233 | IDs (as used for model training) 234 | """ 235 | tokens_a = example.tokens_a 236 | tokens_b = example.tokens_b 237 | # Modifies `tokens_a` and `tokens_b` in place so that the total 238 | # length is less than the specified length. 239 | # Account for [CLS], [SEP], EOS with "- 3" 240 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 241 | 242 | # convert to BERT compatible inputs 243 | tokens_a = [convert_token_to_bert(tok) for tok in tokens_a] 244 | # only mask sent_b because it's seq2seq problem 245 | t1_label = [-1] * len(tokens_a) 246 | while True: 247 | tokens_b, t2_label = random_word(tokens_b, output_vocab) 248 | if any(label != -1 for label in t2_label): 249 | break 250 | # concatenate lm labels and account for CLS, SEP 251 | lm_label_ids = ([-1] + t1_label + [-1] + t2_label) 252 | 253 | # The convention in BERT is: 254 | # (a) For sequence pairs: 255 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 256 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 257 | # For our MT setup 258 | # (b) For single sequences: 259 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [MASK] 260 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 261 | tokens = [] 262 | segment_ids = [] 263 | tokens.append(CLS) 264 | segment_ids.append(0) 265 | for token in tokens_a: 266 | tokens.append(token) 267 | segment_ids.append(0) 268 | tokens.append(SEP) 269 | segment_ids.append(0) 270 | 271 | assert len(tokens_b) > 0 272 | for token in tokens_b: 273 | tokens.append(token) 274 | segment_ids.append(1) 275 | # NOTE the last SEP is handled differently from original BERT 276 | 277 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 278 | 279 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 280 | # tokens are attended to. 
281 | input_mask = [1] * len(input_ids) 282 | 283 | # Zero-pad up to multiples of 8 (for tensor cores) 284 | while len(input_ids) % 8 != 0: 285 | input_ids.append(0) 286 | input_mask.append(0) 287 | segment_ids.append(0) 288 | lm_label_ids.append(-1) 289 | 290 | assert len(input_ids) % 8 == 0 291 | assert (len(input_ids) == len(input_mask) 292 | == len(segment_ids) == len(lm_label_ids)) 293 | 294 | features = InputFeatures(input_ids=input_ids, 295 | input_mask=input_mask, 296 | segment_ids=segment_ids, 297 | lm_label_ids=lm_label_ids) 298 | return features 299 | 300 | 301 | class BucketSampler(Sampler): 302 | def __init__(self, lens, bucket_size, batch_size, droplast=False): 303 | self._lens = lens 304 | self._batch_size = batch_size 305 | self._bucket_size = bucket_size 306 | self._droplast = droplast 307 | 308 | def _create_ids(self): 309 | return list(range(len(self._lens))) 310 | 311 | def _sort_fn(self, i): 312 | return self._lens[i] 313 | 314 | def __iter__(self): 315 | ids = self._create_ids() 316 | random.shuffle(ids) 317 | buckets = [sorted(ids[i:i+self._bucket_size], 318 | key=self._sort_fn, reverse=True) 319 | for i in range(0, len(ids), self._bucket_size)] 320 | batches = [bucket[i:i+self._batch_size] 321 | for bucket in buckets 322 | for i in range(0, len(bucket), self._batch_size)] 323 | if self._droplast: 324 | batches = [batch for batch in batches 325 | if len(batch) == self._batch_size] 326 | random.shuffle(batches) 327 | return iter(batches) 328 | 329 | def __len__(self): 330 | bucket_sizes = ([self._bucket_size] 331 | * (len(self._lens) // self._bucket_size) 332 | + [len(self._lens) % self._bucket_size]) 333 | if self._droplast: 334 | return sum(s//self._batch_size for s in bucket_sizes) 335 | else: 336 | return sum(math.ceil(s/self._batch_size) for s in bucket_sizes) 337 | 338 | 339 | class DistributedBucketSampler(BucketSampler): 340 | def __init__(self, num_replicas, rank, *args, **kwargs): 341 | super().__init__(*args, **kwargs) 342 | self._rank = rank 343 | self._num_replicas = num_replicas 344 | 345 | def _create_ids(self): 346 | return super()._create_ids()[self._rank:-1:self._num_replicas] 347 | 348 | def __len__(self): 349 | num_data = len(self._create_ids()) 350 | bucket_sizes = ([self._bucket_size] 351 | * (num_data // self._bucket_size) 352 | + [num_data % self._bucket_size]) 353 | if self._droplast: 354 | return sum(s//self._batch_size for s in bucket_sizes) 355 | else: 356 | return sum(math.ceil(s/self._batch_size) for s in bucket_sizes) 357 | 358 | 359 | class TokenBucketSampler(Sampler): 360 | def __init__(self, lens, bucket_size, batch_size, droplast=False): 361 | self._lens = lens 362 | self._max_tok = batch_size 363 | self._bucket_size = bucket_size 364 | self._droplast = droplast 365 | 366 | def _create_ids(self): 367 | return list(range(len(self._lens))) 368 | 369 | def _sort_fn(self, i): 370 | return self._lens[i] 371 | 372 | def __iter__(self): 373 | ids = self._create_ids() 374 | random.shuffle(ids) 375 | buckets = [sorted(ids[i:i+self._bucket_size], 376 | key=self._sort_fn, reverse=True) 377 | for i in range(0, len(ids), self._bucket_size)] 378 | # fill batches until max_token (include padding) 379 | batches = [] 380 | for bucket in buckets: 381 | max_len = 0 382 | batch_indices = [] 383 | for index in bucket: 384 | max_len = max(max_len, self._lens[index]) 385 | if max_len * (len(batch_indices) + 1) > self._max_tok: 386 | if not batch_indices: 387 | raise ValueError( 388 | "max_tokens too small / max_seq_len too long") 389 | 
batches.append(batch_indices) 390 | batch_indices = [index] 391 | else: 392 | batch_indices.append(index) 393 | if not self._droplast and batch_indices: 394 | batches.append(batch_indices) 395 | random.shuffle(batches) 396 | return iter(batches) 397 | 398 | def __len__(self): 399 | raise ValueError("NOT supported. " 400 | "This has some randomness across epochs") 401 | 402 | 403 | class DistributedTokenBucketSampler(TokenBucketSampler): 404 | def __init__(self, num_replicas, rank, *args, **kwargs): 405 | super().__init__(*args, **kwargs) 406 | self._rank = rank 407 | self._num_replicas = num_replicas 408 | 409 | def _create_ids(self): 410 | return super()._create_ids()[self._rank:-1:self._num_replicas] 411 | -------------------------------------------------------------------------------- /cmlm/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | distributed utils not presented in OpenNMT-py 6 | """ 7 | import torch.distributed 8 | 9 | 10 | def broadcast_tensors(tensors, rank=0): 11 | """ broadcast list of tensors at once 12 | this can be used to sync parameter initialization across GPUs 13 | 14 | Args: 15 | tensors: list of Tensors to brodcast 16 | rank: rank to broadcast 17 | """ 18 | # buffer size in bytes, determine equiv. # of elements based on data type 19 | sz = sum(t.numel() for t in tensors) 20 | buffer_t = tensors[0].new(sz).zero_() 21 | 22 | # copy tensors into buffer_t 23 | offset = 0 24 | for t in tensors: 25 | numel = t.numel() 26 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 27 | offset += numel 28 | assert offset == sz 29 | 30 | # broadcast 31 | torch.distributed.broadcast(buffer_t, rank) 32 | 33 | # copy all-reduced buffer back into tensors 34 | offset = 0 35 | for t in tensors: 36 | numel = t.numel() 37 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 38 | offset += numel 39 | assert offset == sz == buffer_t.numel() 40 | -------------------------------------------------------------------------------- /cmlm/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | C-MLM model 6 | """ 7 | import torch 8 | from torch import nn 9 | from pytorch_pretrained_bert.modeling import BertForMaskedLM 10 | 11 | 12 | IN_WORD = '@@' 13 | 14 | 15 | def convert_embedding(toker, vocab, emb_weight): 16 | """ seq2seq vs pretrained BERT embedding conversion""" 17 | vocab_size = emb_weight.size(1) 18 | if vocab_size % 8 != 0: 19 | # pad for tensor cores 20 | vocab_size += (8 - vocab_size % 8) 21 | vectors = [torch.zeros(vocab_size) for _ in range(len(vocab))] 22 | for word, id_ in vocab.items(): 23 | word = word.replace(IN_WORD, '') 24 | if word in toker.vocab: 25 | bert_id = toker.vocab[word] 26 | else: 27 | bert_id = toker.vocab['[UNK]'] 28 | vectors[id_] = emb_weight[bert_id].clone() 29 | embedding = nn.Parameter(torch.stack(vectors, dim=0)) 30 | return embedding 31 | 32 | 33 | class BertForSeq2seq(BertForMaskedLM): 34 | """ 35 | The original output projection is shared w/ embedding. 
Now for seq2seq, we 36 | use initilization from bert embedding but untied embedding due to 37 | tokenization difference 38 | """ 39 | def __init__(self, config, causal=False): 40 | super().__init__(config) 41 | self.apply(self.init_bert_weights) 42 | 43 | def update_output_layer(self, output_embedding): 44 | self.cls.predictions.decoder.weight = output_embedding 45 | vocab_size = output_embedding.size(0) 46 | self.cls.predictions.bias = nn.Parameter(torch.zeros(vocab_size)) 47 | self.config.vocab_size = vocab_size 48 | 49 | def update_output_layer_by_size(self, vocab_size): 50 | if vocab_size % 8 != 0: 51 | # pad for tensor cores 52 | vocab_size += (8 - vocab_size % 8) 53 | emb_dim = self.cls.predictions.decoder.weight.size(1) 54 | self.cls.predictions.decoder.weight = nn.Parameter( 55 | torch.Tensor(vocab_size, emb_dim)) 56 | self.cls.predictions.bias = nn.Parameter(torch.zeros(vocab_size)) 57 | self.config.vocab_size = vocab_size 58 | 59 | def update_embedding_layer_by_size(self, vocab_size): 60 | if vocab_size % 8 != 0: 61 | # pad for tensor cores 62 | vocab_size += (8 - vocab_size % 8) 63 | emb_dim = self.cls.predictions.decoder.weight.size(1) 64 | self.bert.embeddings.word_embeddings = nn.Embedding( 65 | vocab_size, emb_dim, padding_idx=0) 66 | self.config.vocab_size = vocab_size 67 | 68 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, 69 | masked_lm_labels=None, output_mask=None, do_padding=True): 70 | """ only computes masked logits to save some computation""" 71 | if output_mask is None: 72 | # reduce to normal forward 73 | return super().forward(input_ids, token_type_ids, attention_mask, 74 | masked_lm_labels) 75 | 76 | sequence_output, pooled_output = self.bert( 77 | input_ids, token_type_ids, attention_mask, 78 | output_all_encoded_layers=False) 79 | # only compute masked outputs 80 | output_mask = output_mask.byte() 81 | sequence_output_masked = sequence_output.masked_select( 82 | output_mask.unsqueeze(-1).expand_as(sequence_output) 83 | ).contiguous().view(-1, self.config.hidden_size) 84 | n_pred, hid = sequence_output_masked.size() 85 | if do_padding and (n_pred == 0 or n_pred % 8): 86 | # pad for tensor cores 87 | n_pad = 8 - n_pred % 8 88 | pad = torch.zeros(n_pad, hid, 89 | dtype=sequence_output_masked.dtype, 90 | device=sequence_output_masked.device) 91 | sequence_output_masked = torch.cat( 92 | [sequence_output_masked, pad], dim=0) 93 | else: 94 | n_pad = 0 95 | prediction_scores = self.cls.predictions(sequence_output_masked) 96 | 97 | if masked_lm_labels is not None: 98 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1) 99 | lm_labels = masked_lm_labels.masked_select(output_mask) 100 | if n_pad != 0: 101 | pad = torch.zeros(n_pad, 102 | dtype=lm_labels.dtype, 103 | device=lm_labels.device).fill_(-1) 104 | lm_labels = torch.cat([lm_labels, pad], dim=0) 105 | masked_lm_loss = loss_fct(prediction_scores, lm_labels) 106 | return masked_lm_loss 107 | else: 108 | return prediction_scores 109 | -------------------------------------------------------------------------------- /cmlm/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | helper for tensorboard logging 6 | """ 7 | import tensorboardX 8 | 9 | 10 | class Logger(object): 11 | def __init__(self): 12 | self._logger = None 13 | self._global_step = 0 14 | 15 | def create(self, path): 16 | self._logger = tensorboardX.SummaryWriter(path) 17 | 18 | def noop(self, *args, **kwargs): 19 | return 20 | 21 | def step(self): 22 | self._global_step += 1 23 | 24 | @property 25 | def global_step(self): 26 | return self._global_step 27 | 28 | def log_scaler_dict(self, log_dict, prefix=''): 29 | """ log a dictionary of scalar values""" 30 | if self._logger is None: 31 | return 32 | if prefix: 33 | prefix = f'{prefix}_' 34 | for name, value in log_dict.items(): 35 | if isinstance(value, dict): 36 | self.log_scaler_dict(value, self._global_step, 37 | prefix=f'{prefix}{name}') 38 | else: 39 | self._logger.add_scalar(f'{prefix}{name}', value, 40 | self._global_step) 41 | 42 | def __getattr__(self, name): 43 | if self._logger is None: 44 | return self.noop 45 | return self._logger.__getattribute__(name) 46 | 47 | 48 | class RunningMeter(object): 49 | """ running meteor of a scalar value 50 | (useful for monitoring training loss) 51 | """ 52 | def __init__(self, name, val=None, smooth=0.99): 53 | self._name = name 54 | self._sm = smooth 55 | self._val = val 56 | 57 | def __call__(self, value): 58 | self._val = (value if self._val is None 59 | else value*(1-self._sm) + self._val*self._sm) 60 | 61 | def __str__(self): 62 | return f'{self._name}: {self._val:.4f}' 63 | 64 | @property 65 | def val(self): 66 | return self._val 67 | -------------------------------------------------------------------------------- /dump_teacher_hiddens.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | precompute hidden states of CMLM teacher to speedup KD training 6 | """ 7 | import argparse 8 | import io 9 | import os 10 | import shelve 11 | 12 | import numpy as np 13 | import torch 14 | from torch.utils.data import Dataset, DataLoader 15 | from tqdm import tqdm 16 | from pytorch_pretrained_bert import BertTokenizer 17 | from toolz.sandbox import unzip 18 | 19 | from cmlm.model import BertForSeq2seq 20 | from cmlm.data import convert_token_to_bert, CLS, SEP, MASK 21 | 22 | 23 | def tensor_dumps(tensor): 24 | with io.BytesIO() as writer: 25 | np.save(writer, tensor.cpu().numpy().astype(np.float16), 26 | allow_pickle=False) 27 | dump = writer.getvalue() 28 | return dump 29 | 30 | 31 | def gather_hiddens(hiddens, masks): 32 | outputs = [] 33 | for hid, mask in zip(hiddens.split(1, dim=1), masks.split(1, dim=1)): 34 | if mask.sum().item() == 0: 35 | continue 36 | mask = mask.unsqueeze(-1).expand_as(hid) 37 | outputs.append(hid.masked_select(mask)) 38 | output = torch.stack(outputs, dim=0) 39 | return output 40 | 41 | 42 | class BertSampleDataset(Dataset): 43 | def __init__(self, corpus_path, tokenizer, num_samples=7): 44 | self.db = shelve.open(corpus_path, 'r') 45 | self.ids = [] 46 | for i, ex in self.db.items(): 47 | if len(ex['src']) + len(ex['tgt']) + 3 <= 512: 48 | self.ids.append(i) 49 | self.toker = tokenizer 50 | self.num_samples = num_samples 51 | 52 | def __len__(self): 53 | return len(self.ids) 54 | 55 | def __getitem__(self, i): 56 | id_ = self.ids[i] 57 | example = self.db[id_] 58 | features = convert_example(example['src'], example['tgt'], 59 | self.toker, self.num_samples) 60 | return (id_, ) + features 61 | 62 | 63 | def convert_example(src, tgt, toker, num_samples): 64 | src = [convert_token_to_bert(tok) for tok in src] 65 | tgt = [convert_token_to_bert(tok) for tok in tgt] + [SEP] 66 | 67 | # build the random masks 68 | tgt_len = len(tgt) 69 | if tgt_len <= num_samples: 70 | masks = torch.eye(tgt_len).byte() 71 | num_samples = tgt_len 72 | else: 73 | mask_inds = [list(range(i, tgt_len, num_samples)) 74 | for i in range(num_samples)] 75 | masks = torch.zeros(num_samples, tgt_len).byte() 76 | for i, indices in enumerate(mask_inds): 77 | for j in indices: 78 | masks.data[i, j] = 1 79 | assert (masks.sum(dim=0) != torch.ones(tgt_len).long()).sum().item() == 0 80 | assert masks.sum().item() == tgt_len 81 | masks = torch.cat([torch.zeros(num_samples, len(src)+2).byte(), masks], 82 | dim=1) 83 | 84 | # make BERT inputs 85 | input_ids = toker.convert_tokens_to_ids([CLS] + src + [SEP] + tgt) 86 | mask_id = toker.convert_tokens_to_ids([MASK])[0] 87 | input_ids = torch.tensor([input_ids for _ in range(num_samples)]) 88 | input_ids.data.masked_fill_(masks, mask_id) 89 | token_ids = torch.tensor([[0] * (len(src) + 2) + [1] * len(tgt) 90 | for _ in range(num_samples)]) 91 | return input_ids, token_ids, masks 92 | 93 | 94 | def batch_features(features): 95 | ids, all_input_ids, all_token_ids, all_masks = map(list, unzip(features)) 96 | batch_size = sum(input_ids.size(0) for input_ids in all_input_ids) 97 | max_len = max(input_ids.size(1) for input_ids in all_input_ids) 98 | input_ids = torch.zeros(batch_size, max_len).long() 99 | token_ids = torch.zeros(batch_size, max_len).long() 100 | attn_mask = torch.zeros(batch_size, max_len).long() 101 | i = 0 102 | for inp, tok in zip(all_input_ids, all_token_ids): 103 | block, len_ = inp.size() 104 | input_ids.data[i: i+block, :len_] = inp.data 105 | token_ids.data[i: i+block, :len_] = tok.data 106 | attn_mask.data[i: i+block, 
:len_].fill_(1) 107 | i += block 108 | return ids, input_ids, token_ids, attn_mask, all_masks 109 | 110 | 111 | def process_batch(batch, bert, toker, num_samples=7): 112 | input_ids, token_ids, attn_mask, all_masks = batch 113 | input_ids = input_ids.cuda() 114 | token_ids = token_ids.cuda() 115 | attn_mask = attn_mask.cuda() 116 | hiddens, _ = bert.bert(input_ids, token_ids, attn_mask, 117 | output_all_encoded_layers=False) 118 | hiddens = bert.cls.predictions.transform(hiddens) 119 | i = 0 120 | outputs = [] 121 | for masks in all_masks: 122 | block, len_ = masks.size() 123 | hids = hiddens[i:i+block, :len_, :] 124 | masks = masks.cuda() 125 | outputs.append(gather_hiddens(hids, masks)) 126 | i += block 127 | return outputs 128 | 129 | 130 | def build_db_batched(corpus, out_db, bert, toker, batch_size=8): 131 | dataset = BertSampleDataset(corpus, toker) 132 | loader = DataLoader(dataset, batch_size=batch_size, 133 | num_workers=4, collate_fn=batch_features) 134 | with tqdm(desc='computing BERT features', total=len(dataset)) as pbar: 135 | for ids, *batch in loader: 136 | outputs = process_batch(batch, bert, toker) 137 | for id_, output in zip(ids, outputs): 138 | out_db[id_] = tensor_dumps(output) 139 | pbar.update(len(ids)) 140 | 141 | 142 | def main(opts): 143 | # load BERT 144 | state_dict = torch.load(opts.ckpt) 145 | vsize = state_dict['cls.predictions.decoder.weight'].size(0) 146 | bert = BertForSeq2seq.from_pretrained(opts.bert).eval().half().cuda() 147 | bert.update_output_layer_by_size(vsize) 148 | bert.load_state_dict(state_dict) 149 | toker = BertTokenizer.from_pretrained(opts.bert, 150 | do_lower_case='uncased' in opts.bert) 151 | 152 | # save the final projection layer 153 | linear = torch.nn.Linear(bert.config.hidden_size, bert.config.vocab_size) 154 | linear.weight.data = state_dict['cls.predictions.decoder.weight'] 155 | linear.bias.data = state_dict['cls.predictions.bias'] 156 | os.makedirs(opts.output) 157 | torch.save(linear, f'{opts.output}/linear.pt') 158 | 159 | # create DB 160 | with shelve.open(f'{opts.output}/db') as out_db, \ 161 | torch.no_grad(): 162 | build_db_batched(opts.db, out_db, bert, toker) 163 | 164 | 165 | if __name__ == '__main__': 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument('--bert', required=True, 168 | choices=['bert-base-uncased', 169 | 'bert-base-multilingual-cased'], 170 | help='BERT model') 171 | parser.add_argument('--ckpt', required=True, help='BERT checkpoint') 172 | parser.add_argument('--db', required=True, help='dataset to compute') 173 | parser.add_argument('--output', required=True, help='path to dump output') 174 | args = parser.parse_args() 175 | 176 | main(args) 177 | -------------------------------------------------------------------------------- /dump_teacher_topk.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Further pre-compute the top-K prob to save memory 6 | """ 7 | import argparse 8 | import io 9 | import shelve 10 | 11 | import numpy as np 12 | import torch 13 | from tqdm import tqdm 14 | 15 | 16 | def tensor_loads(dump): 17 | with io.BytesIO(dump) as reader: 18 | obj = np.load(reader, allow_pickle=False) 19 | if isinstance(obj, np.ndarray): 20 | tensor = obj 21 | else: 22 | tensor = obj['arr_0'] 23 | return tensor 24 | 25 | 26 | def dump_topk(topk): 27 | logit, index = topk 28 | with io.BytesIO() as writer: 29 | torch.save((logit.cpu(), index.cpu()), writer) 30 | dump = writer.getvalue() 31 | return dump 32 | 33 | 34 | def main(opts): 35 | linear = torch.load(f'{opts.bert_hidden}/linear.pt').cuda() 36 | with shelve.open(f'{opts.bert_hidden}/db', 'r') as db, \ 37 | shelve.open(f'{opts.bert_hidden}/topk', 'c') as topk_db: 38 | for key, value in tqdm(db.items(), 39 | total=len(db), desc='computing topk...'): 40 | bert_hidden = torch.tensor(tensor_loads(value)).cuda() 41 | topk = linear(bert_hidden).topk(dim=-1, k=opts.topk) 42 | dump = dump_topk(topk) 43 | topk_db[key] = dump 44 | 45 | 46 | if __name__ == '__main__': 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--bert_hidden', required=True, 49 | help='path to saved bert hidden') 50 | parser.add_argument('--topk', type=int, default=128, 51 | help='topk logits to pre-compute (can extract larger ' 52 | 'K and then set to smaller value at training)') 53 | args = parser.parse_args() 54 | main(args) 55 | -------------------------------------------------------------------------------- /launch_container.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | DATA=$1 5 | OUTPUT=$2 6 | 7 | if [ -z $CUDA_VISIBLE_DEVICES ]; then 8 | CUDA_VISIBLE_DEVICES='all' 9 | fi 10 | 11 | if [ ! -d $OUTPUT ]; then 12 | mkdir -p $OUTPUT 13 | fi 14 | 15 | docker run --gpus '"'device=$CUDA_VISIBLE_DEVICES'"' --ipc=host --rm -it \ 16 | --mount src=$(pwd),dst=/src,type=bind \ 17 | --mount src=$OUTPUT,dst=/output,type=bind \ 18 | --mount src=$DATA/dump,dst=/data,type=bind \ 19 | chenrocks/distill-bert-textgen:latest 20 | -------------------------------------------------------------------------------- /run_cmlm_finetuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT license. 4 | # modified from hugginface github 5 | 6 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. 7 | # team. 8 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 9 | # 10 | # Licensed under the Apache License, Version 2.0 (the "License"); 11 | # you may not use this file except in compliance with the License. 12 | # You may obtain a copy of the License at 13 | # 14 | # http://www.apache.org/licenses/LICENSE-2.0 15 | # 16 | # Unless required by applicable law or agreed to in writing, software 17 | # distributed under the License is distributed on an "AS IS" BASIS, 18 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | # See the License for the specific language governing permissions and 20 | # limitations under the License. 
21 | """C-MLM finetuning runner.""" 22 | import argparse 23 | import copy 24 | import json 25 | import logging 26 | import os 27 | from os.path import abspath, dirname, exists, join 28 | import random 29 | import subprocess 30 | from time import time 31 | 32 | import numpy as np 33 | import torch 34 | from torch.utils.data import DataLoader 35 | from tqdm import tqdm 36 | 37 | from pytorch_pretrained_bert.tokenization import BertTokenizer 38 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 39 | 40 | from cmlm.data import (BertDataset, TokenBucketSampler, 41 | DistributedTokenBucketSampler, 42 | convert_raw_input_to_features) 43 | from cmlm.model import convert_embedding, BertForSeq2seq 44 | from cmlm.util import Logger, RunningMeter 45 | from cmlm.distributed import broadcast_tensors 46 | 47 | # add opennmt to python module search path 48 | # other than distributed utils, this is also needed to load onmt vocab file 49 | import sys 50 | sys.path.insert(0, '/src/opennmt') 51 | from onmt.utils.distributed import (all_reduce_and_rescale_tensors, 52 | all_gather_list) 53 | 54 | 55 | logging.basicConfig( 56 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 57 | datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) 58 | logger = logging.getLogger(__name__) 59 | 60 | TB_LOGGER = Logger() 61 | 62 | 63 | def noam_schedule(step, warmup_step=4000): 64 | if step <= warmup_step: 65 | return step / warmup_step 66 | return (warmup_step ** 0.5) * (step ** -0.5) 67 | 68 | 69 | def main(opts): 70 | if opts.local_rank == -1: 71 | assert torch.cuda.is_available() 72 | device = torch.device("cuda") 73 | n_gpu = 1 74 | else: 75 | torch.cuda.set_device(opts.local_rank) 76 | device = torch.device("cuda", opts.local_rank) 77 | # Initializes the distributed backend which will take care of 78 | # sychronizing nodes/GPUs 79 | torch.distributed.init_process_group(backend='nccl') 80 | n_gpu = torch.distributed.get_world_size() 81 | logger.info("device: {} n_gpu: {}, distributed training: {}, " 82 | "16-bits training: {}".format( 83 | device, n_gpu, bool(opts.local_rank != -1), opts.fp16)) 84 | opts.n_gpu = n_gpu 85 | 86 | if opts.gradient_accumulation_steps < 1: 87 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, " 88 | "should be >= 1".format( 89 | opts.gradient_accumulation_steps)) 90 | 91 | is_master = opts.local_rank == -1 or torch.distributed.get_rank() == 0 92 | 93 | if is_master: 94 | save_training_meta(opts) 95 | 96 | random.seed(opts.seed) 97 | np.random.seed(opts.seed) 98 | torch.manual_seed(opts.seed) 99 | if n_gpu > 0: 100 | torch.cuda.manual_seed_all(opts.seed) 101 | 102 | tokenizer = BertTokenizer.from_pretrained( 103 | opts.bert_model, do_lower_case='uncased' in opts.bert_model) 104 | 105 | # train_examples = None 106 | print("Loading Train Dataset", opts.train_file) 107 | vocab_dump = torch.load(opts.vocab_file) 108 | vocab = vocab_dump['tgt'].fields[0][1].vocab.stoi 109 | train_dataset = BertDataset(opts.train_file, tokenizer, vocab, 110 | seq_len=opts.max_seq_length, 111 | max_len=opts.max_sent_length) 112 | 113 | # Prepare model 114 | model = BertForSeq2seq.from_pretrained(opts.bert_model) 115 | embedding = convert_embedding( 116 | tokenizer, vocab, model.bert.embeddings.word_embeddings.weight) 117 | model.update_output_layer(embedding) 118 | if opts.fp16: 119 | model.half() 120 | model.to(device) 121 | if opts.local_rank != -1: 122 | # need to make sure models are the same in the beginning 123 | params = [p.data for p in model.parameters()] 
124 | broadcast_tensors(params) 125 | for name, module in model.named_modules(): 126 | # we might want to tune dropout for smaller dataset 127 | if isinstance(module, torch.nn.Dropout): 128 | module.p = opts.dropout 129 | 130 | # Prepare optimizer 131 | param_optimizer = [(n, p) for n, p in model.named_parameters() 132 | if 'pooler' not in n] 133 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 134 | optimizer_grouped_parameters = [ 135 | {'params': [p for n, p in param_optimizer 136 | if not any(nd in n for nd in no_decay)], 137 | 'weight_decay': 0.01}, 138 | {'params': [p for n, p in param_optimizer 139 | if any(nd in n for nd in no_decay)], 140 | 'weight_decay': 0.0} 141 | ] 142 | 143 | if opts.fp16: 144 | try: 145 | from apex.optimizers import FP16_Optimizer 146 | from apex.optimizers import FusedAdam 147 | except ImportError: 148 | raise ImportError("Please install apex from " 149 | "https://www.github.com/nvidia/apex " 150 | "to use distributed and fp16 training.") 151 | 152 | optimizer = FusedAdam(optimizer_grouped_parameters, 153 | lr=opts.learning_rate, 154 | bias_correction=False, 155 | max_grad_norm=1.0) 156 | if opts.loss_scale == 0: 157 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 158 | else: 159 | optimizer = FP16_Optimizer(optimizer, 160 | static_loss_scale=opts.loss_scale) 161 | 162 | else: 163 | optimizer = BertAdam(optimizer_grouped_parameters, 164 | lr=opts.learning_rate, 165 | warmup=opts.warmup_proportion, 166 | t_total=opts.num_train_steps) 167 | 168 | global_step = 0 169 | logger.info("***** Running training *****") 170 | logger.info(" Num examples = %d", len(train_dataset)) 171 | logger.info(" Batch size = %d", opts.train_batch_size) 172 | logger.info(" Accumulate steps = %d", opts.gradient_accumulation_steps) 173 | logger.info(" Num steps = %d", opts.num_train_steps) 174 | 175 | if opts.local_rank == -1: 176 | train_sampler = TokenBucketSampler( 177 | train_dataset.lens, 178 | bucket_size=8192, 179 | batch_size=opts.train_batch_size, 180 | droplast=True) 181 | train_dataloader = DataLoader(train_dataset, 182 | batch_sampler=train_sampler, 183 | num_workers=4, 184 | collate_fn=BertDataset.pad_collate) 185 | else: 186 | train_sampler = DistributedTokenBucketSampler( 187 | n_gpu, opts.local_rank, 188 | train_dataset.lens, 189 | bucket_size=8192, 190 | batch_size=opts.train_batch_size, 191 | droplast=True) 192 | train_dataloader = DataLoader(train_dataset, 193 | batch_sampler=train_sampler, 194 | num_workers=4, 195 | collate_fn=BertDataset.pad_collate) 196 | 197 | if is_master: 198 | TB_LOGGER.create(join(opts.output_dir, 'log')) 199 | running_loss = RunningMeter('loss') 200 | model.train() 201 | if is_master: 202 | pbar = tqdm(total=opts.num_train_steps) 203 | else: 204 | logger.disabled = True 205 | pbar = None 206 | n_examples = 0 207 | n_epoch = 0 208 | start = time() 209 | while True: 210 | for step, batch in enumerate(train_dataloader): 211 | batch = tuple(t.to(device) if t is not None else t for t in batch) 212 | input_ids, input_mask, segment_ids, lm_label_ids = batch 213 | n_examples += input_ids.size(0) 214 | mask = lm_label_ids != -1 215 | loss = model(input_ids, segment_ids, input_mask, 216 | lm_label_ids, mask, True) 217 | if opts.fp16: 218 | optimizer.backward(loss) 219 | else: 220 | loss.backward() 221 | running_loss(loss.item()) 222 | if (step + 1) % opts.gradient_accumulation_steps == 0: 223 | global_step += 1 224 | if opts.fp16: 225 | # modify learning rate with special warm up BERT uses 226 | # if opts.fp16 is False, 
BertAdam is used that handles 227 | # this automatically 228 | lr_this_step = opts.learning_rate * warmup_linear( 229 | global_step/opts.num_train_steps, 230 | opts.warmup_proportion) 231 | if lr_this_step < 0: 232 | # save guard for possible miscalculation of train steps 233 | lr_this_step = 1e-8 234 | for param_group in optimizer.param_groups: 235 | param_group['lr'] = lr_this_step 236 | TB_LOGGER.add_scalar('lr', 237 | lr_this_step, global_step) 238 | 239 | # NOTE running loss not gathered across GPUs for speed 240 | TB_LOGGER.add_scalar('loss', running_loss.val, global_step) 241 | TB_LOGGER.step() 242 | 243 | if opts.local_rank != -1: 244 | # gather gradients from every processes 245 | grads = [p.grad.data for p in model.parameters() 246 | if p.requires_grad and p.grad is not None] 247 | all_reduce_and_rescale_tensors(grads, float(1)) 248 | optimizer.step() 249 | optimizer.zero_grad() 250 | if pbar is not None: 251 | pbar.update(1) 252 | if global_step % 5 == 0: 253 | torch.cuda.empty_cache() 254 | if global_step % 100 == 0: 255 | if opts.local_rank != -1: 256 | total = sum(all_gather_list(n_examples)) 257 | else: 258 | total = n_examples 259 | if is_master: 260 | ex_per_sec = int(total / (time()-start)) 261 | logger.info(f'{total} examples trained at ' 262 | f'{ex_per_sec} ex/s') 263 | TB_LOGGER.add_scalar('ex_per_s', ex_per_sec, global_step) 264 | 265 | if global_step % opts.valid_steps == 0: 266 | logger.info(f"start validation at Step {global_step}") 267 | with torch.no_grad(): 268 | val_log = validate(model, 269 | opts.valid_src, opts.valid_tgt, 270 | tokenizer, vocab, device, 271 | opts.local_rank) 272 | logger.info(f"Val Acc: {val_log['val_acc']}; " 273 | f"Val Loss: {val_log['val_loss']}") 274 | TB_LOGGER.log_scaler_dict(val_log) 275 | if is_master: 276 | output_model_file = join( 277 | opts.output_dir, 'ckpt', 278 | f"model_step_{global_step}.pt") 279 | # save cpu checkpoint 280 | state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) 281 | else v 282 | for k, v in model.state_dict().items()} 283 | torch.save(state_dict, output_model_file) 284 | if global_step >= opts.num_train_steps: 285 | break 286 | if global_step >= opts.num_train_steps: 287 | break 288 | n_epoch += 1 289 | if is_master: 290 | logger.info(f"finished {n_epoch} epochs") 291 | if opts.num_train_steps % opts.valid_steps != 0: 292 | with torch.no_grad(): 293 | val_log = validate(model, opts.valid_src, opts.valid_tgt, 294 | tokenizer, vocab, device, opts.local_rank) 295 | TB_LOGGER.log_scaler_dict(val_log) 296 | if is_master: 297 | output_model_file = join(opts.output_dir, 'ckpt', 298 | f"model_step_{global_step}.pt") 299 | # save cpu checkpoint 300 | state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) 301 | else v 302 | for k, v in model.state_dict().items()} 303 | torch.save(model.state_dict(), output_model_file) 304 | 305 | 306 | def validate(model, valid_src, valid_tgt, toker, vocab, device, local_rank): 307 | model.eval() 308 | val_loss = 0 309 | n_correct = 0 310 | n_word = 0 311 | with open(valid_src, 'r') as src_reader, \ 312 | open(valid_tgt, 'r') as tgt_reader: 313 | for i, (src, tgt) in enumerate(zip(src_reader, tgt_reader)): 314 | if local_rank != -1: 315 | global_rank = torch.distributed.get_rank() 316 | world_size = torch.distributed.get_world_size() 317 | if global_rank % world_size != 0: 318 | continue 319 | input_ids, type_ids, mask, labels = convert_raw_input_to_features( 320 | src, tgt, toker, vocab, device) 321 | prediction_scores = model(input_ids, type_ids, mask) 322 | loss_fct = 
torch.nn.CrossEntropyLoss(ignore_index=-1, 323 | reduction='sum') 324 | loss = loss_fct(prediction_scores.squeeze(0), labels.view(-1)) 325 | val_loss += loss.item() 326 | n_correct += accuracy_count(prediction_scores, labels) 327 | n_word += (labels != -1).long().sum().item() 328 | if local_rank != -1: 329 | val_loss = sum(all_gather_list(val_loss)) 330 | n_correct = sum(all_gather_list(n_correct)) 331 | n_word = sum(all_gather_list(n_word)) 332 | val_loss /= n_word 333 | acc = n_correct / n_word 334 | val_log = {'val_loss': val_loss, 'val_acc': acc} 335 | model.train() 336 | return val_log 337 | 338 | 339 | def accuracy_count(out, labels): 340 | outputs = out.max(dim=-1)[1] 341 | mask = labels != -1 342 | n_correct = (outputs == labels).masked_select(mask).sum().item() 343 | return n_correct 344 | 345 | 346 | def save_training_meta(opts): 347 | if not exists(opts.output_dir): 348 | os.makedirs(join(opts.output_dir, 'log')) 349 | os.makedirs(join(opts.output_dir, 'ckpt')) 350 | 351 | with open(join(opts.output_dir, 'log', 'hps.json'), 'w') as writer: 352 | hps = copy.deepcopy(vars(opts)) 353 | del hps['local_rank'] 354 | json.dump(hps, writer, indent=4) 355 | # git info 356 | try: 357 | logger.info("Waiting on git info....") 358 | c = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"], 359 | timeout=10, stdout=subprocess.PIPE) 360 | git_branch_name = c.stdout.decode().strip() 361 | logger.info("Git branch: %s", git_branch_name) 362 | c = subprocess.run(["git", "rev-parse", "HEAD"], 363 | timeout=10, stdout=subprocess.PIPE) 364 | git_sha = c.stdout.decode().strip() 365 | logger.info("Git SHA: %s", git_sha) 366 | git_dir = abspath(dirname(__file__)) 367 | git_status = subprocess.check_output( 368 | ['git', 'status', '--short'], 369 | cwd=git_dir, universal_newlines=True).strip() 370 | with open(join(opts.output_dir, 'log', 'git_info.json'), 371 | 'w') as writer: 372 | json.dump({'branch': git_branch_name, 373 | 'is_dirty': bool(git_status), 374 | 'status': git_status, 375 | 'sha': git_sha}, 376 | writer, indent=4) 377 | except subprocess.TimeoutExpired as e: 378 | logger.exception(e) 379 | logger.warn("Git info not found. Moving right along...") 380 | 381 | 382 | if __name__ == "__main__": 383 | parser = argparse.ArgumentParser() 384 | 385 | # Required parameters 386 | parser.add_argument("--train_file", default=None, type=str, required=True, 387 | help="The input train corpus. (shelve DB)") 388 | parser.add_argument("--vocab_file", default=None, type=str, required=True, 389 | help="seq2seq output vocab") 390 | parser.add_argument("--valid_src", default=None, type=str, required=True, 391 | help="source line txt for validation") 392 | parser.add_argument("--valid_tgt", default=None, type=str, required=True, 393 | help="target line txt for validation") 394 | 395 | parser.add_argument( 396 | "--bert_model", default=None, type=str, required=True, 397 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 398 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, " 399 | "bert-base-chinese.") 400 | parser.add_argument( 401 | "--output_dir", default=None, type=str, required=True, 402 | help="The output directory where the model checkpoints will be " 403 | "written.") 404 | 405 | # Other parameters 406 | parser.add_argument( 407 | "--max_seq_length", default=256, type=int, 408 | help="The maximum total input sequence length after WordPiece " 409 | "tokenization. 
\nSequences longer than this will be truncated, " 410 | "and sequences shorter \nthan this will be padded.") 411 | parser.add_argument("--max_sent_length", default=256, type=int, 412 | help="The maximum number of tokens in a sentence") 413 | parser.add_argument("--train_batch_size", default=512, type=int, 414 | help="Total batch size for training. " 415 | "(batch by tokens)") 416 | parser.add_argument("--learning_rate", default=3e-5, type=float, 417 | help="The initial learning rate for Adam.") 418 | parser.add_argument("--valid_steps", default=1000, type=int, 419 | help="Run validation every X steps") 420 | parser.add_argument("--num_train_steps", default=100000, type=int, 421 | help="Total number of training updates to perform.") 422 | parser.add_argument("--dropout", default=0.1, type=float, 423 | help="tune dropout regularization") 424 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 425 | help="Proportion of training to perform linear " 426 | "learning rate warmup for. (linear decay)" 427 | "E.g., 0.1 = 10%% of training.") 428 | parser.add_argument("--local_rank", type=int, default=-1, 429 | help="local_rank for distributed training on gpus") 430 | parser.add_argument('--seed', type=int, default=42, 431 | help="random seed for initialization") 432 | parser.add_argument('--gradient_accumulation_steps', type=int, default=16, 433 | help="Number of updates steps to accumualte before " 434 | "performing a backward/update pass.") 435 | parser.add_argument('--fp16', 436 | action='store_true', 437 | help="Whether to use 16-bit float precision instead " 438 | "of 32-bit") 439 | parser.add_argument('--loss_scale', type=float, default=0, 440 | help="Loss scaling to improve fp16 numeric stability. " 441 | "Only used when fp16 set to True.\n" 442 | "0 (default value): dynamic loss scaling.\n" 443 | "Positive power of 2: static loss scaling " 444 | "value.\n") 445 | 446 | args = parser.parse_args() 447 | 448 | if exists(args.output_dir) and os.listdir(args.output_dir): 449 | raise ValueError("Output directory ({}) already exists and is not " 450 | "empty.".format(args.output_dir)) 451 | 452 | main(args) 453 | -------------------------------------------------------------------------------- /run_mt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT license. 4 | set -x 5 | set -e 6 | export PYTHONUNBUFFERED=1 7 | 8 | MODEL=$1 9 | CKPT=$2 10 | SPLIT=$3 # dev/test 11 | BEAM=$4 # beam size in beam search 12 | ALPHA=$5 # length penalty 13 | 14 | # German to English (IWSLT 15) 15 | DATAPATH=/data/de-en 16 | SRC=$DATAPATH/$SPLIT.de.bert 17 | TGT=$DATAPATH/$SPLIT.en.bert 18 | REF=$DATAPATH/ref/$SPLIT.en 19 | TGT_LANG='en' 20 | 21 | OUT_PATH=${MODEL}/output 22 | EXP="ckpt_${CKPT}-beam_${BEAM}-alpha_${ALPHA}.${SPLIT}" 23 | GPUID=0 24 | 25 | echo "running IWSLT De-En translation with beam size ${BEAM}, length penalty ${ALPHA}" 26 | if [ ! 
-d "$OUT_PATH" ]; then 27 | mkdir $OUT_PATH 28 | fi 29 | 30 | # run inference 31 | python opennmt/translate.py -gpu ${GPUID} \ 32 | -model ${MODEL}/ckpt/model_step_${CKPT}.pt \ 33 | -src ${SRC} \ 34 | -tgt ${TGT} \ 35 | -output ${OUT_PATH}/${EXP}.${TGT_LANG} \ 36 | -log_file ${OUT_PATH}/${EXP}.log \ 37 | -beam_size ${BEAM} -alpha ${ALPHA} \ 38 | -length_penalty wu \ 39 | -replace_unk -verbose -fp32 40 | 41 | # detokenize BERT BPE 42 | python scripts/bert_detokenize.py --file ${OUT_PATH}/${EXP}.${TGT_LANG} \ 43 | --output_dir ${OUT_PATH} 44 | 45 | # evaluation 46 | perl opennmt/tools/multi-bleu.perl $REF \ 47 | < ${OUT_PATH}/${EXP}.${TGT_LANG}.detok \ 48 | | tee ${OUT_PATH}/${EXP}.bleu 49 | -------------------------------------------------------------------------------- /scripts/bert_detokenize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | detokenize output due to BERT preprocessing 6 | """ 7 | import argparse 8 | import os 9 | from os.path import basename, exists, join 10 | 11 | import ipdb 12 | from tqdm import tqdm 13 | 14 | IN_WORD = '@@' 15 | BERT_IN_WORD = '##' 16 | 17 | # special chars in moses tokenizer 18 | MOSES_SPECIALS = {'|': '|', '<': '<', '>': '>', 19 | "'": ''', '"': '"', '[': '[', ']': ']'} 20 | AMP = '&' 21 | AMP_MOSES = '&' 22 | UNK = '' 23 | 24 | 25 | def convert_moses(tok): 26 | if tok in MOSES_SPECIALS: 27 | return MOSES_SPECIALS[tok] 28 | return tok 29 | 30 | 31 | def detokenize(line, moses=True): 32 | word = '' 33 | words = [] 34 | for tok in line.split(): 35 | if tok.startswith(IN_WORD): 36 | tok = tok[2:] 37 | if tok.startswith(BERT_IN_WORD): 38 | tok = tok[2:] 39 | tok = tok.replace(AMP, AMP_MOSES) 40 | if moses: 41 | tok = convert_moses(tok) 42 | word += tok 43 | else: 44 | if tok.startswith(BERT_IN_WORD): 45 | ipdb.set_trace() 46 | raise ValueError() 47 | words.append(word) 48 | tok = tok.replace(AMP, AMP_MOSES) 49 | if moses: 50 | tok = convert_moses(tok) 51 | word = tok 52 | words.append(word) 53 | text = ' '.join(words).strip() 54 | return text 55 | 56 | 57 | def process(reader, writer, unk, moses=True): 58 | for line in tqdm(reader, desc='tokenizing'): 59 | output = detokenize(line, moses) 60 | output = output.replace(UNK, unk) # UNK format change 61 | writer.write(output + '\n') 62 | 63 | 64 | def main(opts): 65 | if not exists(opts.output_dir): 66 | os.makedirs(opts.output_dir) 67 | output_file = join(opts.output_dir, f'{basename(opts.file)}.detok') 68 | with open(opts.file, 'r') as reader, open(output_file, 'w') as writer: 69 | process(reader, writer, opts.unk, opts.moses) 70 | 71 | 72 | if __name__ == '__main__': 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--file', action='store', required=True, 75 | help='line by line text file for data') 76 | parser.add_argument('--output_dir', action='store', required=True, 77 | help='path to output') 78 | parser.add_argument('--unk', action='store', default='UNK', 79 | choices=['UNK', ''], 80 | help='gigaword dev and test has different UNK') 81 | parser.add_argument('--no-moses', action='store_true', 82 | help='turn off moses sepcial character mapping ' 83 | '(for gigaword)') 84 | args = parser.parse_args() 85 | args.moses = not args.no_moses 86 | main(args) 87 | -------------------------------------------------------------------------------- /scripts/bert_prepro.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | preprocess text files for C-MLM finetuning 6 | """ 7 | import argparse 8 | import shelve 9 | 10 | from tqdm import tqdm 11 | 12 | 13 | def make_db(src_reader, tgt_reader, db): 14 | print() 15 | for i, (src, tgt) in tqdm(enumerate(zip(src_reader, tgt_reader))): 16 | src_toks = src.strip().split() 17 | tgt_toks = tgt.strip().split() 18 | if src_toks and tgt_toks: 19 | dump = {'src': src_toks, 'tgt': tgt_toks} 20 | db[str(i)] = dump 21 | 22 | 23 | def main(args): 24 | # process the dataset 25 | with open(args.src) as src_reader, open(args.tgt) as tgt_reader, \ 26 | shelve.open(args.output, 'n') as db: 27 | make_db(src_reader, tgt_reader, db) 28 | 29 | 30 | if __name__ == '__main__': 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--src', action='store', required=True, 33 | help='line by line text file for source data ') 34 | parser.add_argument('--tgt', action='store', required=True, 35 | help='line by line text file for target data ') 36 | parser.add_argument('--output', action='store', required=True, 37 | help='path to output') 38 | args = parser.parse_args() 39 | main(args) 40 | -------------------------------------------------------------------------------- /scripts/bert_tokenize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | use BERT tokenizer to process seq2seq data 6 | """ 7 | import argparse 8 | import glob 9 | import gzip 10 | import multiprocessing as mp 11 | import os 12 | from os.path import basename, exists, join 13 | 14 | from pytorch_pretrained_bert import BertTokenizer 15 | from tqdm import tqdm 16 | from cytoolz import curry, partition_all 17 | 18 | 19 | IN_WORD = '@@' # This prefix is used for reconstructing the original 20 | # tokenization after generation. 
(BERT tokenizer does not 21 | # preserve white spaces) 22 | # this does not seem to conflict with the corpora we test on 23 | 24 | # special chars in moses tokenizer (map moses escapes to plain characters) 25 | MOSES_SPECIALS = {'&amp;': '&', '&#124;': '|', '&lt;': '<', '&gt;': '>', 26 | '&apos;': "'", '&quot;': '"', '&#91;': '[', '&#93;': ']'} 27 | HYPHEN = '@-@' 28 | 29 | UNK = '<unk>' 30 | 31 | BUF = 65536 32 | CHUNK = 4096 33 | 34 | 35 | @curry 36 | def tokenize(bert_toker, line): 37 | if IN_WORD in line: 38 | # safeguard in case the corpus already contains the IN_WORD tag 39 | raise ValueError() 40 | line = line.strip() 41 | # Gigaword test set 42 | line = line.replace(' UNK ', f' {UNK} ') 43 | if line.startswith('UNK '): 44 | line = UNK + line[3:] 45 | if line.endswith(' UNK'): 46 | line = line[:-3] + UNK 47 | 48 | words = [] 49 | for word in line.split(): 50 | if word[0] == '&': 51 | for special, char in MOSES_SPECIALS.items(): 52 | if word.startswith(special): 53 | words.append(char) 54 | words.append(IN_WORD+word[len(special):]) 55 | break 56 | else: 57 | raise ValueError() 58 | else: 59 | words.append(word) 60 | 61 | tokens = [] 62 | for word in words: 63 | if word == UNK: 64 | tokens.append(word) 65 | elif word == HYPHEN: 66 | tokens.append(word) 67 | elif word.startswith(IN_WORD): 68 | tokens.extend(IN_WORD+tok 69 | for tok in bert_toker.tokenize(word[len(IN_WORD):])) 70 | else: 71 | tokens.extend(tok if i == 0 else IN_WORD+tok 72 | for i, tok in enumerate(bert_toker.tokenize(word))) 73 | return tokens 74 | 75 | 76 | def write(writer, tokens): 77 | writer.write(' '.join(tokens) + '\n') 78 | 79 | 80 | def process(reader, writer, tokenizer): 81 | with mp.Pool() as pool, tqdm(desc='tokenizing') as pbar: 82 | for lines in partition_all(BUF, reader): 83 | for tokens in pool.imap(tokenize(tokenizer), lines, 84 | chunksize=CHUNK): 85 | write(writer, tokens) 86 | pbar.update(len(lines)) 87 | 88 | 89 | def main(opts): 90 | tokenizer = BertTokenizer.from_pretrained( 91 | opts.bert, do_lower_case='uncased' in opts.bert) 92 | for prefix in opts.prefixes: 93 | input_files = glob.glob(f'{prefix}*') 94 | if not exists(opts.output_dir): 95 | os.makedirs(opts.output_dir) 96 | for input_file in input_files: 97 | if input_file.endswith('.gz'): 98 | out_name = basename(input_file)[:-3] 99 | reader = gzip.open(input_file, 'rt') 100 | else: 101 | out_name = basename(input_file) 102 | reader = open(input_file, 'r') 103 | output_file = join(opts.output_dir, f'{out_name}.bert') 104 | with open(output_file, 'w') as writer: 105 | process(reader, writer, tokenizer) 106 | reader.close() 107 | 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument('--bert', action='store', default='bert-large-uncased', 112 | help='bert model to use') 113 | parser.add_argument('--prefixes', action='store', required=True, nargs='+', 114 | help='line by line text file for data ' 115 | '(will apply to all prefixes)') 116 | parser.add_argument('--output_dir', action='store', required=True, 117 | help='path to output') 118 | args = parser.parse_args() 119 | main(args) 120 | -------------------------------------------------------------------------------- /scripts/download-iwslt_deen.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Copyright (c) Microsoft Corporation. 4 | # Licensed under the MIT license.
5 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 6 | 7 | set -x 8 | set -e 9 | 10 | RAW=$1 11 | TMP=$2 12 | 13 | 14 | SCRIPTS=/workspace/mosesdecoder/scripts 15 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 16 | LC=$SCRIPTS/tokenizer/lowercase.perl 17 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 18 | 19 | URL="https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz" 20 | GZ="de-en.tgz" 21 | 22 | src=de 23 | tgt=en 24 | lang=de-en 25 | 26 | prep=$RAW/de-en 27 | orig=$TMP 28 | 29 | 30 | mkdir -p $orig $prep 31 | cd $orig 32 | 33 | if [ -f $GZ ]; then 34 | echo "$GZ already exists, skipping download" 35 | else 36 | echo "Downloading data from ${URL}..." 37 | wget "$URL" 38 | if [ -f $GZ ]; then 39 | echo "Data successfully downloaded." 40 | else 41 | echo "Data not successfully downloaded." 42 | exit 43 | fi 44 | tar zxvf $GZ 45 | fi 46 | cd - 47 | 48 | if [ -f $prep/train.en ] && [ -f $prep/train.de ] && \ 49 | [ -f $prep/valid.en ] && [ -f $prep/valid.de ] && \ 50 | [ -f $prep/test.en ] && [ -f $prep/test.de ]; then 51 | echo "iwslt dataset is already preprocessed, skipping" 52 | else 53 | echo "pre-processing train data..." 54 | for l in $src $tgt; do 55 | f=train.tags.$lang.$l 56 | tok=train.tags.$lang.tok.$l 57 | 58 | cat $orig/$lang/$f | \ 59 | grep -v '<url>' | \ 60 | grep -v '<talkid>' | \ 61 | grep -v '<keywords>' | \ 62 | sed -e 's/<title>//g' | \ 63 | sed -e 's/<\/title>//g' | \ 64 | sed -e 's/<description>//g' | \ 65 | sed -e 's/<\/description>//g' | \ 66 | perl $TOKENIZER -threads 8 -l $l > $prep/$tok 67 | echo "" 68 | done 69 | perl $CLEAN -ratio 1.5 $prep/train.tags.$lang.tok $src $tgt $prep/train.tags.$lang.clean 1 175 70 | for l in $src $tgt; do 71 | perl $LC < $prep/train.tags.$lang.clean.$l > $prep/train.tags.$lang.$l 72 | done 73 | 74 | echo "pre-processing valid/test data..." 75 | for l in $src $tgt; do 76 | for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do 77 | fname=${o##*/} 78 | f=$prep/${fname%.*} 79 | echo $o $f 80 | grep '<seg id' $o | \ 81 | sed -e 's/<seg id="[0-9]*">\s*//g' | \ 82 | sed -e 's/\s*<\/seg>\s*//g' | \ 83 | sed -e "s/\’/\'/g" | \ 84 | perl $TOKENIZER -threads 8 -l $l | \ 85 | perl $LC > $f 86 | echo "" 87 | done 88 | done 89 | 90 | 91 | echo "creating train, valid, test..." 92 | for l in $src $tgt; do 93 | awk '{if (NR%23 == 0) print $0; }' $prep/train.tags.de-en.$l > $prep/valid.$l 94 | awk '{if (NR%23 != 0) print $0; }' $prep/train.tags.de-en.$l > $prep/train.$l 95 | 96 | cat $prep/IWSLT14.TED.dev2010.de-en.$l \ 97 | $prep/IWSLT14.TEDX.dev2012.de-en.$l \ 98 | $prep/IWSLT14.TED.tst2010.de-en.$l \ 99 | $prep/IWSLT14.TED.tst2011.de-en.$l \ 100 | $prep/IWSLT14.TED.tst2012.de-en.$l \ 101 | > $prep/test.$l 102 | done 103 | fi 104 | -------------------------------------------------------------------------------- /scripts/prepare-iwslt_deen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT license.
4 | set -x 5 | set -e 6 | export PYTHONUNBUFFERED=1 7 | 8 | RAW=/data/raw 9 | TMP=/data/tmp 10 | DUMP=/data/dump 11 | DOWNLOAD=/data/download 12 | 13 | # download 14 | echo "===========================================" 15 | bash /src/scripts/download-iwslt_deen.sh $RAW $DOWNLOAD 16 | 17 | RAW=$RAW/de-en 18 | TMP=$TMP/de-en 19 | DUMP=$DUMP/de-en 20 | 21 | 22 | # BERT tokenization 23 | python /src/scripts/bert_tokenize.py \ 24 | --bert bert-base-multilingual-cased \ 25 | --prefixes $RAW/train.en $RAW/train.de $RAW/valid $RAW/test \ 26 | --output_dir $TMP 27 | 28 | 29 | # prepare bert teacher training dataset 30 | mkdir -p $DUMP 31 | python /src/scripts/bert_prepro.py --src $TMP/train.de.bert \ 32 | --tgt $TMP/train.en.bert \ 33 | --output $DUMP/DEEN.db 34 | 35 | # OpenNMT preprocessing 36 | VSIZE=200000 37 | FREQ=0 38 | SHARD_SIZE=200000 39 | python /src/opennmt/preprocess.py \ 40 | -train_src $TMP/train.de.bert \ 41 | -train_tgt $TMP/train.en.bert \ 42 | -valid_src $TMP/valid.de.bert \ 43 | -valid_tgt $TMP/valid.en.bert \ 44 | -save_data $DUMP/DEEN \ 45 | -src_seq_length 150 \ 46 | -tgt_seq_length 150 \ 47 | -src_vocab_size $VSIZE \ 48 | -tgt_vocab_size $VSIZE \ 49 | -vocab_size_multiple 8 \ 50 | -src_words_min_frequency $FREQ \ 51 | -tgt_words_min_frequency $FREQ \ 52 | -share_vocab \ 53 | -shard_size $SHARD_SIZE 54 | 55 | 56 | # move needed files to dump 57 | mv $TMP/valid.en.bert $DUMP/dev.en.bert 58 | mv $TMP/valid.de.bert $DUMP/dev.de.bert 59 | mv $TMP/test.en.bert $DUMP/test.en.bert 60 | mv $TMP/test.de.bert $DUMP/test.de.bert 61 | REFDIR=$DUMP/ref/ 62 | mkdir -p $REFDIR 63 | cp $RAW/valid.en $REFDIR/dev.en 64 | cp $RAW/test.en $REFDIR/test.en 65 | -------------------------------------------------------------------------------- /scripts/setup.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | DATA=$1 5 | 6 | if [ ! -d $DATA ]; then 7 | mkdir -p $DATA 8 | fi 9 | 10 | CMD='./scripts/prepare-iwslt_deen.sh' 11 | 12 | docker run --rm \ 13 | --mount src=$(pwd),dst=/src,type=bind \ 14 | --mount src=$DATA,dst=/data,type=bind \ 15 | chenrocks/distill-bert-textgen:latest \ 16 | bash -c $CMD 17 | --------------------------------------------------------------------------------
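
Note on the `@@` convention: `scripts/bert_tokenize.py` prefixes every non-initial WordPiece of a word with `@@`, and `scripts/bert_detokenize.py` merges those pieces back (also dropping BERT's own `##` marker) before BLEU is computed against the moses-tokenized references. The snippet below is a minimal standalone sketch of that merge rule, not code from this repo; the token sequence is hand-made for illustration and is not actual BERT tokenizer output.

```python
# Standalone sketch of the '@@' merge rule applied in scripts/bert_detokenize.py.
# The input tokens are a hand-made example, not real BERT tokenizer output.
IN_WORD = '@@'        # marks a piece that continues the previous word
BERT_IN_WORD = '##'   # BERT's own WordPiece continuation marker

def merge_pieces(tokens):
    words, word = [], ''
    for tok in tokens:
        if tok.startswith(IN_WORD):               # continuation piece: glue onto current word
            piece = tok[len(IN_WORD):]
            if piece.startswith(BERT_IN_WORD):    # drop the '##' WordPiece marker as well
                piece = piece[len(BERT_IN_WORD):]
            word += piece
        else:                                     # first piece of a new word
            if word:
                words.append(word)
            word = tok
    if word:
        words.append(word)
    return ' '.join(words)

# 'Bei' followed by '@@##spiel' is how a word split into two WordPieces would be marked
print(merge_pieces(['das', 'ist', 'ein', 'Bei', '@@##spiel']))  # -> das ist ein Beispiel
```

The real script additionally re-escapes moses special characters (e.g. `'` back to `&apos;`) so the detokenized hypotheses line up with the moses-escaped references fed to `multi-bleu.perl`.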