├── data
│   └── test.txt
├── modeling.pyc
├── __pycache__
│   ├── predict.cpython-36.pyc
│   ├── modeling.cpython-36.pyc
│   ├── optimization.cpython-36.pyc
│   └── tokenization.cpython-36.pyc
├── requirements.txt
├── predict.sh
├── README.md
├── train.sh
├── __init__.py
├── intent.py
├── optimization.py
├── multilingual.md
├── tokenization.py
├── extract_features.py
├── create_pretraining_data.py
├── run_pretraining.py
├── predict.py
├── modeling.py
└── run_classifier.py
/data/test.txt:
--------------------------------------------------------------------------------
1 | The data has been deleted because someone reported a suspected infringement!
2 |
--------------------------------------------------------------------------------
/modeling.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pengming617/bert_textMatching/HEAD/modeling.pyc
--------------------------------------------------------------------------------
/__pycache__/predict.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pengming617/bert_textMatching/HEAD/__pycache__/predict.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/modeling.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pengming617/bert_textMatching/HEAD/__pycache__/modeling.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/optimization.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pengming617/bert_textMatching/HEAD/__pycache__/optimization.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/tokenization.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pengming617/bert_textMatching/HEAD/__pycache__/tokenization.cpython-36.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow >= 1.11.0   # CPU version of TensorFlow.
2 | # tensorflow-gpu >= 1.11.0  # GPU version of TensorFlow.
3 | pandas
--------------------------------------------------------------------------------
/predict.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python3.6 run_classifier.py \
3 |   --task_name=sim \
4 |   --do_predict=true \
5 |   --data_dir=data \
6 |   --vocab_file=chinese_L-12_H-768_A-12/vocab.txt \
7 |   --bert_config_file=chinese_L-12_H-768_A-12/bert_config.json \
8 |   --init_checkpoint=tmp/sim_model \
9 |   --max_seq_length=50 \
10 |   --output_dir=tmp/output
11 |
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Text matching with the BERT pre-trained Chinese language model
2 |
3 | Training script: train.sh
4 |
5 | Batch testing script: predict.sh
6 |
7 | Single sentence-pair testing with the trained model: intent.py
8 |
9 | chinese_L-12_H-768_A-12 contains the pre-trained model files and the vocabulary
10 |
11 | The data folder contains the training, validation, and test corpora; the dataset is the official LCQMC data (each line: sentence_1<TAB>sentence_2<TAB>label, as read by intent.py)
12 |
13 |
14 | Parameter notes:
15 |     max_seq_length    maximum sentence length (in characters)
16 |     train_batch_size  training batch size
17 |
18 | max_seq_length = 50
19 | eval_accuracy = 0.87207
20 | test_accuracy = 0.86272
21 |
22 | max_seq_length = 40
23 | eval_accuracy = 0.88093615
24 | test_accuracy = 0.86256
25 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python3.6 run_classifier.py \
4 |   --data_dir=data \
5 |   --task_name=sim \
6 |   --vocab_file=chinese_L-12_H-768_A-12/vocab.txt \
7 |   --bert_config_file=chinese_L-12_H-768_A-12/bert_config.json \
8 |   --output_dir=tmp/sim_model \
9 |   --do_train=true \
10 |   --do_eval=true \
11 |   --init_checkpoint=chinese_L-12_H-768_A-12/bert_model.ckpt \
12 |   --max_seq_length=50 \
13 |   --train_batch_size=32 \
14 |   --learning_rate=5e-5 \
15 |   --num_train_epochs=3.0
16 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/intent.py:
--------------------------------------------------------------------------------
1 | from predict import predicts
2 |
3 | # Single sentence-pair test.
4 | sentences = [['长的清新是什么意思', '小清新的意思是什么']]
5 | for sentence in sentences:
6 |     dic = predicts([sentence])
7 |     print(dic)
8 |
9 | # Batch test: each line of data/test.txt is "sentence_1<TAB>sentence_2<TAB>label".
10 | sentences = []
11 | tags = []
12 | with open('data/test.txt', 'r', encoding='utf-8') as test_file:
13 |     for line in test_file:
14 |         data = line.rstrip("\n").split("\t")
15 |         sentences.append([data[0], data[1]])
16 |         tags.append(data[2])
17 |
18 | results = predicts(sentences)
19 | errors = 0
20 | # Write mispredicted pairs to data/erro.txt for inspection.
21 | with open('data/erro.txt', 'w', encoding='utf-8') as error_file:
22 |     for x in range(len(results)):
23 |         if results[x][0] != tags[x]:
24 |             errors += 1
25 |             print(sentences[x])
26 |             error_file.write(sentences[x][0] + "\t" + sentences[x][1] + "\t" +
27 |                              results[x][0] + "\t" + str(results[x][1]) + "\n")
28 | print("Accuracy on the test set: " + str(1 - errors / len(results)))
--------------------------------------------------------------------------------
/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.)
59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 
146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /multilingual.md: -------------------------------------------------------------------------------- 1 | ## Models 2 | 3 | There are two multilingual models currently available. We do not plan to release 4 | more single-language models, but we may release `BERT-Large` versions of these 5 | two in the future: 6 | 7 | * **[`BERT-Base, Multilingual Cased (New, recommended)`](https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip)**: 8 | 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters 9 | * **[`BERT-Base, Multilingual Uncased (Orig, not recommended)`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**: 10 | 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters 11 | * **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**: 12 | Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M 13 | parameters 14 | 15 | **The `Multilingual Cased (New)` model also fixes normalization issues in many 16 | languages, so it is recommended in languages with non-Latin alphabets (and is 17 | often better for most languages with Latin alphabets). When using this model, 18 | make sure to pass `--do_lower_case=false` to `run_pretraining.py` and other 19 | scripts.** 20 | 21 | See the [list of languages](#list-of-languages) that the Multilingual model 22 | supports. The Multilingual model does include Chinese (and English), but if your 23 | fine-tuning data is Chinese-only, then the Chinese model will likely produce 24 | better results. 25 | 26 | ## Results 27 | 28 | To evaluate these systems, we use the 29 | [XNLI dataset](https://github.com/facebookresearch/XNLI) dataset, which is a 30 | version of [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) where the 31 | dev and test sets have been translated (by humans) into 15 languages. Note that 32 | the training set was *machine* translated (we used the translations provided by 33 | XNLI, not Google NMT). 
For clarity, we only report on 6 languages below: 34 | 35 | 36 | 37 | | System | English | Chinese | Spanish | German | Arabic | Urdu | 38 | | --------------------------------- | -------- | -------- | -------- | -------- | -------- | -------- | 39 | | XNLI Baseline - Translate Train | 73.7 | 67.0 | 68.8 | 66.5 | 65.8 | 56.6 | 40 | | XNLI Baseline - Translate Test | 73.7 | 68.3 | 70.7 | 68.7 | 66.8 | 59.3 | 41 | | BERT - Translate Train Cased | **81.9** | **76.6** | **77.8** | **75.9** | **70.7** | 61.6 | 42 | | BERT - Translate Train Uncased | 81.4 | 74.2 | 77.3 | 75.2 | 70.5 | 61.7 | 43 | | BERT - Translate Test Uncased | 81.4 | 70.1 | 74.9 | 74.4 | 70.4 | **62.1** | 44 | | BERT - Zero Shot Uncased | 81.4 | 63.8 | 74.3 | 70.5 | 62.1 | 58.3 | 45 | 46 | 47 | 48 | The first two rows are baselines from the XNLI paper and the last three rows are 49 | our results with BERT. 50 | 51 | **Translate Train** means that the MultiNLI training set was machine translated 52 | from English into the foreign language. So training and evaluation were both 53 | done in the foreign language. Unfortunately, training was done on 54 | machine-translated data, so it is impossible to quantify how much of the lower 55 | accuracy (compared to English) is due to the quality of the machine translation 56 | vs. the quality of the pre-trained model. 57 | 58 | **Translate Test** means that the XNLI test set was machine translated from the 59 | foreign language into English. So training and evaluation were both done on 60 | English. However, test evaluation was done on machine-translated English, so the 61 | accuracy depends on the quality of the machine translation system. 62 | 63 | **Zero Shot** means that the Multilingual BERT system was fine-tuned on English 64 | MultiNLI, and then evaluated on the foreign language XNLI test. In this case, 65 | machine translation was not involved at all in either the pre-training or 66 | fine-tuning. 67 | 68 | Note that the English result is worse than the 84.2 MultiNLI baseline because 69 | this training used Multilingual BERT rather than English-only BERT. This implies 70 | that for high-resource languages, the Multilingual model is somewhat worse than 71 | a single-language model. However, it is not feasible for us to train and 72 | maintain dozens of single-language model. Therefore, if your goal is to maximize 73 | performance with a language other than English or Chinese, you might find it 74 | beneficial to run pre-training for additional steps starting from our 75 | Multilingual model on data from your language of interest. 76 | 77 | Here is a comparison of training Chinese models with the Multilingual 78 | `BERT-Base` and Chinese-only `BERT-Base`: 79 | 80 | System | Chinese 81 | ----------------------- | ------- 82 | XNLI Baseline | 67.0 83 | BERT Multilingual Model | 74.2 84 | BERT Chinese-only Model | 77.2 85 | 86 | Similar to English, the single-language model does 3% better than the 87 | Multilingual model. 88 | 89 | ## Fine-tuning Example 90 | 91 | The multilingual model does **not** require any special consideration or API 92 | changes. We did update the implementation of `BasicTokenizer` in 93 | `tokenization.py` to support Chinese character tokenization, so please update if 94 | you forked it. However, we did not change the tokenization API. 95 | 96 | To test the new models, we did modify `run_classifier.py` to add support for the 97 | [XNLI dataset](https://github.com/facebookresearch/XNLI). 
This is a 15-language
98 | version of MultiNLI where the dev/test sets have been human-translated, and the
99 | training set has been machine-translated.
100 |
101 | To run the fine-tuning code, please download the
102 | [XNLI dev/test set](https://s3.amazonaws.com/xnli/XNLI-1.0.zip) and the
103 | [XNLI machine-translated training set](https://s3.amazonaws.com/xnli/XNLI-MT-1.0.zip)
104 | and then unpack both .zip files into some directory `$XNLI_DIR`.
105 |
106 | To run fine-tuning on XNLI, note that the language is hard-coded into `run_classifier.py`
107 | (Chinese by default), so please modify `XnliProcessor` if you want to run on
108 | another language.
109 |
110 | This is a large dataset, so training will take a few hours on a GPU
111 | (or about 30 minutes on a Cloud TPU). To run an experiment quickly for
112 | debugging, just set `num_train_epochs` to a small value like `0.1`.
113 |
114 | ```shell
115 | export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 # or multilingual_L-12_H-768_A-12
116 | export XNLI_DIR=/path/to/xnli
117 |
118 | python run_classifier.py \
119 |   --task_name=XNLI \
120 |   --do_train=true \
121 |   --do_eval=true \
122 |   --data_dir=$XNLI_DIR \
123 |   --vocab_file=$BERT_BASE_DIR/vocab.txt \
124 |   --bert_config_file=$BERT_BASE_DIR/bert_config.json \
125 |   --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
126 |   --max_seq_length=128 \
127 |   --train_batch_size=32 \
128 |   --learning_rate=5e-5 \
129 |   --num_train_epochs=2.0 \
130 |   --output_dir=/tmp/xnli_output/
131 | ```
132 |
133 | With the Chinese-only model, the results should look something like this:
134 |
135 | ```
136 | ***** Eval results *****
137 | eval_accuracy = 0.774116
138 | eval_loss = 0.83554
139 | global_step = 24543
140 | loss = 0.74603
141 | ```
142 |
143 | ## Details
144 |
145 | ### Data Source and Sampling
146 |
147 | The languages chosen were the
148 | [top 100 languages with the largest Wikipedias](https://meta.wikimedia.org/wiki/List_of_Wikipedias).
149 | The entire Wikipedia dump for each language (excluding user and talk pages) was
150 | taken as the training data for each language.
151 |
152 | However, the size of the Wikipedia for a given language varies greatly, and
153 | therefore low-resource languages may be "under-represented" in terms of the
154 | neural network model (under the assumption that languages are "competing" for
155 | limited model capacity to some extent).
156 |
157 | However, the size of a Wikipedia also correlates with the number of speakers of
158 | a language, and we also don't want to overfit the model by performing thousands
159 | of epochs over a tiny Wikipedia for a particular language.
160 |
161 | To balance these two factors, we performed exponentially smoothed weighting of
162 | the data during pre-training data creation (and WordPiece vocab creation). In
163 | other words, let's say that the probability of a language is *P(L)*, e.g.,
164 | *P(English) = 0.21* means that after concatenating all of the Wikipedias
165 | together, 21% of our data is English. We exponentiate each probability by some
166 | factor *S* and then re-normalize, and sample from that distribution. In our case
167 | we use *S=0.7*. So, high-resource languages like English will be under-sampled,
168 | and low-resource languages like Icelandic will be over-sampled. E.g., in the
169 | original distribution English would be sampled 1000x more than Icelandic, but
170 | after smoothing it's only sampled 100x more.
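As a rough illustration of that smoothing step (this sketch is not part of the BERT codebase; the helper name `smooth_probs` and the toy probabilities are made up for the example, and only the exponent *S=0.7* comes from the description above):

```python
# Sketch of the exponentially smoothed sampling weights described above.
# Only S = 0.7 is taken from the text; the raw probabilities are toy values.

def smooth_probs(raw_probs, s=0.7):
    """Exponentiate each language probability by `s`, then re-normalize."""
    exponentiated = {lang: p ** s for lang, p in raw_probs.items()}
    total = sum(exponentiated.values())
    return {lang: p / total for lang, p in exponentiated.items()}

raw = {"english": 0.21, "icelandic": 0.00021, "rest": 0.78979}  # toy distribution
smoothed = smooth_probs(raw)
# The English/Icelandic sampling ratio falls from 1000x to 1000**0.7, about 126x,
# i.e. roughly the "1000x more -> only 100x more" change quoted above.
print(smoothed)
```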
171 | 172 | ### Tokenization 173 | 174 | For tokenization, we use a 110k shared WordPiece vocabulary. The word counts are 175 | weighted the same way as the data, so low-resource languages are upweighted by 176 | some factor. We intentionally do *not* use any marker to denote the input 177 | language (so that zero-shot training can work). 178 | 179 | Because Chinese (and Japanese Kanji and Korean Hanja) does not have whitespace 180 | characters, we add spaces around every character in the 181 | [CJK Unicode range](https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_\(Unicode_block\)) 182 | before applying WordPiece. This means that Chinese is effectively 183 | character-tokenized. Note that the CJK Unicode block only includes 184 | Chinese-origin characters and does *not* include Hangul Korean or 185 | Katakana/Hiragana Japanese, which are tokenized with whitespace+WordPiece like 186 | all other languages. 187 | 188 | For all other languages, we apply the 189 | [same recipe as English](https://github.com/google-research/bert#tokenization): 190 | (a) lower casing+accent removal, (b) punctuation splitting, (c) whitespace 191 | tokenization. We understand that accent markers have substantial meaning in some 192 | languages, but felt that the benefits of reducing the effective vocabulary make 193 | up for this. Generally the strong contextual models of BERT should make up for 194 | any ambiguity introduced by stripping accent markers. 195 | 196 | ### List of Languages 197 | 198 | The multilingual model supports the following languages. These languages were 199 | chosen because they are the top 100 languages with the largest Wikipedias: 200 | 201 | * Afrikaans 202 | * Albanian 203 | * Arabic 204 | * Aragonese 205 | * Armenian 206 | * Asturian 207 | * Azerbaijani 208 | * Bashkir 209 | * Basque 210 | * Bavarian 211 | * Belarusian 212 | * Bengali 213 | * Bishnupriya Manipuri 214 | * Bosnian 215 | * Breton 216 | * Bulgarian 217 | * Burmese 218 | * Catalan 219 | * Cebuano 220 | * Chechen 221 | * Chinese (Simplified) 222 | * Chinese (Traditional) 223 | * Chuvash 224 | * Croatian 225 | * Czech 226 | * Danish 227 | * Dutch 228 | * English 229 | * Estonian 230 | * Finnish 231 | * French 232 | * Galician 233 | * Georgian 234 | * German 235 | * Greek 236 | * Gujarati 237 | * Haitian 238 | * Hebrew 239 | * Hindi 240 | * Hungarian 241 | * Icelandic 242 | * Ido 243 | * Indonesian 244 | * Irish 245 | * Italian 246 | * Japanese 247 | * Javanese 248 | * Kannada 249 | * Kazakh 250 | * Kirghiz 251 | * Korean 252 | * Latin 253 | * Latvian 254 | * Lithuanian 255 | * Lombard 256 | * Low Saxon 257 | * Luxembourgish 258 | * Macedonian 259 | * Malagasy 260 | * Malay 261 | * Malayalam 262 | * Marathi 263 | * Minangkabau 264 | * Nepali 265 | * Newar 266 | * Norwegian (Bokmal) 267 | * Norwegian (Nynorsk) 268 | * Occitan 269 | * Persian (Farsi) 270 | * Piedmontese 271 | * Polish 272 | * Portuguese 273 | * Punjabi 274 | * Romanian 275 | * Russian 276 | * Scots 277 | * Serbian 278 | * Serbo-Croatian 279 | * Sicilian 280 | * Slovak 281 | * Slovenian 282 | * South Azerbaijani 283 | * Spanish 284 | * Sundanese 285 | * Swahili 286 | * Swedish 287 | * Tagalog 288 | * Tajik 289 | * Tamil 290 | * Tatar 291 | * Telugu 292 | * Turkish 293 | * Ukrainian 294 | * Urdu 295 | * Uzbek 296 | * Vietnamese 297 | * Volapük 298 | * Waray-Waray 299 | * Welsh 300 | * West Frisian 301 | * Western Punjabi 302 | * Yoruba 303 | 304 | The **Multilingual Cased (New)** release contains additionally **Thai** and 305 | **Mongolian**, which were not 
included in the original release. 306 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." 
% (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower 
casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat.startswith("C"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 
390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import codecs 22 | import collections 23 | import json 24 | import re 25 | 26 | import modeling 27 | import tokenization 28 | import tensorflow as tf 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string("input_file", None, "") 35 | 36 | flags.DEFINE_string("output_file", None, "") 37 | 38 | flags.DEFINE_string("layers", "-1,-2,-3,-4", "") 39 | 40 | flags.DEFINE_string( 41 | "bert_config_file", None, 42 | "The config json file corresponding to the pre-trained BERT model. " 43 | "This specifies the model architecture.") 44 | 45 | flags.DEFINE_integer( 46 | "max_seq_length", 128, 47 | "The maximum total input sequence length after WordPiece tokenization. " 48 | "Sequences longer than this will be truncated, and sequences shorter " 49 | "than this will be padded.") 50 | 51 | flags.DEFINE_string( 52 | "init_checkpoint", None, 53 | "Initial checkpoint (usually from a pre-trained BERT model).") 54 | 55 | flags.DEFINE_string("vocab_file", None, 56 | "The vocabulary file that the BERT model was trained on.") 57 | 58 | flags.DEFINE_bool( 59 | "do_lower_case", True, 60 | "Whether to lower case the input text. Should be True for uncased " 61 | "models and False for cased models.") 62 | 63 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") 64 | 65 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 66 | 67 | flags.DEFINE_string("master", None, 68 | "If using a TPU, the address of the master.") 69 | 70 | flags.DEFINE_integer( 71 | "num_tpu_cores", 8, 72 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 73 | 74 | flags.DEFINE_bool( 75 | "use_one_hot_embeddings", False, 76 | "If True, tf.one_hot will be used for embedding lookups, otherwise " 77 | "tf.nn.embedding_lookup will be used. 
On TPUs, this should be True " 78 | "since it is much faster.") 79 | 80 | 81 | class InputExample(object): 82 | 83 | def __init__(self, unique_id, text_a, text_b): 84 | self.unique_id = unique_id 85 | self.text_a = text_a 86 | self.text_b = text_b 87 | 88 | 89 | class InputFeatures(object): 90 | """A single set of features of data.""" 91 | 92 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 93 | self.unique_id = unique_id 94 | self.tokens = tokens 95 | self.input_ids = input_ids 96 | self.input_mask = input_mask 97 | self.input_type_ids = input_type_ids 98 | 99 | 100 | def input_fn_builder(features, seq_length): 101 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 102 | 103 | all_unique_ids = [] 104 | all_input_ids = [] 105 | all_input_mask = [] 106 | all_input_type_ids = [] 107 | 108 | for feature in features: 109 | all_unique_ids.append(feature.unique_id) 110 | all_input_ids.append(feature.input_ids) 111 | all_input_mask.append(feature.input_mask) 112 | all_input_type_ids.append(feature.input_type_ids) 113 | 114 | def input_fn(params): 115 | """The actual input function.""" 116 | batch_size = params["batch_size"] 117 | 118 | num_examples = len(features) 119 | 120 | # This is for demo purposes and does NOT scale to large data sets. We do 121 | # not use Dataset.from_generator() because that uses tf.py_func which is 122 | # not TPU compatible. The right way to load data is with TFRecordReader. 123 | d = tf.data.Dataset.from_tensor_slices({ 124 | "unique_ids": 125 | tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32), 126 | "input_ids": 127 | tf.constant( 128 | all_input_ids, shape=[num_examples, seq_length], 129 | dtype=tf.int32), 130 | "input_mask": 131 | tf.constant( 132 | all_input_mask, 133 | shape=[num_examples, seq_length], 134 | dtype=tf.int32), 135 | "input_type_ids": 136 | tf.constant( 137 | all_input_type_ids, 138 | shape=[num_examples, seq_length], 139 | dtype=tf.int32), 140 | }) 141 | 142 | d = d.batch(batch_size=batch_size, drop_remainder=False) 143 | return d 144 | 145 | return input_fn 146 | 147 | 148 | def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu, 149 | use_one_hot_embeddings): 150 | """Returns `model_fn` closure for TPUEstimator.""" 151 | 152 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 153 | """The `model_fn` for TPUEstimator.""" 154 | 155 | unique_ids = features["unique_ids"] 156 | input_ids = features["input_ids"] 157 | input_mask = features["input_mask"] 158 | input_type_ids = features["input_type_ids"] 159 | 160 | model = modeling.BertModel( 161 | config=bert_config, 162 | is_training=False, 163 | input_ids=input_ids, 164 | input_mask=input_mask, 165 | token_type_ids=input_type_ids, 166 | use_one_hot_embeddings=use_one_hot_embeddings) 167 | 168 | if mode != tf.estimator.ModeKeys.PREDICT: 169 | raise ValueError("Only PREDICT modes are supported: %s" % (mode)) 170 | 171 | tvars = tf.trainable_variables() 172 | scaffold_fn = None 173 | (assignment_map, 174 | initialized_variable_names) = modeling.get_assignment_map_from_checkpoint( 175 | tvars, init_checkpoint) 176 | if use_tpu: 177 | 178 | def tpu_scaffold(): 179 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 180 | return tf.train.Scaffold() 181 | 182 | scaffold_fn = tpu_scaffold 183 | else: 184 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 185 | 186 | tf.logging.info("**** Trainable Variables ****") 187 | for var in tvars: 188 | init_string = "" 189 | if 
var.name in initialized_variable_names: 190 | init_string = ", *INIT_FROM_CKPT*" 191 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 192 | init_string) 193 | 194 | all_layers = model.get_all_encoder_layers() 195 | 196 | predictions = { 197 | "unique_id": unique_ids, 198 | } 199 | 200 | for (i, layer_index) in enumerate(layer_indexes): 201 | predictions["layer_output_%d" % i] = all_layers[layer_index] 202 | 203 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 204 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 205 | return output_spec 206 | 207 | return model_fn 208 | 209 | 210 | def convert_examples_to_features(examples, seq_length, tokenizer): 211 | """Loads a data file into a list of `InputBatch`s.""" 212 | 213 | features = [] 214 | for (ex_index, example) in enumerate(examples): 215 | tokens_a = tokenizer.tokenize(example.text_a) 216 | 217 | tokens_b = None 218 | if example.text_b: 219 | tokens_b = tokenizer.tokenize(example.text_b) 220 | 221 | if tokens_b: 222 | # Modifies `tokens_a` and `tokens_b` in place so that the total 223 | # length is less than the specified length. 224 | # Account for [CLS], [SEP], [SEP] with "- 3" 225 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 226 | else: 227 | # Account for [CLS] and [SEP] with "- 2" 228 | if len(tokens_a) > seq_length - 2: 229 | tokens_a = tokens_a[0:(seq_length - 2)] 230 | 231 | # The convention in BERT is: 232 | # (a) For sequence pairs: 233 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 234 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 235 | # (b) For single sequences: 236 | # tokens: [CLS] the dog is hairy . [SEP] 237 | # type_ids: 0 0 0 0 0 0 0 238 | # 239 | # Where "type_ids" are used to indicate whether this is the first 240 | # sequence or the second sequence. The embedding vectors for `type=0` and 241 | # `type=1` were learned during pre-training and are added to the wordpiece 242 | # embedding vector (and position vector). This is not *strictly* necessary 243 | # since the [SEP] token unambiguously separates the sequences, but it makes 244 | # it easier for the model to learn the concept of sequences. 245 | # 246 | # For classification tasks, the first vector (corresponding to [CLS]) is 247 | # used as as the "sentence vector". Note that this only makes sense because 248 | # the entire model is fine-tuned. 249 | tokens = [] 250 | input_type_ids = [] 251 | tokens.append("[CLS]") 252 | input_type_ids.append(0) 253 | for token in tokens_a: 254 | tokens.append(token) 255 | input_type_ids.append(0) 256 | tokens.append("[SEP]") 257 | input_type_ids.append(0) 258 | 259 | if tokens_b: 260 | for token in tokens_b: 261 | tokens.append(token) 262 | input_type_ids.append(1) 263 | tokens.append("[SEP]") 264 | input_type_ids.append(1) 265 | 266 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 267 | 268 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 269 | # tokens are attended to. 270 | input_mask = [1] * len(input_ids) 271 | 272 | # Zero-pad up to the sequence length. 
273 | while len(input_ids) < seq_length: 274 | input_ids.append(0) 275 | input_mask.append(0) 276 | input_type_ids.append(0) 277 | 278 | assert len(input_ids) == seq_length 279 | assert len(input_mask) == seq_length 280 | assert len(input_type_ids) == seq_length 281 | 282 | if ex_index < 5: 283 | tf.logging.info("*** Example ***") 284 | tf.logging.info("unique_id: %s" % (example.unique_id)) 285 | tf.logging.info("tokens: %s" % " ".join( 286 | [tokenization.printable_text(x) for x in tokens])) 287 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 288 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 289 | tf.logging.info( 290 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 291 | 292 | features.append( 293 | InputFeatures( 294 | unique_id=example.unique_id, 295 | tokens=tokens, 296 | input_ids=input_ids, 297 | input_mask=input_mask, 298 | input_type_ids=input_type_ids)) 299 | return features 300 | 301 | 302 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 303 | """Truncates a sequence pair in place to the maximum length.""" 304 | 305 | # This is a simple heuristic which will always truncate the longer sequence 306 | # one token at a time. This makes more sense than truncating an equal percent 307 | # of tokens from each, since if one sequence is very short then each token 308 | # that's truncated likely contains more information than a longer sequence. 309 | while True: 310 | total_length = len(tokens_a) + len(tokens_b) 311 | if total_length <= max_length: 312 | break 313 | if len(tokens_a) > len(tokens_b): 314 | tokens_a.pop() 315 | else: 316 | tokens_b.pop() 317 | 318 | 319 | def read_examples(input_file): 320 | """Read a list of `InputExample`s from an input file.""" 321 | examples = [] 322 | unique_id = 0 323 | with tf.gfile.GFile(input_file, "r") as reader: 324 | while True: 325 | line = tokenization.convert_to_unicode(reader.readline()) 326 | if not line: 327 | break 328 | line = line.strip() 329 | text_a = None 330 | text_b = None 331 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 332 | if m is None: 333 | text_a = line 334 | else: 335 | text_a = m.group(1) 336 | text_b = m.group(2) 337 | examples.append( 338 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 339 | unique_id += 1 340 | return examples 341 | 342 | 343 | def main(_): 344 | tf.logging.set_verbosity(tf.logging.INFO) 345 | 346 | layer_indexes = [int(x) for x in FLAGS.layers.split(",")] 347 | 348 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 349 | 350 | tokenizer = tokenization.FullTokenizer( 351 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 352 | 353 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 354 | run_config = tf.contrib.tpu.RunConfig( 355 | master=FLAGS.master, 356 | tpu_config=tf.contrib.tpu.TPUConfig( 357 | num_shards=FLAGS.num_tpu_cores, 358 | per_host_input_for_training=is_per_host)) 359 | 360 | examples = read_examples(FLAGS.input_file) 361 | 362 | features = convert_examples_to_features( 363 | examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer) 364 | 365 | unique_id_to_feature = {} 366 | for feature in features: 367 | unique_id_to_feature[feature.unique_id] = feature 368 | 369 | model_fn = model_fn_builder( 370 | bert_config=bert_config, 371 | init_checkpoint=FLAGS.init_checkpoint, 372 | layer_indexes=layer_indexes, 373 | use_tpu=FLAGS.use_tpu, 374 | use_one_hot_embeddings=FLAGS.use_one_hot_embeddings) 375 | 376 | # If TPU is not 
available, this will fall back to normal Estimator on CPU 377 | # or GPU. 378 | estimator = tf.contrib.tpu.TPUEstimator( 379 | use_tpu=FLAGS.use_tpu, 380 | model_fn=model_fn, 381 | config=run_config, 382 | predict_batch_size=FLAGS.batch_size) 383 | 384 | input_fn = input_fn_builder( 385 | features=features, seq_length=FLAGS.max_seq_length) 386 | 387 | with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file, 388 | "w")) as writer: 389 | for result in estimator.predict(input_fn, yield_single_examples=True): 390 | unique_id = int(result["unique_id"]) 391 | feature = unique_id_to_feature[unique_id] 392 | output_json = collections.OrderedDict() 393 | output_json["linex_index"] = unique_id 394 | all_features = [] 395 | for (i, token) in enumerate(feature.tokens): 396 | all_layers = [] 397 | for (j, layer_index) in enumerate(layer_indexes): 398 | layer_output = result["layer_output_%d" % j] 399 | layers = collections.OrderedDict() 400 | layers["index"] = layer_index 401 | layers["values"] = [ 402 | round(float(x), 6) for x in layer_output[i:(i + 1)].flat 403 | ] 404 | all_layers.append(layers) 405 | features = collections.OrderedDict() 406 | features["token"] = token 407 | features["layers"] = all_layers 408 | all_features.append(features) 409 | output_json["features"] = all_features 410 | writer.write(json.dumps(output_json) + "\n") 411 | 412 | 413 | if __name__ == "__main__": 414 | flags.mark_flag_as_required("input_file") 415 | flags.mark_flag_as_required("vocab_file") 416 | flags.mark_flag_as_required("bert_config_file") 417 | flags.mark_flag_as_required("init_checkpoint") 418 | flags.mark_flag_as_required("output_file") 419 | tf.app.run() 420 | -------------------------------------------------------------------------------- /create_pretraining_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import random 23 | import tensorflow as tf 24 | import tokenization 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string("input_file", None, 31 | "Input raw text file (or comma-separated list of files).") 32 | 33 | flags.DEFINE_string( 34 | "output_file", None, 35 | "Output TF example file (or comma-separated list of files).") 36 | 37 | flags.DEFINE_string("vocab_file", None, 38 | "The vocabulary file that the BERT model was trained on.") 39 | 40 | flags.DEFINE_bool( 41 | "do_lower_case", True, 42 | "Whether to lower case the input text. 
Should be True for uncased " 43 | "models and False for cased models.") 44 | 45 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 46 | 47 | flags.DEFINE_integer("max_predictions_per_seq", 20, 48 | "Maximum number of masked LM predictions per sequence.") 49 | 50 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 51 | 52 | flags.DEFINE_integer( 53 | "dupe_factor", 10, 54 | "Number of times to duplicate the input data (with different masks).") 55 | 56 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 57 | 58 | flags.DEFINE_float( 59 | "short_seq_prob", 0.1, 60 | "Probability of creating sequences which are shorter than the " 61 | "maximum length.") 62 | 63 | 64 | class TrainingInstance(object): 65 | """A single training instance (sentence pair).""" 66 | 67 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 68 | is_random_next): 69 | self.tokens = tokens 70 | self.segment_ids = segment_ids 71 | self.is_random_next = is_random_next 72 | self.masked_lm_positions = masked_lm_positions 73 | self.masked_lm_labels = masked_lm_labels 74 | 75 | def __str__(self): 76 | s = "" 77 | s += "tokens: %s\n" % (" ".join( 78 | [tokenization.printable_text(x) for x in self.tokens])) 79 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 80 | s += "is_random_next: %s\n" % self.is_random_next 81 | s += "masked_lm_positions: %s\n" % (" ".join( 82 | [str(x) for x in self.masked_lm_positions])) 83 | s += "masked_lm_labels: %s\n" % (" ".join( 84 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 85 | s += "\n" 86 | return s 87 | 88 | def __repr__(self): 89 | return self.__str__() 90 | 91 | 92 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 93 | max_predictions_per_seq, output_files): 94 | """Create TF example files from `TrainingInstance`s.""" 95 | writers = [] 96 | for output_file in output_files: 97 | writers.append(tf.python_io.TFRecordWriter(output_file)) 98 | 99 | writer_index = 0 100 | 101 | total_written = 0 102 | for (inst_index, instance) in enumerate(instances): 103 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 104 | input_mask = [1] * len(input_ids) 105 | segment_ids = list(instance.segment_ids) 106 | assert len(input_ids) <= max_seq_length 107 | 108 | while len(input_ids) < max_seq_length: 109 | input_ids.append(0) 110 | input_mask.append(0) 111 | segment_ids.append(0) 112 | 113 | assert len(input_ids) == max_seq_length 114 | assert len(input_mask) == max_seq_length 115 | assert len(segment_ids) == max_seq_length 116 | 117 | masked_lm_positions = list(instance.masked_lm_positions) 118 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 119 | masked_lm_weights = [1.0] * len(masked_lm_ids) 120 | 121 | while len(masked_lm_positions) < max_predictions_per_seq: 122 | masked_lm_positions.append(0) 123 | masked_lm_ids.append(0) 124 | masked_lm_weights.append(0.0) 125 | 126 | next_sentence_label = 1 if instance.is_random_next else 0 127 | 128 | features = collections.OrderedDict() 129 | features["input_ids"] = create_int_feature(input_ids) 130 | features["input_mask"] = create_int_feature(input_mask) 131 | features["segment_ids"] = create_int_feature(segment_ids) 132 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 133 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 134 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 135 | 
features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 136 | 137 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 138 | 139 | writers[writer_index].write(tf_example.SerializeToString()) 140 | writer_index = (writer_index + 1) % len(writers) 141 | 142 | total_written += 1 143 | 144 | if inst_index < 20: 145 | tf.logging.info("*** Example ***") 146 | tf.logging.info("tokens: %s" % " ".join( 147 | [tokenization.printable_text(x) for x in instance.tokens])) 148 | 149 | for feature_name in features.keys(): 150 | feature = features[feature_name] 151 | values = [] 152 | if feature.int64_list.value: 153 | values = feature.int64_list.value 154 | elif feature.float_list.value: 155 | values = feature.float_list.value 156 | tf.logging.info( 157 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 158 | 159 | for writer in writers: 160 | writer.close() 161 | 162 | tf.logging.info("Wrote %d total instances", total_written) 163 | 164 | 165 | def create_int_feature(values): 166 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 167 | return feature 168 | 169 | 170 | def create_float_feature(values): 171 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 172 | return feature 173 | 174 | 175 | def create_training_instances(input_files, tokenizer, max_seq_length, 176 | dupe_factor, short_seq_prob, masked_lm_prob, 177 | max_predictions_per_seq, rng): 178 | """Create `TrainingInstance`s from raw text.""" 179 | all_documents = [[]] 180 | 181 | # Input file format: 182 | # (1) One sentence per line. These should ideally be actual sentences, not 183 | # entire paragraphs or arbitrary spans of text. (Because we use the 184 | # sentence boundaries for the "next sentence prediction" task). 185 | # (2) Blank lines between documents. Document boundaries are needed so 186 | # that the "next sentence prediction" task doesn't span between documents. 187 | for input_file in input_files: 188 | with tf.gfile.GFile(input_file, "r") as reader: 189 | while True: 190 | line = tokenization.convert_to_unicode(reader.readline()) 191 | if not line: 192 | break 193 | line = line.strip() 194 | 195 | # Empty lines are used as document delimiters 196 | if not line: 197 | all_documents.append([]) 198 | tokens = tokenizer.tokenize(line) 199 | if tokens: 200 | all_documents[-1].append(tokens) 201 | 202 | # Remove empty documents 203 | all_documents = [x for x in all_documents if x] 204 | rng.shuffle(all_documents) 205 | 206 | vocab_words = list(tokenizer.vocab.keys()) 207 | instances = [] 208 | for _ in range(dupe_factor): 209 | for document_index in range(len(all_documents)): 210 | instances.extend( 211 | create_instances_from_document( 212 | all_documents, document_index, max_seq_length, short_seq_prob, 213 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) 214 | 215 | rng.shuffle(instances) 216 | return instances 217 | 218 | 219 | def create_instances_from_document( 220 | all_documents, document_index, max_seq_length, short_seq_prob, 221 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 222 | """Creates `TrainingInstance`s for a single document.""" 223 | document = all_documents[document_index] 224 | 225 | # Account for [CLS], [SEP], [SEP] 226 | max_num_tokens = max_seq_length - 3 227 | 228 | # We *usually* want to fill up the entire sequence since we are padding 229 | # to `max_seq_length` anyways, so short sequences are generally wasted 230 | # computation. 
However, we *sometimes* 231 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 232 | # sequences to minimize the mismatch between pre-training and fine-tuning. 233 | # The `target_seq_length` is just a rough target however, whereas 234 | # `max_seq_length` is a hard limit. 235 | target_seq_length = max_num_tokens 236 | if rng.random() < short_seq_prob: 237 | target_seq_length = rng.randint(2, max_num_tokens) 238 | 239 | # We DON'T just concatenate all of the tokens from a document into a long 240 | # sequence and choose an arbitrary split point because this would make the 241 | # next sentence prediction task too easy. Instead, we split the input into 242 | # segments "A" and "B" based on the actual "sentences" provided by the user 243 | # input. 244 | instances = [] 245 | current_chunk = [] 246 | current_length = 0 247 | i = 0 248 | while i < len(document): 249 | segment = document[i] 250 | current_chunk.append(segment) 251 | current_length += len(segment) 252 | if i == len(document) - 1 or current_length >= target_seq_length: 253 | if current_chunk: 254 | # `a_end` is how many segments from `current_chunk` go into the `A` 255 | # (first) sentence. 256 | a_end = 1 257 | if len(current_chunk) >= 2: 258 | a_end = rng.randint(1, len(current_chunk) - 1) 259 | 260 | tokens_a = [] 261 | for j in range(a_end): 262 | tokens_a.extend(current_chunk[j]) 263 | 264 | tokens_b = [] 265 | # Random next 266 | is_random_next = False 267 | if len(current_chunk) == 1 or rng.random() < 0.5: 268 | is_random_next = True 269 | target_b_length = target_seq_length - len(tokens_a) 270 | 271 | # This should rarely go for more than one iteration for large 272 | # corpora. However, just to be careful, we try to make sure that 273 | # the random document is not the same as the document 274 | # we're processing. 275 | for _ in range(10): 276 | random_document_index = rng.randint(0, len(all_documents) - 1) 277 | if random_document_index != document_index: 278 | break 279 | 280 | random_document = all_documents[random_document_index] 281 | random_start = rng.randint(0, len(random_document) - 1) 282 | for j in range(random_start, len(random_document)): 283 | tokens_b.extend(random_document[j]) 284 | if len(tokens_b) >= target_b_length: 285 | break 286 | # We didn't actually use these segments so we "put them back" so 287 | # they don't go to waste. 
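# Illustrative walk-through: if `current_chunk` held 5 segments and a_end == 2,
# the 3 segments not consumed for tokens_a are "returned" by moving `i` back 3
# positions, so after the `i += 1` at the bottom of the loop the next pass
# starts from the first unused segment instead of skipping it.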
288 | num_unused_segments = len(current_chunk) - a_end 289 | i -= num_unused_segments 290 | # Actual next 291 | else: 292 | is_random_next = False 293 | for j in range(a_end, len(current_chunk)): 294 | tokens_b.extend(current_chunk[j]) 295 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 296 | 297 | assert len(tokens_a) >= 1 298 | assert len(tokens_b) >= 1 299 | 300 | tokens = [] 301 | segment_ids = [] 302 | tokens.append("[CLS]") 303 | segment_ids.append(0) 304 | for token in tokens_a: 305 | tokens.append(token) 306 | segment_ids.append(0) 307 | 308 | tokens.append("[SEP]") 309 | segment_ids.append(0) 310 | 311 | for token in tokens_b: 312 | tokens.append(token) 313 | segment_ids.append(1) 314 | tokens.append("[SEP]") 315 | segment_ids.append(1) 316 | 317 | (tokens, masked_lm_positions, 318 | masked_lm_labels) = create_masked_lm_predictions( 319 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 320 | instance = TrainingInstance( 321 | tokens=tokens, 322 | segment_ids=segment_ids, 323 | is_random_next=is_random_next, 324 | masked_lm_positions=masked_lm_positions, 325 | masked_lm_labels=masked_lm_labels) 326 | instances.append(instance) 327 | current_chunk = [] 328 | current_length = 0 329 | i += 1 330 | 331 | return instances 332 | 333 | 334 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance", 335 | ["index", "label"]) 336 | 337 | 338 | def create_masked_lm_predictions(tokens, masked_lm_prob, 339 | max_predictions_per_seq, vocab_words, rng): 340 | """Creates the predictions for the masked LM objective.""" 341 | 342 | cand_indexes = [] 343 | for (i, token) in enumerate(tokens): 344 | if token == "[CLS]" or token == "[SEP]": 345 | continue 346 | cand_indexes.append(i) 347 | 348 | rng.shuffle(cand_indexes) 349 | 350 | output_tokens = list(tokens) 351 | 352 | num_to_predict = min(max_predictions_per_seq, 353 | max(1, int(round(len(tokens) * masked_lm_prob)))) 354 | 355 | masked_lms = [] 356 | covered_indexes = set() 357 | for index in cand_indexes: 358 | if len(masked_lms) >= num_to_predict: 359 | break 360 | if index in covered_indexes: 361 | continue 362 | covered_indexes.add(index) 363 | 364 | masked_token = None 365 | # 80% of the time, replace with [MASK] 366 | if rng.random() < 0.8: 367 | masked_token = "[MASK]" 368 | else: 369 | # 10% of the time, keep original 370 | if rng.random() < 0.5: 371 | masked_token = tokens[index] 372 | # 10% of the time, replace with random word 373 | else: 374 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 375 | 376 | output_tokens[index] = masked_token 377 | 378 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) 379 | 380 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 381 | 382 | masked_lm_positions = [] 383 | masked_lm_labels = [] 384 | for p in masked_lms: 385 | masked_lm_positions.append(p.index) 386 | masked_lm_labels.append(p.label) 387 | 388 | return (output_tokens, masked_lm_positions, masked_lm_labels) 389 | 390 | 391 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 392 | """Truncates a pair of sequences to a maximum sequence length.""" 393 | while True: 394 | total_length = len(tokens_a) + len(tokens_b) 395 | if total_length <= max_num_tokens: 396 | break 397 | 398 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 399 | assert len(trunc_tokens) >= 1 400 | 401 | # We want to sometimes truncate from the front and sometimes from the 402 | # back to add more randomness and avoid biases. 
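# Concretely: `trunc_tokens` points at whichever of tokens_a / tokens_b is
# currently longer, and the coin flip below drops either its first token
# (del trunc_tokens[0]) or its last token (pop()) until the pair fits within
# max_num_tokens.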
403 | if rng.random() < 0.5: 404 | del trunc_tokens[0] 405 | else: 406 | trunc_tokens.pop() 407 | 408 | 409 | def main(_): 410 | tf.logging.set_verbosity(tf.logging.INFO) 411 | 412 | tokenizer = tokenization.FullTokenizer( 413 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 414 | 415 | input_files = [] 416 | for input_pattern in FLAGS.input_file.split(","): 417 | input_files.extend(tf.gfile.Glob(input_pattern)) 418 | 419 | tf.logging.info("*** Reading from input files ***") 420 | for input_file in input_files: 421 | tf.logging.info(" %s", input_file) 422 | 423 | rng = random.Random(FLAGS.random_seed) 424 | instances = create_training_instances( 425 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 426 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 427 | rng) 428 | 429 | output_files = FLAGS.output_file.split(",") 430 | tf.logging.info("*** Writing to output files ***") 431 | for output_file in output_files: 432 | tf.logging.info(" %s", output_file) 433 | 434 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 435 | FLAGS.max_predictions_per_seq, output_files) 436 | 437 | 438 | if __name__ == "__main__": 439 | flags.mark_flag_as_required("input_file") 440 | flags.mark_flag_as_required("output_file") 441 | flags.mark_flag_as_required("vocab_file") 442 | tf.app.run() 443 | -------------------------------------------------------------------------------- /run_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run masked LM/next sentence masked_lm pre-training for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import modeling 23 | import optimization 24 | import tensorflow as tf 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | ## Required parameters 31 | flags.DEFINE_string( 32 | "bert_config_file", None, 33 | "The config json file corresponding to the pre-trained BERT model. " 34 | "This specifies the model architecture.") 35 | 36 | flags.DEFINE_string( 37 | "input_file", None, 38 | "Input TF example files (can be a glob or comma separated).") 39 | 40 | flags.DEFINE_string( 41 | "output_dir", None, 42 | "The output directory where the model checkpoints will be written.") 43 | 44 | ## Other parameters 45 | flags.DEFINE_string( 46 | "init_checkpoint", None, 47 | "Initial checkpoint (usually from a pre-trained BERT model).") 48 | 49 | flags.DEFINE_integer( 50 | "max_seq_length", 128, 51 | "The maximum total input sequence length after WordPiece tokenization. " 52 | "Sequences longer than this will be truncated, and sequences shorter " 53 | "than this will be padded. 
Must match data generation.") 54 | 55 | flags.DEFINE_integer( 56 | "max_predictions_per_seq", 20, 57 | "Maximum number of masked LM predictions per sequence. " 58 | "Must match data generation.") 59 | 60 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 61 | 62 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 63 | 64 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 65 | 66 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 67 | 68 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 69 | 70 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.") 71 | 72 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") 73 | 74 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 75 | "How often to save the model checkpoint.") 76 | 77 | flags.DEFINE_integer("iterations_per_loop", 1000, 78 | "How many steps to make in each estimator call.") 79 | 80 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 81 | 82 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 83 | 84 | tf.flags.DEFINE_string( 85 | "tpu_name", None, 86 | "The Cloud TPU to use for training. This should be either the name " 87 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 88 | "url.") 89 | 90 | tf.flags.DEFINE_string( 91 | "tpu_zone", None, 92 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 93 | "specified, we will attempt to automatically detect the GCE project from " 94 | "metadata.") 95 | 96 | tf.flags.DEFINE_string( 97 | "gcp_project", None, 98 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 99 | "specified, we will attempt to automatically detect the GCE project from " 100 | "metadata.") 101 | 102 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 103 | 104 | flags.DEFINE_integer( 105 | "num_tpu_cores", 8, 106 | "Only used if `use_tpu` is True. 
Total number of TPU cores to use.") 107 | 108 | 109 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 110 | num_train_steps, num_warmup_steps, use_tpu, 111 | use_one_hot_embeddings): 112 | """Returns `model_fn` closure for TPUEstimator.""" 113 | 114 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 115 | """The `model_fn` for TPUEstimator.""" 116 | 117 | tf.logging.info("*** Features ***") 118 | for name in sorted(features.keys()): 119 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 120 | 121 | input_ids = features["input_ids"] 122 | input_mask = features["input_mask"] 123 | segment_ids = features["segment_ids"] 124 | masked_lm_positions = features["masked_lm_positions"] 125 | masked_lm_ids = features["masked_lm_ids"] 126 | masked_lm_weights = features["masked_lm_weights"] 127 | next_sentence_labels = features["next_sentence_labels"] 128 | 129 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 130 | 131 | model = modeling.BertModel( 132 | config=bert_config, 133 | is_training=is_training, 134 | input_ids=input_ids, 135 | input_mask=input_mask, 136 | token_type_ids=segment_ids, 137 | use_one_hot_embeddings=use_one_hot_embeddings) 138 | 139 | (masked_lm_loss, 140 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( 141 | bert_config, model.get_sequence_output(), model.get_embedding_table(), 142 | masked_lm_positions, masked_lm_ids, masked_lm_weights) 143 | 144 | (next_sentence_loss, next_sentence_example_loss, 145 | next_sentence_log_probs) = get_next_sentence_output( 146 | bert_config, model.get_pooled_output(), next_sentence_labels) 147 | 148 | total_loss = masked_lm_loss + next_sentence_loss 149 | 150 | tvars = tf.trainable_variables() 151 | 152 | initialized_variable_names = {} 153 | scaffold_fn = None 154 | if init_checkpoint: 155 | (assignment_map, initialized_variable_names 156 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 157 | if use_tpu: 158 | 159 | def tpu_scaffold(): 160 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 161 | return tf.train.Scaffold() 162 | 163 | scaffold_fn = tpu_scaffold 164 | else: 165 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 166 | 167 | tf.logging.info("**** Trainable Variables ****") 168 | for var in tvars: 169 | init_string = "" 170 | if var.name in initialized_variable_names: 171 | init_string = ", *INIT_FROM_CKPT*" 172 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 173 | init_string) 174 | 175 | output_spec = None 176 | if mode == tf.estimator.ModeKeys.TRAIN: 177 | train_op = optimization.create_optimizer( 178 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 179 | 180 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 181 | mode=mode, 182 | loss=total_loss, 183 | train_op=train_op, 184 | scaffold_fn=scaffold_fn) 185 | elif mode == tf.estimator.ModeKeys.EVAL: 186 | 187 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 188 | masked_lm_weights, next_sentence_example_loss, 189 | next_sentence_log_probs, next_sentence_labels): 190 | """Computes the loss and accuracy of the model.""" 191 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs, 192 | [-1, masked_lm_log_probs.shape[-1]]) 193 | masked_lm_predictions = tf.argmax( 194 | masked_lm_log_probs, axis=-1, output_type=tf.int32) 195 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) 196 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) 197 | masked_lm_weights = 
tf.reshape(masked_lm_weights, [-1]) 198 | masked_lm_accuracy = tf.metrics.accuracy( 199 | labels=masked_lm_ids, 200 | predictions=masked_lm_predictions, 201 | weights=masked_lm_weights) 202 | masked_lm_mean_loss = tf.metrics.mean( 203 | values=masked_lm_example_loss, weights=masked_lm_weights) 204 | 205 | next_sentence_log_probs = tf.reshape( 206 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]) 207 | next_sentence_predictions = tf.argmax( 208 | next_sentence_log_probs, axis=-1, output_type=tf.int32) 209 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) 210 | next_sentence_accuracy = tf.metrics.accuracy( 211 | labels=next_sentence_labels, predictions=next_sentence_predictions) 212 | next_sentence_mean_loss = tf.metrics.mean( 213 | values=next_sentence_example_loss) 214 | 215 | return { 216 | "masked_lm_accuracy": masked_lm_accuracy, 217 | "masked_lm_loss": masked_lm_mean_loss, 218 | "next_sentence_accuracy": next_sentence_accuracy, 219 | "next_sentence_loss": next_sentence_mean_loss, 220 | } 221 | 222 | eval_metrics = (metric_fn, [ 223 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 224 | masked_lm_weights, next_sentence_example_loss, 225 | next_sentence_log_probs, next_sentence_labels 226 | ]) 227 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 228 | mode=mode, 229 | loss=total_loss, 230 | eval_metrics=eval_metrics, 231 | scaffold_fn=scaffold_fn) 232 | else: 233 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 234 | 235 | return output_spec 236 | 237 | return model_fn 238 | 239 | 240 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, 241 | label_ids, label_weights): 242 | """Get loss and log probs for the masked LM.""" 243 | input_tensor = gather_indexes(input_tensor, positions) 244 | 245 | with tf.variable_scope("cls/predictions"): 246 | # We apply one more non-linear transformation before the output layer. 247 | # This matrix is not used after pre-training. 248 | with tf.variable_scope("transform"): 249 | input_tensor = tf.layers.dense( 250 | input_tensor, 251 | units=bert_config.hidden_size, 252 | activation=modeling.get_activation(bert_config.hidden_act), 253 | kernel_initializer=modeling.create_initializer( 254 | bert_config.initializer_range)) 255 | input_tensor = modeling.layer_norm(input_tensor) 256 | 257 | # The output weights are the same as the input embeddings, but there is 258 | # an output-only bias for each token. 259 | output_bias = tf.get_variable( 260 | "output_bias", 261 | shape=[bert_config.vocab_size], 262 | initializer=tf.zeros_initializer()) 263 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 264 | logits = tf.nn.bias_add(logits, output_bias) 265 | log_probs = tf.nn.log_softmax(logits, axis=-1) 266 | 267 | label_ids = tf.reshape(label_ids, [-1]) 268 | label_weights = tf.reshape(label_weights, [-1]) 269 | 270 | one_hot_labels = tf.one_hot( 271 | label_ids, depth=bert_config.vocab_size, dtype=tf.float32) 272 | 273 | # The `positions` tensor might be zero-padded (if the sequence is too 274 | # short to have the maximum number of predictions). The `label_weights` 275 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 276 | # padding predictions. 
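# Worked example (hypothetical numbers): with per-example losses of 2.0 and 4.0
# for two real predictions plus one padded slot, label_weights is [1.0, 1.0, 0.0],
# so numerator = 6.0 and denominator = 2.0 + 1e-5, giving a loss of ~3.0 below;
# the padded slot contributes nothing.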
277 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) 278 | numerator = tf.reduce_sum(label_weights * per_example_loss) 279 | denominator = tf.reduce_sum(label_weights) + 1e-5 280 | loss = numerator / denominator 281 | 282 | return (loss, per_example_loss, log_probs) 283 | 284 | 285 | def get_next_sentence_output(bert_config, input_tensor, labels): 286 | """Get loss and log probs for the next sentence prediction.""" 287 | 288 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 289 | # "random sentence". This weight matrix is not used after pre-training. 290 | with tf.variable_scope("cls/seq_relationship"): 291 | output_weights = tf.get_variable( 292 | "output_weights", 293 | shape=[2, bert_config.hidden_size], 294 | initializer=modeling.create_initializer(bert_config.initializer_range)) 295 | output_bias = tf.get_variable( 296 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 297 | 298 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 299 | logits = tf.nn.bias_add(logits, output_bias) 300 | log_probs = tf.nn.log_softmax(logits, axis=-1) 301 | labels = tf.reshape(labels, [-1]) 302 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 303 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 304 | loss = tf.reduce_mean(per_example_loss) 305 | return (loss, per_example_loss, log_probs) 306 | 307 | 308 | def gather_indexes(sequence_tensor, positions): 309 | """Gathers the vectors at the specific positions over a minibatch.""" 310 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 311 | batch_size = sequence_shape[0] 312 | seq_length = sequence_shape[1] 313 | width = sequence_shape[2] 314 | 315 | flat_offsets = tf.reshape( 316 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 317 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 318 | flat_sequence_tensor = tf.reshape(sequence_tensor, 319 | [batch_size * seq_length, width]) 320 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 321 | return output_tensor 322 | 323 | 324 | def input_fn_builder(input_files, 325 | max_seq_length, 326 | max_predictions_per_seq, 327 | is_training, 328 | num_cpu_threads=4): 329 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 330 | 331 | def input_fn(params): 332 | """The actual input function.""" 333 | batch_size = params["batch_size"] 334 | 335 | name_to_features = { 336 | "input_ids": 337 | tf.FixedLenFeature([max_seq_length], tf.int64), 338 | "input_mask": 339 | tf.FixedLenFeature([max_seq_length], tf.int64), 340 | "segment_ids": 341 | tf.FixedLenFeature([max_seq_length], tf.int64), 342 | "masked_lm_positions": 343 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 344 | "masked_lm_ids": 345 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 346 | "masked_lm_weights": 347 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32), 348 | "next_sentence_labels": 349 | tf.FixedLenFeature([1], tf.int64), 350 | } 351 | 352 | # For training, we want a lot of parallel reading and shuffling. 353 | # For eval, we want no shuffling and parallel reading doesn't matter. 354 | if is_training: 355 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 356 | d = d.repeat() 357 | d = d.shuffle(buffer_size=len(input_files)) 358 | 359 | # `cycle_length` is the number of parallel files that get read. 
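# For example, with num_cpu_threads=4 and ten TFRecord shards, four shards are
# read and interleaved in parallel; with only two shards, cycle_length drops to
# two, since there are only two files to interleave.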
360 | cycle_length = min(num_cpu_threads, len(input_files)) 361 | 362 | # `sloppy` mode means that the interleaving is not exact. This adds 363 | # even more randomness to the training pipeline. 364 | d = d.apply( 365 | tf.contrib.data.parallel_interleave( 366 | tf.data.TFRecordDataset, 367 | sloppy=is_training, 368 | cycle_length=cycle_length)) 369 | d = d.shuffle(buffer_size=100) 370 | else: 371 | d = tf.data.TFRecordDataset(input_files) 372 | # Since we evaluate for a fixed number of steps we don't want to encounter 373 | # out-of-range exceptions. 374 | d = d.repeat() 375 | 376 | # We must `drop_remainder` on training because the TPU requires fixed 377 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 378 | # and we *don't* want to drop the remainder, otherwise we wont cover 379 | # every sample. 380 | d = d.apply( 381 | tf.contrib.data.map_and_batch( 382 | lambda record: _decode_record(record, name_to_features), 383 | batch_size=batch_size, 384 | num_parallel_batches=num_cpu_threads, 385 | drop_remainder=True)) 386 | return d 387 | 388 | return input_fn 389 | 390 | 391 | def _decode_record(record, name_to_features): 392 | """Decodes a record to a TensorFlow example.""" 393 | example = tf.parse_single_example(record, name_to_features) 394 | 395 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 396 | # So cast all int64 to int32. 397 | for name in list(example.keys()): 398 | t = example[name] 399 | if t.dtype == tf.int64: 400 | t = tf.to_int32(t) 401 | example[name] = t 402 | 403 | return example 404 | 405 | 406 | def main(_): 407 | tf.logging.set_verbosity(tf.logging.INFO) 408 | 409 | if not FLAGS.do_train and not FLAGS.do_eval: 410 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 411 | 412 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 413 | 414 | tf.gfile.MakeDirs(FLAGS.output_dir) 415 | 416 | input_files = [] 417 | for input_pattern in FLAGS.input_file.split(","): 418 | input_files.extend(tf.gfile.Glob(input_pattern)) 419 | 420 | tf.logging.info("*** Input Files ***") 421 | for input_file in input_files: 422 | tf.logging.info(" %s" % input_file) 423 | 424 | tpu_cluster_resolver = None 425 | if FLAGS.use_tpu and FLAGS.tpu_name: 426 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 427 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 428 | 429 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 430 | run_config = tf.contrib.tpu.RunConfig( 431 | cluster=tpu_cluster_resolver, 432 | master=FLAGS.master, 433 | model_dir=FLAGS.output_dir, 434 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 435 | tpu_config=tf.contrib.tpu.TPUConfig( 436 | iterations_per_loop=FLAGS.iterations_per_loop, 437 | num_shards=FLAGS.num_tpu_cores, 438 | per_host_input_for_training=is_per_host)) 439 | 440 | model_fn = model_fn_builder( 441 | bert_config=bert_config, 442 | init_checkpoint=FLAGS.init_checkpoint, 443 | learning_rate=FLAGS.learning_rate, 444 | num_train_steps=FLAGS.num_train_steps, 445 | num_warmup_steps=FLAGS.num_warmup_steps, 446 | use_tpu=FLAGS.use_tpu, 447 | use_one_hot_embeddings=FLAGS.use_tpu) 448 | 449 | # If TPU is not available, this will fall back to normal Estimator on CPU 450 | # or GPU. 
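# Note: TPUEstimator exposes the chosen batch size to input_fn via
# params["batch_size"] (which is how the input_fn above reads it), and with
# --use_tpu=False it simply trains and evaluates on the local CPU or GPU.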
451 | estimator = tf.contrib.tpu.TPUEstimator( 452 | use_tpu=FLAGS.use_tpu, 453 | model_fn=model_fn, 454 | config=run_config, 455 | train_batch_size=FLAGS.train_batch_size, 456 | eval_batch_size=FLAGS.eval_batch_size) 457 | 458 | if FLAGS.do_train: 459 | tf.logging.info("***** Running training *****") 460 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 461 | train_input_fn = input_fn_builder( 462 | input_files=input_files, 463 | max_seq_length=FLAGS.max_seq_length, 464 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 465 | is_training=True) 466 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 467 | 468 | if FLAGS.do_eval: 469 | tf.logging.info("***** Running evaluation *****") 470 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 471 | 472 | eval_input_fn = input_fn_builder( 473 | input_files=input_files, 474 | max_seq_length=FLAGS.max_seq_length, 475 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 476 | is_training=False) 477 | 478 | result = estimator.evaluate( 479 | input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) 480 | 481 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 482 | with tf.gfile.GFile(output_eval_file, "w") as writer: 483 | tf.logging.info("***** Eval results *****") 484 | for key in sorted(result.keys()): 485 | tf.logging.info(" %s = %s", key, str(result[key])) 486 | writer.write("%s = %s\n" % (key, str(result[key]))) 487 | 488 | 489 | if __name__ == "__main__": 490 | flags.mark_flag_as_required("input_file") 491 | flags.mark_flag_as_required("bert_config_file") 492 | flags.mark_flag_as_required("output_dir") 493 | tf.app.run() 494 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import csv 23 | import os 24 | import modeling 25 | import optimization 26 | import tokenization 27 | import tensorflow as tf 28 | import pandas as pd 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | ## Required parameters 35 | flags.DEFINE_string( 36 | "data_dir", 'data', 37 | "The input data dir. Should contain the .tsv files (or other data files) " 38 | "for the task.") 39 | 40 | flags.DEFINE_string( 41 | "bert_config_file", 'chinese_L-12_H-768_A-12/bert_config.json', 42 | "The config json file corresponding to the pre-trained BERT model. 
" 43 | "This specifies the model architecture.") 44 | 45 | flags.DEFINE_string("task_name", 'sim', "The name of the task to train.") 46 | 47 | flags.DEFINE_string("vocab_file", 'chinese_L-12_H-768_A-12/vocab.txt', 48 | "The vocabulary file that the BERT model was trained on.") 49 | 50 | flags.DEFINE_string( 51 | "output_dir", 'tmp/output', 52 | "The output directory where the model checkpoints will be written.") 53 | 54 | ## Other parameters 55 | 56 | flags.DEFINE_string( 57 | "init_checkpoint", 'tmp/sim_model', 58 | "Initial checkpoint (usually from a pre-trained BERT model).") 59 | 60 | flags.DEFINE_bool( 61 | "do_lower_case", True, 62 | "Whether to lower case the input text. Should be True for uncased " 63 | "models and False for cased models.") 64 | 65 | flags.DEFINE_integer( 66 | "max_seq_length", 50, 67 | "The maximum total input sequence length after WordPiece tokenization. " 68 | "Sequences longer than this will be truncated, and sequences shorter " 69 | "than this will be padded.") 70 | 71 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 72 | 73 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 74 | 75 | flags.DEFINE_bool( 76 | "do_predict", True, 77 | "Whether to run the model in inference mode on the test set.") 78 | 79 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 80 | 81 | flags.DEFINE_integer("eval_batch_size", 16, "Total batch size for eval.") 82 | 83 | flags.DEFINE_integer("predict_batch_size", 16, "Total batch size for predict.") 84 | 85 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 86 | 87 | flags.DEFINE_float("num_train_epochs", 3.0, 88 | "Total number of training epochs to perform.") 89 | 90 | flags.DEFINE_float( 91 | "warmup_proportion", 0.1, 92 | "Proportion of training to perform linear learning rate warmup for. " 93 | "E.g., 0.1 = 10% of training.") 94 | 95 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 96 | "How often to save the model checkpoint.") 97 | 98 | flags.DEFINE_integer("iterations_per_loop", 1000, 99 | "How many steps to make in each estimator call.") 100 | 101 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 102 | 103 | tf.flags.DEFINE_string( 104 | "tpu_name", None, 105 | "The Cloud TPU to use for training. This should be either the name " 106 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 107 | "url.") 108 | 109 | tf.flags.DEFINE_string( 110 | "tpu_zone", None, 111 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 112 | "specified, we will attempt to automatically detect the GCE project from " 113 | "metadata.") 114 | 115 | tf.flags.DEFINE_string( 116 | "gcp_project", None, 117 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 118 | "specified, we will attempt to automatically detect the GCE project from " 119 | "metadata.") 120 | 121 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 122 | 123 | flags.DEFINE_integer( 124 | "num_tpu_cores", 8, 125 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 126 | 127 | 128 | class InputExample(object): 129 | """A single training/test example for simple sequence classification.""" 130 | 131 | def __init__(self, guid, text_a, text_b=None, label=None): 132 | """Constructs a InputExample. 133 | 134 | Args: 135 | guid: Unique id for the example. 136 | text_a: string. The untokenized text of the first sequence. 
For single 137 | sequence tasks, only this sequence must be specified. 138 | text_b: (Optional) string. The untokenized text of the second sequence. 139 | Only must be specified for sequence pair tasks. 140 | label: (Optional) string. The label of the example. This should be 141 | specified for train and dev examples, but not for test examples. 142 | """ 143 | self.guid = guid 144 | self.text_a = text_a 145 | self.text_b = text_b 146 | self.label = label 147 | 148 | 149 | class PaddingInputExample(object): 150 | """Fake example so the num input examples is a multiple of the batch size. 151 | 152 | When running eval/predict on the TPU, we need to pad the number of examples 153 | to be a multiple of the batch size, because the TPU requires a fixed batch 154 | size. The alternative is to drop the last batch, which is bad because it means 155 | the entire output data won't be generated. 156 | 157 | We use this class instead of `None` because treating `None` as padding 158 | battches could cause silent errors. 159 | """ 160 | 161 | 162 | class InputFeatures(object): 163 | """A single set of features of data.""" 164 | 165 | def __init__(self, 166 | input_ids, 167 | input_mask, 168 | segment_ids, 169 | # label_id, 170 | is_real_example=True): 171 | self.input_ids = input_ids 172 | self.input_mask = input_mask 173 | self.segment_ids = segment_ids 174 | # self.label_id = label_id 175 | self.is_real_example = is_real_example 176 | 177 | 178 | class DataProcessor(object): 179 | """Base class for data converters for sequence classification data sets.""" 180 | 181 | def get_train_examples(self, data_dir): 182 | """Gets a collection of `InputExample`s for the train set.""" 183 | raise NotImplementedError() 184 | 185 | def get_dev_examples(self, data_dir): 186 | """Gets a collection of `InputExample`s for the dev set.""" 187 | raise NotImplementedError() 188 | 189 | def get_test_examples(self, data_dir): 190 | """Gets a collection of `InputExample`s for prediction.""" 191 | raise NotImplementedError() 192 | 193 | def get_labels(self): 194 | """Gets the list of labels for this data set.""" 195 | raise NotImplementedError() 196 | 197 | @classmethod 198 | def _read_tsv(cls, input_file, quotechar=None): 199 | """Reads a tab separated value file.""" 200 | with tf.gfile.Open(input_file, "r") as f: 201 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 202 | lines = [] 203 | for line in reader: 204 | lines.append(line) 205 | return lines 206 | 207 | class SimProcessor(DataProcessor): 208 | """Processor for the Sim task""" 209 | 210 | def get_train_examples(self, data_dir): 211 | raise NotImplementedError() 212 | 213 | def get_dev_examples(self, data_dir): 214 | raise NotImplementedError() 215 | 216 | def get_test_examples(self, data_dir): 217 | file_path = os.path.join(data_dir, 'test.txt') 218 | f = open(file_path, 'r') 219 | test_data = [] 220 | index = 0 221 | for line in f.readlines(): 222 | guid = 'test-%d' % index 223 | line = line.replace("\n", "").split("\t") 224 | text_a = tokenization.convert_to_unicode(str(line[0])) 225 | text_b = tokenization.convert_to_unicode(str(line[1])) 226 | label = str(line[2]) 227 | test_data.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 228 | index += 1 229 | return test_data 230 | 231 | def get_labels(self): 232 | return ['0', '1'] 233 | 234 | def convert_single_example(ex_index, example, label_list, max_seq_length, 235 | tokenizer): 236 | """Converts a single `InputExample` into a single `InputFeatures`.""" 237 | 238 | if 
isinstance(example, PaddingInputExample): 239 | return InputFeatures( 240 | input_ids=[0] * max_seq_length, 241 | input_mask=[0] * max_seq_length, 242 | segment_ids=[0] * max_seq_length, 243 | # label_id=0, 244 | is_real_example=False) 245 | 246 | label_map = {} 247 | for (i, label) in enumerate(label_list): 248 | label_map[label] = i 249 | 250 | tokens_a = tokenizer.tokenize(example.text_a) 251 | tokens_b = None 252 | if example.text_b: 253 | tokens_b = tokenizer.tokenize(example.text_b) 254 | 255 | if tokens_b: 256 | # Modifies `tokens_a` and `tokens_b` in place so that the total 257 | # length is less than the specified length. 258 | # Account for [CLS], [SEP], [SEP] with "- 3" 259 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 260 | else: 261 | # Account for [CLS] and [SEP] with "- 2" 262 | if len(tokens_a) > max_seq_length - 2: 263 | tokens_a = tokens_a[0:(max_seq_length - 2)] 264 | 265 | # The convention in BERT is: 266 | # (a) For sequence pairs: 267 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 268 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 269 | # (b) For single sequences: 270 | # tokens: [CLS] the dog is hairy . [SEP] 271 | # type_ids: 0 0 0 0 0 0 0 272 | # 273 | # Where "type_ids" are used to indicate whether this is the first 274 | # sequence or the second sequence. The embedding vectors for `type=0` and 275 | # `type=1` were learned during pre-training and are added to the wordpiece 276 | # embedding vector (and position vector). This is not *strictly* necessary 277 | # since the [SEP] token unambiguously separates the sequences, but it makes 278 | # it easier for the model to learn the concept of sequences. 279 | # 280 | # For classification tasks, the first vector (corresponding to [CLS]) is 281 | # used as the "sentence vector". Note that this only makes sense because 282 | # the entire model is fine-tuned. 283 | tokens = [] 284 | segment_ids = [] 285 | tokens.append("[CLS]") 286 | segment_ids.append(0) 287 | for token in tokens_a: 288 | tokens.append(token) 289 | segment_ids.append(0) 290 | tokens.append("[SEP]") 291 | segment_ids.append(0) 292 | 293 | if tokens_b: 294 | for token in tokens_b: 295 | tokens.append(token) 296 | segment_ids.append(1) 297 | tokens.append("[SEP]") 298 | segment_ids.append(1) 299 | 300 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 301 | 302 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 303 | # tokens are attended to. 304 | input_mask = [1] * len(input_ids) 305 | 306 | # Zero-pad up to the sequence length. 
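# Illustrative: for a 12-token pair and max_seq_length=50, input_ids, input_mask
# and segment_ids each receive 38 trailing zeros; the zeros in input_mask mark
# padding positions that the model does not attend to.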
307 | while len(input_ids) < max_seq_length: 308 | input_ids.append(0) 309 | input_mask.append(0) 310 | segment_ids.append(0) 311 | 312 | assert len(input_ids) == max_seq_length 313 | assert len(input_mask) == max_seq_length 314 | assert len(segment_ids) == max_seq_length 315 | 316 | # label_id = label_map[example.label] 317 | if ex_index < 5: 318 | tf.logging.info("*** Example ***") 319 | tf.logging.info("guid: %s" % (example.guid)) 320 | tf.logging.info("tokens: %s" % " ".join( 321 | [tokenization.printable_text(x) for x in tokens])) 322 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 323 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 324 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 325 | # tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 326 | 327 | feature = InputFeatures( 328 | input_ids=input_ids, 329 | input_mask=input_mask, 330 | segment_ids=segment_ids, 331 | # label_id=label_id, 332 | is_real_example=True) 333 | return feature 334 | 335 | 336 | def file_based_convert_examples_to_features( 337 | examples, label_list, max_seq_length, tokenizer, output_file): 338 | """Convert a set of `InputExample`s to a TFRecord file.""" 339 | 340 | writer = tf.python_io.TFRecordWriter(output_file) 341 | 342 | for (ex_index, example) in enumerate(examples): 343 | if ex_index % 10000 == 0: 344 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 345 | 346 | feature = convert_single_example(ex_index, example, label_list, 347 | max_seq_length, tokenizer) 348 | 349 | def create_int_feature(values): 350 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 351 | return f 352 | 353 | features = collections.OrderedDict() 354 | features["input_ids"] = create_int_feature(feature.input_ids) 355 | features["input_mask"] = create_int_feature(feature.input_mask) 356 | features["segment_ids"] = create_int_feature(feature.segment_ids) 357 | # features["label_ids"] = create_int_feature([feature.label_id]) 358 | features["is_real_example"] = create_int_feature( 359 | [int(feature.is_real_example)]) 360 | 361 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 362 | writer.write(tf_example.SerializeToString()) 363 | writer.close() 364 | 365 | 366 | def file_based_input_fn_builder(input_file, seq_length, is_training, 367 | drop_remainder): 368 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 369 | 370 | name_to_features = { 371 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 372 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 373 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 374 | # "label_ids": tf.FixedLenFeature([], tf.int64), 375 | "is_real_example": tf.FixedLenFeature([], tf.int64), 376 | } 377 | 378 | def _decode_record(record, name_to_features): 379 | """Decodes a record to a TensorFlow example.""" 380 | example = tf.parse_single_example(record, name_to_features) 381 | 382 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 383 | # So cast all int64 to int32. 384 | for name in list(example.keys()): 385 | t = example[name] 386 | if t.dtype == tf.int64: 387 | t = tf.to_int32(t) 388 | example[name] = t 389 | 390 | return example 391 | 392 | def input_fn(params): 393 | """The actual input function.""" 394 | batch_size = params["batch_size"] 395 | 396 | # For training, we want a lot of parallel reading and shuffling. 
397 | # For eval, we want no shuffling and parallel reading doesn't matter. 398 | d = tf.data.TFRecordDataset(input_file) 399 | if is_training: 400 | d = d.repeat() 401 | d = d.shuffle(buffer_size=100) 402 | 403 | d = d.apply( 404 | tf.contrib.data.map_and_batch( 405 | lambda record: _decode_record(record, name_to_features), 406 | batch_size=batch_size, 407 | drop_remainder=drop_remainder)) 408 | 409 | return d 410 | 411 | return input_fn 412 | 413 | 414 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 415 | """Truncates a sequence pair in place to the maximum length.""" 416 | 417 | # This is a simple heuristic which will always truncate the longer sequence 418 | # one token at a time. This makes more sense than truncating an equal percent 419 | # of tokens from each, since if one sequence is very short then each token 420 | # that's truncated likely contains more information than a longer sequence. 421 | while True: 422 | total_length = len(tokens_a) + len(tokens_b) 423 | if total_length <= max_length: 424 | break 425 | if len(tokens_a) > len(tokens_b): 426 | tokens_a.pop() 427 | else: 428 | tokens_b.pop() 429 | 430 | 431 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 432 | num_labels, use_one_hot_embeddings): 433 | """Creates a classification model.""" 434 | model = modeling.BertModel( 435 | config=bert_config, 436 | is_training=is_training, 437 | input_ids=input_ids, 438 | input_mask=input_mask, 439 | token_type_ids=segment_ids, 440 | use_one_hot_embeddings=use_one_hot_embeddings) 441 | 442 | # In the demo, we are doing a simple classification task on the entire 443 | # segment. 444 | # 445 | # If you want to use the token-level output, use model.get_sequence_output() 446 | # instead. 447 | output_layer = model.get_pooled_output() 448 | 449 | hidden_size = output_layer.shape[-1].value 450 | 451 | output_weights = tf.get_variable( 452 | "output_weights", [num_labels, hidden_size], 453 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 454 | 455 | output_bias = tf.get_variable( 456 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 457 | 458 | with tf.variable_scope("loss"): 459 | if is_training: 460 | # I.e., 0.1 dropout 461 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 462 | 463 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 464 | logits = tf.nn.bias_add(logits, output_bias) 465 | probabilities = tf.nn.softmax(logits, axis=-1) 466 | # log_probs = tf.nn.log_softmax(logits, axis=-1) 467 | # 468 | # one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 469 | # 470 | # per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 471 | # loss = tf.reduce_mean(per_example_loss) 472 | 473 | return probabilities 474 | 475 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 476 | num_train_steps, num_warmup_steps, use_tpu, 477 | use_one_hot_embeddings): 478 | """Returns `model_fn` closure for TPUEstimator.""" 479 | 480 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 481 | """The `model_fn` for TPUEstimator.""" 482 | 483 | tf.logging.info("*** Features ***") 484 | for name in sorted(features.keys()): 485 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 486 | 487 | input_ids = features["input_ids"] 488 | input_mask = features["input_mask"] 489 | segment_ids = features["segment_ids"] 490 | 491 | is_training = False 492 | 493 | probabilities = create_model( 494 | bert_config, 
is_training, input_ids, input_mask, segment_ids, 495 | num_labels, use_one_hot_embeddings) 496 | 497 | tvars = tf.trainable_variables() 498 | initialized_variable_names = {} 499 | scaffold_fn = None 500 | if init_checkpoint: 501 | (assignment_map, initialized_variable_names 502 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 503 | if use_tpu: 504 | 505 | def tpu_scaffold(): 506 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 507 | return tf.train.Scaffold() 508 | 509 | scaffold_fn = tpu_scaffold 510 | else: 511 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 512 | 513 | tf.logging.info("**** Trainable Variables ****") 514 | for var in tvars: 515 | init_string = "" 516 | if var.name in initialized_variable_names: 517 | init_string = ", *INIT_FROM_CKPT*" 518 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 519 | init_string) 520 | 521 | output_spec = None 522 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 523 | mode=mode, 524 | predictions={"probabilities": probabilities}, 525 | scaffold_fn=scaffold_fn) 526 | return output_spec 527 | 528 | return model_fn 529 | 530 | 531 | # This function is not used by this file but is still used by the Colab and 532 | # people who depend on it. 533 | def input_fn_builder(features, seq_length, is_training, drop_remainder): 534 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 535 | 536 | all_input_ids = [] 537 | all_input_mask = [] 538 | all_segment_ids = [] 539 | # all_label_ids = [] 540 | 541 | for feature in features: 542 | all_input_ids.append(feature.input_ids) 543 | all_input_mask.append(feature.input_mask) 544 | all_segment_ids.append(feature.segment_ids) 545 | # all_label_ids.append(feature.label_id) 546 | 547 | def input_fn(params): 548 | """The actual input function.""" 549 | batch_size = params["batch_size"] 550 | 551 | num_examples = len(features) 552 | 553 | # This is for demo purposes and does NOT scale to large data sets. We do 554 | # not use Dataset.from_generator() because that uses tf.py_func which is 555 | # not TPU compatible. The right way to load data is with TFRecordReader. 556 | d = tf.data.Dataset.from_tensor_slices({ 557 | "input_ids": 558 | tf.constant( 559 | all_input_ids, shape=[num_examples, seq_length], 560 | dtype=tf.int32), 561 | "input_mask": 562 | tf.constant( 563 | all_input_mask, 564 | shape=[num_examples, seq_length], 565 | dtype=tf.int32), 566 | "segment_ids": 567 | tf.constant( 568 | all_segment_ids, 569 | shape=[num_examples, seq_length], 570 | dtype=tf.int32), 571 | # "label_ids": 572 | # tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32), 573 | }) 574 | 575 | if is_training: 576 | d = d.repeat() 577 | d = d.shuffle(buffer_size=100) 578 | 579 | d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder) 580 | return d 581 | 582 | return input_fn 583 | 584 | 585 | # This function is not used by this file but is still used by the Colab and 586 | # people who depend on it. 
587 | def convert_examples_to_features(examples, label_list, max_seq_length, 588 | tokenizer): 589 | """Convert a set of `InputExample`s to a list of `InputFeatures`.""" 590 | 591 | features = [] 592 | for (ex_index, example) in enumerate(examples): 593 | if ex_index % 10000 == 0: 594 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 595 | 596 | feature = convert_single_example(ex_index, example, label_list, 597 | max_seq_length, tokenizer) 598 | 599 | features.append(feature) 600 | return features 601 | 602 | 603 | def predicts(text_data): 604 | tf.logging.set_verbosity(tf.logging.INFO) 605 | 606 | processors = { 607 | "sim": SimProcessor, 608 | } 609 | 610 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, 611 | FLAGS.init_checkpoint) 612 | 613 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 614 | raise ValueError( 615 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.") 616 | 617 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 618 | 619 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 620 | raise ValueError( 621 | "Cannot use sequence length %d because the BERT model " 622 | "was only trained up to sequence length %d" % 623 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 624 | 625 | tf.gfile.MakeDirs(FLAGS.output_dir) 626 | 627 | task_name = FLAGS.task_name.lower() 628 | 629 | if task_name not in processors: 630 | raise ValueError("Task not found: %s" % (task_name)) 631 | 632 | processor = processors[task_name]() 633 | 634 | label_list = processor.get_labels() 635 | 636 | tokenizer = tokenization.FullTokenizer( 637 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 638 | 639 | tpu_cluster_resolver = None 640 | if FLAGS.use_tpu and FLAGS.tpu_name: 641 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 642 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 643 | 644 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 645 | run_config = tf.contrib.tpu.RunConfig( 646 | cluster=tpu_cluster_resolver, 647 | master=FLAGS.master, 648 | model_dir=FLAGS.output_dir, 649 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 650 | tpu_config=tf.contrib.tpu.TPUConfig( 651 | iterations_per_loop=FLAGS.iterations_per_loop, 652 | num_shards=FLAGS.num_tpu_cores, 653 | per_host_input_for_training=is_per_host)) 654 | 655 | train_examples = None 656 | num_train_steps = None 657 | num_warmup_steps = None 658 | 659 | model_fn = model_fn_builder( 660 | bert_config=bert_config, 661 | num_labels=len(label_list), 662 | init_checkpoint=FLAGS.init_checkpoint, 663 | learning_rate=FLAGS.learning_rate, 664 | num_train_steps=num_train_steps, 665 | num_warmup_steps=num_warmup_steps, 666 | use_tpu=FLAGS.use_tpu, 667 | use_one_hot_embeddings=FLAGS.use_tpu) 668 | 669 | # If TPU is not available, this will fall back to normal Estimator on CPU 670 | # or GPU. 
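# predict_batch_size only needs to divide the example count when use_tpu=True:
# in that case PaddingInputExample rows are appended below to round the batch
# up, and they are discarded again when results are written (the
# `i >= num_actual_predict_examples` check).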
671 | estimator = tf.contrib.tpu.TPUEstimator( 672 | use_tpu=FLAGS.use_tpu, 673 | model_fn=model_fn, 674 | config=run_config, 675 | train_batch_size=FLAGS.train_batch_size, 676 | eval_batch_size=FLAGS.eval_batch_size, 677 | predict_batch_size=FLAGS.predict_batch_size) 678 | 679 | if FLAGS.do_predict: 680 | # predict_examples = processor.get_test_examples(FLAGS.data_dir) 681 | test_data = [] 682 | for index in range(len(text_data)): 683 | guid = 'test-%d' % index 684 | text_a = tokenization.convert_to_unicode(str(text_data[index][0])) 685 | text_b = tokenization.convert_to_unicode(str(text_data[index][1])) 686 | # label = str(test[2]) 687 | test_data.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=None)) 688 | 689 | predict_examples = test_data 690 | num_actual_predict_examples = len(predict_examples) 691 | if FLAGS.use_tpu: 692 | # TPU requires a fixed batch size for all batches, therefore the number 693 | # of examples must be a multiple of the batch size, or else examples 694 | # will get dropped. So we pad with fake examples which are ignored 695 | # later on. 696 | while len(predict_examples) % FLAGS.predict_batch_size != 0: 697 | predict_examples.append(PaddingInputExample()) 698 | 699 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 700 | file_based_convert_examples_to_features(predict_examples, label_list, 701 | FLAGS.max_seq_length, tokenizer, 702 | predict_file) 703 | 704 | tf.logging.info("***** Running prediction*****") 705 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 706 | len(predict_examples), num_actual_predict_examples, 707 | len(predict_examples) - num_actual_predict_examples) 708 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 709 | 710 | predict_drop_remainder = True if FLAGS.use_tpu else False 711 | predict_input_fn = file_based_input_fn_builder( 712 | input_file=predict_file, 713 | seq_length=FLAGS.max_seq_length, 714 | is_training=False, 715 | drop_remainder=predict_drop_remainder) 716 | 717 | result = estimator.predict(input_fn=predict_input_fn) 718 | 719 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") 720 | res = [] 721 | with tf.gfile.GFile(output_predict_file, "w") as writer: 722 | num_written_lines = 0 723 | tf.logging.info("***** Predict results *****") 724 | for (i, prediction) in enumerate(result): 725 | probabilities = prediction["probabilities"] 726 | if i >= num_actual_predict_examples: 727 | break 728 | output_line = "\t".join( 729 | str(class_probability) 730 | for class_probability in probabilities) + "\n" 731 | writer.write(output_line) 732 | num_written_lines += 1 733 | dicts = {} 734 | for i in range(len(probabilities)): 735 | dicts[label_list[i]] = probabilities[i] 736 | print(dicts) 737 | dicts = sorted(dicts.items(), key=lambda x: x[1], reverse=True) 738 | res.append([dicts[0][0], dicts[0][1]]) 739 | assert num_written_lines == num_actual_predict_examples 740 | return res 741 | 742 | 743 | -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import six 27 | import tensorflow as tf 28 | 29 | 30 | class BertConfig(object): 31 | """Configuration for `BertModel`.""" 32 | 33 | def __init__(self, 34 | vocab_size, 35 | hidden_size=768, 36 | num_hidden_layers=12, 37 | num_attention_heads=12, 38 | intermediate_size=3072, 39 | hidden_act="gelu", 40 | hidden_dropout_prob=0.1, 41 | attention_probs_dropout_prob=0.1, 42 | max_position_embeddings=512, 43 | type_vocab_size=16, 44 | initializer_range=0.02): 45 | """Constructs BertConfig. 46 | 47 | Args: 48 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 49 | hidden_size: Size of the encoder layers and the pooler layer. 50 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 51 | num_attention_heads: Number of attention heads for each attention layer in 52 | the Transformer encoder. 53 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 54 | layer in the Transformer encoder. 55 | hidden_act: The non-linear activation function (function or string) in the 56 | encoder and pooler. 57 | hidden_dropout_prob: The dropout probability for all fully connected 58 | layers in the embeddings, encoder, and pooler. 59 | attention_probs_dropout_prob: The dropout ratio for the attention 60 | probabilities. 61 | max_position_embeddings: The maximum sequence length that this model might 62 | ever be used with. Typically set this to something large just in case 63 | (e.g., 512 or 1024 or 2048). 64 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 65 | `BertModel`. 66 | initializer_range: The stdev of the truncated_normal_initializer for 67 | initializing all weight matrices. 
68 | """ 69 | self.vocab_size = vocab_size 70 | self.hidden_size = hidden_size 71 | self.num_hidden_layers = num_hidden_layers 72 | self.num_attention_heads = num_attention_heads 73 | self.hidden_act = hidden_act 74 | self.intermediate_size = intermediate_size 75 | self.hidden_dropout_prob = hidden_dropout_prob 76 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 77 | self.max_position_embeddings = max_position_embeddings 78 | self.type_vocab_size = type_vocab_size 79 | self.initializer_range = initializer_range 80 | 81 | @classmethod 82 | def from_dict(cls, json_object): 83 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 84 | config = BertConfig(vocab_size=None) 85 | for (key, value) in six.iteritems(json_object): 86 | config.__dict__[key] = value 87 | return config 88 | 89 | @classmethod 90 | def from_json_file(cls, json_file): 91 | """Constructs a `BertConfig` from a json file of parameters.""" 92 | with tf.gfile.GFile(json_file, "r") as reader: 93 | text = reader.read() 94 | return cls.from_dict(json.loads(text)) 95 | 96 | def to_dict(self): 97 | """Serializes this instance to a Python dictionary.""" 98 | output = copy.deepcopy(self.__dict__) 99 | return output 100 | 101 | def to_json_string(self): 102 | """Serializes this instance to a JSON string.""" 103 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 104 | 105 | 106 | class BertModel(object): 107 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 108 | 109 | Example usage: 110 | 111 | ```python 112 | # Already been converted into WordPiece token ids 113 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 114 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 115 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 116 | 117 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 118 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 119 | 120 | model = modeling.BertModel(config=config, is_training=True, 121 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 122 | 123 | label_embeddings = tf.get_variable(...) 124 | pooled_output = model.get_pooled_output() 125 | logits = tf.matmul(pooled_output, label_embeddings) 126 | ... 127 | ``` 128 | """ 129 | 130 | def __init__(self, 131 | config, 132 | is_training, 133 | input_ids, 134 | input_mask=None, 135 | token_type_ids=None, 136 | use_one_hot_embeddings=True, 137 | scope=None): 138 | """Constructor for BertModel. 139 | 140 | Args: 141 | config: `BertConfig` instance. 142 | is_training: bool. true for training model, false for eval model. Controls 143 | whether dropout will be applied. 144 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 145 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 146 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 147 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 148 | embeddings or tf.embedding_lookup() for the word embeddings. On the TPU, 149 | it is much faster if this is True, on the CPU or GPU, it is faster if 150 | this is False. 151 | scope: (optional) variable scope. Defaults to "bert". 152 | 153 | Raises: 154 | ValueError: The config is invalid or one of the input tensor shapes 155 | is invalid. 
156 | """ 157 | config = copy.deepcopy(config) 158 | if not is_training: 159 | config.hidden_dropout_prob = 0.0 160 | config.attention_probs_dropout_prob = 0.0 161 | 162 | input_shape = get_shape_list(input_ids, expected_rank=2) 163 | batch_size = input_shape[0] 164 | seq_length = input_shape[1] 165 | 166 | if input_mask is None: 167 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) 168 | 169 | if token_type_ids is None: 170 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) 171 | 172 | with tf.variable_scope(scope, default_name="bert"): 173 | with tf.variable_scope("embeddings"): 174 | # Perform embedding lookup on the word ids. 175 | (self.embedding_output, self.embedding_table) = embedding_lookup( 176 | input_ids=input_ids, 177 | vocab_size=config.vocab_size, 178 | embedding_size=config.hidden_size, 179 | initializer_range=config.initializer_range, 180 | word_embedding_name="word_embeddings", 181 | use_one_hot_embeddings=use_one_hot_embeddings) 182 | 183 | # Add positional embeddings and token type embeddings, then layer 184 | # normalize and perform dropout. 185 | self.embedding_output = embedding_postprocessor( 186 | input_tensor=self.embedding_output, 187 | use_token_type=True, 188 | token_type_ids=token_type_ids, 189 | token_type_vocab_size=config.type_vocab_size, 190 | token_type_embedding_name="token_type_embeddings", 191 | use_position_embeddings=True, 192 | position_embedding_name="position_embeddings", 193 | initializer_range=config.initializer_range, 194 | max_position_embeddings=config.max_position_embeddings, 195 | dropout_prob=config.hidden_dropout_prob) 196 | 197 | with tf.variable_scope("encoder"): 198 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 199 | # mask of shape [batch_size, seq_length, seq_length] which is used 200 | # for the attention scores. 201 | attention_mask = create_attention_mask_from_input_mask( 202 | input_ids, input_mask) 203 | 204 | # Run the stacked transformer. 205 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 206 | self.all_encoder_layers = transformer_model( 207 | input_tensor=self.embedding_output, 208 | attention_mask=attention_mask, 209 | hidden_size=config.hidden_size, 210 | num_hidden_layers=config.num_hidden_layers, 211 | num_attention_heads=config.num_attention_heads, 212 | intermediate_size=config.intermediate_size, 213 | intermediate_act_fn=get_activation(config.hidden_act), 214 | hidden_dropout_prob=config.hidden_dropout_prob, 215 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 216 | initializer_range=config.initializer_range, 217 | do_return_all_layers=True) 218 | 219 | self.sequence_output = self.all_encoder_layers[-1] 220 | # The "pooler" converts the encoded sequence tensor of shape 221 | # [batch_size, seq_length, hidden_size] to a tensor of shape 222 | # [batch_size, hidden_size]. This is necessary for segment-level 223 | # (or segment-pair-level) classification tasks where we need a fixed 224 | # dimensional representation of the segment. 225 | with tf.variable_scope("pooler"): 226 | # We "pool" the model by simply taking the hidden state corresponding 227 | # to the first token. 
We assume that this has been pre-trained 228 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 229 | self.pooled_output = tf.layers.dense( 230 | first_token_tensor, 231 | config.hidden_size, 232 | activation=tf.tanh, 233 | kernel_initializer=create_initializer(config.initializer_range)) 234 | 235 | def get_pooled_output(self): 236 | return self.pooled_output 237 | 238 | def get_sequence_output(self): 239 | """Gets final hidden layer of encoder. 240 | 241 | Returns: 242 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 243 | to the final hidden of the transformer encoder. 244 | """ 245 | return self.sequence_output 246 | 247 | def get_all_encoder_layers(self): 248 | return self.all_encoder_layers 249 | 250 | def get_embedding_output(self): 251 | """Gets output of the embedding lookup (i.e., input to the transformer). 252 | 253 | Returns: 254 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 255 | to the output of the embedding layer, after summing the word 256 | embeddings with the positional embeddings and the token type embeddings, 257 | then performing layer normalization. This is the input to the transformer. 258 | """ 259 | return self.embedding_output 260 | 261 | def get_embedding_table(self): 262 | return self.embedding_table 263 | 264 | 265 | def gelu(input_tensor): 266 | """Gaussian Error Linear Unit. 267 | 268 | This is a smoother version of the RELU. 269 | Original paper: https://arxiv.org/abs/1606.08415 270 | 271 | Args: 272 | input_tensor: float Tensor to perform activation. 273 | 274 | Returns: 275 | `input_tensor` with the GELU activation applied. 276 | """ 277 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0))) 278 | return input_tensor * cdf 279 | 280 | 281 | def get_activation(activation_string): 282 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 283 | 284 | Args: 285 | activation_string: String name of the activation function. 286 | 287 | Returns: 288 | A Python function corresponding to the activation function. If 289 | `activation_string` is None, empty, or "linear", this will return None. 290 | If `activation_string` is not a string, it will return `activation_string`. 291 | 292 | Raises: 293 | ValueError: The `activation_string` does not correspond to a known 294 | activation. 295 | """ 296 | 297 | # We assume that anything that"s not a string is already an activation 298 | # function, so we just return it. 
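# Illustrative mapping (added commentary): get_activation("gelu") returns the
# gelu function defined above, get_activation("relu") returns tf.nn.relu,
# get_activation("tanh") returns tf.tanh, get_activation("linear") and
# get_activation(None) return None, and a callable such as tf.nn.relu is
# passed through unchanged.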
299 | if not isinstance(activation_string, six.string_types): 300 | return activation_string 301 | 302 | if not activation_string: 303 | return None 304 | 305 | act = activation_string.lower() 306 | if act == "linear": 307 | return None 308 | elif act == "relu": 309 | return tf.nn.relu 310 | elif act == "gelu": 311 | return gelu 312 | elif act == "tanh": 313 | return tf.tanh 314 | else: 315 | raise ValueError("Unsupported activation: %s" % act) 316 | 317 | 318 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 319 | """Compute the union of the current variables and checkpoint variables.""" 320 | assignment_map = {} 321 | initialized_variable_names = {} 322 | 323 | name_to_variable = collections.OrderedDict() 324 | for var in tvars: 325 | name = var.name 326 | m = re.match("^(.*):\\d+$", name) 327 | if m is not None: 328 | name = m.group(1) 329 | name_to_variable[name] = var 330 | 331 | init_vars = tf.train.list_variables(init_checkpoint) 332 | 333 | assignment_map = collections.OrderedDict() 334 | for x in init_vars: 335 | (name, var) = (x[0], x[1]) 336 | if name not in name_to_variable: 337 | continue 338 | assignment_map[name] = name 339 | initialized_variable_names[name] = 1 340 | initialized_variable_names[name + ":0"] = 1 341 | 342 | return (assignment_map, initialized_variable_names) 343 | 344 | 345 | def dropout(input_tensor, dropout_prob): 346 | """Perform dropout. 347 | 348 | Args: 349 | input_tensor: float Tensor. 350 | dropout_prob: Python float. The probability of dropping out a value (NOT of 351 | *keeping* a dimension as in `tf.nn.dropout`). 352 | 353 | Returns: 354 | A version of `input_tensor` with dropout applied. 355 | """ 356 | if dropout_prob is None or dropout_prob == 0.0: 357 | return input_tensor 358 | 359 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 360 | return output 361 | 362 | 363 | def layer_norm(input_tensor, name=None): 364 | """Run layer normalization on the last dimension of the tensor.""" 365 | return tf.contrib.layers.layer_norm( 366 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 367 | 368 | 369 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 370 | """Runs layer normalization followed by dropout.""" 371 | output_tensor = layer_norm(input_tensor, name) 372 | output_tensor = dropout(output_tensor, dropout_prob) 373 | return output_tensor 374 | 375 | 376 | def create_initializer(initializer_range=0.02): 377 | """Creates a `truncated_normal_initializer` with the given range.""" 378 | return tf.truncated_normal_initializer(stddev=initializer_range) 379 | 380 | 381 | def embedding_lookup(input_ids, 382 | vocab_size, 383 | embedding_size=128, 384 | initializer_range=0.02, 385 | word_embedding_name="word_embeddings", 386 | use_one_hot_embeddings=False): 387 | """Looks up words embeddings for id tensor. 388 | 389 | Args: 390 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 391 | ids. 392 | vocab_size: int. Size of the embedding vocabulary. 393 | embedding_size: int. Width of the word embeddings. 394 | initializer_range: float. Embedding initialization range. 395 | word_embedding_name: string. Name of the embedding table. 396 | use_one_hot_embeddings: bool. If True, use one-hot method for word 397 | embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better 398 | for TPUs. 399 | 400 | Returns: 401 | float Tensor of shape [batch_size, seq_length, embedding_size]. 
402 | """ 403 | # This function assumes that the input is of shape [batch_size, seq_length, 404 | # num_inputs]. 405 | # 406 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 407 | # reshape to [batch_size, seq_length, 1]. 408 | if input_ids.shape.ndims == 2: 409 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 410 | 411 | embedding_table = tf.get_variable( 412 | name=word_embedding_name, 413 | shape=[vocab_size, embedding_size], 414 | initializer=create_initializer(initializer_range)) 415 | 416 | if use_one_hot_embeddings: 417 | flat_input_ids = tf.reshape(input_ids, [-1]) 418 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 419 | output = tf.matmul(one_hot_input_ids, embedding_table) 420 | else: 421 | output = tf.nn.embedding_lookup(embedding_table, input_ids) 422 | 423 | input_shape = get_shape_list(input_ids) 424 | 425 | output = tf.reshape(output, 426 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 427 | return (output, embedding_table) 428 | 429 | 430 | def embedding_postprocessor(input_tensor, 431 | use_token_type=False, 432 | token_type_ids=None, 433 | token_type_vocab_size=16, 434 | token_type_embedding_name="token_type_embeddings", 435 | use_position_embeddings=True, 436 | position_embedding_name="position_embeddings", 437 | initializer_range=0.02, 438 | max_position_embeddings=512, 439 | dropout_prob=0.1): 440 | """Performs various post-processing on a word embedding tensor. 441 | 442 | Args: 443 | input_tensor: float Tensor of shape [batch_size, seq_length, 444 | embedding_size]. 445 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 446 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 447 | Must be specified if `use_token_type` is True. 448 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`. 449 | token_type_embedding_name: string. The name of the embedding table variable 450 | for token type ids. 451 | use_position_embeddings: bool. Whether to add position embeddings for the 452 | position of each token in the sequence. 453 | position_embedding_name: string. The name of the embedding table variable 454 | for positional embeddings. 455 | initializer_range: float. Range of the weight initialization. 456 | max_position_embeddings: int. Maximum sequence length that might ever be 457 | used with this model. This can be longer than the sequence length of 458 | input_tensor, but cannot be shorter. 459 | dropout_prob: float. Dropout probability applied to the final output tensor. 460 | 461 | Returns: 462 | float tensor with same shape as `input_tensor`. 463 | 464 | Raises: 465 | ValueError: One of the tensor shapes or input values is invalid. 466 | """ 467 | input_shape = get_shape_list(input_tensor, expected_rank=3) 468 | batch_size = input_shape[0] 469 | seq_length = input_shape[1] 470 | width = input_shape[2] 471 | 472 | output = input_tensor 473 | 474 | if use_token_type: 475 | if token_type_ids is None: 476 | raise ValueError("`token_type_ids` must be specified if" 477 | "`use_token_type` is True.") 478 | token_type_table = tf.get_variable( 479 | name=token_type_embedding_name, 480 | shape=[token_type_vocab_size, width], 481 | initializer=create_initializer(initializer_range)) 482 | # This vocab will be small so we always do one-hot here, since it is always 483 | # faster for a small vocabulary. 
484 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 485 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) 486 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 487 | token_type_embeddings = tf.reshape(token_type_embeddings, 488 | [batch_size, seq_length, width]) 489 | output += token_type_embeddings 490 | 491 | if use_position_embeddings: 492 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) 493 | with tf.control_dependencies([assert_op]): 494 | full_position_embeddings = tf.get_variable( 495 | name=position_embedding_name, 496 | shape=[max_position_embeddings, width], 497 | initializer=create_initializer(initializer_range)) 498 | # Since the position embedding table is a learned variable, we create it 499 | # using a (long) sequence length `max_position_embeddings`. The actual 500 | # sequence length might be shorter than this, for faster training of 501 | # tasks that do not have long sequences. 502 | # 503 | # So `full_position_embeddings` is effectively an embedding table 504 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 505 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 506 | # perform a slice. 507 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 508 | [seq_length, -1]) 509 | num_dims = len(output.shape.as_list()) 510 | 511 | # Only the last two dimensions are relevant (`seq_length` and `width`), so 512 | # we broadcast among the first dimensions, which is typically just 513 | # the batch size. 514 | position_broadcast_shape = [] 515 | for _ in range(num_dims - 2): 516 | position_broadcast_shape.append(1) 517 | position_broadcast_shape.extend([seq_length, width]) 518 | position_embeddings = tf.reshape(position_embeddings, 519 | position_broadcast_shape) 520 | output += position_embeddings 521 | 522 | output = layer_norm_and_dropout(output, dropout_prob) 523 | return output 524 | 525 | 526 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 527 | """Create 3D attention mask from a 2D tensor mask. 528 | 529 | Args: 530 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 531 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 532 | 533 | Returns: 534 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 535 | """ 536 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 537 | batch_size = from_shape[0] 538 | from_seq_length = from_shape[1] 539 | 540 | to_shape = get_shape_list(to_mask, expected_rank=2) 541 | to_seq_length = to_shape[1] 542 | 543 | to_mask = tf.cast( 544 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) 545 | 546 | # We don't assume that `from_tensor` is a mask (although it could be). We 547 | # don't actually care if we attend *from* padding tokens (only *to* padding) 548 | # tokens so we create a tensor of all ones. 549 | # 550 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 551 | broadcast_ones = tf.ones( 552 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32) 553 | 554 | # Here we broadcast along two dimensions to create the mask. 
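# Illustrative example (added commentary, made-up values): for one sequence of
# length 3 whose last position is padding, to_mask = [[1, 1, 0]] was reshaped
# above to [[[1., 1., 0.]]] with shape [B, 1, T], broadcast_ones has shape
# [B, F, 1], and their product below is
# [[[1., 1., 0.], [1., 1., 0.], [1., 1., 0.]]] with shape [B, F, T]: every
# query position may attend to the two real tokens but never to the padding
# position.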
555 | mask = broadcast_ones * to_mask 556 | 557 | return mask 558 | 559 | 560 | def attention_layer(from_tensor, 561 | to_tensor, 562 | attention_mask=None, 563 | num_attention_heads=1, 564 | size_per_head=512, 565 | query_act=None, 566 | key_act=None, 567 | value_act=None, 568 | attention_probs_dropout_prob=0.0, 569 | initializer_range=0.02, 570 | do_return_2d_tensor=False, 571 | batch_size=None, 572 | from_seq_length=None, 573 | to_seq_length=None): 574 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 575 | 576 | This is an implementation of multi-headed attention based on "Attention 577 | is all you Need". If `from_tensor` and `to_tensor` are the same, then 578 | this is self-attention. Each timestep in `from_tensor` attends to the 579 | corresponding sequence in `to_tensor`, and returns a fixed-with vector. 580 | 581 | This function first projects `from_tensor` into a "query" tensor and 582 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 583 | of tensors of length `num_attention_heads`, where each tensor is of shape 584 | [batch_size, seq_length, size_per_head]. 585 | 586 | Then, the query and key tensors are dot-producted and scaled. These are 587 | softmaxed to obtain attention probabilities. The value tensors are then 588 | interpolated by these probabilities, then concatenated back to a single 589 | tensor and returned. 590 | 591 | In practice, the multi-headed attention are done with transposes and 592 | reshapes rather than actual separate tensors. 593 | 594 | Args: 595 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 596 | from_width]. 597 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 598 | attention_mask: (optional) int32 Tensor of shape [batch_size, 599 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 600 | attention scores will effectively be set to -infinity for any positions in 601 | the mask that are 0, and will be unchanged for positions that are 1. 602 | num_attention_heads: int. Number of attention heads. 603 | size_per_head: int. Size of each attention head. 604 | query_act: (optional) Activation function for the query transform. 605 | key_act: (optional) Activation function for the key transform. 606 | value_act: (optional) Activation function for the value transform. 607 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 608 | attention probabilities. 609 | initializer_range: float. Range of the weight initializer. 610 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 611 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 612 | output will be of shape [batch_size, from_seq_length, num_attention_heads 613 | * size_per_head]. 614 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 615 | of the 3D version of the `from_tensor` and `to_tensor`. 616 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 617 | of the 3D version of the `from_tensor`. 618 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 619 | of the 3D version of the `to_tensor`. 620 | 621 | Returns: 622 | float Tensor of shape [batch_size, from_seq_length, 623 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 624 | true, this will be of shape [batch_size * from_seq_length, 625 | num_attention_heads * size_per_head]). 626 | 627 | Raises: 628 | ValueError: Any of the arguments or tensor shapes are invalid. 
629 | """ 630 | 631 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 632 | seq_length, width): 633 | output_tensor = tf.reshape( 634 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 635 | 636 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 637 | return output_tensor 638 | 639 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 640 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 641 | 642 | if len(from_shape) != len(to_shape): 643 | raise ValueError( 644 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 645 | 646 | if len(from_shape) == 3: 647 | batch_size = from_shape[0] 648 | from_seq_length = from_shape[1] 649 | to_seq_length = to_shape[1] 650 | elif len(from_shape) == 2: 651 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 652 | raise ValueError( 653 | "When passing in rank 2 tensors to attention_layer, the values " 654 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 655 | "must all be specified.") 656 | 657 | # Scalar dimensions referenced here: 658 | # B = batch size (number of sequences) 659 | # F = `from_tensor` sequence length 660 | # T = `to_tensor` sequence length 661 | # N = `num_attention_heads` 662 | # H = `size_per_head` 663 | 664 | from_tensor_2d = reshape_to_matrix(from_tensor) 665 | to_tensor_2d = reshape_to_matrix(to_tensor) 666 | 667 | # `query_layer` = [B*F, N*H] 668 | query_layer = tf.layers.dense( 669 | from_tensor_2d, 670 | num_attention_heads * size_per_head, 671 | activation=query_act, 672 | name="query", 673 | kernel_initializer=create_initializer(initializer_range)) 674 | 675 | # `key_layer` = [B*T, N*H] 676 | key_layer = tf.layers.dense( 677 | to_tensor_2d, 678 | num_attention_heads * size_per_head, 679 | activation=key_act, 680 | name="key", 681 | kernel_initializer=create_initializer(initializer_range)) 682 | 683 | # `value_layer` = [B*T, N*H] 684 | value_layer = tf.layers.dense( 685 | to_tensor_2d, 686 | num_attention_heads * size_per_head, 687 | activation=value_act, 688 | name="value", 689 | kernel_initializer=create_initializer(initializer_range)) 690 | 691 | # `query_layer` = [B, N, F, H] 692 | query_layer = transpose_for_scores(query_layer, batch_size, 693 | num_attention_heads, from_seq_length, 694 | size_per_head) 695 | 696 | # `key_layer` = [B, N, T, H] 697 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 698 | to_seq_length, size_per_head) 699 | 700 | # Take the dot product between "query" and "key" to get the raw 701 | # attention scores. 702 | # `attention_scores` = [B, N, F, T] 703 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 704 | attention_scores = tf.multiply(attention_scores, 705 | 1.0 / math.sqrt(float(size_per_head))) 706 | 707 | if attention_mask is not None: 708 | # `attention_mask` = [B, 1, F, T] 709 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 710 | 711 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 712 | # masked positions, this operation will create a tensor which is 0.0 for 713 | # positions we want to attend and -10000.0 for masked positions. 714 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 715 | 716 | # Since we are adding it to the raw scores before the softmax, this is 717 | # effectively the same as removing these entirely. 718 | attention_scores += adder 719 | 720 | # Normalize the attention scores to probabilities. 
721 | # `attention_probs` = [B, N, F, T] 722 | attention_probs = tf.nn.softmax(attention_scores) 723 | 724 | # This is actually dropping out entire tokens to attend to, which might 725 | # seem a bit unusual, but is taken from the original Transformer paper. 726 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 727 | 728 | # `value_layer` = [B, T, N, H] 729 | value_layer = tf.reshape( 730 | value_layer, 731 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 732 | 733 | # `value_layer` = [B, N, T, H] 734 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 735 | 736 | # `context_layer` = [B, N, F, H] 737 | context_layer = tf.matmul(attention_probs, value_layer) 738 | 739 | # `context_layer` = [B, F, N, H] 740 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 741 | 742 | if do_return_2d_tensor: 743 | # `context_layer` = [B*F, N*H] 744 | context_layer = tf.reshape( 745 | context_layer, 746 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 747 | else: 748 | # `context_layer` = [B, F, N*H] 749 | context_layer = tf.reshape( 750 | context_layer, 751 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 752 | 753 | return context_layer 754 | 755 | 756 | def transformer_model(input_tensor, 757 | attention_mask=None, 758 | hidden_size=768, 759 | num_hidden_layers=12, 760 | num_attention_heads=12, 761 | intermediate_size=3072, 762 | intermediate_act_fn=gelu, 763 | hidden_dropout_prob=0.1, 764 | attention_probs_dropout_prob=0.1, 765 | initializer_range=0.02, 766 | do_return_all_layers=False): 767 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 768 | 769 | This is almost an exact implementation of the original Transformer encoder. 770 | 771 | See the original paper: 772 | https://arxiv.org/abs/1706.03762 773 | 774 | Also see: 775 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 776 | 777 | Args: 778 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 779 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 780 | seq_length], with 1 for positions that can be attended to and 0 in 781 | positions that should not be. 782 | hidden_size: int. Hidden size of the Transformer. 783 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 784 | num_attention_heads: int. Number of attention heads in the Transformer. 785 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 786 | forward) layer. 787 | intermediate_act_fn: function. The non-linear activation function to apply 788 | to the output of the intermediate/feed-forward layer. 789 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 790 | attention_probs_dropout_prob: float. Dropout probability of the attention 791 | probabilities. 792 | initializer_range: float. Range of the initializer (stddev of truncated 793 | normal). 794 | do_return_all_layers: Whether to also return all layers or just the final 795 | layer. 796 | 797 | Returns: 798 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 799 | hidden layer of the Transformer. 800 | 801 | Raises: 802 | ValueError: A Tensor shape or parameter is invalid. 
803 | """ 804 | if hidden_size % num_attention_heads != 0: 805 | raise ValueError( 806 | "The hidden size (%d) is not a multiple of the number of attention " 807 | "heads (%d)" % (hidden_size, num_attention_heads)) 808 | 809 | attention_head_size = int(hidden_size / num_attention_heads) 810 | input_shape = get_shape_list(input_tensor, expected_rank=3) 811 | batch_size = input_shape[0] 812 | seq_length = input_shape[1] 813 | input_width = input_shape[2] 814 | 815 | # The Transformer performs sum residuals on all layers so the input needs 816 | # to be the same as the hidden size. 817 | if input_width != hidden_size: 818 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 819 | (input_width, hidden_size)) 820 | 821 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 822 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 823 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 824 | # help the optimizer. 825 | prev_output = reshape_to_matrix(input_tensor) 826 | 827 | all_layer_outputs = [] 828 | for layer_idx in range(num_hidden_layers): 829 | with tf.variable_scope("layer_%d" % layer_idx): 830 | layer_input = prev_output 831 | 832 | with tf.variable_scope("attention"): 833 | attention_heads = [] 834 | with tf.variable_scope("self"): 835 | attention_head = attention_layer( 836 | from_tensor=layer_input, 837 | to_tensor=layer_input, 838 | attention_mask=attention_mask, 839 | num_attention_heads=num_attention_heads, 840 | size_per_head=attention_head_size, 841 | attention_probs_dropout_prob=attention_probs_dropout_prob, 842 | initializer_range=initializer_range, 843 | do_return_2d_tensor=True, 844 | batch_size=batch_size, 845 | from_seq_length=seq_length, 846 | to_seq_length=seq_length) 847 | attention_heads.append(attention_head) 848 | 849 | attention_output = None 850 | if len(attention_heads) == 1: 851 | attention_output = attention_heads[0] 852 | else: 853 | # In the case where we have other sequences, we just concatenate 854 | # them to the self-attention head before the projection. 855 | attention_output = tf.concat(attention_heads, axis=-1) 856 | 857 | # Run a linear projection of `hidden_size` then add a residual 858 | # with `layer_input`. 859 | with tf.variable_scope("output"): 860 | attention_output = tf.layers.dense( 861 | attention_output, 862 | hidden_size, 863 | kernel_initializer=create_initializer(initializer_range)) 864 | attention_output = dropout(attention_output, hidden_dropout_prob) 865 | attention_output = layer_norm(attention_output + layer_input) 866 | 867 | # The activation is only applied to the "intermediate" hidden layer. 868 | with tf.variable_scope("intermediate"): 869 | intermediate_output = tf.layers.dense( 870 | attention_output, 871 | intermediate_size, 872 | activation=intermediate_act_fn, 873 | kernel_initializer=create_initializer(initializer_range)) 874 | 875 | # Down-project back to `hidden_size` then add the residual. 
876 | with tf.variable_scope("output"): 877 | layer_output = tf.layers.dense( 878 | intermediate_output, 879 | hidden_size, 880 | kernel_initializer=create_initializer(initializer_range)) 881 | layer_output = dropout(layer_output, hidden_dropout_prob) 882 | layer_output = layer_norm(layer_output + attention_output) 883 | prev_output = layer_output 884 | all_layer_outputs.append(layer_output) 885 | 886 | if do_return_all_layers: 887 | final_outputs = [] 888 | for layer_output in all_layer_outputs: 889 | final_output = reshape_from_matrix(layer_output, input_shape) 890 | final_outputs.append(final_output) 891 | return final_outputs 892 | else: 893 | final_output = reshape_from_matrix(prev_output, input_shape) 894 | return final_output 895 | 896 | 897 | def get_shape_list(tensor, expected_rank=None, name=None): 898 | """Returns a list of the shape of tensor, preferring static dimensions. 899 | 900 | Args: 901 | tensor: A tf.Tensor object to find the shape of. 902 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 903 | specified and the `tensor` has a different rank, and exception will be 904 | thrown. 905 | name: Optional name of the tensor for the error message. 906 | 907 | Returns: 908 | A list of dimensions of the shape of tensor. All static dimensions will 909 | be returned as python integers, and dynamic dimensions will be returned 910 | as tf.Tensor scalars. 911 | """ 912 | if name is None: 913 | name = tensor.name 914 | 915 | if expected_rank is not None: 916 | assert_rank(tensor, expected_rank, name) 917 | 918 | shape = tensor.shape.as_list() 919 | 920 | non_static_indexes = [] 921 | for (index, dim) in enumerate(shape): 922 | if dim is None: 923 | non_static_indexes.append(index) 924 | 925 | if not non_static_indexes: 926 | return shape 927 | 928 | dyn_shape = tf.shape(tensor) 929 | for index in non_static_indexes: 930 | shape[index] = dyn_shape[index] 931 | return shape 932 | 933 | 934 | def reshape_to_matrix(input_tensor): 935 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 936 | ndims = input_tensor.shape.ndims 937 | if ndims < 2: 938 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 939 | (input_tensor.shape)) 940 | if ndims == 2: 941 | return input_tensor 942 | 943 | width = input_tensor.shape[-1] 944 | output_tensor = tf.reshape(input_tensor, [-1, width]) 945 | return output_tensor 946 | 947 | 948 | def reshape_from_matrix(output_tensor, orig_shape_list): 949 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 950 | if len(orig_shape_list) == 2: 951 | return output_tensor 952 | 953 | output_shape = get_shape_list(output_tensor) 954 | 955 | orig_dims = orig_shape_list[0:-1] 956 | width = output_shape[-1] 957 | 958 | return tf.reshape(output_tensor, orig_dims + [width]) 959 | 960 | 961 | def assert_rank(tensor, expected_rank, name=None): 962 | """Raises an exception if the tensor rank is not of the expected rank. 963 | 964 | Args: 965 | tensor: A tf.Tensor to check the rank of. 966 | expected_rank: Python integer or list of integers, expected rank. 967 | name: Optional name of the tensor for the error message. 968 | 969 | Raises: 970 | ValueError: If the expected shape doesn't match the actual shape. 
971 | """ 972 | if name is None: 973 | name = tensor.name 974 | 975 | expected_rank_dict = {} 976 | if isinstance(expected_rank, six.integer_types): 977 | expected_rank_dict[expected_rank] = True 978 | else: 979 | for x in expected_rank: 980 | expected_rank_dict[x] = True 981 | 982 | actual_rank = tensor.shape.ndims 983 | if actual_rank not in expected_rank_dict: 984 | scope_name = tf.get_variable_scope().name 985 | raise ValueError( 986 | "For the tensor `%s` in scope `%s`, the actual rank " 987 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 988 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 989 | -------------------------------------------------------------------------------- /run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import csv 23 | import os 24 | import modeling 25 | import optimization 26 | import tokenization 27 | import tensorflow as tf 28 | import pandas as pd 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | ## Required parameters 35 | flags.DEFINE_string( 36 | "data_dir", None, 37 | "The input data dir. Should contain the .tsv files (or other data files) " 38 | "for the task.") 39 | 40 | flags.DEFINE_string( 41 | "bert_config_file", None, 42 | "The config json file corresponding to the pre-trained BERT model. " 43 | "This specifies the model architecture.") 44 | 45 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 46 | 47 | flags.DEFINE_string("vocab_file", None, 48 | "The vocabulary file that the BERT model was trained on.") 49 | 50 | flags.DEFINE_string( 51 | "output_dir", None, 52 | "The output directory where the model checkpoints will be written.") 53 | 54 | ## Other parameters 55 | 56 | flags.DEFINE_string( 57 | "init_checkpoint", None, 58 | "Initial checkpoint (usually from a pre-trained BERT model).") 59 | 60 | flags.DEFINE_bool( 61 | "do_lower_case", True, 62 | "Whether to lower case the input text. Should be True for uncased " 63 | "models and False for cased models.") 64 | 65 | flags.DEFINE_integer( 66 | "max_seq_length", 128, 67 | "The maximum total input sequence length after WordPiece tokenization. 
" 68 | "Sequences longer than this will be truncated, and sequences shorter " 69 | "than this will be padded.") 70 | 71 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 72 | 73 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 74 | 75 | flags.DEFINE_bool( 76 | "do_predict", False, 77 | "Whether to run the model in inference mode on the test set.") 78 | 79 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 80 | 81 | flags.DEFINE_integer("eval_batch_size", 16, "Total batch size for eval.") 82 | 83 | flags.DEFINE_integer("predict_batch_size", 16, "Total batch size for predict.") 84 | 85 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 86 | 87 | flags.DEFINE_float("num_train_epochs", 3.0, 88 | "Total number of training epochs to perform.") 89 | 90 | flags.DEFINE_float( 91 | "warmup_proportion", 0.1, 92 | "Proportion of training to perform linear learning rate warmup for. " 93 | "E.g., 0.1 = 10% of training.") 94 | 95 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 96 | "How often to save the model checkpoint.") 97 | 98 | flags.DEFINE_integer("iterations_per_loop", 1000, 99 | "How many steps to make in each estimator call.") 100 | 101 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 102 | 103 | tf.flags.DEFINE_string( 104 | "tpu_name", None, 105 | "The Cloud TPU to use for training. This should be either the name " 106 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 107 | "url.") 108 | 109 | tf.flags.DEFINE_string( 110 | "tpu_zone", None, 111 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 112 | "specified, we will attempt to automatically detect the GCE project from " 113 | "metadata.") 114 | 115 | tf.flags.DEFINE_string( 116 | "gcp_project", None, 117 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 118 | "specified, we will attempt to automatically detect the GCE project from " 119 | "metadata.") 120 | 121 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 122 | 123 | flags.DEFINE_integer( 124 | "num_tpu_cores", 8, 125 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 126 | 127 | 128 | class InputExample(object): 129 | """A single training/test example for simple sequence classification.""" 130 | 131 | def __init__(self, guid, text_a, text_b=None, label=None): 132 | """Constructs a InputExample. 133 | 134 | Args: 135 | guid: Unique id for the example. 136 | text_a: string. The untokenized text of the first sequence. For single 137 | sequence tasks, only this sequence must be specified. 138 | text_b: (Optional) string. The untokenized text of the second sequence. 139 | Only must be specified for sequence pair tasks. 140 | label: (Optional) string. The label of the example. This should be 141 | specified for train and dev examples, but not for test examples. 142 | """ 143 | self.guid = guid 144 | self.text_a = text_a 145 | self.text_b = text_b 146 | self.label = label 147 | 148 | 149 | class PaddingInputExample(object): 150 | """Fake example so the num input examples is a multiple of the batch size. 151 | 152 | When running eval/predict on the TPU, we need to pad the number of examples 153 | to be a multiple of the batch size, because the TPU requires a fixed batch 154 | size. The alternative is to drop the last batch, which is bad because it means 155 | the entire output data won't be generated. 
156 | 157 | We use this class instead of `None` because treating `None` as padding 158 | battches could cause silent errors. 159 | """ 160 | 161 | 162 | class InputFeatures(object): 163 | """A single set of features of data.""" 164 | 165 | def __init__(self, 166 | input_ids, 167 | input_mask, 168 | segment_ids, 169 | label_id, 170 | is_real_example=True): 171 | self.input_ids = input_ids 172 | self.input_mask = input_mask 173 | self.segment_ids = segment_ids 174 | self.label_id = label_id 175 | self.is_real_example = is_real_example 176 | 177 | 178 | class DataProcessor(object): 179 | """Base class for data converters for sequence classification data sets.""" 180 | 181 | def get_train_examples(self, data_dir): 182 | """Gets a collection of `InputExample`s for the train set.""" 183 | raise NotImplementedError() 184 | 185 | def get_dev_examples(self, data_dir): 186 | """Gets a collection of `InputExample`s for the dev set.""" 187 | raise NotImplementedError() 188 | 189 | def get_test_examples(self, data_dir): 190 | """Gets a collection of `InputExample`s for prediction.""" 191 | raise NotImplementedError() 192 | 193 | def get_labels(self): 194 | """Gets the list of labels for this data set.""" 195 | raise NotImplementedError() 196 | 197 | @classmethod 198 | def _read_tsv(cls, input_file, quotechar=None): 199 | """Reads a tab separated value file.""" 200 | with tf.gfile.Open(input_file, "r") as f: 201 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 202 | lines = [] 203 | for line in reader: 204 | lines.append(line) 205 | return lines 206 | 207 | 208 | class XnliProcessor(DataProcessor): 209 | """Processor for the XNLI data set.""" 210 | 211 | def __init__(self): 212 | self.language = "zh" 213 | 214 | def get_train_examples(self, data_dir): 215 | """See base class.""" 216 | lines = self._read_tsv( 217 | os.path.join(data_dir, "multinli", 218 | "multinli.train.%s.tsv" % self.language)) 219 | examples = [] 220 | for (i, line) in enumerate(lines): 221 | if i == 0: 222 | continue 223 | guid = "train-%d" % (i) 224 | text_a = tokenization.convert_to_unicode(line[0]) 225 | text_b = tokenization.convert_to_unicode(line[1]) 226 | label = tokenization.convert_to_unicode(line[2]) 227 | if label == tokenization.convert_to_unicode("contradictory"): 228 | label = tokenization.convert_to_unicode("contradiction") 229 | examples.append( 230 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 231 | return examples 232 | 233 | def get_dev_examples(self, data_dir): 234 | """See base class.""" 235 | lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv")) 236 | examples = [] 237 | for (i, line) in enumerate(lines): 238 | if i == 0: 239 | continue 240 | guid = "dev-%d" % (i) 241 | language = tokenization.convert_to_unicode(line[0]) 242 | if language != tokenization.convert_to_unicode(self.language): 243 | continue 244 | text_a = tokenization.convert_to_unicode(line[6]) 245 | text_b = tokenization.convert_to_unicode(line[7]) 246 | label = tokenization.convert_to_unicode(line[1]) 247 | examples.append( 248 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 249 | return examples 250 | 251 | def get_labels(self): 252 | """See base class.""" 253 | return ["contradiction", "entailment", "neutral"] 254 | 255 | 256 | class SimProcessor(DataProcessor): 257 | """Processor for the Sim task""" 258 | 259 | # read csv 260 | # def get_train_examples(self, data_dir): 261 | # file_path = os.path.join(data_dir, 'train.csv') 262 | # train_df = pd.read_csv(file_path, 
encoding='utf-8') 263 | # train_data = [] 264 | # for index, train in enumerate(train_df.values): 265 | # guid = 'train-%d' % index 266 | # text_a = tokenization.convert_to_unicode(str(train[0])) 267 | # # text_b = tokenization.convert_to_unicode(str(train[1])) 268 | # label = str(train[1]) 269 | # train_data.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 270 | # return train_data 271 | 272 | # read txt 273 | def get_train_examples(self, data_dir): 274 | file_path = os.path.join(data_dir, 'train.txt') 275 | f = open(file_path, 'r') 276 | train_data = [] 277 | index = 0 278 | for line in f.readlines(): 279 | guid = 'train-%d' % index 280 | line = line.replace("\n", "").split("\t") 281 | text_a = tokenization.convert_to_unicode(str(line[0])) 282 | text_b = tokenization.convert_to_unicode(str(line[1])) 283 | label = str(line[2]) 284 | train_data.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 285 | index += 1 286 | return train_data 287 | 288 | # csv 289 | # def get_dev_examples(self, data_dir): 290 | # file_path = os.path.join(data_dir, 'dev.csv') 291 | # dev_df = pd.read_csv(file_path, encoding='utf-8') 292 | # dev_data = [] 293 | # for index, dev in enumerate(dev_df.values): 294 | # guid = 'dev-%d' % index 295 | # text_a = tokenization.convert_to_unicode(str(dev[0])) 296 | # # text_b = tokenization.convert_to_unicode(str(dev[1])) 297 | # label = str(dev[1]) 298 | # dev_data.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 299 | # return dev_data 300 | 301 | def get_dev_examples(self, data_dir): 302 | file_path = os.path.join(data_dir, 'dev.txt') 303 | f = open(file_path, 'r') 304 | dev_data = [] 305 | index = 0 306 | for line in f.readlines(): 307 | guid = 'dev-%d' % index 308 | line = line.replace("\n", "").split("\t") 309 | text_a = tokenization.convert_to_unicode(str(line[0])) 310 | text_b = tokenization.convert_to_unicode(str(line[1])) 311 | label = str(line[2]) 312 | dev_data.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 313 | index += 1 314 | return dev_data 315 | 316 | def get_test_examples(self, data_dir): 317 | file_path = os.path.join(data_dir, 'test.txt') 318 | f = open(file_path, 'r') 319 | test_data = [] 320 | index = 0 321 | for line in f.readlines(): 322 | guid = 'test-%d' % index 323 | line = line.replace("\n", "").split("\t") 324 | text_a = tokenization.convert_to_unicode(str(line[0])) 325 | text_b = tokenization.convert_to_unicode(str(line[1])) 326 | label = str(line[2]) 327 | test_data.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 328 | index += 1 329 | return test_data 330 | 331 | def get_labels(self): 332 | return ['0', '1'] 333 | 334 | 335 | class MnliProcessor(DataProcessor): 336 | """Processor for the MultiNLI data set (GLUE version).""" 337 | 338 | def get_train_examples(self, data_dir): 339 | """See base class.""" 340 | return self._create_examples( 341 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 342 | 343 | def get_dev_examples(self, data_dir): 344 | """See base class.""" 345 | return self._create_examples( 346 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 347 | "dev_matched") 348 | 349 | def get_test_examples(self, data_dir): 350 | """See base class.""" 351 | return self._create_examples( 352 | self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test") 353 | 354 | def get_labels(self): 355 | """See base class.""" 356 | return ["contradiction", "entailment", "neutral"] 357 | 358 | def 
_create_examples(self, lines, set_type): 359 | """Creates examples for the training and dev sets.""" 360 | examples = [] 361 | for (i, line) in enumerate(lines): 362 | if i == 0: 363 | continue 364 | guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) 365 | text_a = tokenization.convert_to_unicode(line[8]) 366 | text_b = tokenization.convert_to_unicode(line[9]) 367 | if set_type == "test": 368 | label = "contradiction" 369 | else: 370 | label = tokenization.convert_to_unicode(line[-1]) 371 | examples.append( 372 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 373 | return examples 374 | 375 | 376 | class MrpcProcessor(DataProcessor): 377 | """Processor for the MRPC data set (GLUE version).""" 378 | 379 | def get_train_examples(self, data_dir): 380 | """See base class.""" 381 | return self._create_examples( 382 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 383 | 384 | def get_dev_examples(self, data_dir): 385 | """See base class.""" 386 | return self._create_examples( 387 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 388 | 389 | def get_test_examples(self, data_dir): 390 | """See base class.""" 391 | return self._create_examples( 392 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 393 | 394 | def get_labels(self): 395 | """See base class.""" 396 | return ["0", "1"] 397 | 398 | def _create_examples(self, lines, set_type): 399 | """Creates examples for the training and dev sets.""" 400 | examples = [] 401 | for (i, line) in enumerate(lines): 402 | if i == 0: 403 | continue 404 | guid = "%s-%s" % (set_type, i) 405 | text_a = tokenization.convert_to_unicode(line[3]) 406 | text_b = tokenization.convert_to_unicode(line[4]) 407 | if set_type == "test": 408 | label = "0" 409 | else: 410 | label = tokenization.convert_to_unicode(line[0]) 411 | examples.append( 412 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 413 | return examples 414 | 415 | 416 | class ColaProcessor(DataProcessor): 417 | """Processor for the CoLA data set (GLUE version).""" 418 | 419 | def get_train_examples(self, data_dir): 420 | """See base class.""" 421 | return self._create_examples( 422 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 423 | 424 | def get_dev_examples(self, data_dir): 425 | """See base class.""" 426 | return self._create_examples( 427 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 428 | 429 | def get_test_examples(self, data_dir): 430 | """See base class.""" 431 | return self._create_examples( 432 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 433 | 434 | def get_labels(self): 435 | """See base class.""" 436 | return ["0", "1"] 437 | 438 | def _create_examples(self, lines, set_type): 439 | """Creates examples for the training and dev sets.""" 440 | examples = [] 441 | for (i, line) in enumerate(lines): 442 | # Only the test set has a header 443 | if set_type == "test" and i == 0: 444 | continue 445 | guid = "%s-%s" % (set_type, i) 446 | if set_type == "test": 447 | text_a = tokenization.convert_to_unicode(line[1]) 448 | label = "0" 449 | else: 450 | text_a = tokenization.convert_to_unicode(line[3]) 451 | label = tokenization.convert_to_unicode(line[1]) 452 | examples.append( 453 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 454 | return examples 455 | 456 | 457 | def convert_single_example(ex_index, example, label_list, max_seq_length, 458 | tokenizer): 459 | """Converts a single `InputExample` into a single `InputFeatures`.""" 460 | 461 
| if isinstance(example, PaddingInputExample): 462 | return InputFeatures( 463 | input_ids=[0] * max_seq_length, 464 | input_mask=[0] * max_seq_length, 465 | segment_ids=[0] * max_seq_length, 466 | label_id=0, 467 | is_real_example=False) 468 | 469 | label_map = {} 470 | for (i, label) in enumerate(label_list): 471 | label_map[label] = i 472 | 473 | tokens_a = tokenizer.tokenize(example.text_a) 474 | tokens_b = None 475 | if example.text_b: 476 | tokens_b = tokenizer.tokenize(example.text_b) 477 | 478 | if tokens_b: 479 | # Modifies `tokens_a` and `tokens_b` in place so that the total 480 | # length is less than the specified length. 481 | # Account for [CLS], [SEP], [SEP] with "- 3" 482 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 483 | else: 484 | # Account for [CLS] and [SEP] with "- 2" 485 | if len(tokens_a) > max_seq_length - 2: 486 | tokens_a = tokens_a[0:(max_seq_length - 2)] 487 | 488 | # The convention in BERT is: 489 | # (a) For sequence pairs: 490 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 491 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 492 | # (b) For single sequences: 493 | # tokens: [CLS] the dog is hairy . [SEP] 494 | # type_ids: 0 0 0 0 0 0 0 495 | # 496 | # Where "type_ids" are used to indicate whether this is the first 497 | # sequence or the second sequence. The embedding vectors for `type=0` and 498 | # `type=1` were learned during pre-training and are added to the wordpiece 499 | # embedding vector (and position vector). This is not *strictly* necessary 500 | # since the [SEP] token unambiguously separates the sequences, but it makes 501 | # it easier for the model to learn the concept of sequences. 502 | # 503 | # For classification tasks, the first vector (corresponding to [CLS]) is 504 | # used as the "sentence vector". Note that this only makes sense because 505 | # the entire model is fine-tuned. 506 | tokens = [] 507 | segment_ids = [] 508 | tokens.append("[CLS]") 509 | segment_ids.append(0) 510 | for token in tokens_a: 511 | tokens.append(token) 512 | segment_ids.append(0) 513 | tokens.append("[SEP]") 514 | segment_ids.append(0) 515 | 516 | if tokens_b: 517 | for token in tokens_b: 518 | tokens.append(token) 519 | segment_ids.append(1) 520 | tokens.append("[SEP]") 521 | segment_ids.append(1) 522 | 523 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 524 | 525 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 526 | # tokens are attended to. 527 | input_mask = [1] * len(input_ids) 528 | 529 | # Zero-pad up to the sequence length. 
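# Illustrative example (added commentary, made-up pair with max_seq_length=10):
# for the tokens [CLS] A1 A2 [SEP] B1 B2 [SEP] the lists built above are
#   input_ids   = [id(CLS), id(A1), id(A2), id(SEP), id(B1), id(B2), id(SEP)]
#   input_mask  = [1, 1, 1, 1, 1, 1, 1]
#   segment_ids = [0, 0, 0, 0, 1, 1, 1]
# and the loop below appends three zeros to each list so every list reaches
# length 10.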
530 | while len(input_ids) < max_seq_length: 531 | input_ids.append(0) 532 | input_mask.append(0) 533 | segment_ids.append(0) 534 | 535 | assert len(input_ids) == max_seq_length 536 | assert len(input_mask) == max_seq_length 537 | assert len(segment_ids) == max_seq_length 538 | 539 | label_id = label_map[example.label] 540 | if ex_index < 5: 541 | tf.logging.info("*** Example ***") 542 | tf.logging.info("guid: %s" % (example.guid)) 543 | tf.logging.info("tokens: %s" % " ".join( 544 | [tokenization.printable_text(x) for x in tokens])) 545 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 546 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 547 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 548 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 549 | 550 | feature = InputFeatures( 551 | input_ids=input_ids, 552 | input_mask=input_mask, 553 | segment_ids=segment_ids, 554 | label_id=label_id, 555 | is_real_example=True) 556 | return feature 557 | 558 | 559 | def file_based_convert_examples_to_features( 560 | examples, label_list, max_seq_length, tokenizer, output_file): 561 | """Convert a set of `InputExample`s to a TFRecord file.""" 562 | 563 | writer = tf.python_io.TFRecordWriter(output_file) 564 | 565 | for (ex_index, example) in enumerate(examples): 566 | if ex_index % 10000 == 0: 567 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 568 | 569 | feature = convert_single_example(ex_index, example, label_list, 570 | max_seq_length, tokenizer) 571 | 572 | def create_int_feature(values): 573 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 574 | return f 575 | 576 | features = collections.OrderedDict() 577 | features["input_ids"] = create_int_feature(feature.input_ids) 578 | features["input_mask"] = create_int_feature(feature.input_mask) 579 | features["segment_ids"] = create_int_feature(feature.segment_ids) 580 | features["label_ids"] = create_int_feature([feature.label_id]) 581 | features["is_real_example"] = create_int_feature( 582 | [int(feature.is_real_example)]) 583 | 584 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 585 | writer.write(tf_example.SerializeToString()) 586 | writer.close() 587 | 588 | 589 | def file_based_input_fn_builder(input_file, seq_length, is_training, 590 | drop_remainder): 591 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 592 | 593 | name_to_features = { 594 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 595 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 596 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 597 | "label_ids": tf.FixedLenFeature([], tf.int64), 598 | "is_real_example": tf.FixedLenFeature([], tf.int64), 599 | } 600 | 601 | def _decode_record(record, name_to_features): 602 | """Decodes a record to a TensorFlow example.""" 603 | example = tf.parse_single_example(record, name_to_features) 604 | 605 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 606 | # So cast all int64 to int32. 607 | for name in list(example.keys()): 608 | t = example[name] 609 | if t.dtype == tf.int64: 610 | t = tf.to_int32(t) 611 | example[name] = t 612 | 613 | return example 614 | 615 | def input_fn(params): 616 | """The actual input function.""" 617 | batch_size = params["batch_size"] 618 | 619 | # For training, we want a lot of parallel reading and shuffling. 
620 | # For eval, we want no shuffling and parallel reading doesn't matter. 621 | d = tf.data.TFRecordDataset(input_file) 622 | if is_training: 623 | d = d.repeat() 624 | d = d.shuffle(buffer_size=100) 625 | 626 | d = d.apply( 627 | tf.contrib.data.map_and_batch( 628 | lambda record: _decode_record(record, name_to_features), 629 | batch_size=batch_size, 630 | drop_remainder=drop_remainder)) 631 | 632 | return d 633 | 634 | return input_fn 635 | 636 | 637 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 638 | """Truncates a sequence pair in place to the maximum length.""" 639 | 640 | # This is a simple heuristic which will always truncate the longer sequence 641 | # one token at a time. This makes more sense than truncating an equal percent 642 | # of tokens from each, since if one sequence is very short then each token 643 | # that's truncated likely contains more information than a longer sequence. 644 | while True: 645 | total_length = len(tokens_a) + len(tokens_b) 646 | if total_length <= max_length: 647 | break 648 | if len(tokens_a) > len(tokens_b): 649 | tokens_a.pop() 650 | else: 651 | tokens_b.pop() 652 | 653 | 654 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 655 | labels, num_labels, use_one_hot_embeddings): 656 | """Creates a classification model.""" 657 | model = modeling.BertModel( 658 | config=bert_config, 659 | is_training=is_training, 660 | input_ids=input_ids, 661 | input_mask=input_mask, 662 | token_type_ids=segment_ids, 663 | use_one_hot_embeddings=use_one_hot_embeddings) 664 | 665 | # In the demo, we are doing a simple classification task on the entire 666 | # segment. 667 | # 668 | # If you want to use the token-level output, use model.get_sequence_output() 669 | # instead. 670 | output_layer = model.get_pooled_output() 671 | 672 | hidden_size = output_layer.shape[-1].value 673 | 674 | output_weights = tf.get_variable( 675 | "output_weights", [num_labels, hidden_size], 676 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 677 | 678 | output_bias = tf.get_variable( 679 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 680 | 681 | with tf.variable_scope("loss"): 682 | if is_training: 683 | # I.e., 0.1 dropout 684 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 685 | 686 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 687 | logits = tf.nn.bias_add(logits, output_bias) 688 | probabilities = tf.nn.softmax(logits, axis=-1) 689 | log_probs = tf.nn.log_softmax(logits, axis=-1) 690 | 691 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 692 | 693 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 694 | loss = tf.reduce_mean(per_example_loss) 695 | 696 | return (loss, per_example_loss, logits, probabilities) 697 | 698 | 699 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 700 | num_train_steps, num_warmup_steps, use_tpu, 701 | use_one_hot_embeddings): 702 | """Returns `model_fn` closure for TPUEstimator.""" 703 | 704 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 705 | """The `model_fn` for TPUEstimator.""" 706 | 707 | tf.logging.info("*** Features ***") 708 | for name in sorted(features.keys()): 709 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 710 | 711 | input_ids = features["input_ids"] 712 | input_mask = features["input_mask"] 713 | segment_ids = features["segment_ids"] 714 | label_ids = features["label_ids"] 715 | is_real_example = 
None 716 | if "is_real_example" in features: 717 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 718 | else: 719 | is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32) 720 | 721 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 722 | 723 | (total_loss, per_example_loss, logits, probabilities) = create_model( 724 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, 725 | num_labels, use_one_hot_embeddings) 726 | 727 | tvars = tf.trainable_variables() 728 | initialized_variable_names = {} 729 | scaffold_fn = None 730 | if init_checkpoint: 731 | (assignment_map, initialized_variable_names 732 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 733 | if use_tpu: 734 | 735 | def tpu_scaffold(): 736 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 737 | return tf.train.Scaffold() 738 | 739 | scaffold_fn = tpu_scaffold 740 | else: 741 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 742 | 743 | tf.logging.info("**** Trainable Variables ****") 744 | for var in tvars: 745 | init_string = "" 746 | if var.name in initialized_variable_names: 747 | init_string = ", *INIT_FROM_CKPT*" 748 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 749 | init_string) 750 | 751 | output_spec = None 752 | if mode == tf.estimator.ModeKeys.TRAIN: 753 | 754 | train_op = optimization.create_optimizer( 755 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 756 | 757 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 758 | mode=mode, 759 | loss=total_loss, 760 | train_op=train_op, 761 | scaffold_fn=scaffold_fn) 762 | elif mode == tf.estimator.ModeKeys.EVAL: 763 | 764 | def metric_fn(per_example_loss, label_ids, logits, is_real_example): 765 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 766 | accuracy = tf.metrics.accuracy( 767 | labels=label_ids, predictions=predictions, weights=is_real_example) 768 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) 769 | return { 770 | "eval_accuracy": accuracy, 771 | "eval_loss": loss, 772 | } 773 | 774 | eval_metrics = (metric_fn, 775 | [per_example_loss, label_ids, logits, is_real_example]) 776 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 777 | mode=mode, 778 | loss=total_loss, 779 | eval_metrics=eval_metrics, 780 | scaffold_fn=scaffold_fn) 781 | else: 782 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 783 | mode=mode, 784 | predictions={"probabilities": probabilities}, 785 | scaffold_fn=scaffold_fn) 786 | return output_spec 787 | 788 | return model_fn 789 | 790 | 791 | # This function is not used by this file but is still used by the Colab and 792 | # people who depend on it. 793 | def input_fn_builder(features, seq_length, is_training, drop_remainder): 794 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 795 | 796 | all_input_ids = [] 797 | all_input_mask = [] 798 | all_segment_ids = [] 799 | all_label_ids = [] 800 | 801 | for feature in features: 802 | all_input_ids.append(feature.input_ids) 803 | all_input_mask.append(feature.input_mask) 804 | all_segment_ids.append(feature.segment_ids) 805 | all_label_ids.append(feature.label_id) 806 | 807 | def input_fn(params): 808 | """The actual input function.""" 809 | batch_size = params["batch_size"] 810 | 811 | num_examples = len(features) 812 | 813 | # This is for demo purposes and does NOT scale to large data sets. 
We do 814 | # not use Dataset.from_generator() because that uses tf.py_func which is 815 | # not TPU compatible. The right way to load data is with TFRecordReader. 816 | d = tf.data.Dataset.from_tensor_slices({ 817 | "input_ids": 818 | tf.constant( 819 | all_input_ids, shape=[num_examples, seq_length], 820 | dtype=tf.int32), 821 | "input_mask": 822 | tf.constant( 823 | all_input_mask, 824 | shape=[num_examples, seq_length], 825 | dtype=tf.int32), 826 | "segment_ids": 827 | tf.constant( 828 | all_segment_ids, 829 | shape=[num_examples, seq_length], 830 | dtype=tf.int32), 831 | "label_ids": 832 | tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32), 833 | }) 834 | 835 | if is_training: 836 | d = d.repeat() 837 | d = d.shuffle(buffer_size=100) 838 | 839 | d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder) 840 | return d 841 | 842 | return input_fn 843 | 844 | 845 | # This function is not used by this file but is still used by the Colab and 846 | # people who depend on it. 847 | def convert_examples_to_features(examples, label_list, max_seq_length, 848 | tokenizer): 849 | """Convert a set of `InputExample`s to a list of `InputFeatures`.""" 850 | 851 | features = [] 852 | for (ex_index, example) in enumerate(examples): 853 | if ex_index % 10000 == 0: 854 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 855 | 856 | feature = convert_single_example(ex_index, example, label_list, 857 | max_seq_length, tokenizer) 858 | 859 | features.append(feature) 860 | return features 861 | 862 | 863 | def main(_): 864 | tf.logging.set_verbosity(tf.logging.INFO) 865 | 866 | processors = { 867 | "cola": ColaProcessor, 868 | "mnli": MnliProcessor, 869 | "mrpc": MrpcProcessor, 870 | "xnli": XnliProcessor, 871 | "sim": SimProcessor, 872 | } 873 | 874 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, 875 | FLAGS.init_checkpoint) 876 | 877 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 878 | raise ValueError( 879 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.") 880 | 881 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 882 | 883 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 884 | raise ValueError( 885 | "Cannot use sequence length %d because the BERT model " 886 | "was only trained up to sequence length %d" % 887 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 888 | 889 | tf.gfile.MakeDirs(FLAGS.output_dir) 890 | 891 | task_name = FLAGS.task_name.lower() 892 | 893 | if task_name not in processors: 894 | raise ValueError("Task not found: %s" % (task_name)) 895 | 896 | processor = processors[task_name]() 897 | 898 | label_list = processor.get_labels() 899 | 900 | tokenizer = tokenization.FullTokenizer( 901 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 902 | 903 | tpu_cluster_resolver = None 904 | if FLAGS.use_tpu and FLAGS.tpu_name: 905 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 906 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 907 | 908 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 909 | run_config = tf.contrib.tpu.RunConfig( 910 | cluster=tpu_cluster_resolver, 911 | master=FLAGS.master, 912 | model_dir=FLAGS.output_dir, 913 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 914 | tpu_config=tf.contrib.tpu.TPUConfig( 915 | iterations_per_loop=FLAGS.iterations_per_loop, 916 | num_shards=FLAGS.num_tpu_cores, 917 | per_host_input_for_training=is_per_host)) 
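# --- Editor's note: illustrative comment, not part of the original run_classifier.py ---
# The block just below turns the flag settings into an optimizer schedule. As a
# worked example with hypothetical numbers (not taken from any real data set):
# 100,000 training examples, train_batch_size=32, num_train_epochs=3.0 and a
# warmup_proportion of 0.1 (the usual default) give
#   num_train_steps  = int(100000 / 32 * 3.0) = 9375
#   num_warmup_steps = int(9375 * 0.1)        = 937
# i.e. the learning rate is warmed up over the first ~937 steps and then decayed
# over the remaining ones by optimization.create_optimizer.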
918 | 919 | train_examples = None 920 | num_train_steps = None 921 | num_warmup_steps = None 922 | if FLAGS.do_train: 923 | train_examples = processor.get_train_examples(FLAGS.data_dir) 924 | num_train_steps = int( 925 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 926 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 927 | 928 | model_fn = model_fn_builder( 929 | bert_config=bert_config, 930 | num_labels=len(label_list), 931 | init_checkpoint=FLAGS.init_checkpoint, 932 | learning_rate=FLAGS.learning_rate, 933 | num_train_steps=num_train_steps, 934 | num_warmup_steps=num_warmup_steps, 935 | use_tpu=FLAGS.use_tpu, 936 | use_one_hot_embeddings=FLAGS.use_tpu) 937 | 938 | # If TPU is not available, this will fall back to normal Estimator on CPU 939 | # or GPU. 940 | estimator = tf.contrib.tpu.TPUEstimator( 941 | use_tpu=FLAGS.use_tpu, 942 | model_fn=model_fn, 943 | config=run_config, 944 | train_batch_size=FLAGS.train_batch_size, 945 | eval_batch_size=FLAGS.eval_batch_size, 946 | predict_batch_size=FLAGS.predict_batch_size) 947 | 948 | if FLAGS.do_train: 949 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 950 | file_based_convert_examples_to_features( 951 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 952 | tf.logging.info("***** Running training *****") 953 | tf.logging.info(" Num examples = %d", len(train_examples)) 954 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 955 | tf.logging.info(" Num steps = %d", num_train_steps) 956 | train_input_fn = file_based_input_fn_builder( 957 | input_file=train_file, 958 | seq_length=FLAGS.max_seq_length, 959 | is_training=True, 960 | drop_remainder=True) 961 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 962 | 963 | if FLAGS.do_eval: 964 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 965 | num_actual_eval_examples = len(eval_examples) 966 | if FLAGS.use_tpu: 967 | # TPU requires a fixed batch size for all batches, therefore the number 968 | # of examples must be a multiple of the batch size, or else examples 969 | # will get dropped. So we pad with fake examples which are ignored 970 | # later on. These do NOT count towards the metric (all tf.metrics 971 | # support a per-instance weight, and these get a weight of 0.0). 972 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 973 | eval_examples.append(PaddingInputExample()) 974 | 975 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 976 | file_based_convert_examples_to_features( 977 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 978 | 979 | tf.logging.info("***** Running evaluation *****") 980 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 981 | len(eval_examples), num_actual_eval_examples, 982 | len(eval_examples) - num_actual_eval_examples) 983 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 984 | 985 | # This tells the estimator to run through the entire set. 986 | eval_steps = None 987 | # However, if running eval on the TPU, you will need to specify the 988 | # number of steps. 
989 | if FLAGS.use_tpu: 990 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 991 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 992 | 993 | eval_drop_remainder = True if FLAGS.use_tpu else False 994 | eval_input_fn = file_based_input_fn_builder( 995 | input_file=eval_file, 996 | seq_length=FLAGS.max_seq_length, 997 | is_training=False, 998 | drop_remainder=eval_drop_remainder) 999 | 1000 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 1001 | 1002 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 1003 | with tf.gfile.GFile(output_eval_file, "w") as writer: 1004 | tf.logging.info("***** Eval results *****") 1005 | for key in sorted(result.keys()): 1006 | tf.logging.info(" %s = %s", key, str(result[key])) 1007 | writer.write("%s = %s\n" % (key, str(result[key]))) 1008 | 1009 | if FLAGS.do_predict: 1010 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 1011 | num_actual_predict_examples = len(predict_examples) 1012 | if FLAGS.use_tpu: 1013 | # TPU requires a fixed batch size for all batches, therefore the number 1014 | # of examples must be a multiple of the batch size, or else examples 1015 | # will get dropped. So we pad with fake examples which are ignored 1016 | # later on. 1017 | while len(predict_examples) % FLAGS.predict_batch_size != 0: 1018 | predict_examples.append(PaddingInputExample()) 1019 | 1020 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 1021 | file_based_convert_examples_to_features(predict_examples, label_list, 1022 | FLAGS.max_seq_length, tokenizer, 1023 | predict_file) 1024 | 1025 | tf.logging.info("***** Running prediction*****") 1026 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 1027 | len(predict_examples), num_actual_predict_examples, 1028 | len(predict_examples) - num_actual_predict_examples) 1029 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 1030 | 1031 | predict_drop_remainder = True if FLAGS.use_tpu else False 1032 | predict_input_fn = file_based_input_fn_builder( 1033 | input_file=predict_file, 1034 | seq_length=FLAGS.max_seq_length, 1035 | is_training=False, 1036 | drop_remainder=predict_drop_remainder) 1037 | 1038 | result = estimator.predict(input_fn=predict_input_fn) 1039 | 1040 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") 1041 | with tf.gfile.GFile(output_predict_file, "w") as writer: 1042 | num_written_lines = 0 1043 | tf.logging.info("***** Predict results *****") 1044 | for (i, prediction) in enumerate(result): 1045 | probabilities = prediction["probabilities"] 1046 | if i >= num_actual_predict_examples: 1047 | break 1048 | output_line = "\t".join( 1049 | str(class_probability) 1050 | for class_probability in probabilities) + "\n" 1051 | writer.write(output_line) 1052 | num_written_lines += 1 1053 | assert num_written_lines == num_actual_predict_examples 1054 | 1055 | 1056 | if __name__ == "__main__": 1057 | flags.mark_flag_as_required("data_dir") 1058 | flags.mark_flag_as_required("task_name") 1059 | flags.mark_flag_as_required("vocab_file") 1060 | flags.mark_flag_as_required("bert_config_file") 1061 | flags.mark_flag_as_required("output_dir") 1062 | tf.app.run() 1063 | --------------------------------------------------------------------------------
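Note on consuming the prediction output: the do_predict branch above writes one line of
tab-separated class probabilities per real test example to test_results.tsv inside
FLAGS.output_dir. The snippet below is a minimal sketch (not part of the repository) of
turning that file back into 0/1 labels; the path tmp/output/test_results.tsv and the label
order ["0", "1"] for the sim task are assumptions, not values taken from the code above.

# read_test_results.py -- illustrative sketch only; paths and label order are assumed
import csv

RESULT_FILE = "tmp/output/test_results.tsv"   # assumed: output_dir used for prediction
LABEL_LIST = ["0", "1"]                        # assumed: label order of the sim task

with open(RESULT_FILE, newline="") as f:
    for row in csv.reader(f, delimiter="\t"):
        probs = [float(p) for p in row]                         # one probability per label
        best = max(range(len(probs)), key=lambda i: probs[i])   # index of the largest prob
        print(LABEL_LIST[best], probs)

Rows are written in the same order as the test examples (padding examples are skipped), so
the predicted labels can be aligned line-by-line with the original test file.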