├── README.md
├── create_pretrain_data.sh
├── create_pretraining_data.py
├── modeling.py
├── optimization.py
├── optimization_finetuning.py
├── resources
    ├── RoBERTa_zh_Large_Learning_Curve.png
    └── vocab.txt
├── run_classifier.py
├── run_pretraining.py
└── tokenization.py


/README.md:
--------------------------------------------------------------------------------
RoBERTa for Chinese, TensorFlow & PyTorch

Chinese pre-trained RoBERTa models
-------------------------------------------------
RoBERTa is an improved version of BERT. It reaches state-of-the-art results by changing the training task and the way training data is generated, training longer, using larger batches, and using more data. The models can be loaded directly with BERT code.

This project implements RoBERTa pre-training on a large-scale Chinese corpus in TensorFlow. PyTorch versions of the pre-trained models and instructions for loading them are provided as well.

*** 2019-10-12: added a comparison of different models on reading comprehension ***

*** 2019-09-08: added download links hosted in mainland China, a PyTorch version, and a preliminary comparison with bert-wwm, xlnet and other models ***


NLP automatic annotation tool (up to 100x efficiency gain): reservation

A Chinese version of the pre-trained ALBERT model is also available now.

Chinese pre-trained RoBERTa models: download
-------------------------------------------------
*** 6-layer RoBERTa trial version ***
RoBERTa-zh-Layer6: Google Drive, Baidu Netdisk; TensorFlow version, loads directly with BERT; size 200M

###### ** Recommended: RoBERTa-zh-Large (verified) **
RoBERTa-zh-Large: Google Drive, Baidu Netdisk; TensorFlow version, loads directly with BERT

RoBERTa-zh-Large: Google Drive, Baidu Netdisk; PyTorch version, loads directly with the PyTorch port of BERT

Training data for the 24-layer and 12-layer RoBERTa models: 30 GB of raw text, nearly 300 million sentences and 10 billion Chinese characters (tokens), yielding 250 million training instances.

The data covers news, community Q&A, several encyclopedias, and more.

This project uses the same training data as XLNet_zh, the 24-layer Chinese pre-trained XLNet model.

RoBERTa_zh_L12: Google Drive, Baidu Netdisk; TensorFlow version, loads directly with BERT

RoBERTa_zh_L12: Google Drive, Baidu Netdisk; PyTorch version, loads directly with the PyTorch port of BERT

---------------------------------------------------------------

Roberta_l24_zh_base: TensorFlow version, loads directly with BERT

Training data for the 24-layer base version: 10 GB of text, including news, community Q&A, several encyclopedias, and more.



What is RoBERTa:
-------------------------------------------------
A robustly optimized method for pretraining natural language processing (NLP) systems that improves on Bidirectional Encoder Representations from Transformers, or BERT, the self-supervised method released by Google in 2018.

RoBERTa produces state-of-the-art results on the widely used NLP benchmark, General Language Understanding Evaluation (GLUE). The model delivered state-of-the-art performance on the MNLI, QNLI, RTE, STS-B, and RACE tasks and a sizable performance improvement on the GLUE benchmark. With a score of 88.5, RoBERTa reached the top position on the GLUE leaderboard, matching the performance of the previous leader, XLNet-Large.

(Introduction from Facebook blog)

Release Plan:
-------------------------------------------------
1. 24-layer RoBERTa model (roberta_l24_zh), trained on 30 GB of text: September 8

2. 12-layer RoBERTa model (roberta_l12_zh), trained on 30 GB of text: September 8

3. 6-layer RoBERTa model (roberta_l6_zh), trained on 30 GB of text: September 8

4. PyTorch version of the model (roberta_l6_zh_pytorch): September 8

5. 30 GB Chinese corpus in pre-training format, ready for training (bert, xlnet, gpt2): to be determined

6. Evaluation on test sets and comparison of results: September 14

Performance
-------------------------------------------------
### Internet news sentiment analysis: CCF-Sentiment-Analysis

| Model | Online F1 |
| :------- | :---------: |
| BERT | 80.3 |
| Bert-wwm-ext | 80.5 |
| XLNet | 79.6 |
| Roberta-mid | 80.5 |
| Roberta-large (max_seq_length=512, split_num=1) | 81.25 |

Note: the results come from guoday's open-source project. For the dataset and task description, see CCF Internet News Sentiment Analysis.

### Natural language inference: XNLI

| Model | Dev | Test |
| :------- | :---------: | :---------: |
| BERT | 77.8 (77.4) | 77.8 (77.5) |
| ERNIE | 79.7 (79.4) | 78.6 (78.2) |
| BERT-wwm | 79.0 (78.4) | 78.2 (78.0) |
| BERT-wwm-ext | 79.4 (78.6) | 78.7 (78.3) |
| XLNet | 79.2 | 78.7 |
| RoBERTa-zh-base | 79.8 | 78.8 |
| **RoBERTa-zh-Large** | **80.2 (80.0)** | **79.9 (79.5)** |

Note: RoBERTa_l24_zh was run only twice, so its performance may still improve.

The BERT-wwm-ext and XLNet figures are taken from their respective projects. RoBERTa-zh-base refers to the 12-layer Chinese RoBERTa model.

### Question matching task: LCQMC (Sentence Pair Matching)

| Model | Dev | Test |
| :------- | :---------: | :---------: |
| BERT | 89.4 (88.4) | 86.9 (86.4) |
| ERNIE | 89.8 (89.6) | **87.2** (87.0) |
| BERT-wwm | 89.4 (89.2) | 87.0 (86.8) |
| BERT-wwm-ext | - | - |
| RoBERTa-zh-base | 88.7 | 87.0 |
| **RoBERTa-zh-Large** | **89.9** (89.6) | **87.2** (86.7) |
| RoBERTa-zh-Large (20w_steps) | 89.7 | 87.0 |

Note: RoBERTa_l24_zh was run only twice, so its performance may still improve. The number of training epochs is kept consistent with the paper.

### Reading comprehension tests
For reading comprehension tasks, the best hyperparameters found so far for both bert and roberta are epoch=2, batch=32, lr=3e-5, warmup=0.1.

#### cmrc2018 (reading comprehension)

| models | DEV |
| ------ | ------ |
| sibert_base | F1:87.521(88.628) EM:67.381(69.152) |
| sialbert_middle | F1:87.6956(87.878) EM:67.897(68.624) |
| HFL (HIT & iFLYTEK) roberta_wwm_ext_base | F1:87.521(88.628) EM:67.381(69.152) |
| brightmart roberta_middle | F1:86.841(87.242) EM:67.195(68.313) |
| brightmart roberta_large | **F1:88.608(89.431) EM:69.935(72.538)** |

#### DRCD (reading comprehension)

| models | DEV |
| ------ | ------ |
| siBert_base | F1:93.343(93.524) EM:87.968(88.28) |
| siALBert_middle | F1:93.865(93.975) EM:88.723(88.961) |
| HFL (HIT & iFLYTEK) roberta_wwm_ext_base | F1:94.257(94.48) EM:89.291(89.642) |
| brightmart roberta_large | **F1:94.933(95.057) EM:90.113(90.238)** |

#### CJRC (reading comprehension with yes, no, and unknown answers)

| models | DEV |
| ------ | ------ |
| siBert_base | F1:80.714(81.14) EM:64.44(65.04) |
| siALBert_middle | F1:80.9838(81.299) EM:63.796(64.202) |
| HFL (HIT & iFLYTEK) roberta_wwm_ext_base | F1:81.510(81.684) EM:64.924(65.574) |
| brightmart roberta_large | F1:80.16(80.475) EM:65.249(66.133) |

The reading comprehension comparison data comes from bert_cn_finetune.

Entries marked "?" will be updated with concrete values soon.

RoBERTa Chinese Version
-------------------------------------------------
In this project, "Chinese pre-trained RoBERTa model" refers only to models trained according to the main ideas of the RoBERTa paper. These are:

1. Changes to data generation and the training task: next sentence prediction is removed, and data is drawn contiguously from a single document (see: Model Input Format and Next Sentence Prediction, DOC-SENTENCES).

2. Larger and more diverse data: training uses 30 GB of Chinese text containing 300 million sentences and 10 billion characters (tokens). The text comes from news, community discussions, and several encyclopedias, and spans hundreds of thousands of topics, so the data is diverse. For even more diversity, web books, novels, story-style literature, Weibo posts, and similar sources could be added.

3. Longer training: nearly 200,000 steps in total, seeing nearly 1.6 billion training instances. Training took 24 hours on a Cloud TPU v3-256, which is equivalent to about one month on a TPU v3-8 (128 GB of memory).

4. Larger batches: a very large batch size (8k) is used.

5. Tuned optimizer and other hyperparameters.

In addition, the Chinese version in this project uses whole word masking. With whole word masking, if some of the WordPiece sub-tokens of a complete word are masked, the other parts of the same word are masked as well.

This project does not implement dynamic masking directly. Instead, each training sample is duplicated into several copies, each with a different mask. Increasing the number of copies gives an effect similar to dynamic masking.

##### Instructions for Use

The models were trained with a sequence length of 256, so they are likely to work well on inputs within that length. If your task has longer inputs (for example a sequence length of 512), performance may be affected.

Some users split long sequences with a sliding window and still obtained good results; see #issue-16.

##### Chinese Whole Word Mask

| Description | Example |
| :------- | :--------- |
| Original text | 使用语言模型来预测下一个词的probability。 |
| Segmented text | 使用 语言 模型 来 预测 下 一个 词 的 probability 。 |
| Original mask input | 使 用 语 言 [MASK] 型 来 [MASK] 测 下 一 个 词 的 pro [MASK] ##lity 。 |
| Whole word mask input | 使 用 语 言 [MASK] [MASK] 来 [MASK] [MASK] 下 一 个 词 的 [MASK] [MASK] [MASK] 。 |

Loading the model (using the sentence pair matching task LCQMC as an example)
-------------------------------------------------

Download the LCQMC dataset, which contains training, validation, and test sets. The training set contains 240,000 pairs of colloquial Chinese sentences labeled 1 or 0: 1 means the two sentences are semantically similar, 0 means they are not.

TensorFlow version:

1. Clone this project: git clone https://github.com/brightmart/roberta_zh

2. Enter the project directory (roberta_zh).

Assume you have downloaded the pre-trained RoBERTa model and extracted it into the project's roberta_zh_large directory, that is, roberta_zh/roberta_zh_large.

Run:

    export BERT_BASE_DIR=./roberta_zh_large
    export MY_DATA_DIR=./data/lcqmc
    python run_classifier.py \
      --task_name=lcqmc_pair \
      --do_train=true \
      --do_eval=true \
      --data_dir=$MY_DATA_DIR \
      --vocab_file=$BERT_BASE_DIR/vocab.txt \
      --bert_config_file=$BERT_BASE_DIR/bert_config_large.json \
      --init_checkpoint=$BERT_BASE_DIR/roberta_zh_large_model.ckpt \
      --max_seq_length=128 \
      --train_batch_size=64 \
      --learning_rate=2e-5 \
      --num_train_epochs=3 \
      --output_dir=./checkpoint_lcqmc

Note: task_name is lcqmc_pair. A processor for it has already been added to run_classifier.py and registered in processors. It selects the LCQMC task and loads the training and validation data.

For loading with PyTorch, refer to issue 9 for now; more detailed instructions will be provided soon.

Pre-training
-------------------------------------------------
#### 1) Pre-training data
You can train on data from the domain of your task, or select the subset of a general-purpose corpus that is relevant to your domain.

For general-purpose corpora, see nlp_chinese_corpus, which contains several datasets with tens of millions of sentences each.

#### 2) Generating data for pre-training
Data is drawn contiguously from a single document, following the DOC-SENTENCES format, and whole word masking is applied.

A shell script converts multiple txt files into tfrecord data in batch.

For example, to convert txt files 1 through 10 into tfrecords files:

    nohup bash create_pretrain_data.sh 1 10 &

Note: in our experiments, whole word masking at a 15% ratio made the task hard for the model to learn and slow to converge, so we used 10% instead.

#### 3) Running the pre-training command
The next sentence prediction task is removed.

    export BERT_BASE_DIR=
    nohup python3 run_pretraining.py --input_file=./tf_records_all/tf*.tfrecord \
    --output_dir=my_new_model_path --do_train=True --do_eval=True --bert_config_file=$BERT_BASE_DIR/bert_config.json \
    --train_batch_size=8192 --max_seq_length=256 --max_predictions_per_seq=23 \
    --num_train_steps=200000 --num_warmup_steps=10000 --learning_rate=1e-4 \
    --save_checkpoints_steps=3000 --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt &

Notes: if you train from scratch, you can leave init_checkpoint unset.
If you continue training from an existing model, set BERT_BASE_DIR and make sure that bert_config_file and init_checkpoint point to the corresponding files.
Domain-specific pre-training does not need to run for very long.

Learning Curve
-------------------------------------------------


Memory requirements: trade-off between batch size and sequence length
-------------------------------------------------

System | Seq Length | Max Batch Size
------------ | ---------- | --------------
`RoBERTa-Base` | 64 | 64
... | 128 | 32
... | 256 | 16
... | 320 | 14
... | 384 | 12
... | 512 | 6
`RoBERTa-Large` | 64 | 12
... | 128 | 6
... | 256 | 2
... | 320 | 1
... | 384 | 0
... | 512 | 0



#### QQ group for technical exchange and discussion: 836811304

If you have any questions, you can raise an issue or send me an email: brightmart@hotmail.com.

You can also send a pull request to report your performance on your task, or to add methods for loading the models in PyTorch, and so on.

If you have ideas for producing a better-performing pre-trained Chinese model, please also let me know.

Please report the accuracy on your task and how it compares with other models.


Project contributors also include:
-------------------------------------------------
skyhawk1990


##### Research supported with Cloud TPUs from Google's TensorFlow Research Cloud (TFRC)




Reference
-------------------------------------------------
1. RoBERTa: A Robustly Optimized BERT Pretraining Approach

2. Pre-Training with Whole Word Masking for Chinese BERT

3. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

4. LCQMC: A Large-scale Chinese Question Matching Corpus

--------------------------------------------------------------------------------
/create_pretrain_data.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
echo $1,$2

for((i=$1;i<=$2;i++));
do
python3 create_pretraining_data.py --do_whole_word_mask=True --input_file=./raw_text/news2016zh_$i.txt \
--output_file=./tf_records_all/tf_news2016zh_$i.tfrecord --vocab_file=./resources/vocab.txt \
--do_lower_case=True --max_seq_length=256 \
--max_predictions_per_seq=23 --masked_lm_prob=0.10 --random_seed=12345 --dupe_factor=5 9 | done 10 | -------------------------------------------------------------------------------- /create_pretraining_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import random 23 | import re 24 | import tokenization 25 | import tensorflow as tf 26 | import jieba 27 | 28 | flags = tf.flags 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | flags.DEFINE_string("input_file", None, 33 | "Input raw text file (or comma-separated list of files).") 34 | 35 | flags.DEFINE_string( 36 | "output_file", None, 37 | "Output TF example file (or comma-separated list of files).") 38 | 39 | flags.DEFINE_string("vocab_file", None, 40 | "The vocabulary file that the BERT model was trained on.") 41 | 42 | flags.DEFINE_bool( 43 | "do_lower_case", True, 44 | "Whether to lower case the input text. 
Should be True for uncased " 45 | "models and False for cased models.") 46 | 47 | flags.DEFINE_bool( 48 | "do_whole_word_mask", False, 49 | "Whether to use whole word masking rather than per-WordPiece masking.") 50 | 51 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 52 | 53 | flags.DEFINE_integer("max_predictions_per_seq", 20, 54 | "Maximum number of masked LM predictions per sequence.") 55 | 56 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 57 | 58 | flags.DEFINE_integer( 59 | "dupe_factor", 10, 60 | "Number of times to duplicate the input data (with different masks).") 61 | 62 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 63 | 64 | flags.DEFINE_float( 65 | "short_seq_prob", 0.1, 66 | "Probability of creating sequences which are shorter than the " 67 | "maximum length.") 68 | 69 | 70 | class TrainingInstance(object): 71 | """A single training instance (sentence pair).""" 72 | 73 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 74 | is_random_next): 75 | self.tokens = tokens 76 | self.segment_ids = segment_ids 77 | self.is_random_next = is_random_next 78 | self.masked_lm_positions = masked_lm_positions 79 | self.masked_lm_labels = masked_lm_labels 80 | 81 | def __str__(self): 82 | s = "" 83 | s += "tokens: %s\n" % (" ".join( 84 | [tokenization.printable_text(x) for x in self.tokens])) 85 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 86 | s += "is_random_next: %s\n" % self.is_random_next 87 | s += "masked_lm_positions: %s\n" % (" ".join( 88 | [str(x) for x in self.masked_lm_positions])) 89 | s += "masked_lm_labels: %s\n" % (" ".join( 90 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 91 | s += "\n" 92 | return s 93 | 94 | def __repr__(self): 95 | return self.__str__() 96 | 97 | 98 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 99 | max_predictions_per_seq, output_files): 
100 | """Create TF example files from `TrainingInstance`s.""" 101 | writers = [] 102 | for output_file in output_files: 103 | writers.append(tf.python_io.TFRecordWriter(output_file)) 104 | 105 | writer_index = 0 106 | 107 | total_written = 0 108 | for (inst_index, instance) in enumerate(instances): 109 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 110 | input_mask = [1] * len(input_ids) 111 | segment_ids = list(instance.segment_ids) 112 | assert len(input_ids) <= max_seq_length 113 | 114 | while len(input_ids) < max_seq_length: 115 | input_ids.append(0) 116 | input_mask.append(0) 117 | segment_ids.append(0) 118 | 119 | assert len(input_ids) == max_seq_length 120 | assert len(input_mask) == max_seq_length 121 | # print("length of segment_ids:",len(segment_ids),"max_seq_length:", max_seq_length) 122 | assert len(segment_ids) == max_seq_length 123 | 124 | masked_lm_positions = list(instance.masked_lm_positions) 125 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 126 | masked_lm_weights = [1.0] * len(masked_lm_ids) 127 | 128 | while len(masked_lm_positions) < max_predictions_per_seq: 129 | masked_lm_positions.append(0) 130 | masked_lm_ids.append(0) 131 | masked_lm_weights.append(0.0) 132 | 133 | next_sentence_label = 1 if instance.is_random_next else 0 134 | 135 | features = collections.OrderedDict() 136 | features["input_ids"] = create_int_feature(input_ids) 137 | features["input_mask"] = create_int_feature(input_mask) 138 | features["segment_ids"] = create_int_feature(segment_ids) 139 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 140 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 141 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 142 | features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 143 | 144 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 145 | 146 | 
writers[writer_index].write(tf_example.SerializeToString()) 147 | writer_index = (writer_index + 1) % len(writers) 148 | 149 | total_written += 1 150 | 151 | if inst_index < 20: 152 | tf.logging.info("*** Example ***") 153 | tf.logging.info("tokens: %s" % " ".join( 154 | [tokenization.printable_text(x) for x in instance.tokens])) 155 | 156 | for feature_name in features.keys(): 157 | feature = features[feature_name] 158 | values = [] 159 | if feature.int64_list.value: 160 | values = feature.int64_list.value 161 | elif feature.float_list.value: 162 | values = feature.float_list.value 163 | tf.logging.info( 164 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 165 | 166 | for writer in writers: 167 | writer.close() 168 | 169 | tf.logging.info("Wrote %d total instances", total_written) 170 | 171 | 172 | def create_int_feature(values): 173 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 174 | return feature 175 | 176 | 177 | def create_float_feature(values): 178 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 179 | return feature 180 | 181 | 182 | def create_training_instances(input_files, tokenizer, max_seq_length, 183 | dupe_factor, short_seq_prob, masked_lm_prob, 184 | max_predictions_per_seq, rng): 185 | """Create `TrainingInstance`s from raw text.""" 186 | all_documents = [[]] 187 | 188 | # Input file format: 189 | # (1) One sentence per line. These should ideally be actual sentences, not 190 | # entire paragraphs or arbitrary spans of text. (Because we use the 191 | # sentence boundaries for the "next sentence prediction" task). 192 | # (2) Blank lines between documents. Document boundaries are needed so 193 | # that the "next sentence prediction" task doesn't span between documents. 
  print("create_training_instances.started...")
  for input_file in input_files:
    with tf.gfile.GFile(input_file, "r") as reader:
      while True:
        # The marker stripped by the first replace() is missing from this copy,
        # so as written replace("", "") is a no-op. The commented-out call
        # would additionally strip the ” character.
        line = tokenization.convert_to_unicode(reader.readline().replace("", ""))  # .replace("”",""))
        if not line:
          break
        line = line.strip()

        # Empty lines are used as document delimiters
        if not line:
          all_documents.append([])
        tokens = tokenizer.tokenize(line)
        if tokens:
          all_documents[-1].append(tokens)

  # Remove empty documents
  all_documents = [x for x in all_documents if x]
  rng.shuffle(all_documents)

  vocab_words = list(tokenizer.vocab.keys())
  instances = []
  for _ in range(dupe_factor):
    for document_index in range(len(all_documents)):
      instances.extend(
          create_instances_from_document(
              all_documents, document_index, max_seq_length, short_seq_prob,
              masked_lm_prob, max_predictions_per_seq, vocab_words, rng))

  rng.shuffle(instances)
  print("create_training_instances.ended...")

  return instances


def _is_chinese_char(cp):
  """Checks whether CP is the codepoint of a CJK character."""
  # This defines a "chinese character" as anything in the CJK Unicode block:
  #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
  #
  # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
  # despite its name. The modern Korean Hangul alphabet is a different block,
  # as is Japanese Hiragana and Katakana. Those alphabets are used to write
  # space-separated words, so they are not treated specially and handled
  # like all of the other languages.
  if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
      (cp >= 0x3400 and cp <= 0x4DBF) or  #
      (cp >= 0x20000 and cp <= 0x2A6DF) or  #
      (cp >= 0x2A700 and cp <= 0x2B73F) or  #
      (cp >= 0x2B740 and cp <= 0x2B81F) or  #
      (cp >= 0x2B820 and cp <= 0x2CEAF) or
      (cp >= 0xF900 and cp <= 0xFAFF) or  #
      (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
    return True

  return False


def get_new_segment(segment):  # newly added method
  """
  Takes one sentence and returns a processed version of it. To support Chinese
  whole word masking, characters that were split out of the same word are given
  a special marker ("##") so that later stages know which characters belong to
  the same word.
  :param segment: one sentence, as a list of tokens
  :return: the processed sentence
  """
  seq_cws = jieba.lcut("".join(segment))
  seq_cws_dict = {x: 1 for x in seq_cws}
  new_segment = []
  i = 0
  while i < len(segment):
    if len(re.findall('[\u4E00-\u9FA5]', segment[i])) == 0:  # Not Chinese: append unchanged.
      new_segment.append(segment[i])
      i += 1
      continue

    has_add = False
    # Try the longest match first: words of 3, 2, then 1 characters.
    for length in range(3, 0, -1):
      if i + length > len(segment):
        continue
      if ''.join(segment[i:i + length]) in seq_cws_dict:
        new_segment.append(segment[i])
        for l in range(1, length):
          new_segment.append('##' + segment[i + l])
        i += length
        has_add = True
        break
    if not has_add:
      new_segment.append(segment[i])
      i += 1
  return new_segment


def get_raw_instance(document, max_sequence_length):  # newly added method
  """
  Builds preliminary training instances: splits a whole passage into several
  parts according to max_sequence_length and returns them as a list of
  processed instances.
  :param document: a whole passage
  :param max_sequence_length:
  :return: a list. each element is a sequence of text
  """
  max_sequence_length_allowed = max_sequence_length - 2
  # NOTE: in this copy the text between the length filter and the final check
  # was garbled. The greedy packing loop below is a reconstruction that is
  # consistent with the docstring, not a verbatim restoration.
  document = [seq for seq in document if len(seq) < max_sequence_length_allowed]
  sizes = [len(seq) for seq in document]

  result_list = []
  curr_seq = []  # the sequence currently being built
  sz_idx = 0
  while sz_idx < len(sizes):
    # If the current sequence plus the next sentence still fits within the
    # limit, merge them; otherwise emit the current sequence and start a new one.
    if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed:
      curr_seq += document[sz_idx]
      sz_idx += 1
    else:
      result_list.append(curr_seq)
      curr_seq = []
  # Handle the last sequence: drop it if it is too short.
  if len(curr_seq) > max_sequence_length_allowed / 2:  # /2
    result_list.append(curr_seq)

  # # Compute how many parts we can get in total
  # num_instance=int(len(big_list)/max_sequence_length_allowed)+1
  # print("num_instance:",num_instance)
  # # Split into parts and add them to the list
  # result_list=[]
  # for j in range(num_instance):
  #     index=j*max_sequence_length_allowed
  #     end_index=index+max_sequence_length_allowed if j!=num_instance-1 else -1
  #     result_list.append(big_list[index:end_index])
  return result_list


def create_instances_from_document(  # newly added method
    # Goal: follow the RoBERTa approach, using DOC-SENTENCES and dropping the
    # NSP task: take text contiguously from one document until the maximum
    # length is reached. If text is taken from the next document, add a separator.
    # `document` is a whole passage containing several sentences; each sentence
    # is called a segment.
    # Given a document (a whole passage), generate some instances.
    all_documents, document_index, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
  """Creates `TrainingInstance`s for a single document."""
  document = all_documents[document_index]

  # Account for [CLS], [SEP], [SEP]
  max_num_tokens = max_seq_length - 3

  # We *usually* want to fill up the entire sequence since we are padding
  # to `max_seq_length` anyways, so short sequences are generally wasted
  # computation. However, we *sometimes*
  # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
  # sequences to minimize the mismatch between pre-training and fine-tuning.
  # The `target_seq_length` is just a rough target however, whereas
  # `max_seq_length` is a hard limit.

  #target_seq_length = max_num_tokens
  #if rng.random() < short_seq_prob:
  #  target_seq_length = rng.randint(2, max_num_tokens)

  instances = []
  raw_text_list_list = get_raw_instance(document, max_seq_length)  # `document` is a whole passage containing several sentences; each sentence is called a segment.
  for j, raw_text_list in enumerate(raw_text_list_list):
    ##########################################################################
    raw_text_list = get_new_segment(raw_text_list)  # Chinese whole word masking based on word segmentation: add "##" where needed
    # 1. Set tokens and segment_ids
    is_random_next = True  # this will not be used, so its value doesn't matter
    tokens = []
    segment_ids = []
    tokens.append("[CLS]")
    segment_ids.append(0)
    for token in raw_text_list:
      tokens.append(token)
      segment_ids.append(0)
    tokens.append("[SEP]")
    segment_ids.append(0)
    ##########################################################################
    # 2. Call the original masking method
    (tokens, masked_lm_positions,
     masked_lm_labels) = create_masked_lm_predictions(
         tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
    instance = TrainingInstance(
        tokens=tokens,
        segment_ids=segment_ids,
        is_random_next=is_random_next,
        masked_lm_positions=masked_lm_positions,
        masked_lm_labels=masked_lm_labels)
    instances.append(instance)

  return instances



def create_instances_from_document_original(
    all_documents, document_index, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
  """Creates `TrainingInstance`s for a single document."""
  document = all_documents[document_index]

  # Account for [CLS], [SEP], [SEP]
  max_num_tokens = max_seq_length - 3

  # We *usually* want to fill up the entire sequence since we are padding
  # to `max_seq_length` anyways, so short sequences are generally wasted
  # computation. However, we *sometimes*
  # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
  # sequences to minimize the mismatch between pre-training and fine-tuning.
  # The `target_seq_length` is just a rough target however, whereas
  # `max_seq_length` is a hard limit.
  target_seq_length = max_num_tokens
  if rng.random() < short_seq_prob:
    target_seq_length = rng.randint(2, max_num_tokens)

  # We DON'T just concatenate all of the tokens from a document into a long
  # sequence and choose an arbitrary split point because this would make the
  # next sentence prediction task too easy. Instead, we split the input into
  # segments "A" and "B" based on the actual "sentences" provided by the user
  # input.
  instances = []
  current_chunk = []
  current_length = 0
  i = 0
  print("document_index:",document_index,"document:",type(document)," ;document:",document)  # `document` is a whole passage containing several sentences; each sentence is called a segment.
  while i < len(document):
    segment = document[i]  # take one part (possibly a whole paragraph)
    print("i:",i," ;segment:",segment)
    ##########################################################################
    segment = get_new_segment(segment)  # Chinese whole word masking based on word segmentation: add "##" where needed
    ##########################################################################
    current_chunk.append(segment)
    current_length += len(segment)
    print("#####condition:",i == len(document) - 1 or current_length >= target_seq_length)
    if i == len(document) - 1 or current_length >= target_seq_length:
      if current_chunk:
        # `a_end` is how many segments from `current_chunk` go into the `A`
        # (first) sentence.
419 | a_end = 1 420 | if len(current_chunk) >= 2: 421 | a_end = rng.randint(1, len(current_chunk) - 1) 422 | 423 | tokens_a = [] 424 | for j in range(a_end): 425 | tokens_a.extend(current_chunk[j]) 426 | 427 | tokens_b = [] 428 | # Random next 429 | is_random_next = False 430 | if len(current_chunk) == 1 or rng.random() < 0.5: 431 | is_random_next = True 432 | target_b_length = target_seq_length - len(tokens_a) 433 | 434 | # This should rarely go for more than one iteration for large 435 | # corpora. However, just to be careful, we try to make sure that 436 | # the random document is not the same as the document 437 | # we're processing. 438 | for _ in range(10): 439 | random_document_index = rng.randint(0, len(all_documents) - 1) 440 | if random_document_index != document_index: 441 | break 442 | 443 | random_document = all_documents[random_document_index] 444 | random_start = rng.randint(0, len(random_document) - 1) 445 | for j in range(random_start, len(random_document)): 446 | tokens_b.extend(random_document[j]) 447 | if len(tokens_b) >= target_b_length: 448 | break 449 | # We didn't actually use these segments so we "put them back" so 450 | # they don't go to waste. 
451 | num_unused_segments = len(current_chunk) - a_end 452 | i -= num_unused_segments 453 | # Actual next 454 | else: 455 | is_random_next = False 456 | for j in range(a_end, len(current_chunk)): 457 | tokens_b.extend(current_chunk[j]) 458 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 459 | 460 | assert len(tokens_a) >= 1 461 | assert len(tokens_b) >= 1 462 | 463 | tokens = [] 464 | segment_ids = [] 465 | tokens.append("[CLS]") 466 | segment_ids.append(0) 467 | for token in tokens_a: 468 | tokens.append(token) 469 | segment_ids.append(0) 470 | 471 | tokens.append("[SEP]") 472 | segment_ids.append(0) 473 | 474 | for token in tokens_b: 475 | tokens.append(token) 476 | segment_ids.append(1) 477 | tokens.append("[SEP]") 478 | segment_ids.append(1) 479 | 480 | (tokens, masked_lm_positions, 481 | masked_lm_labels) = create_masked_lm_predictions( 482 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 483 | instance = TrainingInstance( 484 | tokens=tokens, 485 | segment_ids=segment_ids, 486 | is_random_next=is_random_next, 487 | masked_lm_positions=masked_lm_positions, 488 | masked_lm_labels=masked_lm_labels) 489 | instances.append(instance) 490 | current_chunk = [] 491 | current_length = 0 492 | i += 1 493 | 494 | return instances 495 | 496 | 497 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance", 498 | ["index", "label"]) 499 | 500 | 501 | def create_masked_lm_predictions(tokens, masked_lm_prob, 502 | max_predictions_per_seq, vocab_words, rng): 503 | """Creates the predictions for the masked LM objective.""" 504 | 505 | cand_indexes = [] 506 | for (i, token) in enumerate(tokens): 507 | if token == "[CLS]" or token == "[SEP]": 508 | continue 509 | # Whole Word Masking means that if we mask all of the wordpieces 510 | # corresponding to an original word. When a word has been split into 511 | # WordPieces, the first token does not have any marker and any subsequence 512 | # tokens are prefixed with ##. 
So whenever we see the ## token, we 513 | # append it to the previous set of word indexes. 514 | # 515 | # Note that Whole Word Masking does *not* change the training code 516 | # at all -- we still predict each WordPiece independently, softmaxed 517 | # over the entire vocabulary. 518 | if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and 519 | token.startswith("##")): 520 | cand_indexes[-1].append(i) 521 | else: 522 | cand_indexes.append([i]) 523 | 524 | rng.shuffle(cand_indexes) 525 | 526 | output_tokens = [t[2:] if len(re.findall('##[\u4E00-\u9FA5]', t))>0 else t for t in tokens] 527 | 528 | num_to_predict = min(max_predictions_per_seq, 529 | max(1, int(round(len(tokens) * masked_lm_prob)))) 530 | 531 | masked_lms = [] 532 | covered_indexes = set() 533 | for index_set in cand_indexes: 534 | if len(masked_lms) >= num_to_predict: 535 | break 536 | # If adding a whole-word mask would exceed the maximum number of 537 | # predictions, then just skip this candidate. 538 | if len(masked_lms) + len(index_set) > num_to_predict: 539 | continue 540 | is_any_index_covered = False 541 | for index in index_set: 542 | if index in covered_indexes: 543 | is_any_index_covered = True 544 | break 545 | if is_any_index_covered: 546 | continue 547 | for index in index_set: 548 | covered_indexes.add(index) 549 | 550 | masked_token = None 551 | # 80% of the time, replace with [MASK] 552 | if rng.random() < 0.8: 553 | masked_token = "[MASK]" 554 | else: 555 | # 10% of the time, keep original 556 | if rng.random() < 0.5: 557 | masked_token = tokens[index][2:] if len(re.findall('##[\u4E00-\u9FA5]', tokens[index]))>0 else tokens[index] 558 | # 10% of the time, replace with random word 559 | else: 560 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 561 | 562 | output_tokens[index] = masked_token 563 | 564 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) 565 | assert len(masked_lms) <= num_to_predict 566 | masked_lms = sorted(masked_lms, 
key=lambda x: x.index) 567 | 568 | masked_lm_positions = [] 569 | masked_lm_labels = [] 570 | for p in masked_lms: 571 | masked_lm_positions.append(p.index) 572 | masked_lm_labels.append(p.label) 573 | 574 | # tf.logging.info('%s' % (tokens)) 575 | # tf.logging.info('%s' % (output_tokens)) 576 | return (output_tokens, masked_lm_positions, masked_lm_labels) 577 | 578 | 579 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 580 | """Truncates a pair of sequences to a maximum sequence length.""" 581 | while True: 582 | total_length = len(tokens_a) + len(tokens_b) 583 | if total_length <= max_num_tokens: 584 | break 585 | 586 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 587 | assert len(trunc_tokens) >= 1 588 | 589 | # We want to sometimes truncate from the front and sometimes from the 590 | # back to add more randomness and avoid biases. 591 | if rng.random() < 0.5: 592 | del trunc_tokens[0] 593 | else: 594 | trunc_tokens.pop() 595 | 596 | 597 | def main(_): 598 | tf.logging.set_verbosity(tf.logging.INFO) 599 | 600 | tokenizer = tokenization.FullTokenizer( 601 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 602 | 603 | input_files = [] 604 | for input_pattern in FLAGS.input_file.split(","): 605 | input_files.extend(tf.gfile.Glob(input_pattern)) 606 | 607 | tf.logging.info("*** Reading from input files ***") 608 | for input_file in input_files: 609 | tf.logging.info(" %s", input_file) 610 | 611 | rng = random.Random(FLAGS.random_seed) 612 | instances = create_training_instances( 613 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 614 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 615 | rng) 616 | 617 | output_files = FLAGS.output_file.split(",") 618 | tf.logging.info("*** Writing to output files ***") 619 | for output_file in output_files: 620 | tf.logging.info(" %s", output_file) 621 | 622 | write_instance_to_example_files(instances, tokenizer, 
FLAGS.max_seq_length, 623 | FLAGS.max_predictions_per_seq, output_files) 624 | 625 | 626 | if __name__ == "__main__": 627 | flags.mark_flag_as_required("input_file") 628 | flags.mark_flag_as_required("output_file") 629 | flags.mark_flag_as_required("vocab_file") 630 | tf.app.run() -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import numpy as np 27 | import six 28 | import tensorflow as tf 29 | 30 | 31 | class BertConfig(object): 32 | """Configuration for `BertModel`.""" 33 | 34 | def __init__(self, 35 | vocab_size, 36 | hidden_size=768, 37 | num_hidden_layers=12, 38 | num_attention_heads=12, 39 | intermediate_size=3072, 40 | hidden_act="gelu", 41 | hidden_dropout_prob=0.1, 42 | attention_probs_dropout_prob=0.1, 43 | max_position_embeddings=512, 44 | type_vocab_size=16, 45 | initializer_range=0.02): 46 | """Constructs BertConfig. 
47 | 48 | Args: 49 | vocab_size: Vocabulary size of `input_ids` in `BertModel`. 50 | hidden_size: Size of the encoder layers and the pooler layer. 51 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 52 | num_attention_heads: Number of attention heads for each attention layer in 53 | the Transformer encoder. 54 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 55 | layer in the Transformer encoder. 56 | hidden_act: The non-linear activation function (function or string) in the 57 | encoder and pooler. 58 | hidden_dropout_prob: The dropout probability for all fully connected 59 | layers in the embeddings, encoder, and pooler. 60 | attention_probs_dropout_prob: The dropout ratio for the attention 61 | probabilities. 62 | max_position_embeddings: The maximum sequence length that this model might 63 | ever be used with. Typically set this to something large just in case 64 | (e.g., 512 or 1024 or 2048). 65 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 66 | `BertModel`. 67 | initializer_range: The stdev of the truncated_normal_initializer for 68 | initializing all weight matrices.
69 | """ 70 | self.vocab_size = vocab_size 71 | self.hidden_size = hidden_size 72 | self.num_hidden_layers = num_hidden_layers 73 | self.num_attention_heads = num_attention_heads 74 | self.hidden_act = hidden_act 75 | self.intermediate_size = intermediate_size 76 | self.hidden_dropout_prob = hidden_dropout_prob 77 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 78 | self.max_position_embeddings = max_position_embeddings 79 | self.type_vocab_size = type_vocab_size 80 | self.initializer_range = initializer_range 81 | 82 | @classmethod 83 | def from_dict(cls, json_object): 84 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 85 | config = BertConfig(vocab_size=None) 86 | for (key, value) in six.iteritems(json_object): 87 | config.__dict__[key] = value 88 | return config 89 | 90 | @classmethod 91 | def from_json_file(cls, json_file): 92 | """Constructs a `BertConfig` from a json file of parameters.""" 93 | with tf.gfile.GFile(json_file, "r") as reader: 94 | text = reader.read() 95 | return cls.from_dict(json.loads(text)) 96 | 97 | def to_dict(self): 98 | """Serializes this instance to a Python dictionary.""" 99 | output = copy.deepcopy(self.__dict__) 100 | return output 101 | 102 | def to_json_string(self): 103 | """Serializes this instance to a JSON string.""" 104 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 105 | 106 | 107 | class BertModel(object): 108 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 
109 | 110 | Example usage: 111 | 112 | ```python 113 | # Already been converted into WordPiece token ids 114 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 115 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 116 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 117 | 118 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 119 | num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024) 120 | 121 | model = modeling.BertModel(config=config, is_training=True, 122 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 123 | 124 | label_embeddings = tf.get_variable(...) 125 | pooled_output = model.get_pooled_output() 126 | logits = tf.matmul(pooled_output, label_embeddings) 127 | ... 128 | ``` 129 | """ 130 | 131 | def __init__(self, 132 | config, 133 | is_training, 134 | input_ids, 135 | input_mask=None, 136 | token_type_ids=None, 137 | use_one_hot_embeddings=False, 138 | scope=None): 139 | """Constructor for BertModel. 140 | 141 | Args: 142 | config: `BertConfig` instance. 143 | is_training: bool. true for training model, false for eval model. Controls 144 | whether dropout will be applied. 145 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 146 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 147 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 148 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 149 | embeddings or tf.embedding_lookup() for the word embeddings. 150 | scope: (optional) variable scope. Defaults to "bert". 151 | 152 | Raises: 153 | ValueError: The config is invalid or one of the input tensor shapes 154 | is invalid.
155 | """ 156 | config = copy.deepcopy(config) 157 | if not is_training: 158 | config.hidden_dropout_prob = 0.0 159 | config.attention_probs_dropout_prob = 0.0 160 | 161 | input_shape = get_shape_list(input_ids, expected_rank=2) 162 | batch_size = input_shape[0] 163 | seq_length = input_shape[1] 164 | 165 | if input_mask is None: 166 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) 167 | 168 | if token_type_ids is None: 169 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) 170 | 171 | with tf.variable_scope(scope, default_name="bert"): 172 | with tf.variable_scope("embeddings"): 173 | # Perform embedding lookup on the word ids. 174 | (self.embedding_output, self.embedding_table) = embedding_lookup( 175 | input_ids=input_ids, 176 | vocab_size=config.vocab_size, 177 | embedding_size=config.hidden_size, 178 | initializer_range=config.initializer_range, 179 | word_embedding_name="word_embeddings", 180 | use_one_hot_embeddings=use_one_hot_embeddings) 181 | 182 | # Add positional embeddings and token type embeddings, then layer 183 | # normalize and perform dropout. 184 | self.embedding_output = embedding_postprocessor( 185 | input_tensor=self.embedding_output, 186 | use_token_type=True, 187 | token_type_ids=token_type_ids, 188 | token_type_vocab_size=config.type_vocab_size, 189 | token_type_embedding_name="token_type_embeddings", 190 | use_position_embeddings=True, 191 | position_embedding_name="position_embeddings", 192 | initializer_range=config.initializer_range, 193 | max_position_embeddings=config.max_position_embeddings, 194 | dropout_prob=config.hidden_dropout_prob) 195 | 196 | with tf.variable_scope("encoder"): 197 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 198 | # mask of shape [batch_size, seq_length, seq_length] which is used 199 | # for the attention scores. 
200 | attention_mask = create_attention_mask_from_input_mask( 201 | input_ids, input_mask) 202 | 203 | # Run the stacked transformer. 204 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 205 | self.all_encoder_layers = transformer_model( 206 | input_tensor=self.embedding_output, 207 | attention_mask=attention_mask, 208 | hidden_size=config.hidden_size, 209 | num_hidden_layers=config.num_hidden_layers, 210 | num_attention_heads=config.num_attention_heads, 211 | intermediate_size=config.intermediate_size, 212 | intermediate_act_fn=get_activation(config.hidden_act), 213 | hidden_dropout_prob=config.hidden_dropout_prob, 214 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 215 | initializer_range=config.initializer_range, 216 | do_return_all_layers=True) 217 | 218 | self.sequence_output = self.all_encoder_layers[-1] # [batch_size, seq_length, hidden_size] 219 | # The "pooler" converts the encoded sequence tensor of shape 220 | # [batch_size, seq_length, hidden_size] to a tensor of shape 221 | # [batch_size, hidden_size]. This is necessary for segment-level 222 | # (or segment-pair-level) classification tasks where we need a fixed 223 | # dimensional representation of the segment. 224 | with tf.variable_scope("pooler"): 225 | # We "pool" the model by simply taking the hidden state corresponding 226 | # to the first token. We assume that this has been pre-trained 227 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 228 | self.pooled_output = tf.layers.dense( 229 | first_token_tensor, 230 | config.hidden_size, 231 | activation=tf.tanh, 232 | kernel_initializer=create_initializer(config.initializer_range)) 233 | 234 | def get_pooled_output(self): 235 | return self.pooled_output 236 | 237 | def get_sequence_output(self): 238 | """Gets final hidden layer of encoder. 
239 | 240 | Returns: 241 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 242 | to the final hidden layer of the transformer encoder. 243 | """ 244 | return self.sequence_output 245 | 246 | def get_all_encoder_layers(self): 247 | return self.all_encoder_layers 248 | 249 | def get_embedding_output(self): 250 | """Gets output of the embedding lookup (i.e., input to the transformer). 251 | 252 | Returns: 253 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 254 | to the output of the embedding layer, after summing the word 255 | embeddings with the positional embeddings and the token type embeddings, 256 | then performing layer normalization. This is the input to the transformer. 257 | """ 258 | return self.embedding_output 259 | 260 | def get_embedding_table(self): 261 | return self.embedding_table 262 | 263 | 264 | def gelu(x): 265 | """Gaussian Error Linear Unit. 266 | 267 | This is a smoother version of the RELU. 268 | Original paper: https://arxiv.org/abs/1606.08415 269 | Args: 270 | x: float Tensor to perform activation. 271 | 272 | Returns: 273 | `x` with the GELU activation applied. 274 | """ 275 | cdf = 0.5 * (1.0 + tf.tanh( 276 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) 277 | return x * cdf 278 | 279 | 280 | def get_activation(activation_string): 281 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 282 | 283 | Args: 284 | activation_string: String name of the activation function. 285 | 286 | Returns: 287 | A Python function corresponding to the activation function. If 288 | `activation_string` is None, empty, or "linear", this will return None. 289 | If `activation_string` is not a string, it will return `activation_string`. 290 | 291 | Raises: 292 | ValueError: The `activation_string` does not correspond to a known 293 | activation. 294 | """ 295 | 296 | # We assume that anything that's not a string is already an activation 297 | # function, so we just return it.
298 | if not isinstance(activation_string, six.string_types): 299 | return activation_string 300 | 301 | if not activation_string: 302 | return None 303 | 304 | act = activation_string.lower() 305 | if act == "linear": 306 | return None 307 | elif act == "relu": 308 | return tf.nn.relu 309 | elif act == "gelu": 310 | return gelu 311 | elif act == "tanh": 312 | return tf.tanh 313 | else: 314 | raise ValueError("Unsupported activation: %s" % act) 315 | 316 | 317 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 318 | """Compute the union of the current variables and checkpoint variables.""" 319 | assignment_map = {} 320 | initialized_variable_names = {} 321 | 322 | name_to_variable = collections.OrderedDict() 323 | for var in tvars: 324 | name = var.name 325 | m = re.match("^(.*):\\d+$", name) 326 | if m is not None: 327 | name = m.group(1) 328 | name_to_variable[name] = var 329 | 330 | init_vars = tf.train.list_variables(init_checkpoint) 331 | 332 | assignment_map = collections.OrderedDict() 333 | for x in init_vars: 334 | (name, var) = (x[0], x[1]) 335 | if name not in name_to_variable: 336 | continue 337 | assignment_map[name] = name 338 | initialized_variable_names[name] = 1 339 | initialized_variable_names[name + ":0"] = 1 340 | 341 | return (assignment_map, initialized_variable_names) 342 | 343 | 344 | def dropout(input_tensor, dropout_prob): 345 | """Perform dropout. 346 | 347 | Args: 348 | input_tensor: float Tensor. 349 | dropout_prob: Python float. The probability of dropping out a value (NOT of 350 | *keeping* a dimension as in `tf.nn.dropout`). 351 | 352 | Returns: 353 | A version of `input_tensor` with dropout applied. 
354 | """ 355 | if dropout_prob is None or dropout_prob == 0.0: 356 | return input_tensor 357 | 358 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 359 | return output 360 | 361 | 362 | def layer_norm(input_tensor, name=None): 363 | """Run layer normalization on the last dimension of the tensor.""" 364 | return tf.contrib.layers.layer_norm( 365 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 366 | 367 | 368 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 369 | """Runs layer normalization followed by dropout.""" 370 | output_tensor = layer_norm(input_tensor, name) 371 | output_tensor = dropout(output_tensor, dropout_prob) 372 | return output_tensor 373 | 374 | 375 | def create_initializer(initializer_range=0.02): 376 | """Creates a `truncated_normal_initializer` with the given range.""" 377 | return tf.truncated_normal_initializer(stddev=initializer_range) 378 | 379 | 380 | def embedding_lookup(input_ids, 381 | vocab_size, 382 | embedding_size=128, 383 | initializer_range=0.02, 384 | word_embedding_name="word_embeddings", 385 | use_one_hot_embeddings=False): 386 | """Looks up word embeddings for an id tensor. 387 | 388 | Args: 389 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 390 | ids. 391 | vocab_size: int. Size of the embedding vocabulary. 392 | embedding_size: int. Width of the word embeddings. 393 | initializer_range: float. Embedding initialization range. 394 | word_embedding_name: string. Name of the embedding table. 395 | use_one_hot_embeddings: bool. If True, use one-hot method for word 396 | embeddings. If False, use `tf.gather()`. 397 | 398 | Returns: 399 | Tuple of (float Tensor of shape [batch_size, seq_length, embedding_size], embedding table). 400 | """ 401 | # This function assumes that the input is of shape [batch_size, seq_length, 402 | # num_inputs]. 403 | # 404 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 405 | # reshape to [batch_size, seq_length, 1].
406 | if input_ids.shape.ndims == 2: 407 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 408 | 409 | embedding_table = tf.get_variable( 410 | name=word_embedding_name, 411 | shape=[vocab_size, embedding_size], 412 | initializer=create_initializer(initializer_range)) 413 | 414 | flat_input_ids = tf.reshape(input_ids, [-1]) 415 | if use_one_hot_embeddings: 416 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 417 | output = tf.matmul(one_hot_input_ids, embedding_table) 418 | else: 419 | output = tf.gather(embedding_table, flat_input_ids) 420 | 421 | input_shape = get_shape_list(input_ids) 422 | 423 | output = tf.reshape(output, 424 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 425 | return (output, embedding_table) 426 | 427 | 428 | def embedding_postprocessor(input_tensor, 429 | use_token_type=False, 430 | token_type_ids=None, 431 | token_type_vocab_size=16, 432 | token_type_embedding_name="token_type_embeddings", 433 | use_position_embeddings=True, 434 | position_embedding_name="position_embeddings", 435 | initializer_range=0.02, 436 | max_position_embeddings=512, 437 | dropout_prob=0.1): 438 | """Performs various post-processing on a word embedding tensor. 439 | 440 | Args: 441 | input_tensor: float Tensor of shape [batch_size, seq_length, 442 | embedding_size]. 443 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 444 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 445 | Must be specified if `use_token_type` is True. 446 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`. 447 | token_type_embedding_name: string. The name of the embedding table variable 448 | for token type ids. 449 | use_position_embeddings: bool. Whether to add position embeddings for the 450 | position of each token in the sequence. 451 | position_embedding_name: string. The name of the embedding table variable 452 | for positional embeddings. 453 | initializer_range: float. 
Range of the weight initialization. 454 | max_position_embeddings: int. Maximum sequence length that might ever be 455 | used with this model. This can be longer than the sequence length of 456 | input_tensor, but cannot be shorter. 457 | dropout_prob: float. Dropout probability applied to the final output tensor. 458 | 459 | Returns: 460 | float tensor with same shape as `input_tensor`. 461 | 462 | Raises: 463 | ValueError: One of the tensor shapes or input values is invalid. 464 | """ 465 | input_shape = get_shape_list(input_tensor, expected_rank=3) 466 | batch_size = input_shape[0] 467 | seq_length = input_shape[1] 468 | width = input_shape[2] 469 | 470 | output = input_tensor 471 | 472 | if use_token_type: 473 | if token_type_ids is None: 474 | raise ValueError("`token_type_ids` must be specified if " 475 | "`use_token_type` is True.") 476 | token_type_table = tf.get_variable( 477 | name=token_type_embedding_name, 478 | shape=[token_type_vocab_size, width], 479 | initializer=create_initializer(initializer_range)) 480 | # This vocab will be small so we always do one-hot here, since it is always 481 | # faster for a small vocabulary.
482 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 483 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) 484 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 485 | token_type_embeddings = tf.reshape(token_type_embeddings, 486 | [batch_size, seq_length, width]) 487 | output += token_type_embeddings 488 | 489 | if use_position_embeddings: 490 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) 491 | with tf.control_dependencies([assert_op]): 492 | full_position_embeddings = tf.get_variable( 493 | name=position_embedding_name, 494 | shape=[max_position_embeddings, width], 495 | initializer=create_initializer(initializer_range)) 496 | # Since the position embedding table is a learned variable, we create it 497 | # using a (long) sequence length `max_position_embeddings`. The actual 498 | # sequence length might be shorter than this, for faster training of 499 | # tasks that do not have long sequences. 500 | # 501 | # So `full_position_embeddings` is effectively an embedding table 502 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 503 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 504 | # perform a slice. 505 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 506 | [seq_length, -1]) 507 | num_dims = len(output.shape.as_list()) 508 | 509 | # Only the last two dimensions are relevant (`seq_length` and `width`), so 510 | # we broadcast among the first dimensions, which is typically just 511 | # the batch size. 
512 | position_broadcast_shape = [] 513 | for _ in range(num_dims - 2): 514 | position_broadcast_shape.append(1) 515 | position_broadcast_shape.extend([seq_length, width]) 516 | position_embeddings = tf.reshape(position_embeddings, 517 | position_broadcast_shape) 518 | output += position_embeddings 519 | 520 | output = layer_norm_and_dropout(output, dropout_prob) 521 | return output 522 | 523 | 524 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 525 | """Create 3D attention mask from a 2D tensor mask. 526 | 527 | Args: 528 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 529 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 530 | 531 | Returns: 532 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 533 | """ 534 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 535 | batch_size = from_shape[0] 536 | from_seq_length = from_shape[1] 537 | 538 | to_shape = get_shape_list(to_mask, expected_rank=2) 539 | to_seq_length = to_shape[1] 540 | 541 | to_mask = tf.cast( 542 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) 543 | 544 | # We don't assume that `from_tensor` is a mask (although it could be). We 545 | # don't actually care if we attend *from* padding tokens (only *to* padding 546 | # tokens), so we create a tensor of all ones. 547 | # 548 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 549 | broadcast_ones = tf.ones( 550 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32) 551 | 552 | # Here we broadcast along two dimensions to create the mask.
553 | mask = broadcast_ones * to_mask 554 | 555 | return mask 556 | 557 | 558 | def attention_layer(from_tensor, 559 | to_tensor, 560 | attention_mask=None, 561 | num_attention_heads=1, 562 | size_per_head=512, 563 | query_act=None, 564 | key_act=None, 565 | value_act=None, 566 | attention_probs_dropout_prob=0.0, 567 | initializer_range=0.02, 568 | do_return_2d_tensor=False, 569 | batch_size=None, 570 | from_seq_length=None, 571 | to_seq_length=None): 572 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 573 | 574 | This is an implementation of multi-headed attention based on "Attention 575 | is all you Need". If `from_tensor` and `to_tensor` are the same, then 576 | this is self-attention. Each timestep in `from_tensor` attends to the 577 | corresponding sequence in `to_tensor`, and returns a fixed-width vector. 578 | 579 | This function first projects `from_tensor` into a "query" tensor and 580 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 581 | of tensors of length `num_attention_heads`, where each tensor is of shape 582 | [batch_size, seq_length, size_per_head]. 583 | 584 | Then, the query and key tensors are dot-producted and scaled. These are 585 | softmaxed to obtain attention probabilities. The value tensors are then 586 | interpolated by these probabilities, then concatenated back to a single 587 | tensor and returned. 588 | 589 | In practice, the multi-headed attention is done with transposes and 590 | reshapes rather than actual separate tensors. 591 | 592 | Args: 593 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 594 | from_width]. 595 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 596 | attention_mask: (optional) int32 Tensor of shape [batch_size, 597 | from_seq_length, to_seq_length]. The values should be 1 or 0.
The 598 | attention scores will effectively be set to -infinity for any positions in 599 | the mask that are 0, and will be unchanged for positions that are 1. 600 | num_attention_heads: int. Number of attention heads. 601 | size_per_head: int. Size of each attention head. 602 | query_act: (optional) Activation function for the query transform. 603 | key_act: (optional) Activation function for the key transform. 604 | value_act: (optional) Activation function for the value transform. 605 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 606 | attention probabilities. 607 | initializer_range: float. Range of the weight initializer. 608 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 609 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 610 | output will be of shape [batch_size, from_seq_length, num_attention_heads 611 | * size_per_head]. 612 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 613 | of the 3D version of the `from_tensor` and `to_tensor`. 614 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 615 | of the 3D version of the `from_tensor`. 616 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 617 | of the 3D version of the `to_tensor`. 618 | 619 | Returns: 620 | float Tensor of shape [batch_size, from_seq_length, 621 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 622 | true, this will be of shape [batch_size * from_seq_length, 623 | num_attention_heads * size_per_head]). 624 | 625 | Raises: 626 | ValueError: Any of the arguments or tensor shapes are invalid. 
627 | """ 628 | 629 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 630 | seq_length, width): 631 | output_tensor = tf.reshape( 632 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 633 | 634 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 635 | return output_tensor 636 | 637 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 638 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 639 | 640 | if len(from_shape) != len(to_shape): 641 | raise ValueError( 642 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 643 | 644 | if len(from_shape) == 3: 645 | batch_size = from_shape[0] 646 | from_seq_length = from_shape[1] 647 | to_seq_length = to_shape[1] 648 | elif len(from_shape) == 2: 649 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 650 | raise ValueError( 651 | "When passing in rank 2 tensors to attention_layer, the values " 652 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 653 | "must all be specified.") 654 | 655 | # Scalar dimensions referenced here: 656 | # B = batch size (number of sequences) 657 | # F = `from_tensor` sequence length 658 | # T = `to_tensor` sequence length 659 | # N = `num_attention_heads` 660 | # H = `size_per_head` 661 | 662 | from_tensor_2d = reshape_to_matrix(from_tensor) 663 | to_tensor_2d = reshape_to_matrix(to_tensor) 664 | 665 | # `query_layer` = [B*F, N*H] 666 | query_layer = tf.layers.dense( 667 | from_tensor_2d, 668 | num_attention_heads * size_per_head, 669 | activation=query_act, 670 | name="query", 671 | kernel_initializer=create_initializer(initializer_range)) 672 | 673 | # `key_layer` = [B*T, N*H] 674 | key_layer = tf.layers.dense( 675 | to_tensor_2d, 676 | num_attention_heads * size_per_head, 677 | activation=key_act, 678 | name="key", 679 | kernel_initializer=create_initializer(initializer_range)) 680 | 681 | # `value_layer` = [B*T, N*H] 682 | value_layer = tf.layers.dense( 683 | 
to_tensor_2d, 684 | num_attention_heads * size_per_head, 685 | activation=value_act, 686 | name="value", 687 | kernel_initializer=create_initializer(initializer_range)) 688 | 689 | # `query_layer` = [B, N, F, H] 690 | query_layer = transpose_for_scores(query_layer, batch_size, 691 | num_attention_heads, from_seq_length, 692 | size_per_head) 693 | 694 | # `key_layer` = [B, N, T, H] 695 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 696 | to_seq_length, size_per_head) 697 | 698 | # Take the dot product between "query" and "key" to get the raw 699 | # attention scores. 700 | # `attention_scores` = [B, N, F, T] 701 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 702 | attention_scores = tf.multiply(attention_scores, 703 | 1.0 / math.sqrt(float(size_per_head))) 704 | 705 | if attention_mask is not None: 706 | # `attention_mask` = [B, 1, F, T] 707 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 708 | 709 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 710 | # masked positions, this operation will create a tensor which is 0.0 for 711 | # positions we want to attend and -10000.0 for masked positions. 712 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 713 | 714 | # Since we are adding it to the raw scores before the softmax, this is 715 | # effectively the same as removing these entirely. 716 | attention_scores += adder 717 | 718 | # Normalize the attention scores to probabilities. 719 | # `attention_probs` = [B, N, F, T] 720 | attention_probs = tf.nn.softmax(attention_scores) 721 | 722 | # This is actually dropping out entire tokens to attend to, which might 723 | # seem a bit unusual, but is taken from the original Transformer paper. 
724 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 725 | 726 | # `value_layer` = [B, T, N, H] 727 | value_layer = tf.reshape( 728 | value_layer, 729 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 730 | 731 | # `value_layer` = [B, N, T, H] 732 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 733 | 734 | # `context_layer` = [B, N, F, H] 735 | context_layer = tf.matmul(attention_probs, value_layer) 736 | 737 | # `context_layer` = [B, F, N, H] 738 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 739 | 740 | if do_return_2d_tensor: 741 | # `context_layer` = [B*F, N*H] 742 | context_layer = tf.reshape( 743 | context_layer, 744 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 745 | else: 746 | # `context_layer` = [B, F, N*H] 747 | context_layer = tf.reshape( 748 | context_layer, 749 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 750 | 751 | return context_layer 752 | 753 | 754 | def transformer_model(input_tensor, 755 | attention_mask=None, 756 | hidden_size=768, 757 | num_hidden_layers=12, 758 | num_attention_heads=12, 759 | intermediate_size=3072, 760 | intermediate_act_fn=gelu, 761 | hidden_dropout_prob=0.1, 762 | attention_probs_dropout_prob=0.1, 763 | initializer_range=0.02, 764 | do_return_all_layers=False): 765 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 766 | 767 | This is almost an exact implementation of the original Transformer encoder. 768 | 769 | See the original paper: 770 | https://arxiv.org/abs/1706.03762 771 | 772 | Also see: 773 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 774 | 775 | Args: 776 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 777 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 778 | seq_length], with 1 for positions that can be attended to and 0 in 779 | positions that should not be. 
780 | hidden_size: int. Hidden size of the Transformer. 781 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 782 | num_attention_heads: int. Number of attention heads in the Transformer. 783 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 784 | forward) layer. 785 | intermediate_act_fn: function. The non-linear activation function to apply 786 | to the output of the intermediate/feed-forward layer. 787 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 788 | attention_probs_dropout_prob: float. Dropout probability of the attention 789 | probabilities. 790 | initializer_range: float. Range of the initializer (stddev of truncated 791 | normal). 792 | do_return_all_layers: Whether to also return all layers or just the final 793 | layer. 794 | 795 | Returns: 796 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 797 | hidden layer of the Transformer. 798 | 799 | Raises: 800 | ValueError: A Tensor shape or parameter is invalid. 801 | """ 802 | if hidden_size % num_attention_heads != 0: 803 | raise ValueError( 804 | "The hidden size (%d) is not a multiple of the number of attention " 805 | "heads (%d)" % (hidden_size, num_attention_heads)) 806 | 807 | attention_head_size = int(hidden_size / num_attention_heads) 808 | input_shape = get_shape_list(input_tensor, expected_rank=3) 809 | batch_size = input_shape[0] 810 | seq_length = input_shape[1] 811 | input_width = input_shape[2] 812 | 813 | # The Transformer performs sum residuals on all layers so the input needs 814 | # to be the same as the hidden size. 815 | if input_width != hidden_size: 816 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 817 | (input_width, hidden_size)) 818 | 819 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 820 | # forth from a 3D tensor to a 2D tensor. 
Re-shapes are normally free on 821 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 822 | # help the optimizer. 823 | prev_output = reshape_to_matrix(input_tensor) 824 | 825 | all_layer_outputs = [] 826 | for layer_idx in range(num_hidden_layers): 827 | with tf.variable_scope("layer_%d" % layer_idx): 828 | layer_input = prev_output 829 | 830 | with tf.variable_scope("attention"): 831 | attention_heads = [] 832 | with tf.variable_scope("self"): 833 | attention_head = attention_layer( 834 | from_tensor=layer_input, 835 | to_tensor=layer_input, 836 | attention_mask=attention_mask, 837 | num_attention_heads=num_attention_heads, 838 | size_per_head=attention_head_size, 839 | attention_probs_dropout_prob=attention_probs_dropout_prob, 840 | initializer_range=initializer_range, 841 | do_return_2d_tensor=True, 842 | batch_size=batch_size, 843 | from_seq_length=seq_length, 844 | to_seq_length=seq_length) 845 | attention_heads.append(attention_head) 846 | 847 | attention_output = None 848 | if len(attention_heads) == 1: 849 | attention_output = attention_heads[0] 850 | else: 851 | # In the case where we have other sequences, we just concatenate 852 | # them to the self-attention head before the projection. 853 | attention_output = tf.concat(attention_heads, axis=-1) 854 | 855 | # Run a linear projection of `hidden_size` then add a residual 856 | # with `layer_input`. 857 | with tf.variable_scope("output"): 858 | attention_output = tf.layers.dense( 859 | attention_output, 860 | hidden_size, 861 | kernel_initializer=create_initializer(initializer_range)) 862 | attention_output = dropout(attention_output, hidden_dropout_prob) 863 | attention_output = layer_norm(attention_output + layer_input) 864 | 865 | # The activation is only applied to the "intermediate" hidden layer. 
866 | with tf.variable_scope("intermediate"): 867 | intermediate_output = tf.layers.dense( 868 | attention_output, 869 | intermediate_size, 870 | activation=intermediate_act_fn, 871 | kernel_initializer=create_initializer(initializer_range)) 872 | 873 | # Down-project back to `hidden_size` then add the residual. 874 | with tf.variable_scope("output"): 875 | layer_output = tf.layers.dense( 876 | intermediate_output, 877 | hidden_size, 878 | kernel_initializer=create_initializer(initializer_range)) 879 | layer_output = dropout(layer_output, hidden_dropout_prob) 880 | layer_output = layer_norm(layer_output + attention_output) 881 | prev_output = layer_output 882 | all_layer_outputs.append(layer_output) 883 | 884 | if do_return_all_layers: 885 | final_outputs = [] 886 | for layer_output in all_layer_outputs: 887 | final_output = reshape_from_matrix(layer_output, input_shape) 888 | final_outputs.append(final_output) 889 | return final_outputs 890 | else: 891 | final_output = reshape_from_matrix(prev_output, input_shape) 892 | return final_output 893 | 894 | 895 | def get_shape_list(tensor, expected_rank=None, name=None): 896 | """Returns a list of the shape of tensor, preferring static dimensions. 897 | 898 | Args: 899 | tensor: A tf.Tensor object to find the shape of. 900 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 901 | specified and the `tensor` has a different rank, an exception will be 902 | thrown. 903 | name: Optional name of the tensor for the error message. 904 | 905 | Returns: 906 | A list of dimensions of the shape of tensor. All static dimensions will 907 | be returned as python integers, and dynamic dimensions will be returned 908 | as tf.Tensor scalars.
909 | """ 910 | if name is None: 911 | name = tensor.name 912 | 913 | if expected_rank is not None: 914 | assert_rank(tensor, expected_rank, name) 915 | 916 | shape = tensor.shape.as_list() 917 | 918 | non_static_indexes = [] 919 | for (index, dim) in enumerate(shape): 920 | if dim is None: 921 | non_static_indexes.append(index) 922 | 923 | if not non_static_indexes: 924 | return shape 925 | 926 | dyn_shape = tf.shape(tensor) 927 | for index in non_static_indexes: 928 | shape[index] = dyn_shape[index] 929 | return shape 930 | 931 | 932 | def reshape_to_matrix(input_tensor): 933 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 934 | ndims = input_tensor.shape.ndims 935 | if ndims < 2: 936 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 937 | (input_tensor.shape)) 938 | if ndims == 2: 939 | return input_tensor 940 | 941 | width = input_tensor.shape[-1] 942 | output_tensor = tf.reshape(input_tensor, [-1, width]) 943 | return output_tensor 944 | 945 | 946 | def reshape_from_matrix(output_tensor, orig_shape_list): 947 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 948 | if len(orig_shape_list) == 2: 949 | return output_tensor 950 | 951 | output_shape = get_shape_list(output_tensor) 952 | 953 | orig_dims = orig_shape_list[0:-1] 954 | width = output_shape[-1] 955 | 956 | return tf.reshape(output_tensor, orig_dims + [width]) 957 | 958 | 959 | def assert_rank(tensor, expected_rank, name=None): 960 | """Raises an exception if the tensor rank is not of the expected rank. 961 | 962 | Args: 963 | tensor: A tf.Tensor to check the rank of. 964 | expected_rank: Python integer or list of integers, expected rank. 965 | name: Optional name of the tensor for the error message. 966 | 967 | Raises: 968 | ValueError: If the expected shape doesn't match the actual shape. 
969 | """ 970 | if name is None: 971 | name = tensor.name 972 | 973 | expected_rank_dict = {} 974 | if isinstance(expected_rank, six.integer_types): 975 | expected_rank_dict[expected_rank] = True 976 | else: 977 | for x in expected_rank: 978 | expected_rank_dict[x] = True 979 | 980 | actual_rank = tensor.shape.ndims 981 | if actual_rank not in expected_rank_dict: 982 | scope_name = tf.get_variable_scope().name 983 | raise ValueError( 984 | "For the tensor `%s` in scope `%s`, the actual rank " 985 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 986 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 987 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.98, # 0.98 ONLY USED FOR PRETRAIN. 
MUST CHANGE TO 0.999 FOR FINE-TUNING, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | 72 | # tvars=find_train_variables(tvars) # optional: freeze layers 0-8 and train only layer 9 onward. 73 | 74 | grads = tf.gradients(loss, tvars) 75 | 76 | # This is how the model was pre-trained. 77 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 78 | 79 | train_op = optimizer.apply_gradients( 80 | zip(grads, tvars), global_step=global_step) 81 | 82 | # Normally the global step update is done inside of `apply_gradients`. 83 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 84 | # a different optimizer, you should probably take this line out. 85 | new_global_step = global_step + 1 86 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 87 | return train_op 88 | 89 | def find_train_variables(tvars): 90 | """ 91 | Keep only the variables to train: encoder layers 9 and above, plus the prediction head and pooler. 92 | :param tvars: list of trainable variables 93 | :return: a new list containing only the variables to train 94 | """ 95 | # bert/encoder/layer_21, bert/encoder/layer_9, bert/encoder/layer_20/attention/output/dense/bias:0, bert/encoder/layer_20/attention/output/dense/kernel: 96 | tvars_result_list=[] 97 | 98 | for var in tvars: 99 | if 'cls/predictions' in var.name or 'bert/pooler/dense' in var.name: # the last few layers 100 | tvars_result_list.append(var) 101 | else: # parameters of the second half of the network 102 | layer_number_list=re.findall(r"layer_(\d+)/", var.name) 103 | if len(layer_number_list)>0: # matched a layer number (re.findall returns strings, so no isinstance(..., int) check) 104 | layer_number=int(layer_number_list[0]) 105 | if layer_number>=9: 106 | tvars_result_list.append(var) 107 | 108 | # print train variables 109 | for i,var_ in enumerate(tvars_result_list): 110 | print("####find_train_variables.i:",i, "variable name:",var_.name) 111 | 112 | print("####find_train_variables:length of 
tvars_result_list:",len(tvars_result_list)) 113 | return tvars_result_list 114 | 115 | 116 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 117 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 118 | 119 | def __init__(self, 120 | learning_rate, 121 | weight_decay_rate=0.0, 122 | beta_1=0.9, 123 | beta_2=0.999, 124 | epsilon=1e-6, 125 | exclude_from_weight_decay=None, 126 | name="AdamWeightDecayOptimizer"): 127 | """Constructs an AdamWeightDecayOptimizer.""" 128 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 129 | 130 | self.learning_rate = learning_rate 131 | self.weight_decay_rate = weight_decay_rate 132 | self.beta_1 = beta_1 133 | self.beta_2 = beta_2 134 | self.epsilon = epsilon 135 | self.exclude_from_weight_decay = exclude_from_weight_decay 136 | 137 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 138 | """See base class.""" 139 | assignments = [] 140 | for (grad, param) in grads_and_vars: 141 | if grad is None or param is None: 142 | continue 143 | 144 | param_name = self._get_variable_name(param.name) 145 | 146 | m = tf.get_variable( 147 | name=param_name + "/adam_m", 148 | shape=param.shape.as_list(), 149 | dtype=tf.float32, 150 | trainable=False, 151 | initializer=tf.zeros_initializer()) 152 | v = tf.get_variable( 153 | name=param_name + "/adam_v", 154 | shape=param.shape.as_list(), 155 | dtype=tf.float32, 156 | trainable=False, 157 | initializer=tf.zeros_initializer()) 158 | 159 | # Standard Adam update. 160 | next_m = ( 161 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 162 | next_v = ( 163 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 164 | tf.square(grad))) 165 | 166 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 167 | 168 | # Just adding the square of the weights to the loss function is *not* 169 | # the correct way of using L2 regularization/weight decay with Adam, 170 | # since that will interact with the m and v parameters in strange ways. 
171 | # 172 | # Instead we want to decay the weights in a manner that doesn't interact 173 | # with the m/v parameters. This is equivalent to adding the square 174 | # of the weights to the loss with plain (non-momentum) SGD. 175 | if self._do_use_weight_decay(param_name): 176 | update += self.weight_decay_rate * param 177 | 178 | update_with_lr = self.learning_rate * update 179 | 180 | next_param = param - update_with_lr 181 | 182 | assignments.extend( 183 | [param.assign(next_param), 184 | m.assign(next_m), 185 | v.assign(next_v)]) 186 | return tf.group(*assignments, name=name) 187 | 188 | def _do_use_weight_decay(self, param_name): 189 | """Whether to use L2 weight decay for `param_name`.""" 190 | if not self.weight_decay_rate: 191 | return False 192 | if self.exclude_from_weight_decay: 193 | for r in self.exclude_from_weight_decay: 194 | if re.search(r, param_name) is not None: 195 | return False 196 | return True 197 | 198 | def _get_variable_name(self, param_name): 199 | """Get the variable name from the tensor name.""" 200 | m = re.match("^(.*):\\d+$", param_name) 201 | if m is not None: 202 | param_name = m.group(1) 203 | return param_name 204 | -------------------------------------------------------------------------------- /optimization_finetuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, # 0.98 ONLY USED FOR PRETRAIN. 
FINE-TUNING USES 0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs an AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | 
trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want to decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 
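To make the update rule above concrete, here is a minimal NumPy sketch of a single step with decoupled weight decay. It follows the same arithmetic as `apply_gradients` (no bias correction, decay added to the update rather than to the loss). The function name `adam_weight_decay_step` and the toy values are assumptions for illustration and are not part of this repository.

```python
import numpy as np


def adam_weight_decay_step(param, grad, m, v, lr=1e-3, beta_1=0.9,
                           beta_2=0.999, epsilon=1e-6,
                           weight_decay_rate=0.01):
    """Hypothetical helper: one Adam step with decoupled weight decay."""
    # Moving averages of the gradient and squared gradient (no bias correction).
    next_m = beta_1 * m + (1.0 - beta_1) * grad
    next_v = beta_2 * v + (1.0 - beta_2) * np.square(grad)
    update = next_m / (np.sqrt(next_v) + epsilon)
    # The decay term is added to the update directly, so it never enters m or v.
    update = update + weight_decay_rate * param
    next_param = param - lr * update
    return next_param, next_m, next_v


param = np.array([1.0, -2.0])
zeros = np.zeros_like(param)
# With a zero gradient the Adam term vanishes and only the decay acts:
# next_param = param * (1 - lr * weight_decay_rate).
new_param, new_m, new_v = adam_weight_decay_step(param, zeros, zeros, zeros)
```

Because the decay bypasses `m` and `v`, the shrinkage per step depends only on the learning rate and `weight_decay_rate`, not on the gradient history.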
146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /resources/RoBERTa_zh_Large_Learning_Curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brightmart/roberta_zh/438476f7da1661faf45e5b8da9f55df403e44997/resources/RoBERTa_zh_Large_Learning_Curve.png -------------------------------------------------------------------------------- /run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import csv 23 | import os 24 | import modeling 25 | import optimization_finetuning as optimization 26 | import tokenization 27 | import tensorflow as tf 28 | # from loss import bi_tempered_logistic_loss 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | ## Required parameters 35 | flags.DEFINE_string( 36 | "data_dir", None, 37 | "The input data dir. Should contain the .tsv files (or other data files) " 38 | "for the task.") 39 | 40 | flags.DEFINE_string( 41 | "bert_config_file", None, 42 | "The config json file corresponding to the pre-trained BERT model. " 43 | "This specifies the model architecture.") 44 | 45 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 46 | 47 | flags.DEFINE_string("vocab_file", None, 48 | "The vocabulary file that the BERT model was trained on.") 49 | 50 | flags.DEFINE_string( 51 | "output_dir", None, 52 | "The output directory where the model checkpoints will be written.") 53 | 54 | ## Other parameters 55 | 56 | flags.DEFINE_string( 57 | "init_checkpoint", None, 58 | "Initial checkpoint (usually from a pre-trained BERT model).") 59 | 60 | flags.DEFINE_bool( 61 | "do_lower_case", True, 62 | "Whether to lower case the input text. 
Should be True for uncased " 63 | "models and False for cased models.") 64 | 65 | flags.DEFINE_integer( 66 | "max_seq_length", 128, 67 | "The maximum total input sequence length after WordPiece tokenization. " 68 | "Sequences longer than this will be truncated, and sequences shorter " 69 | "than this will be padded.") 70 | 71 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 72 | 73 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 74 | 75 | flags.DEFINE_bool( 76 | "do_predict", False, 77 | "Whether to run the model in inference mode on the test set.") 78 | 79 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 80 | 81 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 82 | 83 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 84 | 85 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 86 | 87 | flags.DEFINE_float("num_train_epochs", 3.0, 88 | "Total number of training epochs to perform.") 89 | 90 | flags.DEFINE_float( 91 | "warmup_proportion", 0.1, 92 | "Proportion of training to perform linear learning rate warmup for. " 93 | "E.g., 0.1 = 10% of training.") 94 | 95 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 96 | "How often to save the model checkpoint.") 97 | 98 | flags.DEFINE_integer("iterations_per_loop", 1000, 99 | "How many steps to make in each estimator call.") 100 | 101 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 102 | 103 | tf.flags.DEFINE_string( 104 | "tpu_name", None, 105 | "The Cloud TPU to use for training. This should be either the name " 106 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 107 | "url.") 108 | 109 | tf.flags.DEFINE_string( 110 | "tpu_zone", None, 111 | "[Optional] GCE zone where the Cloud TPU is located in. 
If not " 112 | "specified, we will attempt to automatically detect the GCE zone from " 113 | "metadata.") 114 | 115 | tf.flags.DEFINE_string( 116 | "gcp_project", None, 117 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 118 | "specified, we will attempt to automatically detect the GCE project from " 119 | "metadata.") 120 | 121 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 122 | 123 | flags.DEFINE_integer( 124 | "num_tpu_cores", 8, 125 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 126 | 127 | 128 | class InputExample(object): 129 | """A single training/test example for simple sequence classification.""" 130 | 131 | def __init__(self, guid, text_a, text_b=None, label=None): 132 | """Constructs an InputExample. 133 | Args: 134 | guid: Unique id for the example. 135 | text_a: string. The untokenized text of the first sequence. For single 136 | sequence tasks, only this sequence must be specified. 137 | text_b: (Optional) string. The untokenized text of the second sequence. 138 | Only must be specified for sequence pair tasks. 139 | label: (Optional) string. The label of the example. This should be 140 | specified for train and dev examples, but not for test examples. 141 | """ 142 | self.guid = guid 143 | self.text_a = text_a 144 | self.text_b = text_b 145 | self.label = label 146 | 147 | 148 | class PaddingInputExample(object): 149 | """Fake example so the num input examples is a multiple of the batch size. 150 | When running eval/predict on the TPU, we need to pad the number of examples 151 | to be a multiple of the batch size, because the TPU requires a fixed batch 152 | size. The alternative is to drop the last batch, which is bad because it means 153 | the entire output data won't be generated. 154 | We use this class instead of `None` because treating `None` as padding 155 | batches could cause silent errors. 
156 | """ 157 | 158 | 159 | class InputFeatures(object): 160 | """A single set of features of data.""" 161 | 162 | def __init__(self, 163 | input_ids, 164 | input_mask, 165 | segment_ids, 166 | label_id, 167 | is_real_example=True): 168 | self.input_ids = input_ids 169 | self.input_mask = input_mask 170 | self.segment_ids = segment_ids 171 | self.label_id = label_id 172 | self.is_real_example = is_real_example 173 | 174 | 175 | class DataProcessor(object): 176 | """Base class for data converters for sequence classification data sets.""" 177 | 178 | def get_train_examples(self, data_dir): 179 | """Gets a collection of `InputExample`s for the train set.""" 180 | raise NotImplementedError() 181 | 182 | def get_dev_examples(self, data_dir): 183 | """Gets a collection of `InputExample`s for the dev set.""" 184 | raise NotImplementedError() 185 | 186 | def get_test_examples(self, data_dir): 187 | """Gets a collection of `InputExample`s for prediction.""" 188 | raise NotImplementedError() 189 | 190 | def get_labels(self): 191 | """Gets the list of labels for this data set.""" 192 | raise NotImplementedError() 193 | 194 | @classmethod 195 | def _read_tsv(cls, input_file, quotechar=None): 196 | """Reads a tab separated value file.""" 197 | with tf.gfile.Open(input_file, "r") as f: 198 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 199 | lines = [] 200 | for line in reader: 201 | lines.append(line) 202 | return lines 203 | 204 | 205 | class LCQMCPairClassificationProcessor(DataProcessor): # TODO NEED CHANGE2 206 | """Processor for the internal data set. 
sentence pair classification""" 207 | def __init__(self): 208 | self.language = "zh" 209 | 210 | def get_train_examples(self, data_dir): 211 | """See base class.""" 212 | return self._create_examples( 213 | self._read_tsv(os.path.join(data_dir, "train.txt")), "train") 214 | # dev_0827.tsv 215 | 216 | def get_dev_examples(self, data_dir): 217 | """See base class.""" 218 | return self._create_examples( 219 | self._read_tsv(os.path.join(data_dir, "dev.txt")), "dev") 220 | 221 | def get_test_examples(self, data_dir): 222 | """See base class.""" 223 | return self._create_examples( 224 | self._read_tsv(os.path.join(data_dir, "test.txt")), "test") 225 | 226 | def get_labels(self): 227 | """See base class.""" 228 | return ["0", "1"] 229 | 230 | def _create_examples(self, lines, set_type): 231 | """Creates examples for the training and dev sets.""" 232 | examples = [] 233 | print("length of lines:",len(lines)) 234 | for (i, line) in enumerate(lines): 235 | #print('#i:',i,line) 236 | if i == 0: 237 | continue 238 | guid = "%s-%s" % (set_type, i) 239 | try: 240 | label = tokenization.convert_to_unicode(line[2]) 241 | text_a = tokenization.convert_to_unicode(line[0]) 242 | text_b = tokenization.convert_to_unicode(line[1]) 243 | examples.append( 244 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 245 | except Exception: 246 | print('###error.i:', i, line) 247 | return examples 248 | 249 | 250 | def convert_single_example(ex_index, example, label_list, max_seq_length, 251 | tokenizer): 252 | """Converts a single `InputExample` into a single `InputFeatures`.""" 253 | 254 | if isinstance(example, PaddingInputExample): 255 | return InputFeatures( 256 | input_ids=[0] * max_seq_length, 257 | input_mask=[0] * max_seq_length, 258 | segment_ids=[0] * max_seq_length, 259 | label_id=0, 260 | is_real_example=False) 261 | 262 | label_map = {} 263 | for (i, label) in enumerate(label_list): 264 | label_map[label] = i 265 | 266 | tokens_a = 
tokenizer.tokenize(example.text_a) 267 | tokens_b = None 268 | if example.text_b: 269 | tokens_b = tokenizer.tokenize(example.text_b) 270 | 271 | if tokens_b: 272 | # Modifies `tokens_a` and `tokens_b` in place so that the total 273 | # length is less than the specified length. 274 | # Account for [CLS], [SEP], [SEP] with "- 3" 275 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 276 | else: 277 | # Account for [CLS] and [SEP] with "- 2" 278 | if len(tokens_a) > max_seq_length - 2: 279 | tokens_a = tokens_a[0:(max_seq_length - 2)] 280 | 281 | # The convention in BERT is: 282 | # (a) For sequence pairs: 283 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 284 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 285 | # (b) For single sequences: 286 | # tokens: [CLS] the dog is hairy . [SEP] 287 | # type_ids: 0 0 0 0 0 0 0 288 | # 289 | # Where "type_ids" are used to indicate whether this is the first 290 | # sequence or the second sequence. The embedding vectors for `type=0` and 291 | # `type=1` were learned during pre-training and are added to the wordpiece 292 | # embedding vector (and position vector). This is not *strictly* necessary 293 | # since the [SEP] token unambiguously separates the sequences, but it makes 294 | # it easier for the model to learn the concept of sequences. 295 | # 296 | # For classification tasks, the first vector (corresponding to [CLS]) is 297 | # used as the "sentence vector". Note that this only makes sense because 298 | # the entire model is fine-tuned. 
299 | tokens = [] 300 | segment_ids = [] 301 | tokens.append("[CLS]") 302 | segment_ids.append(0) 303 | for token in tokens_a: 304 | tokens.append(token) 305 | segment_ids.append(0) 306 | tokens.append("[SEP]") 307 | segment_ids.append(0) 308 | 309 | if tokens_b: 310 | for token in tokens_b: 311 | tokens.append(token) 312 | segment_ids.append(1) 313 | tokens.append("[SEP]") 314 | segment_ids.append(1) 315 | 316 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 317 | 318 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 319 | # tokens are attended to. 320 | input_mask = [1] * len(input_ids) 321 | 322 | # Zero-pad up to the sequence length. 323 | while len(input_ids) < max_seq_length: 324 | input_ids.append(0) 325 | input_mask.append(0) 326 | segment_ids.append(0) 327 | 328 | assert len(input_ids) == max_seq_length 329 | assert len(input_mask) == max_seq_length 330 | assert len(segment_ids) == max_seq_length 331 | 332 | label_id = label_map[example.label] 333 | if ex_index < 5: 334 | tf.logging.info("*** Example ***") 335 | tf.logging.info("guid: %s" % (example.guid)) 336 | tf.logging.info("tokens: %s" % " ".join( 337 | [tokenization.printable_text(x) for x in tokens])) 338 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 339 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 340 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 341 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 342 | 343 | feature = InputFeatures( 344 | input_ids=input_ids, 345 | input_mask=input_mask, 346 | segment_ids=segment_ids, 347 | label_id=label_id, 348 | is_real_example=True) 349 | return feature 350 | 351 | 352 | def file_based_convert_examples_to_features( 353 | examples, label_list, max_seq_length, tokenizer, output_file): 354 | """Convert a set of `InputExample`s to a TFRecord file.""" 355 | 356 | writer = tf.python_io.TFRecordWriter(output_file) 357 | 358 | 
for (ex_index, example) in enumerate(examples): 359 | if ex_index % 10000 == 0: 360 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 361 | 362 | feature = convert_single_example(ex_index, example, label_list, 363 | max_seq_length, tokenizer) 364 | 365 | def create_int_feature(values): 366 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 367 | return f 368 | 369 | features = collections.OrderedDict() 370 | features["input_ids"] = create_int_feature(feature.input_ids) 371 | features["input_mask"] = create_int_feature(feature.input_mask) 372 | features["segment_ids"] = create_int_feature(feature.segment_ids) 373 | features["label_ids"] = create_int_feature([feature.label_id]) 374 | features["is_real_example"] = create_int_feature( 375 | [int(feature.is_real_example)]) 376 | 377 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 378 | writer.write(tf_example.SerializeToString()) 379 | writer.close() 380 | 381 | 382 | def file_based_input_fn_builder(input_file, seq_length, is_training, 383 | drop_remainder): 384 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 385 | 386 | name_to_features = { 387 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 388 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 389 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 390 | "label_ids": tf.FixedLenFeature([], tf.int64), 391 | "is_real_example": tf.FixedLenFeature([], tf.int64), 392 | } 393 | 394 | def _decode_record(record, name_to_features): 395 | """Decodes a record to a TensorFlow example.""" 396 | example = tf.parse_single_example(record, name_to_features) 397 | 398 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 399 | # So cast all int64 to int32. 
400 | for name in list(example.keys()): 401 | t = example[name] 402 | if t.dtype == tf.int64: 403 | t = tf.to_int32(t) 404 | example[name] = t 405 | 406 | return example 407 | 408 | def input_fn(params): 409 | """The actual input function.""" 410 | batch_size = params["batch_size"] 411 | 412 | # For training, we want a lot of parallel reading and shuffling. 413 | # For eval, we want no shuffling and parallel reading doesn't matter. 414 | d = tf.data.TFRecordDataset(input_file) 415 | if is_training: 416 | d = d.repeat() 417 | d = d.shuffle(buffer_size=100) 418 | 419 | d = d.apply( 420 | tf.contrib.data.map_and_batch( 421 | lambda record: _decode_record(record, name_to_features), 422 | batch_size=batch_size, 423 | drop_remainder=drop_remainder)) 424 | 425 | return d 426 | 427 | return input_fn 428 | 429 | 430 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 431 | """Truncates a sequence pair in place to the maximum length.""" 432 | 433 | # This is a simple heuristic which will always truncate the longer sequence 434 | # one token at a time. This makes more sense than truncating an equal percent 435 | # of tokens from each, since if one sequence is very short then each token 436 | # that's truncated likely contains more information than a longer sequence. 
437 | while True: 438 | total_length = len(tokens_a) + len(tokens_b) 439 | if total_length <= max_length: 440 | break 441 | if len(tokens_a) > len(tokens_b): 442 | tokens_a.pop() 443 | else: 444 | tokens_b.pop() 445 | 446 | 447 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 448 | labels, num_labels, use_one_hot_embeddings): 449 | """Creates a classification model.""" 450 | model = modeling.BertModel( 451 | config=bert_config, 452 | is_training=is_training, 453 | input_ids=input_ids, 454 | input_mask=input_mask, 455 | token_type_ids=segment_ids, 456 | use_one_hot_embeddings=use_one_hot_embeddings) 457 | 458 | # In the demo, we are doing a simple classification task on the entire 459 | # segment. 460 | # 461 | # If you want to use the token-level output, use model.get_sequence_output() 462 | # instead. 463 | output_layer = model.get_pooled_output() 464 | 465 | hidden_size = output_layer.shape[-1].value 466 | 467 | output_weights = tf.get_variable( 468 | "output_weights", [num_labels, hidden_size], 469 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 470 | 471 | output_bias = tf.get_variable( 472 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 473 | 474 | with tf.variable_scope("loss"): 475 | if is_training: 476 | # I.e., 0.1 dropout 477 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 478 | 479 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 480 | logits = tf.nn.bias_add(logits, output_bias) 481 | probabilities = tf.nn.softmax(logits, axis=-1) 482 | log_probs = tf.nn.log_softmax(logits, axis=-1) 483 | 484 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 485 | 486 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) # todo 08-29 try temp-loss 487 | ###############bi_tempered_logistic_loss############################################################################ 488 | # print("##cross entropy loss is used...."); 
tf.logging.info("##cross entropy loss is used....") 489 | # t1=0.9 #t1=0.90 490 | # t2=1.05 #t2=1.05 491 | # per_example_loss=bi_tempered_logistic_loss(log_probs,one_hot_labels,t1,t2,label_smoothing=0.1,num_iters=5) # TODO label_smoothing=0.0 492 | #tf.logging.info("per_example_loss:"+str(per_example_loss.shape)) 493 | ##############bi_tempered_logistic_loss############################################################################# 494 | 495 | loss = tf.reduce_mean(per_example_loss) 496 | 497 | return (loss, per_example_loss, logits, probabilities) 498 | 499 | 500 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 501 | num_train_steps, num_warmup_steps, use_tpu, 502 | use_one_hot_embeddings): 503 | """Returns `model_fn` closure for TPUEstimator.""" 504 | 505 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 506 | """The `model_fn` for TPUEstimator.""" 507 | 508 | tf.logging.info("*** Features ***") 509 | for name in sorted(features.keys()): 510 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 511 | 512 | input_ids = features["input_ids"] 513 | input_mask = features["input_mask"] 514 | segment_ids = features["segment_ids"] 515 | label_ids = features["label_ids"] 516 | is_real_example = None 517 | if "is_real_example" in features: 518 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 519 | else: 520 | is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32) 521 | 522 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 523 | 524 | (total_loss, per_example_loss, logits, probabilities) = create_model( 525 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, 526 | num_labels, use_one_hot_embeddings) 527 | 528 | tvars = tf.trainable_variables() 529 | initialized_variable_names = {} 530 | scaffold_fn = None 531 | if init_checkpoint: 532 | (assignment_map, initialized_variable_names 533 | ) = 
modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 534 | if use_tpu: 535 | 536 | def tpu_scaffold(): 537 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 538 | return tf.train.Scaffold() 539 | 540 | scaffold_fn = tpu_scaffold 541 | else: 542 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 543 | 544 | tf.logging.info("**** Trainable Variables ****") 545 | for var in tvars: 546 | init_string = "" 547 | if var.name in initialized_variable_names: 548 | init_string = ", *INIT_FROM_CKPT*" 549 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 550 | init_string) 551 | 552 | output_spec = None 553 | if mode == tf.estimator.ModeKeys.TRAIN: 554 | 555 | train_op = optimization.create_optimizer( 556 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 557 | 558 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 559 | mode=mode, 560 | loss=total_loss, 561 | train_op=train_op, 562 | scaffold_fn=scaffold_fn) 563 | elif mode == tf.estimator.ModeKeys.EVAL: 564 | 565 | def metric_fn(per_example_loss, label_ids, logits, is_real_example): 566 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 567 | accuracy = tf.metrics.accuracy( 568 | labels=label_ids, predictions=predictions, weights=is_real_example) 569 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) 570 | return { 571 | "eval_accuracy": accuracy, 572 | "eval_loss": loss, 573 | } 574 | 575 | eval_metrics = (metric_fn, 576 | [per_example_loss, label_ids, logits, is_real_example]) 577 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 578 | mode=mode, 579 | loss=total_loss, 580 | eval_metrics=eval_metrics, 581 | scaffold_fn=scaffold_fn) 582 | else: 583 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 584 | mode=mode, 585 | predictions={"probabilities": probabilities}, 586 | scaffold_fn=scaffold_fn) 587 | return output_spec 588 | 589 | return model_fn 590 | 591 | 592 | # This function is not used by this file but 
is still used by the Colab and 593 | # people who depend on it. 594 | def input_fn_builder(features, seq_length, is_training, drop_remainder): 595 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 596 | 597 | all_input_ids = [] 598 | all_input_mask = [] 599 | all_segment_ids = [] 600 | all_label_ids = [] 601 | 602 | for feature in features: 603 | all_input_ids.append(feature.input_ids) 604 | all_input_mask.append(feature.input_mask) 605 | all_segment_ids.append(feature.segment_ids) 606 | all_label_ids.append(feature.label_id) 607 | 608 | def input_fn(params): 609 | """The actual input function.""" 610 | batch_size = params["batch_size"] 611 | 612 | num_examples = len(features) 613 | 614 | # This is for demo purposes and does NOT scale to large data sets. We do 615 | # not use Dataset.from_generator() because that uses tf.py_func which is 616 | # not TPU compatible. The right way to load data is with TFRecordReader. 617 | d = tf.data.Dataset.from_tensor_slices({ 618 | "input_ids": 619 | tf.constant( 620 | all_input_ids, shape=[num_examples, seq_length], 621 | dtype=tf.int32), 622 | "input_mask": 623 | tf.constant( 624 | all_input_mask, 625 | shape=[num_examples, seq_length], 626 | dtype=tf.int32), 627 | "segment_ids": 628 | tf.constant( 629 | all_segment_ids, 630 | shape=[num_examples, seq_length], 631 | dtype=tf.int32), 632 | "label_ids": 633 | tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32), 634 | }) 635 | 636 | if is_training: 637 | d = d.repeat() 638 | d = d.shuffle(buffer_size=100) 639 | 640 | d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder) 641 | return d 642 | 643 | return input_fn 644 | 645 | class LCQMCPairClassificationProcessor(DataProcessor): # TODO NEED CHANGE2 646 | """Processor for the internal data set. 
sentence pair classification""" 647 | def __init__(self): 648 | self.language = "zh" 649 | 650 | def get_train_examples(self, data_dir): 651 | """See base class.""" 652 | return self._create_examples( 653 | self._read_tsv(os.path.join(data_dir, "train.txt")), "train") 654 | # dev_0827.tsv 655 | 656 | def get_dev_examples(self, data_dir): 657 | """See base class.""" 658 | return self._create_examples( 659 | self._read_tsv(os.path.join(data_dir, "test.txt")), "dev") # todo change temp for test purpose 660 | 661 | def get_test_examples(self, data_dir): 662 | """See base class.""" 663 | return self._create_examples( 664 | self._read_tsv(os.path.join(data_dir, "test.txt")), "test") 665 | 666 | def get_labels(self): 667 | """See base class.""" 668 | return ["0", "1"] 669 | #return ["-1","0", "1"] 670 | 671 | def _create_examples(self, lines, set_type): 672 | """Creates examples for the training and dev sets.""" 673 | examples = [] 674 | print("length of lines:",len(lines)) 675 | for (i, line) in enumerate(lines): 676 | #print('#i:',i,line) 677 | if i == 0: 678 | continue 679 | guid = "%s-%s" % (set_type, i) 680 | try: 681 | label = tokenization.convert_to_unicode(line[2]) 682 | text_a = tokenization.convert_to_unicode(line[0]) 683 | text_b = tokenization.convert_to_unicode(line[1]) 684 | examples.append( 685 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 686 | except Exception: 687 | print('###error.i:', i, line) 688 | return examples 689 | 690 | class SentencePairClassificationProcessor(DataProcessor): 691 | """Processor for the internal data set. 
sentence pair classification""" 692 | def __init__(self): 693 | self.language = "zh" 694 | 695 | def get_train_examples(self, data_dir): 696 | """See base class.""" 697 | return self._create_examples( 698 | self._read_tsv(os.path.join(data_dir, "train_0827.tsv")), "train") 699 | # dev_0827.tsv 700 | 701 | def get_dev_examples(self, data_dir): 702 | """See base class.""" 703 | return self._create_examples( 704 | self._read_tsv(os.path.join(data_dir, "dev_0827.tsv")), "dev") 705 | 706 | def get_test_examples(self, data_dir): 707 | """See base class.""" 708 | return self._create_examples( 709 | self._read_tsv(os.path.join(data_dir, "test_0827.tsv")), "test") 710 | 711 | def get_labels(self): 712 | """See base class.""" 713 | return ["0", "1"] 714 | #return ["-1","0", "1"] 715 | 716 | def _create_examples(self, lines, set_type): 717 | """Creates examples for the training and dev sets.""" 718 | examples = [] 719 | print("length of lines:",len(lines)) 720 | for (i, line) in enumerate(lines): 721 | #print('#i:',i,line) 722 | if i == 0: 723 | continue 724 | guid = "%s-%s" % (set_type, i) 725 | try: 726 | label = tokenization.convert_to_unicode(line[0]) 727 | text_a = tokenization.convert_to_unicode(line[1]) 728 | text_b = tokenization.convert_to_unicode(line[2]) 729 | examples.append( 730 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 731 | except Exception: 732 | print('###error.i:', i, line) 733 | return examples 734 | 735 | # This function is not used by this file but is still used by the Colab and 736 | # people who depend on it. 
737 | def convert_examples_to_features(examples, label_list, max_seq_length, 738 | tokenizer): 739 | """Convert a set of `InputExample`s to a list of `InputFeatures`.""" 740 | 741 | features = [] 742 | for (ex_index, example) in enumerate(examples): 743 | if ex_index % 10000 == 0: 744 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 745 | 746 | feature = convert_single_example(ex_index, example, label_list, 747 | max_seq_length, tokenizer) 748 | 749 | features.append(feature) 750 | return features 751 | 752 | 753 | def main(_): 754 | tf.logging.set_verbosity(tf.logging.INFO) 755 | 756 | processors = { 757 | "sentence_pair": SentencePairClassificationProcessor, 758 | "lcqmc_pair": LCQMCPairClassificationProcessor 759 | 760 | 761 | } 762 | 763 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, 764 | FLAGS.init_checkpoint) 765 | 766 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 767 | raise ValueError( 768 | "At least one of `do_train`, `do_eval` or `do_predict` must be True.") 769 | 770 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 771 | 772 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 773 | raise ValueError( 774 | "Cannot use sequence length %d because the BERT model " 775 | "was only trained up to sequence length %d" % 776 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 777 | 778 | tf.gfile.MakeDirs(FLAGS.output_dir) 779 | 780 | task_name = FLAGS.task_name.lower() 781 | 782 | if task_name not in processors: 783 | raise ValueError("Task not found: %s" % (task_name)) 784 | 785 | processor = processors[task_name]() 786 | 787 | label_list = processor.get_labels() 788 | 789 | tokenizer = tokenization.FullTokenizer( 790 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 791 | 792 | tpu_cluster_resolver = None 793 | if FLAGS.use_tpu and FLAGS.tpu_name: 794 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 
795 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 796 | 797 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 798 | # Cloud TPU: Invalid TPU configuration, ensure ClusterResolver is passed to tpu. 799 | print("###tpu_cluster_resolver:",tpu_cluster_resolver) 800 | run_config = tf.contrib.tpu.RunConfig( 801 | cluster=tpu_cluster_resolver, 802 | master=FLAGS.master, 803 | model_dir=FLAGS.output_dir, 804 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 805 | tpu_config=tf.contrib.tpu.TPUConfig( 806 | iterations_per_loop=FLAGS.iterations_per_loop, 807 | num_shards=FLAGS.num_tpu_cores, 808 | per_host_input_for_training=is_per_host)) 809 | 810 | train_examples = None 811 | num_train_steps = None 812 | num_warmup_steps = None 813 | if FLAGS.do_train: 814 | train_examples =processor.get_train_examples(FLAGS.data_dir) # TODO 815 | print("###length of total train_examples:",len(train_examples)) 816 | num_train_steps = int(len(train_examples)/ FLAGS.train_batch_size * FLAGS.num_train_epochs) 817 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 818 | 819 | model_fn = model_fn_builder( 820 | bert_config=bert_config, 821 | num_labels=len(label_list), 822 | init_checkpoint=FLAGS.init_checkpoint, 823 | learning_rate=FLAGS.learning_rate, 824 | num_train_steps=num_train_steps, 825 | num_warmup_steps=num_warmup_steps, 826 | use_tpu=FLAGS.use_tpu, 827 | use_one_hot_embeddings=FLAGS.use_tpu) 828 | 829 | # If TPU is not available, this will fall back to normal Estimator on CPU 830 | # or GPU. 
831 | estimator = tf.contrib.tpu.TPUEstimator( 832 | use_tpu=FLAGS.use_tpu, 833 | model_fn=model_fn, 834 | config=run_config, 835 | train_batch_size=FLAGS.train_batch_size, 836 | eval_batch_size=FLAGS.eval_batch_size, 837 | predict_batch_size=FLAGS.predict_batch_size) 838 | 839 | if FLAGS.do_train: 840 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 841 | train_file_exists=os.path.exists(train_file) 842 | print("###train_file_exists:", train_file_exists," ;train_file:",train_file) 843 | if not train_file_exists: # if tf_record file not exist, convert from raw text file. # TODO 844 | file_based_convert_examples_to_features(train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 845 | tf.logging.info("***** Running training *****") 846 | tf.logging.info(" Num examples = %d", len(train_examples)) 847 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 848 | tf.logging.info(" Num steps = %d", num_train_steps) 849 | train_input_fn = file_based_input_fn_builder( 850 | input_file=train_file, 851 | seq_length=FLAGS.max_seq_length, 852 | is_training=True, 853 | drop_remainder=True) 854 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 855 | 856 | if FLAGS.do_eval: 857 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 858 | num_actual_eval_examples = len(eval_examples) 859 | if FLAGS.use_tpu: 860 | # TPU requires a fixed batch size for all batches, therefore the number 861 | # of examples must be a multiple of the batch size, or else examples 862 | # will get dropped. So we pad with fake examples which are ignored 863 | # later on. These do NOT count towards the metric (all tf.metrics 864 | # support a per-instance weight, and these get a weight of 0.0). 
865 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 866 | eval_examples.append(PaddingInputExample()) 867 | 868 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 869 | file_based_convert_examples_to_features( 870 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 871 | 872 | tf.logging.info("***** Running evaluation *****") 873 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 874 | len(eval_examples), num_actual_eval_examples, 875 | len(eval_examples) - num_actual_eval_examples) 876 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 877 | 878 | # This tells the estimator to run through the entire set. 879 | eval_steps = None 880 | # However, if running eval on the TPU, you will need to specify the 881 | # number of steps. 882 | if FLAGS.use_tpu: 883 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 884 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 885 | 886 | eval_drop_remainder = True if FLAGS.use_tpu else False 887 | eval_input_fn = file_based_input_fn_builder( 888 | input_file=eval_file, 889 | seq_length=FLAGS.max_seq_length, 890 | is_training=False, 891 | drop_remainder=eval_drop_remainder) 892 | 893 | ####################################################################################################################### 894 | # Evaluate every checkpoint found in output_dir, in order of global step. 895 | steps_and_files = [] 896 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 897 | for filename in filenames: 898 | if filename.endswith(".index"): 899 | ckpt_name = filename[:-6] 900 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name) 901 | global_step = int(cur_filename.split("-")[-1]) 902 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 903 | steps_and_files.append([global_step, cur_filename]) 904 | steps_and_files = sorted(steps_and_files, key=lambda x: x[0]) 905 | 906 | output_eval_file = os.path.join(FLAGS.data_dir, "eval_results16-layer24-4million-2.txt") # finetuning-layer24-4million 
907 | print("output_eval_file:",output_eval_file) 908 | tf.logging.info("output_eval_file:"+output_eval_file) 909 | with tf.gfile.GFile(output_eval_file, "w") as writer: 910 | for global_step, filename in sorted(steps_and_files, key=lambda x: x[0]): 911 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=filename) 912 | 913 | tf.logging.info("***** Eval results %s *****" % (filename)) 914 | writer.write("***** Eval results %s *****\n" % (filename)) 915 | for key in sorted(result.keys()): 916 | tf.logging.info(" %s = %s", key, str(result[key])) 917 | writer.write("%s = %s\n" % (key, str(result[key]))) 918 | ####################################################################################################################### 919 | 920 | #result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 921 | # 922 | #output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 923 | #with tf.gfile.GFile(output_eval_file, "w") as writer: 924 | # tf.logging.info("***** Eval results *****") 925 | # for key in sorted(result.keys()): 926 | # tf.logging.info(" %s = %s", key, str(result[key])) 927 | # writer.write("%s = %s\n" % (key, str(result[key]))) 928 | 929 | if FLAGS.do_predict: 930 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 931 | num_actual_predict_examples = len(predict_examples) 932 | if FLAGS.use_tpu: 933 | # TPU requires a fixed batch size for all batches, therefore the number 934 | # of examples must be a multiple of the batch size, or else examples 935 | # will get dropped. So we pad with fake examples which are ignored 936 | # later on. 
937 | while len(predict_examples) % FLAGS.predict_batch_size != 0: 938 | predict_examples.append(PaddingInputExample()) 939 | 940 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 941 | file_based_convert_examples_to_features(predict_examples, label_list, 942 | FLAGS.max_seq_length, tokenizer, 943 | predict_file) 944 | 945 | tf.logging.info("***** Running prediction*****") 946 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 947 | len(predict_examples), num_actual_predict_examples, 948 | len(predict_examples) - num_actual_predict_examples) 949 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 950 | 951 | predict_drop_remainder = True if FLAGS.use_tpu else False 952 | predict_input_fn = file_based_input_fn_builder( 953 | input_file=predict_file, 954 | seq_length=FLAGS.max_seq_length, 955 | is_training=False, 956 | drop_remainder=predict_drop_remainder) 957 | 958 | result = estimator.predict(input_fn=predict_input_fn) 959 | 960 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") 961 | with tf.gfile.GFile(output_predict_file, "w") as writer: 962 | num_written_lines = 0 963 | tf.logging.info("***** Predict results *****") 964 | for (i, prediction) in enumerate(result): 965 | probabilities = prediction["probabilities"] 966 | if i >= num_actual_predict_examples: 967 | break 968 | output_line = "\t".join( 969 | str(class_probability) 970 | for class_probability in probabilities) + "\n" 971 | writer.write(output_line) 972 | num_written_lines += 1 973 | assert num_written_lines == num_actual_predict_examples 974 | 975 | 976 | if __name__ == "__main__": 977 | flags.mark_flag_as_required("data_dir") 978 | flags.mark_flag_as_required("task_name") 979 | flags.mark_flag_as_required("vocab_file") 980 | flags.mark_flag_as_required("bert_config_file") 981 | flags.mark_flag_as_required("output_dir") 982 | tf.app.run() -------------------------------------------------------------------------------- 
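The feature-building steps in `convert_single_example` and `_truncate_seq_pair` above (truncate the pair, add `[CLS]`/`[SEP]`, assign segment ids, zero-pad) can be sketched in plain Python without TensorFlow. This is a minimal illustration, not the script's own code: `build_pair_features` and `toy_vocab` are hypothetical names introduced here, and the real script uses `tokenization.FullTokenizer` to tokenize and map tokens to ids.

```python
# Pure-Python sketch of the pair-truncation and padding logic in run_classifier.py.
# `build_pair_features` and `toy_vocab` are hypothetical, for illustration only.


def truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Pop one token at a time from the longer list until the pair fits."""
    while len(tokens_a) + len(tokens_b) > max_length:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def build_pair_features(tokens_a, tokens_b, vocab, max_seq_length):
    """Return (input_ids, input_mask, segment_ids), each of length max_seq_length."""
    tokens_a, tokens_b = list(tokens_a), list(tokens_b)  # do not mutate caller data
    # Reserve three positions for [CLS], [SEP], [SEP].
    truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment 0 covers [CLS], sentence A and its [SEP]; segment 1 covers the rest.
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    input_ids = [vocab[t] for t in tokens]
    input_mask = [1] * len(input_ids)  # 1 = real token, 0 = padding

    pad = max_seq_length - len(input_ids)
    input_ids += [0] * pad
    input_mask += [0] * pad
    segment_ids += [0] * pad
    return input_ids, input_mask, segment_ids


toy_vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "a": 3, "b": 4, "c": 5, "d": 6}
ids, mask, segs = build_pair_features(["a", "b", "c"], ["d"], toy_vocab, 8)
print(ids)   # [1, 3, 4, 5, 2, 6, 2, 0]
print(mask)  # [1, 1, 1, 1, 1, 1, 1, 0]
print(segs)  # [0, 0, 0, 0, 0, 1, 1, 0]
```

When the two sequences are the same length, the second one is shortened first, which matches the `else` branch of `_truncate_seq_pair`.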
/run_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run masked LM/next sentence masked_lm pre-training for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import modeling 23 | import optimization 24 | import tensorflow as tf 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | ## Required parameters 31 | flags.DEFINE_string( 32 | "bert_config_file", None, 33 | "The config json file corresponding to the pre-trained BERT model. " 34 | "This specifies the model architecture.") 35 | 36 | flags.DEFINE_string( 37 | "input_file", None, 38 | "Input TF example files (can be a glob or comma separated).") 39 | 40 | flags.DEFINE_string( 41 | "output_dir", None, 42 | "The output directory where the model checkpoints will be written.") 43 | 44 | ## Other parameters 45 | flags.DEFINE_string( 46 | "init_checkpoint", None, 47 | "Initial checkpoint (usually from a pre-trained BERT model).") 48 | 49 | flags.DEFINE_integer( 50 | "max_seq_length", 128, 51 | "The maximum total input sequence length after WordPiece tokenization. " 52 | "Sequences longer than this will be truncated, and sequences shorter " 53 | "than this will be padded. 
Must match data generation.") 54 | 55 | flags.DEFINE_integer( 56 | "max_predictions_per_seq", 20, 57 | "Maximum number of masked LM predictions per sequence. " 58 | "Must match data generation.") 59 | 60 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 61 | 62 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 63 | 64 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 65 | 66 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 67 | 68 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 69 | 70 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.") 71 | 72 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") 73 | 74 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 75 | "How often to save the model checkpoint.") 76 | 77 | flags.DEFINE_integer("iterations_per_loop", 1000, 78 | "How many steps to make in each estimator call.") 79 | 80 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 81 | 82 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 83 | 84 | tf.flags.DEFINE_string( 85 | "tpu_name", None, 86 | "The Cloud TPU to use for training. This should be either the name " 87 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 88 | "url.") 89 | 90 | tf.flags.DEFINE_string( 91 | "tpu_zone", None, 92 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 93 | "specified, we will attempt to automatically detect the GCE project from " 94 | "metadata.") 95 | 96 | tf.flags.DEFINE_string( 97 | "gcp_project", None, 98 | "[Optional] Project name for the Cloud TPU-enabled project. 
If not " 99 | "specified, we will attempt to automatically detect the GCE project from " 100 | "metadata.") 101 | 102 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 103 | 104 | flags.DEFINE_integer( 105 | "num_tpu_cores", 8, 106 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 107 | 108 | 109 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 110 | num_train_steps, num_warmup_steps, use_tpu, 111 | use_one_hot_embeddings): 112 | """Returns `model_fn` closure for TPUEstimator.""" 113 | 114 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 115 | """The `model_fn` for TPUEstimator.""" 116 | 117 | tf.logging.info("*** Features ***") 118 | for name in sorted(features.keys()): 119 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 120 | 121 | input_ids = features["input_ids"] 122 | input_mask = features["input_mask"] 123 | segment_ids = features["segment_ids"] 124 | masked_lm_positions = features["masked_lm_positions"] 125 | masked_lm_ids = features["masked_lm_ids"] 126 | masked_lm_weights = features["masked_lm_weights"] 127 | next_sentence_labels = features["next_sentence_labels"] 128 | 129 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 130 | 131 | model = modeling.BertModel( 132 | config=bert_config, 133 | is_training=is_training, 134 | input_ids=input_ids, 135 | input_mask=input_mask, 136 | token_type_ids=segment_ids, 137 | use_one_hot_embeddings=use_one_hot_embeddings) 138 | 139 | (masked_lm_loss, 140 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( 141 | bert_config, model.get_sequence_output(), model.get_embedding_table(), 142 | masked_lm_positions, masked_lm_ids, masked_lm_weights) 143 | 144 | (next_sentence_loss, next_sentence_example_loss, # TODO: the NSP loss can still be computed, but it is not counted toward the total loss 145 | next_sentence_log_probs) = get_next_sentence_output( 146 | bert_config, model.get_pooled_output(), next_sentence_labels) 147 | #
batch_size=masked_lm_log_probs.shape[0] 148 | # next_sentence_example_loss=tf.zeros((batch_size)) #tf.constant(0.0,dtype=tf.float32) 149 | # next_sentence_log_probs=tf.zeros((batch_size,2)) 150 | total_loss = masked_lm_loss # TODO remove next sentence loss 2019-08-08, + next_sentence_loss 151 | 152 | tvars = tf.trainable_variables() 153 | 154 | initialized_variable_names = {} 155 | print("init_checkpoint:",init_checkpoint) 156 | scaffold_fn = None 157 | if init_checkpoint: 158 | (assignment_map, initialized_variable_names 159 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 160 | if use_tpu: 161 | 162 | def tpu_scaffold(): 163 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 164 | return tf.train.Scaffold() 165 | 166 | scaffold_fn = tpu_scaffold 167 | else: 168 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 169 | 170 | tf.logging.info("**** Trainable Variables ****") 171 | for var in tvars: 172 | init_string = "" 173 | if var.name in initialized_variable_names: 174 | init_string = ", *INIT_FROM_CKPT*" 175 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 176 | init_string) 177 | 178 | output_spec = None 179 | if mode == tf.estimator.ModeKeys.TRAIN: 180 | train_op = optimization.create_optimizer( 181 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 182 | 183 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 184 | mode=mode, 185 | loss=total_loss, 186 | train_op=train_op, 187 | scaffold_fn=scaffold_fn) 188 | elif mode == tf.estimator.ModeKeys.EVAL: 189 | 190 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 191 | masked_lm_weights, next_sentence_example_loss, 192 | next_sentence_log_probs, next_sentence_labels): 193 | """Computes the loss and accuracy of the model.""" 194 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs,[-1, masked_lm_log_probs.shape[-1]]) 195 | masked_lm_predictions = tf.argmax(masked_lm_log_probs, axis=-1, 
output_type=tf.int32) 196 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) 197 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) 198 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1]) 199 | masked_lm_accuracy = tf.metrics.accuracy( 200 | labels=masked_lm_ids, 201 | predictions=masked_lm_predictions, 202 | weights=masked_lm_weights) 203 | masked_lm_mean_loss = tf.metrics.mean( 204 | values=masked_lm_example_loss, weights=masked_lm_weights) 205 | 206 | next_sentence_log_probs = tf.reshape( 207 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]) 208 | next_sentence_predictions = tf.argmax( 209 | next_sentence_log_probs, axis=-1, output_type=tf.int32) 210 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) 211 | next_sentence_accuracy = tf.metrics.accuracy( 212 | labels=next_sentence_labels, predictions=next_sentence_predictions) 213 | next_sentence_mean_loss = tf.metrics.mean( 214 | values=next_sentence_example_loss) 215 | 216 | return { 217 | "masked_lm_accuracy": masked_lm_accuracy, 218 | "masked_lm_loss": masked_lm_mean_loss, 219 | "next_sentence_accuracy": next_sentence_accuracy, 220 | "next_sentence_loss": next_sentence_mean_loss, 221 | } 222 | 223 | # next_sentence_example_loss=0.0 TODO 224 | # next_sentence_log_probs=0.0 # TODO 225 | eval_metrics = (metric_fn, [ 226 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 227 | masked_lm_weights, next_sentence_example_loss, 228 | next_sentence_log_probs, next_sentence_labels 229 | ]) 230 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 231 | mode=mode, 232 | loss=total_loss, 233 | eval_metrics=eval_metrics, 234 | scaffold_fn=scaffold_fn) 235 | else: 236 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 237 | 238 | return output_spec 239 | 240 | return model_fn 241 | 242 | 243 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, 244 | label_ids, label_weights): 245 | """Get loss and log probs for the 
masked LM.""" 246 | input_tensor = gather_indexes(input_tensor, positions) 247 | 248 | with tf.variable_scope("cls/predictions"): 249 | # We apply one more non-linear transformation before the output layer. 250 | # This matrix is not used after pre-training. 251 | with tf.variable_scope("transform"): 252 | input_tensor = tf.layers.dense( 253 | input_tensor, 254 | units=bert_config.hidden_size, 255 | activation=modeling.get_activation(bert_config.hidden_act), 256 | kernel_initializer=modeling.create_initializer( 257 | bert_config.initializer_range)) 258 | input_tensor = modeling.layer_norm(input_tensor) 259 | 260 | # The output weights are the same as the input embeddings, but there is 261 | # an output-only bias for each token. 262 | output_bias = tf.get_variable( 263 | "output_bias", 264 | shape=[bert_config.vocab_size], 265 | initializer=tf.zeros_initializer()) 266 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 267 | logits = tf.nn.bias_add(logits, output_bias) 268 | log_probs = tf.nn.log_softmax(logits, axis=-1) 269 | 270 | label_ids = tf.reshape(label_ids, [-1]) 271 | label_weights = tf.reshape(label_weights, [-1]) 272 | 273 | one_hot_labels = tf.one_hot(label_ids, depth=bert_config.vocab_size, dtype=tf.float32) 274 | 275 | # The `positions` tensor might be zero-padded (if the sequence is too 276 | # short to have the maximum number of predictions). The `label_weights` 277 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 278 | # padding predictions. 
279 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) 280 | numerator = tf.reduce_sum(label_weights * per_example_loss) 281 | denominator = tf.reduce_sum(label_weights) + 1e-5 282 | loss = numerator / denominator 283 | 284 | return (loss, per_example_loss, log_probs) 285 | 286 | 287 | def get_next_sentence_output(bert_config, input_tensor, labels): 288 | """Get loss and log probs for the next sentence prediction.""" 289 | 290 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 291 | # "random sentence". This weight matrix is not used after pre-training. 292 | with tf.variable_scope("cls/seq_relationship"): 293 | output_weights = tf.get_variable( 294 | "output_weights", 295 | shape=[2, bert_config.hidden_size], 296 | initializer=modeling.create_initializer(bert_config.initializer_range)) 297 | output_bias = tf.get_variable( 298 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 299 | 300 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 301 | logits = tf.nn.bias_add(logits, output_bias) 302 | log_probs = tf.nn.log_softmax(logits, axis=-1) 303 | labels = tf.reshape(labels, [-1]) 304 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 305 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 306 | loss = tf.reduce_mean(per_example_loss) 307 | return (loss, per_example_loss, log_probs) 308 | 309 | 310 | def gather_indexes(sequence_tensor, positions): 311 | """Gathers the vectors at the specific positions over a minibatch.""" 312 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 313 | batch_size = sequence_shape[0] 314 | seq_length = sequence_shape[1] 315 | width = sequence_shape[2] 316 | 317 | flat_offsets = tf.reshape( 318 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 319 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 320 | flat_sequence_tensor = tf.reshape(sequence_tensor, 321 | [batch_size * 
seq_length, width]) 322 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 323 | return output_tensor 324 | 325 | 326 | def input_fn_builder(input_files, 327 | max_seq_length, 328 | max_predictions_per_seq, 329 | is_training, 330 | num_cpu_threads=4): 331 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 332 | 333 | def input_fn(params): 334 | """The actual input function.""" 335 | batch_size = params["batch_size"] 336 | 337 | name_to_features = { 338 | "input_ids": 339 | tf.FixedLenFeature([max_seq_length], tf.int64), 340 | "input_mask": 341 | tf.FixedLenFeature([max_seq_length], tf.int64), 342 | "segment_ids": 343 | tf.FixedLenFeature([max_seq_length], tf.int64), 344 | "masked_lm_positions": 345 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 346 | "masked_lm_ids": 347 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 348 | "masked_lm_weights": 349 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32), 350 | "next_sentence_labels": 351 | tf.FixedLenFeature([1], tf.int64), 352 | } 353 | 354 | # For training, we want a lot of parallel reading and shuffling. 355 | # For eval, we want no shuffling and parallel reading doesn't matter. 356 | if is_training: 357 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 358 | d = d.repeat() 359 | d = d.shuffle(buffer_size=len(input_files)) 360 | 361 | # `cycle_length` is the number of parallel files that get read. 362 | cycle_length = min(num_cpu_threads, len(input_files)) 363 | 364 | # `sloppy` mode means that the interleaving is not exact. This adds 365 | # even more randomness to the training pipeline. 
366 | d = d.apply( 367 | tf.contrib.data.parallel_interleave( 368 | tf.data.TFRecordDataset, 369 | sloppy=is_training, 370 | cycle_length=cycle_length)) 371 | d = d.shuffle(buffer_size=100) 372 | else: 373 | d = tf.data.TFRecordDataset(input_files) 374 | # Since we evaluate for a fixed number of steps we don't want to encounter 375 | # out-of-range exceptions. 376 | d = d.repeat() 377 | 378 | # We must `drop_remainder` on training because the TPU requires fixed 379 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 380 | # and we *don't* want to drop the remainder, otherwise we won't cover 381 | # every sample. 382 | d = d.apply( 383 | tf.contrib.data.map_and_batch( 384 | lambda record: _decode_record(record, name_to_features), 385 | batch_size=batch_size, 386 | num_parallel_batches=num_cpu_threads, 387 | drop_remainder=True)) 388 | return d 389 | 390 | return input_fn 391 | 392 | 393 | def _decode_record(record, name_to_features): 394 | """Decodes a record to a TensorFlow example.""" 395 | example = tf.parse_single_example(record, name_to_features) 396 | 397 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 398 | # So cast all int64 to int32.
399 | for name in list(example.keys()): 400 | t = example[name] 401 | if t.dtype == tf.int64: 402 | t = tf.to_int32(t) 403 | example[name] = t 404 | 405 | return example 406 | 407 | 408 | def main(_): 409 | tf.logging.set_verbosity(tf.logging.INFO) 410 | 411 | if not FLAGS.do_train and not FLAGS.do_eval: # at least one of training or evaluation must be requested 412 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 413 | 414 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) # load the model configuration from the JSON file 415 | 416 | tf.gfile.MakeDirs(FLAGS.output_dir) 417 | 418 | input_files = [] # the input may be several comma-separated files, or a glob pattern such as "input_x*" 419 | for input_pattern in FLAGS.input_file.split(","): 420 | input_files.extend(tf.gfile.Glob(input_pattern)) 421 | 422 | tf.logging.info("*** Input Files ***") 423 | for input_file in input_files: 424 | tf.logging.info(" %s" % input_file) 425 | 426 | tpu_cluster_resolver = None 427 | #if FLAGS.use_tpu and FLAGS.tpu_name: 428 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( # TODO 429 | tpu=FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 430 | 431 | print("###tpu_cluster_resolver:",tpu_cluster_resolver,";FLAGS.use_tpu:",FLAGS.use_tpu,";FLAGS.tpu_name:",FLAGS.tpu_name,";FLAGS.tpu_zone:",FLAGS.tpu_zone) 432 | # ###tpu_cluster_resolver: ;FLAGS.use_tpu: True ;FLAGS.tpu_name: grpc://10.240.1.83:8470 433 | 434 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 435 | run_config = tf.contrib.tpu.RunConfig( 436 | keep_checkpoint_max=20, # 10 437 | cluster=tpu_cluster_resolver, 438 | master=FLAGS.master, 439 | model_dir=FLAGS.output_dir, 440 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 441 | tpu_config=tf.contrib.tpu.TPUConfig( 442 | iterations_per_loop=FLAGS.iterations_per_loop, 443 | num_shards=FLAGS.num_tpu_cores, 444 | per_host_input_for_training=is_per_host)) 445 | 446 | model_fn = model_fn_builder( 447 | bert_config=bert_config, 448 | init_checkpoint=FLAGS.init_checkpoint, 449 |
learning_rate=FLAGS.learning_rate, 450 | num_train_steps=FLAGS.num_train_steps, 451 | num_warmup_steps=FLAGS.num_warmup_steps, 452 | use_tpu=FLAGS.use_tpu, 453 | use_one_hot_embeddings=FLAGS.use_tpu) 454 | 455 | # If TPU is not available, this will fall back to normal Estimator on CPU 456 | # or GPU. 457 | estimator = tf.contrib.tpu.TPUEstimator( 458 | use_tpu=FLAGS.use_tpu, 459 | model_fn=model_fn, 460 | config=run_config, 461 | train_batch_size=FLAGS.train_batch_size, 462 | eval_batch_size=FLAGS.eval_batch_size) 463 | 464 | if FLAGS.do_train: 465 | tf.logging.info("***** Running training *****") 466 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 467 | train_input_fn = input_fn_builder( 468 | input_files=input_files, 469 | max_seq_length=FLAGS.max_seq_length, 470 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 471 | is_training=True) 472 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 473 | 474 | if FLAGS.do_eval: 475 | tf.logging.info("***** Running evaluation *****") 476 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 477 | 478 | eval_input_fn = input_fn_builder( 479 | input_files=input_files, 480 | max_seq_length=FLAGS.max_seq_length, 481 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 482 | is_training=False) 483 | 484 | result = estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) 485 | 486 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 487 | with tf.gfile.GFile(output_eval_file, "w") as writer: 488 | tf.logging.info("***** Eval results *****") 489 | for key in sorted(result.keys()): 490 | tf.logging.info(" %s = %s", key, str(result[key])) 491 | writer.write("%s = %s\n" % (key, str(result[key]))) 492 | 493 | 494 | if __name__ == "__main__": 495 | flags.mark_flag_as_required("input_file") 496 | flags.mark_flag_as_required("bert_config_file") 497 | flags.mark_flag_as_required("output_dir") 498 | tf.app.run() 499 | 
-------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 
35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-trained. If this error is wrong, please " 74 | "just comment out this check."
% (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts 
a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | # Example `items` (WordPiece tokens from a masked Chinese sample): ['[CLS]', '日', '##期', ',', '但', '被', '##告', ..., '[MASK]', '[SEP]'] 140 | for item in items: 141 | # Each item is a single token such as '##期' (or an id when `vocab` is the inverse vocab). 142 | output.append(vocab[item]) 143 | return output 144 | 145 | 146 | def convert_tokens_to_ids(vocab, tokens): 147 | return convert_by_vocab(vocab, tokens) 148 | 149 | 150 | def convert_ids_to_tokens(inv_vocab, ids): 151 | return convert_by_vocab(inv_vocab, ids) 152 | 153 | 154 | def whitespace_tokenize(text): 155 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 156 | text = text.strip() 157 | if not text: 158 | return [] 159 | tokens = text.split() 160 | return tokens 161 | 162 | 163 | class FullTokenizer(object): 164 | """Runs end-to-end tokenization.""" 165 | 166 | def __init__(self, vocab_file, do_lower_case=True): 167 | self.vocab = load_vocab(vocab_file) 168 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 169 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 170 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 171 | 172 | def tokenize(self, text): 173 | split_tokens = [] 174 | for token in
self.basic_tokenizer.tokenize(text): 175 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 176 | split_tokens.append(sub_token) 177 | 178 | return split_tokens 179 | 180 | def convert_tokens_to_ids(self, tokens): 181 | return convert_by_vocab(self.vocab, tokens) 182 | 183 | def convert_ids_to_tokens(self, ids): 184 | return convert_by_vocab(self.inv_vocab, ids) 185 | 186 | 187 | class BasicTokenizer(object): 188 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 189 | 190 | def __init__(self, do_lower_case=True): 191 | """Constructs a BasicTokenizer. 192 | 193 | Args: 194 | do_lower_case: Whether to lower case the input. 195 | """ 196 | self.do_lower_case = do_lower_case 197 | 198 | def tokenize(self, text): 199 | """Tokenizes a piece of text.""" 200 | text = convert_to_unicode(text) 201 | text = self._clean_text(text) 202 | 203 | # This was added on November 1st, 2018 for the multilingual and Chinese 204 | # models. This is also applied to the English models now, but it doesn't 205 | # matter since the English models were not trained on any Chinese data 206 | # and generally don't have any Chinese data in them (there are Chinese 207 | # characters in the vocabulary because Wikipedia does have some Chinese 208 | # words in the English Wikipedia.). 
209 | text = self._tokenize_chinese_chars(text) 210 | 211 | orig_tokens = whitespace_tokenize(text) 212 | split_tokens = [] 213 | for token in orig_tokens: 214 | if self.do_lower_case: 215 | token = token.lower() 216 | token = self._run_strip_accents(token) 217 | split_tokens.extend(self._run_split_on_punc(token)) 218 | 219 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 220 | return output_tokens 221 | 222 | def _run_strip_accents(self, text): 223 | """Strips accents from a piece of text.""" 224 | text = unicodedata.normalize("NFD", text) 225 | output = [] 226 | for char in text: 227 | cat = unicodedata.category(char) 228 | if cat == "Mn": 229 | continue 230 | output.append(char) 231 | return "".join(output) 232 | 233 | def _run_split_on_punc(self, text): 234 | """Splits punctuation on a piece of text.""" 235 | chars = list(text) 236 | i = 0 237 | start_new_word = True 238 | output = [] 239 | while i < len(chars): 240 | char = chars[i] 241 | if _is_punctuation(char): 242 | output.append([char]) 243 | start_new_word = True 244 | else: 245 | if start_new_word: 246 | output.append([]) 247 | start_new_word = False 248 | output[-1].append(char) 249 | i += 1 250 | 251 | return ["".join(x) for x in output] 252 | 253 | def _tokenize_chinese_chars(self, text): 254 | """Adds whitespace around any CJK character.""" 255 | output = [] 256 | for char in text: 257 | cp = ord(char) 258 | if self._is_chinese_char(cp): 259 | output.append(" ") 260 | output.append(char) 261 | output.append(" ") 262 | else: 263 | output.append(char) 264 | return "".join(output) 265 | 266 | def _is_chinese_char(self, cp): 267 | """Checks whether CP is the codepoint of a CJK character.""" 268 | # This defines a "chinese character" as anything in the CJK Unicode block: 269 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 270 | # 271 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 272 | # despite its name. 
The modern Korean Hangul alphabet is a different block, 273 | # as are Japanese Hiragana and Katakana. Those alphabets are used to write 274 | # space-separated words, so they are not treated specially and are handled 275 | # like all of the other languages. 276 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 277 | (cp >= 0x3400 and cp <= 0x4DBF) or # 278 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 279 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 280 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 281 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 282 | (cp >= 0xF900 and cp <= 0xFAFF) or # 283 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 284 | return True 285 | 286 | return False 287 | 288 | def _clean_text(self, text): 289 | """Performs invalid character removal and whitespace cleanup on text.""" 290 | output = [] 291 | for char in text: 292 | cp = ord(char) 293 | if cp == 0 or cp == 0xfffd or _is_control(char): 294 | continue 295 | if _is_whitespace(char): 296 | output.append(" ") 297 | else: 298 | output.append(char) 299 | return "".join(output) 300 | 301 | 302 | class WordpieceTokenizer(object): 303 | """Runs WordPiece tokenization.""" 304 | 305 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 306 | self.vocab = vocab 307 | self.unk_token = unk_token 308 | self.max_input_chars_per_word = max_input_chars_per_word 309 | 310 | def tokenize(self, text): 311 | """Tokenizes a piece of text into its word pieces. 312 | 313 | This uses a greedy longest-match-first algorithm to perform tokenization 314 | using the given vocabulary. 315 | 316 | For example: 317 | input = "unaffable" 318 | output = ["un", "##aff", "##able"] 319 | 320 | Args: 321 | text: A single token or whitespace separated tokens. This should have 322 | already been passed through `BasicTokenizer`. 323 | 324 | Returns: 325 | A list of wordpiece tokens.
326 | """ 327 | 328 | text = convert_to_unicode(text) 329 | 330 | output_tokens = [] 331 | for token in whitespace_tokenize(text): 332 | chars = list(token) 333 | if len(chars) > self.max_input_chars_per_word: 334 | output_tokens.append(self.unk_token) 335 | continue 336 | 337 | is_bad = False 338 | start = 0 339 | sub_tokens = [] 340 | while start < len(chars): 341 | end = len(chars) 342 | cur_substr = None 343 | while start < end: 344 | substr = "".join(chars[start:end]) 345 | if start > 0: 346 | substr = "##" + substr 347 | if substr in self.vocab: 348 | cur_substr = substr 349 | break 350 | end -= 1 351 | if cur_substr is None: 352 | is_bad = True 353 | break 354 | sub_tokens.append(cur_substr) 355 | start = end 356 | 357 | if is_bad: 358 | output_tokens.append(self.unk_token) 359 | else: 360 | output_tokens.extend(sub_tokens) 361 | return output_tokens 362 | 363 | 364 | def _is_whitespace(char): 365 | """Checks whether `char` is a whitespace character.""" 366 | # \t, \n, and \r are technically control characters but we treat them 367 | # as whitespace since they are generally considered as such. 368 | if char == " " or char == "\t" or char == "\n" or char == "\r": 369 | return True 370 | cat = unicodedata.category(char) 371 | if cat == "Zs": 372 | return True 373 | return False 374 | 375 | 376 | def _is_control(char): 377 | """Checks whether `char` is a control character.""" 378 | # These are technically control characters but we count them as whitespace 379 | # characters. 380 | if char == "\t" or char == "\n" or char == "\r": 381 | return False 382 | cat = unicodedata.category(char) 383 | if cat in ("Cc", "Cf"): 384 | return True 385 | return False 386 | 387 | 388 | def _is_punctuation(char): 389 | """Checks whether `char` is a punctuation character.""" 390 | cp = ord(char) 391 | # We treat all non-letter/number ASCII as punctuation.
392 | # Characters such as "^", "$", and "`" are not in the Unicode 393 | # Punctuation class but we treat them as punctuation anyways, for 394 | # consistency. 395 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 396 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 397 | return True 398 | cat = unicodedata.category(char) 399 | if cat.startswith("P"): 400 | return True 401 | return False 402 | --------------------------------------------------------------------------------
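The greedy longest-match-first procedure used by `WordpieceTokenizer.tokenize` in `tokenization.py` can be tried without TensorFlow. The sketch below is a minimal standalone version that handles a single word; the function name `wordpiece_split` and the toy vocabulary are assumptions made for this illustration and are not part of the repository.

```python
def wordpiece_split(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first split of one word (illustrative sketch)."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is found in the vocab.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                # Pieces that continue a word carry the "##" prefix.
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # Nothing matches at this position: the whole word becomes unknown.
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


# Toy vocabulary, hypothetical and for demonstration only.
toy_vocab = {"un", "##aff", "##able"}

print(wordpiece_split("unaffable", toy_vocab))  # ['un', '##aff', '##able']
print(wordpiece_split("xyz", toy_vocab))        # ['[UNK]']
```

As in the repository's implementation, a word that cannot be matched at some position is replaced entirely by `[UNK]` rather than being partially tokenized.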